Composable Kernel: include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp Source File

grouped_gemm_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// (assumed) headers providing GemmKernel/GemmKernelArgs and OffsettedTile1DPartitioner
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"

#include "ck_tile/host.hpp"

#include <hip/hip_runtime.h>

namespace ck_tile {

struct GemmTransKernelArg
{
    GemmKernelArgs<> group_karg;
    ck_tile::index_t block_start;
    ck_tile::index_t block_end;

    GemmTransKernelArg() = delete;
    GemmTransKernelArg(GemmKernelArgs<>&& karg, index_t bl_start, index_t bl_end)
        : group_karg{karg}, block_start{bl_start}, block_end{bl_end}
    {
    }

    GemmTransKernelArg(GemmKernelArgs<>&& karg) : group_karg{karg} {}
};

template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GroupedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
    using ELayout          = remove_cvref_t<typename GemmPipeline::CLayout>;

    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    using Base   = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
    using Kernel = GroupedGemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>;
    using OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>;

    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
    static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        using P_ = GemmPipeline;

        return concat('_', "gemm_grouped", gemm_prec_str<ADataType, BDataType>,
                      concat('x', P_::MPerBlock, P_::NPerBlock, P_::KPerBlock),
                      concat('x', P_::GetVectorSizeA(), P_::GetVectorSizeB(), P_::GetVectorSizeC()),
                      concat('x', P_::kPadM, P_::kPadN, P_::kPadK),
                      (UsePersistentKernel ? "Persistent" : "NonPersistent"));
        // clang-format on
    }

    CK_TILE_HOST static auto
    GetWorkSpaceSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs) -> std::size_t
    {
        return gemm_descs.size() * sizeof(GemmTransKernelArg);
    }

    CK_TILE_HOST static auto GetWorkSpaceSize(index_t group_count) -> std::size_t
    {
        return group_count * sizeof(GemmTransKernelArg);
    }

    CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(KernelBlockSize); }

    /// @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
    CK_TILE_HOST static auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        using ConstantPointer = const void CK_CONSTANT_ADDRESS_SPACE*;
        const auto kernel = kentry<KernelBlockSize, 1, Kernel, ConstantPointer, index_t>;
        int occupancy;
        HIP_CHECK_ERROR(
            hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
        const int grid_size = get_available_compute_units(s) * occupancy;
        return dim3(grid_size, 1, 1);
    }

    CK_TILE_HOST static constexpr auto
    GridSize(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
    {
        index_t grid_size = 0;
        for(const auto& it_desc : gemm_descs)
        {
            const auto local_grid_size = TilePartitioner::GridSize(it_desc.M, it_desc.N);
            grid_size += local_grid_size * it_desc.k_batch;
        }
        return dim3(grid_size, 1, 1);
    }

    CK_TILE_HOST static auto
    MakeKargs(const std::vector<GemmHostArgs</*NumDTensor = 0*/>>& gemm_descs)
        -> std::vector<GemmTransKernelArg>
    {
        std::vector<GemmTransKernelArg> gemm_kernel_args_;
        index_t group_count = ck_tile::type_convert<ck_tile::index_t>(gemm_descs.size());
        index_t grid_size = 0;
        gemm_kernel_args_.reserve(group_count);

        for(std::size_t i = 0; i < gemm_descs.size(); ++i)
        {
            const index_t M = gemm_descs[i].M;
            const index_t N = gemm_descs[i].N;
            const index_t K = gemm_descs[i].K;

            if(M == 0 || N == 0 || K == 0)
            {
                continue;
            }

            const index_t stride_a = gemm_descs[i].stride_A;
            const index_t stride_b = gemm_descs[i].stride_B;
            const index_t stride_e = gemm_descs[i].stride_E;

            const index_t grid_size_grp = TilePartitioner::GridSize(M, N) * gemm_descs[i].k_batch;

            const index_t block_start = grid_size;
            const index_t block_end = grid_size + grid_size_grp;

            grid_size += grid_size_grp;

            auto karg = GemmKernelArgs<>{type_convert<const ADataType*>(gemm_descs[i].a_ptr),
                                         type_convert<const BDataType*>(gemm_descs[i].b_ptr),
                                         {},
                                         type_convert<CDataType*>(gemm_descs[i].e_ptr),
                                         M,
                                         N,
                                         K,
                                         stride_a,
                                         stride_b,
                                         {},
                                         stride_e,
                                         gemm_descs[i].k_batch};

            gemm_kernel_args_.emplace_back(std::move(karg), block_start, block_end);
        }

        return gemm_kernel_args_;
    }

    CK_TILE_HOST static bool IsSupportedArgument(const std::vector<GemmTransKernelArg>& kargs)
    {
        for(const auto& karg : kargs)
        {
            if(!Base::IsSupportedArgument(karg.group_karg))
            {
                return false;
            }
        }
        return true;
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetSmemSize() -> index_t
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void Run(const GemmTransKernelArg& kargs,
                            const tuple<index_t, index_t>& block_idx_2d,
                            const index_t block_idx_z) const
    {
        Run(kargs.group_karg, block_idx_2d, block_idx_z);
    }

    CK_TILE_DEVICE void Run(const GemmKernelArgs<>& kargs,
                            const tuple<index_t, index_t>& block_idx_2d,
                            const index_t block_idx_z) const
    {
        const auto [iM, iN] = block_idx_2d;

        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const typename Base::SplitKBatchOffset splitk_batch_offset(kargs, block_idx_z);

        const ADataType* a_ptr =
            static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
        const BDataType* b_ptr =
            static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
        CDataType* c_ptr = static_cast<CDataType*>(kargs.e_ptr);

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        if constexpr(UsePersistentKernel)
        {
            RunGemmWithPipelineSelection(
                a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
        else
        {
            this->RunGemm(a_ptr, b_ptr, {}, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
    }

    /// @brief Runs single GEMM problem cooperatively by whole workgroup.
    CK_TILE_DEVICE static void
    RunGemmWithPipelineSelection(const ADataType* a_ptr,
                                 const BDataType* b_ptr,
                                 CDataType* c_ptr,
                                 void* smem_ptr_0,
                                 const GemmKernelArgs<>& kargs,
                                 const typename Base::SplitKBatchOffset& splitk_batch_offset,
                                 const index_t block_idx_m,
                                 const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, {}, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows =
            Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
        const auto& a_block_window = gemm_tile_windows.at(Base::I0);
        const auto& b_block_window = gemm_tile_windows.at(Base::I1);
        const auto& d_block_window = gemm_tile_windows.at(Base::I2);

        // Get hot-loop and tail configuration
        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
        const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
        const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);

        // Run GEMM pipeline
        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, has_hot_loop, tail_num, smem_ptr_0);
        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(Base::I3);
        EpiloguePipeline{}.template
            operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
                c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

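    // Maps a flattened 1D block index back to its GEMM group: MakeKargs lays the
    // per-group block ranges [block_start, block_end) out back-to-back, so a
    // binary-search style narrowing over the group descriptors finds the owner.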
    CK_TILE_DEVICE index_t FindGroupId(const GemmTransKernelArg* gemm_desc_ptr,
                                       index_t block_id,
                                       index_t group_count) const
    {
        index_t left = 0;
        index_t right = group_count;
        index_t group_id = index_t((left + right) >> 1);

        while((!(block_id >= gemm_desc_ptr[group_id].block_start &&
                 block_id < gemm_desc_ptr[group_id].block_end)) &&
              left <= right)
        {
            if(block_id < gemm_desc_ptr[group_id].block_start)
            {
                right = group_id;
            }
            else
            {
                left = group_id;
            }
            group_id = index_t((left + right) >> 1);
        }

        return group_id;
    }

    // For non-persistent kernels
    template <bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   index_t group_count) const
    {
        const index_t block_id = ck_tile::get_block_1d_id();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));

        const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
        const auto& kargs = gemm_desc_ptr[group_id];
        const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
        const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
            0,
            kargs.group_karg.M,
            kargs.group_karg.N,
            (block_id - kargs.block_start) % grid_size_2d);
        Run(kargs, block_idx_2d, (block_id - kargs.block_start) / grid_size_2d);
    }

    // For persistent kernels
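    // Each workgroup walks the concatenated per-group grids in a grid-stride loop:
    // starting from its initial block id, it advances by the launched grid size
    // until every tile of every group (including split-K batches) is covered.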
    template <bool U = UsePersistentKernel,
              typename = std::enable_if_t<U>,
              typename = void> // extra template parameter to avoid redefinition
    CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                   const index_t group_count) const
    {
        const index_t grid_size = ck_tile::get_grid_size();
        const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
            cast_pointer_to_generic_address_space(gemm_descs_const));
        index_t block_id = ck_tile::get_block_1d_id(); // initial block_id
        index_t cum_grid_size = 0;
        for(index_t group_id = 0; group_id < group_count; ++group_id)
        {
            const auto& kargs = gemm_desc_ptr[group_id].group_karg;
            const auto& k_batch = kargs.k_batch;
            const auto block_start = cum_grid_size;
            cum_grid_size += TilePartitioner::GridSize(kargs.M, kargs.N) * k_batch;
            while(block_id < cum_grid_size)
            {
                const auto grid_size_2d = TilePartitioner::GridSize(kargs.M, kargs.N);
                const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
                    0, kargs.M, kargs.N, (block_id - block_start) % grid_size_2d);
                Run(kargs, block_idx_2d, (block_id - block_start) / grid_size_2d);
                block_id = block_id + grid_size; // advance to next block
                // NOTE: this check is redundant but helps the compiler avoid spilling some VGPR
                if(block_id >= cum_grid_size)
                {
                    break; // exit the loop if all blocks are processed
                }
            }
        }
    }
};

} // namespace ck_tile
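
The host-side flow implied by the helpers above: build one GemmHostArgs<> descriptor per GEMM, convert them with MakeKargs, copy the resulting GemmTransKernelArg array into a device workspace sized by GetWorkSpaceSize, then launch either GridSize(gemm_descs) blocks (non-persistent) or MaxOccupancyGridSize(stream) blocks (persistent) of BlockSize() threads, passing the workspace pointer and group count to the kernel entry. The sketch below is illustrative only: the helper name prepare_grouped_gemm is hypothetical, the descriptors are assumed to be filled by the caller, and the final launch through the library's kentry/launch utilities is omitted.

#include <stdexcept>
#include <vector>
#include <hip/hip_runtime.h>
#include "ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp"

// Hypothetical helper: stages grouped-GEMM arguments in a device workspace and
// selects the launch grid. "GroupedKernel" is a concrete GroupedGemmKernel type.
template <typename GroupedKernel>
void* prepare_grouped_gemm(const std::vector<ck_tile::GemmHostArgs<>>& gemm_descs,
                           const ck_tile::stream_config& stream,
                           dim3& grid,
                           dim3& block)
{
    // Per-group kernel arguments with flattened [block_start, block_end) ranges.
    auto kargs = GroupedKernel::MakeKargs(gemm_descs);
    if(!GroupedKernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("unsupported grouped GEMM configuration");
    }

    // Stage the argument array in device memory; the kernel reads it through a
    // constant-address-space pointer.
    void* workspace = nullptr;
    HIP_CHECK_ERROR(hipMalloc(&workspace, GroupedKernel::GetWorkSpaceSize(gemm_descs)));
    HIP_CHECK_ERROR(hipMemcpy(workspace,
                              kargs.data(),
                              kargs.size() * sizeof(ck_tile::GemmTransKernelArg),
                              hipMemcpyHostToDevice));

    // Non-persistent: one block per output tile (times k_batch) across all groups.
    // Persistent: an occupancy-limited grid that loops over all tiles.
    grid = GroupedKernel::UsePersistentKernel ? GroupedKernel::MaxOccupancyGridSize(stream)
                                              : GroupedKernel::GridSize(gemm_descs);
    block = GroupedKernel::BlockSize();

    // Launching the kernel entry with (workspace pointer, group count) via the
    // library's launch utilities is omitted here.
    return workspace;
}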