Composable Kernel: include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/host.hpp"

#include <hip/hip_runtime.h>

namespace ck_tile {
/// @brief The Grouped GEMM kernel host arguments.
struct GroupedGemmHostArgs
{
    CK_TILE_HOST GroupedGemmHostArgs(const void* a_ptr_,
                                     const void* b_ptr_,
                                     void* e_ptr_,
                                     index_t k_batch_,
                                     index_t M_,
                                     index_t N_,
                                     index_t K_,
                                     index_t stride_A_,
                                     index_t stride_B_,
                                     index_t stride_E_)
        : a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          e_ptr(e_ptr_),
          M(M_),
          N(N_),
          K(K_),
          stride_A(stride_A_),
          stride_B(stride_B_),
          stride_E(stride_E_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    union
    {
        void* e_ptr;
        void* c_ptr;
    };

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;

    union
    {
        index_t stride_E;
        index_t stride_C;
    };

    index_t k_batch;
};
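
// Illustrative host-side construction of the per-group descriptors (a sketch; the device
// pointers a_devs/b_devs/e_devs, the sizes Ms/Ns/Ks and the row-major stride choices are
// assumptions for the example, not part of this header):
//
//   std::vector<ck_tile::GroupedGemmHostArgs> gemm_descs;
//   gemm_descs.reserve(group_count);
//   for(ck_tile::index_t g = 0; g < group_count; ++g)
//   {
//       // Argument order follows the constructor above:
//       // a_ptr, b_ptr, e_ptr, k_batch, M, N, K, stride_A, stride_B, stride_E.
//       gemm_descs.emplace_back(a_devs[g], b_devs[g], e_devs[g],
//                               /*k_batch=*/1,
//                               Ms[g], Ns[g], Ks[g],
//                               /*stride_A=*/Ks[g],
//                               /*stride_B=*/Ns[g],
//                               /*stride_E=*/Ns[g]);
//   }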

struct GemmTransKernelArg
{
    UniversalGemmKernelArgs<> group_karg;
    ck_tile::index_t block_start;
    ck_tile::index_t block_end;

    GemmTransKernelArg() = delete;
    GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
        : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
    {
    }

    GemmTransKernelArg(UniversalGemmKernelArgs<>&& karg)
        : group_karg{karg}, block_start{0}, block_end{0}
    {
    }
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel
    : public UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using Base = UniversalGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;

    using ALayout = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout = remove_cvref_t<typename GemmPipeline::CLayout>;

    /// @brief Specify the data type configurations for A, B, C/E.
    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static_assert(!is_detected<is_tuple, ALayout>::value &&
                      !is_detected<is_tuple, ADataType>::value,
                  "ALayout and ADataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(!is_detected<is_tuple, BLayout>::value &&
                      !is_detected<is_tuple, BDataType>::value,
                  "BLayout and BDataType must be scalars. Multiple parameters are not currently supported.");

    static_assert(!is_detected<is_tuple, CLayout>::value &&
                      !is_detected<is_tuple, CDataType>::value,
                  "C/ELayout and C/EDataType must be scalars.");

    using Kernel = GroupedGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;

    static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        using P_ = GemmPipeline;

        return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>(),
                      concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
                      (UsePersistentKernel ? "Persistent" : "NonPersistent"));
        // clang-format on
    }

    CK_TILE_HOST static auto
    GetWorkSpaceSize(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::size_t
    {
        return gemm_descs.size() * sizeof(GemmTransKernelArg);
    }

    CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
    {
        return group_count * sizeof(GemmTransKernelArg);
    }

    CK_TILE_HOST static auto BlockSize() -> dim3
    {
        if(is_wave32())
        {
            return dim3(kBlockSize / 2);
        }
        else
        {
            return dim3(kBlockSize);
        }
    }

    /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
        const auto kernel = kentry<1, Kernel, ConstantPointer, index_t>;
        int occupancy;
        HIP_CHECK_ERROR(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, kBlockSize, 0));
        const int grid_size = get_available_compute_units(s) * occupancy;
        return dim3(grid_size, 1, 1);
    }

    CK_TILE_HOST static auto GridSize(const std::vector<GroupedGemmHostArgs>& gemm_descs)
    {
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
            const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
            grid_size += local_grid_size * it_desc.k_batch;
        }
        return dim3(grid_size, 1, 1);
    }

    CK_TILE_HOST static auto
    MakeKargs(const std::vector<GroupedGemmHostArgs>& gemm_descs) -> std::vector<GemmTransKernelArg>
    {
        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
        index_t grid_size = 0;
        gemm_kernel_args_.reserve(group_count);

        for(std::size_t i = 0; i < gemm_descs.size(); ++i)
        {
            const index_t M = gemm_descs[i].M;
            const index_t N = gemm_descs[i].N;
            const index_t K = gemm_descs[i].K;

            if(M == 0 || N == 0 || K == 0)
            {
                continue;
            }

            const index_t stride_a = gemm_descs[i].stride_A;
            const index_t stride_b = gemm_descs[i].stride_B;
            const index_t stride_e = gemm_descs[i].stride_E;

            const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;

            const index_t block_start = grid_size;
            const index_t block_end   = grid_size + grid_size_grp;

            grid_size += grid_size_grp;

            auto karg =
                UniversalGemmKernelArgs<>{{type_convert<const ADataType*>(gemm_descs[i].a_ptr)},
                                          {type_convert<const BDataType*>(gemm_descs[i].b_ptr)},
                                          {/*ds_ptr*/},
                                          type_convert<CDataType*>(gemm_descs[i].e_ptr),
                                          M,
                                          N,
                                          K,
                                          {stride_a},
                                          {stride_b},
                                          {/*stride_ds*/},
                                          stride_e,
                                          gemm_descs[i].k_batch};

            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
        }

        return gemm_kernel_args_;
    }
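
    // Worked example (illustrative sizes, not taken from this header): three groups whose
    // TilePartitioner grids contain 4, 6 and 2 tiles with k_batch == 1 are assigned the
    // block ranges [0, 4), [4, 10) and [10, 12) by MakeKargs(); GridSize() then returns
    // dim3(12, 1, 1), and GetWorkSpaceSize() reports the number of bytes needed to stage
    // the resulting GemmTransKernelArg vector in device memory before launch.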

    CK_TILE_HOST static bool IsSupportedArgument(const std::vector<GemmTransKernelArg>& kargs)
    {
        for(const auto& karg : kargs)
        {
            if(!Base::IsSupportedArgument(karg.group_karg))
            {
                return false;
            }
        }
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void Run(const UniversalGemmKernelArgs<>& kargs,
                            const tuple<index_t, index_t>& block_idx_2d,
                            const index_t block_idx_z) const
    {
        const auto [iM, iN] = block_idx_2d;

        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);

        const ADataType* a_ptr = static_cast<const ADataType*>(kargs.as_ptr[0]) +
                                 splitk_batch_offset.as_k_split_offset[0];
        const BDataType* b_ptr = static_cast<const BDataType*>(kargs.bs_ptr[0]) +
                                 splitk_batch_offset.bs_k_split_offset[0];
        CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {
            __shared__ char smem_ptr_1[GetSmemSize()];
            if constexpr(UsePersistentKernel)
            {
                RunGemmWithPipelineSelection2LDS(a_ptr,
                                                 b_ptr,
                                                 c_ptr,
                                                 smem_ptr_0,
                                                 smem_ptr_1,
                                                 kargs,
                                                 splitk_batch_offset,
                                                 i_m,
                                                 i_n);
            }
            else
            {
                Base::RunGemm2LDS({a_ptr},
                                  {b_ptr},
                                  {/*ds_ptr*/},
                                  c_ptr,
                                  smem_ptr_0,
                                  smem_ptr_1,
                                  kargs,
                                  splitk_batch_offset,
                                  i_m,
                                  i_n);
            }
        }
        else
        {
            if constexpr(UsePersistentKernel)
            {
                RunGemmWithPipelineSelection(
                    a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
            }
            else
            {
                Base::RunGemm({a_ptr},
                              {b_ptr},
                              {/*ds_ptr*/},
                              c_ptr,
                              smem_ptr_0,
                              kargs,
                              splitk_batch_offset,
                              i_m,
                              i_n);
            }
        }
    }

    /// @brief Runs a single GEMM problem cooperatively by the whole workgroup.
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection(const ADataType* a_ptr,
                                 const BDataType* b_ptr,
                                 CDataType* c_ptr,
                                 void* smem_ptr_0,
                                 const UniversalGemmKernelArgs<>& kargs,
                                 const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                 const index_t block_idx_m,
                                 const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I1);
        const auto& d_block_window = gemm_tile_windows.at(Base::I2);

        // Get hot-loop and tail configuration
        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        // Run GEMM pipeline
        const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
                                                                      b_block_window[Base::I0],
                                                                      num_loop,
                                                                      has_hot_loop,
                                                                      tail_num,
                                                                      smem_ptr_0);
        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(Base::I3);
        EpiloguePipeline{}.template
        operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    /// @brief Runs a single GEMM problem cooperatively by the whole workgroup, using two LDS buffers.
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
                                     const BDataType* b_ptr,
                                     CDataType* c_ptr,
                                     void* __restrict__ smem_ptr_0,
                                     void* __restrict__ smem_ptr_1,
                                     const UniversalGemmKernelArgs<>& kargs,
                                     const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                     const index_t block_idx_m,
                                     const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                {a_ptr}, {b_ptr}, {/*ds_ptr*/}, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I1);
        const auto& d_block_window = gemm_tile_windows.at(Base::I2);

        // Get hot-loop and tail configuration
        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        // Run GEMM pipeline
        const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
                                                                      b_block_window[Base::I0],
                                                                      num_loop,
                                                                      has_hot_loop,
                                                                      tail_num,
                                                                      smem_ptr_0,
                                                                      smem_ptr_1);
        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(Base::I3);
        EpiloguePipeline{}.template
        operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
                                       index_t block_id,
                                       index_t group_count) const
    {
        index_t left     = 0;
        index_t right    = group_count;
        index_t group_id = index_t((left + right) >> 1);

        while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                 block_id < gemm_desc_ptr[group_id].block_end)) &&
              left <= right)
        {
            if(block_id < gemm_desc_ptr[group_id].block_start)
            {
                right = group_id;
            }
            else
            {
                left = group_id;
            }
            group_id = index_t((left + right) >> 1);
        }

        return group_id;
    }
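
    // Worked example (illustrative numbers, continuing the MakeKargs example above): with
    // block ranges [0, 4), [4, 10) and [10, 12), a workgroup with block_id == 7 starts its
    // search at group_id == 1, finds 4 <= 7 < 10, and returns 1; block_id == 11 first probes
    // group 1, moves `left` up, and converges on group 2.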

    // For non-persistent kernels
    template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   index_t group_count) const
    {
        const index_t block_id  = ck_tile::get_block_1d_id();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));

        const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
        const auto& kargs      = gemm_desc_ptr[group_id];
        const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
        const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
            0,
            kargs.group_karg.M,
            kargs.group_karg.N,
            (block_id - kargs.block_start) % grid_size_2d);
        Run(kargs.group_karg, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
    }

    // For persistent kernels
    template <bool U = UsePersistentKernel,
              typename = std::enable_if_t<U>,
              typename = void> // extra template parameter to avoid redefinition
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   const index_t group_count) const
    {
        const index_t grid_size = ck_tile::get_grid_size();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));
        index_t block_id      = ck_tile::get_block_1d_id(); // initial block_id
        index_t cum_grid_size = 0;
        for(index_t group_id = 0; group_id < group_count; ++group_id)
        {
            const auto& kargs      = gemm_desc_ptr[group_id].group_karg;
            const auto& k_batch    = kargs.k_batch;
            const auto block_start = cum_grid_size;
            cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
            while(block_id < cum_grid_size)
            {
                const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
                const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
                    0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
                Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
                block_id = block_id + grid_size; // advance to next block
                // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
                if(block_id >= cum_grid_size)
                {
                    break; // exit the loop if all blocks are processed
                }
            }
        }
    }
};

} // namespace ck_tile
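
For reference, the host-side wiring of this kernel can be sketched as follows. This is a hedged sketch rather than an excerpt from Composable Kernel: the concrete TilePartitioner/GemmPipeline/EpiloguePipeline instantiation, the gemm_descs vector, kargs_dev and stream_cfg are assumed to exist, and the actual dispatch goes through ck_tile's kentry entry point (as used by MaxOccupancyGridSize above) rather than a raw kernel launch.

    // Hypothetical host-side wiring; all names below other than the GroupedGemmKernel API are assumptions.
    using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>;

    const auto kargs = Kernel::MakeKargs(gemm_descs);
    if(!Kernel::IsSupportedArgument(kargs))
        return; // fall back to a different kernel configuration

    void* kargs_dev = nullptr;
    HIP_CHECK_ERROR(hipMalloc(&kargs_dev, Kernel::GetWorkSpaceSize(gemm_descs)));
    HIP_CHECK_ERROR(hipMemcpy(kargs_dev, kargs.data(),
                              kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
                              hipMemcpyHostToDevice));

    const dim3 blocks = Kernel::BlockSize();
    // Non-persistent: one workgroup per output tile per k-batch slice, resolved via FindGroupId.
    // Persistent: a fixed, occupancy-sized grid that strides over the concatenated tile space.
    const dim3 grids = Kernel::UsePersistentKernel ? Kernel::MaxOccupancyGridSize(stream_cfg)
                                                   : Kernel::GridSize(gemm_descs);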