Composable Kernel: include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp Source File

grouped_gemm_kernel.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/host.hpp"

#include <hip/hip_runtime.h>

namespace ck_tile {

/// @brief The Grouped GEMM kernel host arguments.
struct GroupedGemmHostArgs
{
    CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     void* e_ptr_,
                                     index_t k_batch_,
                                     index_t M_,
                                     index_t N_,
                                     index_t K_,
                                     index_t stride_A_,
                                     index_t stride_B_,
                                     index_t stride_E_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          e_ptr(e_ptr_),
          M(M_),
          N(N_),
          K(K_),
          stride_A(stride_A_),
          stride_B(stride_B_),
          stride_E(stride_E_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    union
    {
        void* e_ptr;
        void* c_ptr;
    };

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;

    union
    {
        index_t stride_E;
        index_t stride_C;
    };

    index_t k_batch;
};
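
// Example usage (illustrative, not part of the original header). Assuming a row-major
// MxK A, a column-major KxN B and a row-major MxN E, the leading dimensions are K, K
// and N respectively; all values below are placeholders:
//
//   ck_tile::GroupedGemmHostArgs desc(a_dev, b_dev, e_dev,
//                                     /*k_batch=*/1,
//                                     /*M=*/1024, /*N=*/768, /*K=*/512,
//                                     /*stride_A=*/512, /*stride_B=*/512, /*stride_E=*/768);
//
// One such descriptor is filled per GEMM problem in the group, and the whole
// std::vector<GroupedGemmHostArgs> is handed to the MakeKargs()/GridSize() helpers below.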

struct GemmTransKernelArg
{
    UniversalGemmKernelArgs<> group_karg;
    ck_tile::index_t block_start;
    ck_tile::index_t block_end;

    GemmTransKernelArg() = delete;
    GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
        : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
    {
    }

    GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg)
        : group_karg{karg}, block_start{0}, block_end{0}
    {
    }
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel
    : public UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using Base = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    /// @brief Specify the data type configurations for A, B, C/E.
    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static_assert(
        !is_detected<is_tuple, ALayout>::value && !is_detected<is_tuple, ADataType>::value,
        "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(
        !is_detected<is_tuple, BLayout>::value && !is_detected<is_tuple, BDataType>::value,
        "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(!is_detected<is_tuple, CLayout>::value &&
                      !is_detected<is_tuple, CDataType>::value,
                  "C/ELayout and C/EDataType must be scalars.");

    using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;
    using Kernel                  = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

    static constexpr index_t KernelBlockSize  = GemmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        using P_ = GemmPipeline;

        return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
                      concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
                      (UsePersistentKernel ? "Persistent" : "NonPersistent"));
        // clang-format on
    }

    CK_TILE_HOST static auto
    GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::size_t
    {
        return gemm_descs.size() * sizeof(GemmTransKernelArg);
    }

    CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
    {
        return group_count * sizeof(GemmTransKernelArg);
    }

    CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }

    /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
        const auto kernel     = kentry<KernelBlockSize, 1, Kernel, ConstantPointer, index_t>;
        int occupancy;
        HIP_CHECK_ERROR(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
        const int grid_size = get_available_compute_units(s) * occupancy;
        return dim3(grid_size, 1, 1);
    }
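
    // For persistent launches the grid is sized to keep the device fully occupied
    // (compute units * occupancy) rather than to the total number of output tiles.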

    CK_TILE_HOST static auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
    {
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
            const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
            grid_size += local_grid_size * it_desc.k_batch;
        }
        return dim3(grid_size, 1, 1);
    }
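
    // Illustrative example (values are placeholders): assuming the partitioner tiles the
    // output into 256x256 blocks, a group with M = N = 512 covers 2 x 2 = 4 tiles; with
    // k_batch = 2 it contributes 8 workgroups, so two such groups yield grid_size = 16.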
187 
188  CK_TILE_HOST static auto
189  MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::vector<GemmTransKernelArg>
190  {
191  std::vector<GemmTransKernelArg> gemm_kernel_args_;
192  index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
193  index_t grid_size = 0;
194  gemm_kernel_args_.reserve(group_count);
195 
196  for(std::size_t i = 0; i < gemm_descs.size(); ++i)
197  {
198  const index_t M = gemm_descs[i].M;
199  const index_t N = gemm_descs[i].N;
200  const index_t K = gemm_descs[i].K;
201 
202  if(M == 0 || N == 0 || K == 0)
203  {
204  continue;
205  }
206 
207  const index_t stride_a = gemm_descs[i].stride_A;
208  const index_t stride_b = gemm_descs[i].stride_B;
209  const index_t stride_e = gemm_descs[i].stride_E;
210 
211  const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;
212 
213  const index_t block_start = grid_size;
214  const index_t block_end = grid_size + grid_size_grp;
215 
216  grid_size += grid_size_grp;
217 
218  auto karg =
219  UniversalGemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
220  {type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
221  {/*ds_ptr*/},
222  type_convert<CDataType*>(gemm_descs[i].e_ptr),
223  M,
224  N,
225  K,
226  {stride_a},
227  {stride_b},
228  {/*stride_ds*/},
229  stride_e,
230  gemm_descs[i].k_batch};
231 
232  gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
233  }
234 
235  return gemm_kernel_args_;
236  }

    CK_TILE_HOST static bool IsSupportedArgument(const std::vector<GemmTransKernelArg>& kargs)
    {
        for(const auto& karg : kargs)
        {
            if(!Base::IsSupportedArgument(karg.group_karg))
            {
                return false;
            }
        }
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
                            const tuple<index_t, index_t>& block_idx_2d,
                            const index_t block_idx_z) const
    {
        const auto [iM, iN] = block_idx_2d;

        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
                                 splitk_batch_offset.as_k_split_offset[0];
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
                                 splitk_batch_offset.bs_k_split_offset[0];
        CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {
            __shared__ char smem_ptr_1[GetSmemSize()];
            if constexpr(UsePersistentKernel)
            {
                RunGemmWithPipelineSelection2LDS(a_ptr,
                                                 b_ptr,
                                                 c_ptr,
                                                 smem_ptr_0,
                                                 smem_ptr_1,
                                                 kargs,
                                                 splitk_batch_offset,
                                                 i_m,
                                                 i_n);
            }
            else
            {
                Base::RunGemm2LDS({a_ptr},
                                  {b_ptr},
                                  {/*ds_ptr*/},
                                  c_ptr,
                                  smem_ptr_0,
                                  smem_ptr_1,
                                  kargs,
                                  splitk_batch_offset,
                                  i_m,
                                  i_n);
            }
        }
        else
        {
            if constexpr(UsePersistentKernel)
            {
                RunGemmWithPipelineSelection(
                    a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
            }
            else
            {
                Base::RunGemm({a_ptr},
                              {b_ptr},
                              {/*ds_ptr*/},
                              c_ptr,
                              smem_ptr_0,
                              kargs,
                              splitk_batch_offset,
                              i_m,
                              i_n);
            }
        }
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection(const ADataType* a_ptr,
                                 const BDataType* b_ptr,
                                 CDataType* c_ptr,
                                 void* smem_ptr_0,
                                 const UniversalGemmKernelArgs<>& kargs,
                                 const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                 const index_t block_idx_m,
                                 const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I1);
        const auto& d_block_window = gemm_tile_windows.at(Base::I2);

        // Get hot-loop and tail configuration
        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        // Run GEMM pipeline
        const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
                                                                      b_block_window[Base::I0],
                                                                      num_loop,
                                                                      has_hot_loop,
                                                                      tail_num,
                                                                      smem_ptr_0);
        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(Base::I3);
        EpiloguePipeline{}.template
        operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup, using two LDS buffers.
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
                                     const BDataType* b_ptr,
                                     CDataType* c_ptr,
                                     void* __restrict__ smem_ptr_0,
                                     void* __restrict__ smem_ptr_1,
                                     const UniversalGemmKernelArgs<>& kargs,
                                     const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                     const index_t block_idx_m,
                                     const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I1);
        const auto& d_block_window = gemm_tile_windows.at(Base::I2);

        // Get hot-loop and tail configuration
        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        // Run GEMM pipeline
        const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
                                                                      b_block_window[Base::I0],
                                                                      num_loop,
                                                                      has_hot_loop,
                                                                      tail_num,
                                                                      smem_ptr_0,
                                                                      smem_ptr_1);
        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(Base::I3);
        EpiloguePipeline{}.template
        operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
                                       index_t block_id,
                                       index_t group_count) const
    {
        index_t left     = 0;
        index_t right    = group_count;
        index_t group_id = index_t((left + right) >> 1);

        while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                 block_id < gemm_desc_ptr[group_id].block_end)) &&
              left <= right)
        {
            if(block_id < gemm_desc_ptr[group_id].block_start)
            {
                right = group_id;
            }
            else
            {
                left = group_id;
            }
            group_id = index_t((left + right) >> 1);
        }

        return group_id;
    }
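
    // The [block_start, block_end) ranges written by MakeKargs partition the 1D grid, so the
    // binary search above returns the group whose block range contains the given block_id.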

    // For non-persistent kernels
    template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   index_t group_count) const
    {
        const index_t block_id   = ck_tile::get_block_1d_id();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));

        const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
        const auto& kargs      = gemm_desc_ptr[group_id];
        const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
        const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
            0,
            kargs.group_karg.M,
            kargs.group_karg.N,
            (block_id - kargs.block_start) % grid_size_2d);
        Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
    }

    // For persistent kernels
    template <bool U = UsePersistentKernel,
              typename = std::enable_if_t<U>,
              typename = void> // extra template parameter to avoid redefinition
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   const index_t group_count) const
    {
        const index_t grid_size  = ck_tile::get_grid_size();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));
        index_t block_id      = ck_tile::get_block_1d_id(); // initial block_id
        index_t cum_grid_size = 0;
        for(index_t group_id = 0; group_id < group_count; ++group_id)
        {
            const auto& kargs      = gemm_desc_ptr[group_id].group_karg;
            const auto& k_batch    = kargs.k_batch;
            const auto block_start = cum_grid_size;
            cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
            while(block_id < cum_grid_size)
            {
                const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
                const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
                    0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
                Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
                block_id = block_id + grid_size; // advance to next block
                // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
                if(block_id >= cum_grid_size)
                {
                    break; // exit the loop if all blocks are processed
                }
            }
        }
    }
};

} // namespace ck_tile
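
For reference, the host-side entry points above (MakeKargs, GetWorkSpaceSize, GridSize, BlockSize, MaxOccupancyGridSize) are typically used together when preparing a launch. The sketch below illustrates that flow for a hypothetical instantiation; MyPartitioner, MyPipeline, MyEpilogue, the device pointers, sizes and the stream configuration are placeholders, and the actual launch goes through ck_tile's usual kernel-launch utilities rather than a call shown here.

// Minimal host-side sketch (assumes a concrete kernel instantiation; names are placeholders).
using Kernel = ck_tile::GroupedGemmKernel<MyPartitioner, MyPipeline, MyEpilogue>;

std::vector<ck_tile::GroupedGemmHostArgs> groups;
groups.emplace_back(a0_dev, b0_dev, e0_dev, /*k_batch=*/1, M0, N0, K0, lda0, ldb0, lde0);
groups.emplace_back(a1_dev, b1_dev, e1_dev, /*k_batch=*/1, M1, N1, K1, lda1, ldb1, lde1);

// Build per-group device arguments and validate them.
const auto kargs = Kernel::MakeKargs(groups);
if(!Kernel::IsSupportedArgument(kargs)) { /* fall back to another kernel */ }

// Copy the GemmTransKernelArg array into a device workspace.
void* workspace_dev = nullptr;
hipMalloc(&workspace_dev, Kernel::GetWorkSpaceSize(groups));
hipMemcpy(workspace_dev, kargs.data(), Kernel::GetWorkSpaceSize(groups), hipMemcpyHostToDevice);

// Pick the launch grid: one workgroup per output tile, or a persistent grid sized by occupancy.
const dim3 blocks = Kernel::BlockSize();
const dim3 grids  = Kernel::UsePersistentKernel ? Kernel::MaxOccupancyGridSize(stream_cfg)
                                                : Kernel::GridSize(groups);
// The kernel is then launched with the workspace pointer (cast to a CK_CONSTANT_ADDRESS_SPACE
// pointer) and the group count as its two arguments.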