Composable Kernel: include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp (source listing)

grouped_gemm_quant_kernel.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/host.hpp"

#include <hip/hip_runtime.h>

namespace ck_tile {

/// @brief The Grouped GEMM kernel host arguments.
struct QuantGroupedGemmHostArgs
{
    CK_TILE_HOST QuantGroupedGemmHostArgs(const void* a_ptr_,
                                          const void* b_ptr_,
                                          void* e_ptr_,
                                          const void* aq_ptr_,
                                          const void* bq_ptr_,
                                          index_t k_batch_,
                                          index_t M_,
                                          index_t N_,
                                          index_t K_,
                                          index_t QK_A_,
                                          index_t QK_B_,
                                          index_t stride_A_,
                                          index_t stride_B_,
                                          index_t stride_E_,
                                          index_t stride_AQ_,
                                          index_t stride_BQ_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          aq_ptr(aq_ptr_),
          bq_ptr(bq_ptr_),
          e_ptr(e_ptr_),
          M(M_),
          N(N_),
          K(K_),
          QK_A(QK_A_),
          QK_B(QK_B_),
          stride_A(stride_A_),
          stride_B(stride_B_),
          stride_AQ(stride_AQ_),
          stride_BQ(stride_BQ_),
          stride_E(stride_E_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    const void* aq_ptr;
    const void* bq_ptr;
    union
    {
        void* e_ptr;
        void* c_ptr;
    };

    index_t M;
    index_t N;
    index_t K;
    index_t QK_A;
    index_t QK_B;
    index_t stride_A;
    index_t stride_B;
    index_t stride_AQ;
    index_t stride_BQ;

    union
    {
        index_t stride_E;
        index_t stride_C;
    };

    index_t k_batch;
};

using QuantGroupedGemmKernelArgs = QuantGemmKernelArgs;

struct QuantGemmTransKernelArg
{
    QuantGroupedGemmKernelArgs group_karg;
    ck_tile::index_t block_start;
    ck_tile::index_t block_end;

    QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg, index_t bl_start, index_t bl_end)
        : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
    {
    }

    QuantGemmTransKernelArg(QuantGroupedGemmKernelArgs&& karg)
        : group_karg{karg}, block_start{0}, block_end{0}
    {
    }
};

template <typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_,
          QuantType QuantType_>
struct QuantGroupedGemmKernel
    : public QuantGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_>
{
    using Base = QuantGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_, QuantType_>;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    /// @brief Specify the data type configurations for A, B, C/E.
    using ADataType   = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType   = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType   = remove_cvref_t<typename EpiloguePipeline::ODataType>;
    using AccDataType = remove_cvref_t<typename EpiloguePipeline::AccDataType>;

    using AQDataType =
        remove_cvref_t<typename detail::get_aq_data_type_or<GemmPipeline, AccDataType>::type>;
    using BQDataType =
        remove_cvref_t<typename detail::get_bq_data_type_or<GemmPipeline, AccDataType>::type>;

    static constexpr auto kQuantType = QuantType_;

    static_assert(
        !is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
        "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(
        !is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
        "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(!is_detected<is_tuple, CLayout>::value &&
                      !is_detected<is_tuple, CDataType>::value,
                  "C/ELayout and C/EDataType must be scalars.");

    using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
    using Kernel =
        QuantGroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline, QuantType_>;

    static constexpr index_t kBlockSize       = GemmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;
    static_assert(UsePersistentKernel == true, "UsePersistentKernel must be true");

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        using P_ = GemmPipeline;

        return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
                      concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
                      (UsePersistentKernel ? "Persistent" : "NonPersistent"));
        // clang-format on
    }

    CK_TILE_HOST static auto
    GetWorkSpaceSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs) -> std::size_t
    {
        return gemm_descs.size() * sizeof(QuantGemmTransKernelArg);
    }

    CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
    {
        return group_count * sizeof(QuantGemmTransKernelArg);
    }

    CK_TILE_HOST static auto BlockSize() -> dim3
    {
        if(is_wave32())
        {
            return dim3(kBlockSize / 2);
        }
        else
        {
            return dim3(kBlockSize);
        }
    }

    /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        using ConstantPointer  = const void CK_TILE_CONSTANT_ADDRESS_SPACE*;
        const auto kernel_func = kentry<1, Kernel, ConstantPointer, index_t>;
        int occupancy;
        HIP_CHECK_ERROR(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel_func, kBlockSize, 0));
        const int grid_size = get_available_compute_units(s) * occupancy;
        return dim3(grid_size, 1, 1);
    }

    CK_TILE_HOST static auto GridSize(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
    {
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
            const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
            grid_size += local_grid_size * it_desc.k_batch;
        }
        return dim3(grid_size, 1, 1);
    }

    CK_TILE_HOST static auto MakeKargs(const std::vector<QuantGroupedGemmHostArgs>& gemm_descs)
        -> std::vector<QuantGemmTransKernelArg>
    {
        std::vector<QuantGemmTransKernelArg> gemm_kernel_args_;
        index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
        index_t grid_size   = 0;
        gemm_kernel_args_.reserve(group_count);

        for(std::size_t i = 0; i < gemm_descs.size(); ++i)
        {
            const index_t M = gemm_descs[i].M;
            const index_t N = gemm_descs[i].N;
            const index_t K = gemm_descs[i].K;

            if(M == 0 || N == 0 || K == 0)
            {
                continue;
            }

            const index_t stride_a = gemm_descs[i].stride_A;
            const index_t stride_b = gemm_descs[i].stride_B;
            const index_t stride_e = gemm_descs[i].stride_C;

            const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;

            const index_t block_start = grid_size;
            const index_t block_end   = grid_size + grid_size_grp;

            grid_size += grid_size_grp;

            auto karg =
                QuantGroupedGemmKernelArgs{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
                                           type_convert<const BDataType*>(gemm_descs[i].b_ptr),
                                           type_convert<CDataType*>(gemm_descs[i].e_ptr),
                                           type_convert<const AQDataType*>(gemm_descs[i].aq_ptr),
                                           type_convert<const BQDataType*>(gemm_descs[i].bq_ptr),
                                           gemm_descs[i].k_batch,
                                           M,
                                           N,
                                           K,
                                           gemm_descs[i].QK_A,
                                           gemm_descs[i].QK_B,
                                           stride_a,
                                           stride_b,
                                           stride_e,
                                           gemm_descs[i].stride_AQ,
                                           gemm_descs[i].stride_BQ};

            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
        }

        return gemm_kernel_args_;
    }

    CK_TILE_HOST static bool IsSupportedArgument(const std::vector<QuantGemmTransKernelArg>& kargs)
    {
        for(const auto& karg : kargs)
        {
            if(!Base::IsSupportedArgument(karg.group_karg))
            {
                return false;
            }
        }
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void Run(const QuantGroupedGemmKernelArgs& kargs,
                            const tuple<index_t, index_t>& block_idx_2d,
                            const index_t block_idx_z) const
    {
        const auto [iM, iN] = block_idx_2d;

        const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);

        // Cast the type-erased argument pointers to the configured data types.
        const ADataType* a_ptr   = static_cast<const ADataType*>(kargs.a_ptr);
        const BDataType* b_ptr   = static_cast<const BDataType*>(kargs.b_ptr);
        const AQDataType* aq_ptr = static_cast<const AQDataType*>(kargs.aq_ptr);
        const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
        CDataType* c_ptr         = static_cast<CDataType*>(kargs.c_ptr);

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        // DoubleSmemBuffer is only supported for BQuantGrouped.
        if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
                     kQuantType == QuantType::BQuantGrouped)
        {
            __shared__ char smem_ptr_1[GetSmemSize()];
            RunGemmWithPipelineSelection2LDS(a_ptr,
                                             b_ptr,
                                             aq_ptr,
                                             bq_ptr,
                                             c_ptr,
                                             smem_ptr_0,
                                             smem_ptr_1,
                                             kargs,
                                             splitk_batch_offset,
                                             i_m,
                                             i_n);
        }
        else
        {
            RunGemmWithPipelineSelection(a_ptr,
                                         b_ptr,
                                         aq_ptr,
                                         bq_ptr,
                                         c_ptr,
                                         smem_ptr_0,
                                         kargs,
                                         splitk_batch_offset,
                                         i_m,
                                         i_n);
        }
    }

    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
                                     const BDataType* b_ptr,
                                     const AQDataType* aq_ptr,
                                     const BQDataType* bq_ptr,
                                     CDataType* c_ptr,
                                     void* smem_ptr_0,
                                     void* smem_ptr_1,
                                     const QuantGroupedGemmKernelArgs& kargs,
                                     const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                     const index_t block_idx_m,
                                     const index_t block_idx_n)
    {
        static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        // Run GEMM cooperatively by the whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I2);

        const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
        const auto& c_block_tile    = GemmPipeline{}.template operator()(a_block_window,
                                                                         b_block_window,
                                                                         bq_block_window,
                                                                         num_loop,
                                                                         tail_num,
                                                                         smem_ptr_0,
                                                                         smem_ptr_1);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(Base::I4);

        EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
    }

    /// @brief Runs a single GEMM problem cooperatively by the whole workgroup.
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection(const ADataType* a_ptr,
                                 const BDataType* b_ptr,
                                 const AQDataType* aq_ptr,
                                 const BQDataType* bq_ptr,
                                 CDataType* c_ptr,
                                 void* smem_ptr_0,
                                 const QuantGroupedGemmKernelArgs& kargs,
                                 const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                 const index_t block_idx_m,
                                 const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I2);

        // Get hot-loop and tail configuration
        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        if constexpr(kQuantType == QuantType::BQuantGrouped)
        {
            const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
            // Run GEMM pipeline
            const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
                                                                          b_block_window,
                                                                          bq_block_window,
                                                                          num_loop,
                                                                          has_hot_loop,
                                                                          tail_num,
                                                                          smem_ptr_0);

            auto& c_block_window = gemm_tile_windows.at(Base::I4);

            // Run Epilogue Pipeline
            EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
        }
        else
        {
            // Run GEMM pipeline
            const auto& c_block_tile = GemmPipeline{}.template operator()(
                a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
            // Run Epilogue Pipeline
            auto& c_block_window = gemm_tile_windows.at(Base::I4);
            if constexpr(kQuantType == QuantType::RowColQuant)
            {
                const auto& aq_block_window = gemm_tile_windows.at(Base::I1);
                const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
                EpiloguePipeline{}(c_block_window,
                                   c_block_tile,
                                   c_block_window,
                                   smem_ptr_0,
                                   aq_block_window,
                                   bq_block_window);
            }
            else if constexpr(kQuantType == QuantType::TensorQuant)
            {
                const AccDataType aq_scale = type_convert<AccDataType>(*aq_ptr);
                const AccDataType bq_scale = type_convert<AccDataType>(*bq_ptr);
                EpiloguePipeline{}(
                    c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale);
            }
        }
    }

    // For persistent kernels
    template <bool U = UsePersistentKernel,
              typename = std::enable_if_t<U>,
              typename = void> // extra template parameter to avoid redefinition
    CK_TILE_DEVICE void operator()(const void CK_TILE_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   const index_t group_count) const
    {
        const index_t grid_size   = ck_tile::get_grid_size();
        const auto gemm_desc_ptr  = reinterpret_cast<const QuantGemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));
        index_t block_id      = ck_tile::get_block_1d_id(); // initial block_id
        index_t cum_grid_size = 0;
        for(index_t group_id = 0; group_id < group_count; ++group_id)
        {
            const auto& kargs       = gemm_desc_ptr[group_id].group_karg;
            const auto& k_batch     = kargs.k_batch;
            const auto block_start  = cum_grid_size;
            cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
            while(block_id < cum_grid_size)
            {
                const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
                const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
                    0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
                Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
                block_id = block_id + grid_size; // advance to next block
                // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
                if(block_id >= cum_grid_size)
                {
                    break; // exit the loop if all blocks are processed
                }
            }
        }
    }
};

} // namespace ck_tile
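
The listing above is only the kernel side. For orientation, the sketch below shows one plausible host-side flow for driving it, assuming Kernel is a concrete QuantGroupedGemmKernel instantiation defined elsewhere; the helper name launch_grouped_quant_gemm and the commented-out kentry launch line are illustrative placeholders, not APIs defined in this header.

#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/host.hpp"

// Hypothetical usage sketch: Kernel is assumed to be a concrete
// QuantGroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline, QuantType>.
template <typename Kernel>
void launch_grouped_quant_gemm(const std::vector<ck_tile::QuantGroupedGemmHostArgs>& groups,
                               const ck_tile::stream_config& stream_cfg)
{
    // Translate host arguments into per-group kernel arguments, each carrying its
    // [block_start, block_end) range of 1D block indices.
    auto kargs = Kernel::MakeKargs(groups);
    if(!Kernel::IsSupportedArgument(kargs))
    {
        return; // fall back to another kernel configuration
    }

    // The persistent kernel reads the transformed arguments from a device workspace.
    void* workspace = nullptr;
    HIP_CHECK_ERROR(hipMalloc(&workspace, Kernel::GetWorkSpaceSize(groups)));
    HIP_CHECK_ERROR(hipMemcpy(
        workspace, kargs.data(), Kernel::GetWorkSpaceSize(groups), hipMemcpyHostToDevice));

    // Persistent launch: enough workgroups to fill the device once; each workgroup then
    // walks the grid-stride loop in operator() over every group.
    [[maybe_unused]] const dim3 blocks = Kernel::BlockSize();
    [[maybe_unused]] const dim3 grids  = Kernel::MaxOccupancyGridSize(stream_cfg);

    // The actual launch would go through the ck_tile launch helpers (see the use of
    // kentry<1, Kernel, ConstantPointer, index_t> in MaxOccupancyGridSize above),
    // passing the workspace pointer and the group count, conceptually:
    //   kentry<1, Kernel, const void CK_TILE_CONSTANT_ADDRESS_SPACE*, index_t>
    //       <<<grids, blocks, 0, stream>>>(workspace_as_constant_ptr, group_count);

    // ... after the kernel has finished:
    HIP_CHECK_ERROR(hipFree(workspace));
}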