gemm_kernel.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

namespace ck_tile {

struct GemmProblem
{
    CK_TILE_HOST GemmProblem() = default;
    CK_TILE_HOST GemmProblem(
        index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

struct GemmHostArgs : public GemmProblem
{
    CK_TILE_HOST GemmHostArgs() = default;
    CK_TILE_HOST GemmHostArgs(const void* a_ptr_,
                              const void* b_ptr_,
                              void* c_ptr_,
                              index_t k_batch_,
                              index_t M_,
                              index_t N_,
                              index_t K_,
                              index_t stride_A_,
                              index_t stride_B_,
                              index_t stride_C_)
        : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
          a_ptr(a_ptr_),
          b_ptr(b_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_ptr;
    void* c_ptr;
    index_t k_batch;
};

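// GemmKernel combines a TilePartitioner (workgroup-to-output-tile mapping), a GemmPipeline
// (the main K loop over block tiles) and an EpiloguePipeline (storing or atomically
// accumulating the output tile) into a single device kernel with split-K support.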
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename GemmPipeline::ALayout>;
    using BLayout          = remove_cvref_t<typename GemmPipeline::BLayout>;
    using CLayout          = remove_cvref_t<typename GemmPipeline::CLayout>;
    static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;

    using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    // Below type is actually accumulation data type - the output of block GEMM.
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();

    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }

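    // Device-side kernel arguments: raw global-memory pointers plus the problem sizes,
    // strides and split-K factor copied over from GemmHostArgs.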
    struct GemmKernelArgs
    {
        const void* a_ptr;
        const void* b_ptr;
        void* c_ptr;
        index_t M;
        index_t N;
        index_t K;
        index_t stride_A;
        index_t stride_B;
        index_t stride_C;
        index_t k_batch;
    };

    CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const GemmHostArgs& hostArgs)
    {
        return GemmKernelArgs{hostArgs.a_ptr,
                              hostArgs.b_ptr,
                              hostArgs.c_ptr,
                              hostArgs.M,
                              hostArgs.N,
                              hostArgs.K,
                              hostArgs.stride_A,
                              hostArgs.stride_B,
                              hostArgs.stride_C,
                              hostArgs.k_batch};
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

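    // Per-workgroup split-K bookkeeping: the K dimension is divided into k_batch chunks of
    // KRead elements (roughly K / k_batch, rounded up to a multiple of the warp-tile K extent).
    // blockIdx.z selects the chunk, and the offsets below shift the A/B base pointers to it;
    // the last chunk covers whatever K remainder is left.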
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const GemmKernelArgs& kargs,
                                     const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            const index_t K_t   = kargs.k_batch * K1;
            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = k_id * KRead;
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = k_id * KRead * kargs.stride_A;
            }

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = k_id * KRead * kargs.stride_B;
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                b_k_split_offset = k_id * KRead;
            }

            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
            {
                splitted_k = KRead;
            }
            else
            {
                splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };

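    // Host-side argument check: rejects k_batch > 1 when the epilogue cannot atomically
    // accumulate C with its vector size, and rejects shapes whose unpadded dimensions are not
    // multiples of the block tile or of the A/B/C vector access widths.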
    CK_TILE_HOST static bool IsSupportedArgument(const GemmKernelArgs& kargs)
    {
        if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
                     is_any_of<CDataType, fp16_t, bf16_t>::value)
        {
            if(kargs.k_batch != 1)
            {
                std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
        return true;
    }

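    // Builds global-memory tensor views over A, B and C for this split-K slice. The view shape
    // and the vector-accessed dimension follow the respective layout; column-major C is written
    // element-wise (vector size 1).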
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_ptr,
                                                   CDataType* c_ptr,
                                                   const GemmKernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
                    make_tuple(kargs.stride_A, 1),
                    number<GemmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
        }();

        const auto& b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.N),
                    make_tuple(kargs.stride_B, 1),
                    number<GemmPipeline::GetVectorSizeB()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    b_ptr,
                    make_tuple(kargs.N, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_B, 1),
                    number<GemmPipeline::GetVectorSizeB()>{},
                    number<1>{});
            }
        }();

        // TODO: enable vector write for C in ColMajor
        const auto& c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
                    number<1>{});
            }
        }();

        return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
    }

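    // Pads each view up to the block tile size in the dimensions the pipeline allows
    // (kPadM / kPadN / kPadK), so partial tiles at the tensor edges can be processed.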
    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::MPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadM>{});
            }
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I1);
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::NPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(b_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
        }();

        // TODO vector write in for C in ColMajor
        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I2);
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, GemmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, false>{});
            }
        }();

        return make_tuple(a_pad_view, b_pad_view, c_pad_view);
    }

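    // Creates the block-scope tile windows: the A window starts at the block's row offset i_m
    // (K origin 0), the B window at the column offset i_n (K origin 0), and the C window is the
    // (MPerBlock, NPerBlock) output tile at (i_m, i_n).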
    template <typename PadView>
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view = views.at(I0);
        const auto& b_pad_view = views.at(I1);
        const auto& c_pad_view = views.at(I2);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_m, 0});
            }
            else
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::MPerBlock>{}),
                                        {0, i_m});
            }
        }();

        const auto& b_block_window = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            {
                return make_tile_window(b_pad_view,
                                        make_tuple(number<TilePartitioner::NPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_n, 0});
            }
            else
            {
                return make_tile_window(b_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {0, i_n});
            }
        }();

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_block_window, c_block_window);
    }

    // Runs single GEMM problem cooperatively by whole workgroup.
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
                                       const BDataType* b_ptr,
                                       CDataType* c_ptr,
                                       void* smem_ptr,
                                       const GemmKernelArgs& kargs,
                                       const SplitKBatchOffset& splitk_batch_offset,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& c_block_tile =
            GemmPipeline{}.template operator()(a_block_window, b_block_window, num_loop, smem_ptr);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);

        EpiloguePipeline{}
            .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
                c_block_window, c_block_tile, smem_ptr);
    }

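    // Kernel entry point: blockIdx.x selects the (M, N) output tile via the TilePartitioner,
    // blockIdx.z selects the split-K slice, and the result is either stored directly
    // (k_batch == 1) or accumulated into C with atomic adds (k_batch > 1).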
    CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
    {
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const SplitKBatchOffset splitk_batch_offset(kargs);
        // options
        const ADataType* a_ptr =
            static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
        const BDataType* b_ptr =
            static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        if(kargs.k_batch == 1)
        {
            RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
        else
        {
            // Do not compile in case where we have unsupported
            // VectorSizeC & data type configuration.
            if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
                           is_any_of<CDataType, fp16_t, bf16_t>::value))
            {
                RunGemm<memory_operation_enum::atomic_add>(
                    a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
            }
        }
    }
};

} // namespace ck_tile
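A minimal host-side launch sketch, assuming concrete MyTilePartitioner, MyGemmPipeline and MyEpilogue policy types are defined elsewhere (they are not part of this header, and the names are placeholders); the __global__ wrapper is only for illustration rather than the library's own launch helpers.

#include <hip/hip_runtime.h>
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"

// Hypothetical partitioner/pipeline/epilogue types supplied by the user or by other CK Tile headers.
using Kernel = ck_tile::GemmKernel<MyTilePartitioner, MyGemmPipeline, MyEpilogue>;

// Thin illustrative wrapper: the device functor does all the work.
__global__ void gemm_entry(typename Kernel::GemmKernelArgs kargs) { Kernel{}(kargs); }

void run_gemm(const void* a, const void* b, void* c,
              ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
              ck_tile::index_t stride_A, ck_tile::index_t stride_B, ck_tile::index_t stride_C,
              ck_tile::index_t k_batch, hipStream_t stream)
{
    ck_tile::GemmHostArgs host_args(a, b, c, k_batch, M, N, K, stride_A, stride_B, stride_C);
    const auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
        return; // shape/layout not supported by this instantiation

    const dim3 grids  = Kernel::GridSize(M, N, k_batch); // x: output tiles, z: split-K slices
    const dim3 blocks = Kernel::BlockSize();              // GemmPipeline::BlockSize threads
    gemm_entry<<<grids, blocks, 0, stream>>>(kargs);      // LDS is allocated statically inside
}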