flatmm_kernel.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

namespace ck_tile {

struct FlatmmProblem
{
    CK_TILE_HOST FlatmmProblem() = default;
    CK_TILE_HOST FlatmmProblem(
        index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
        : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
    {
    }

    index_t M;
    index_t N;
    index_t K;
    index_t stride_A;
    index_t stride_B;
    index_t stride_C;
};

struct FlatmmHostArgs : public FlatmmProblem
{
    CK_TILE_HOST FlatmmHostArgs() = default;
    CK_TILE_HOST FlatmmHostArgs(const void* a_ptr_,
                                const void* b_shuffle_ptr_,
                                void* c_ptr_,
                                index_t k_batch_,
                                index_t M_,
                                index_t N_,
                                index_t K_,
                                index_t stride_A_,
                                index_t stride_B_,
                                index_t stride_C_)
        : FlatmmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
          a_ptr(a_ptr_),
          b_shuffle_ptr(b_shuffle_ptr_),
          c_ptr(c_ptr_),
          k_batch(k_batch_)
    {
    }

    const void* a_ptr;
    const void* b_shuffle_ptr;
    void* c_ptr;
    index_t k_batch;
};
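
// Example (illustrative only): populating host-side arguments for an assumed
// 4096 x 4096 x 4096 row-major problem. The device pointers and the k_batch
// choice are placeholders for this sketch; note that B must already be
// shuffled into the flat layout the pipeline expects.
//
//   ck_tile::FlatmmHostArgs host_args(a_dev,         // A operand
//                                     b_shuffle_dev, // pre-shuffled B operand
//                                     c_dev,         // C output
//                                     /*k_batch=*/1,
//                                     /*M=*/4096, /*N=*/4096, /*K=*/4096,
//                                     /*stride_A=*/4096,
//                                     /*stride_B=*/4096,
//                                     /*stride_C=*/4096);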

template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct FlatmmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using FlatmmPipeline   = remove_cvref_t<FlatmmPipeline_>;
    using BlockGemmShape   = remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using ALayout          = remove_cvref_t<typename FlatmmPipeline::ALayout>;
    using BLayout          = remove_cvref_t<typename FlatmmPipeline::BLayout>;
    using CLayout          = remove_cvref_t<typename FlatmmPipeline::CLayout>;
    static constexpr index_t KernelBlockSize = FlatmmPipeline::BlockSize;

    using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
    using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
    // Below type is actually accumulation data type - the output of block GEMM.
    using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto idxM = I0;
    static constexpr auto idxN = I1;
    static constexpr auto idxK = I2;

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
        // clang-format on
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
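
    // Example (illustrative only): with assumed 256 x 256 output tiles and
    // M = N = 4096, TilePartitioner::GridSize(M, N) covers 16 * 16 = 256
    // tiles, so GridSize(4096, 4096, /*KBatch=*/2) yields dim3(256, 1, 2);
    // blockIdx.z then selects the split-K batch in operator().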

    struct FlatmmKernelArgs
    {
        const void* a_ptr;
        const void* b_shuffle_ptr;
        void* c_ptr;
        index_t M;
        index_t N;
        index_t K;
        index_t stride_A;
        index_t stride_B;
        index_t stride_C;
        index_t k_batch;
    };

    CK_TILE_HOST static constexpr FlatmmKernelArgs MakeKernelArgs(const FlatmmHostArgs& hostArgs)
    {
        return FlatmmKernelArgs{hostArgs.a_ptr,
                                hostArgs.b_shuffle_ptr,
                                hostArgs.c_ptr,
                                hostArgs.M,
                                hostArgs.N,
                                hostArgs.K,
                                hostArgs.stride_A,
                                hostArgs.stride_B,
                                hostArgs.stride_C,
                                hostArgs.k_batch};
    }
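
    // Example (illustrative only): `Kernel` stands for a concrete
    // FlatmmKernel instantiation and `host_args` for a FlatmmHostArgs
    // instance such as the sketch above.
    //
    //   auto kargs = Kernel::MakeKernelArgs(host_args);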

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
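
    // The flatmm pipeline and the epilogue run one after the other within a
    // workgroup, so a single LDS buffer sized to the larger of the two
    // requirements is reused by both rather than allocating their sum.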

    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const FlatmmKernelArgs& kargs,
                                     const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            const index_t K_t   = kargs.k_batch * K1;
            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = k_id * KRead;
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = k_id * KRead * kargs.stride_A;
            }

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = k_id * KRead * kargs.stride_B;
            }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                b_k_split_offset = k_id * KRead;
            }

            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
            {
                splitted_k = KRead;
            }
            else
            {
                splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };
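
    // Worked example (assumed sizes): with K = 1024, k_batch = 3 and a warp
    // K-tile of K1 = 32, K_t = 96 and KRead = ceil(1024 / 96) * 32 = 352.
    // Batches 0 and 1 each process splitted_k = 352 and the last batch takes
    // the remainder 1024 - 2 * 352 = 320; for row-major A, batch k_id starts
    // at element offset k_id * 352 along K.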

    CK_TILE_HOST static bool IsSupportedArgument(const FlatmmKernelArgs& kargs)
    {
        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                     is_any_of<CDataType, fp16_t, bf16_t>::value)
        {
            if(kargs.k_batch != 1)
            {
                std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
            {
                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
            {
                std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
                return false;
            }
        }

        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
            {
                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
            {
                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                             " without padding!"
                          << std::endl;
                return false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
                return false;
            }
        }
        return true;
    }
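
    // The checks above mirror the pipeline's compile-time padding flags: a
    // dimension may be a non-multiple of its per-block tile size only when
    // the matching kPadM/kPadN/kPadK flag is set, and the contiguous
    // dimension must additionally be divisible by the vector access width.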

    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_flat_ptr,
                                                   CDataType* c_ptr,
                                                   const FlatmmKernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(kargs.M, splitk_batch_offset.splitted_k),
                    make_tuple(kargs.stride_A, 1),
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    make_tuple(splitk_batch_offset.splitted_k, kargs.M),
                    make_tuple(kargs.stride_A, 1),
                    number<FlatmmPipeline::GetVectorSizeA()>{},
                    number<1>{});
            }
        }();

        index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
                                                         BlockGemmShape::WarpTile::at(number<2>{}));
        index_t kFlatN = kargs.N * kargs.K / kFlatK;
        const auto& b_flat_tensor_view = [&]() {
            return make_naive_tensor_view<address_space_enum::global>(
                b_flat_ptr,
                make_tuple(kFlatN, kFlatK),
                make_tuple(kFlatK, 1),
                number<FlatmmPipeline::GetVectorSizeB()>{},
                number<1>{});
        }();

        // TODO: enable vector write for C in ColMajor
        const auto& c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(kargs.stride_C, 1),
                    number<EpiloguePipeline::GetVectorSizeC()>{},
                    number<1>{});
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    c_ptr,
                    make_tuple(kargs.M, kargs.N),
                    make_tuple(1, kargs.stride_C),
                    number<1>{},
                    number<1>{});
            }
        }();

        return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view);
    }
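
    // The pre-shuffled B operand is viewed as a 2-D (kFlatN, kFlatK) matrix
    // in which each row holds the warp-contiguous K-fragments of one N-tile.
    // Worked example (assumed sizes): with N = K = 4096, a warp K-tile of 32
    // and flatKPerWarp = 512 at splitted_k = 4096, kFlatK = 512 * (4096 / 32)
    // = 65536 and kFlatN = 4096 * 4096 / 65536 = 256 rows.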

    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::KPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadK>{});
            }
            else
            {
                return pad_tensor_view(a_tensor_view,
                                       make_tuple(number<TilePartitioner::KPerBlock>{},
                                                  number<TilePartitioner::MPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadM>{});
            }
        }();

        const auto& b_flat_tensor_view = views.at(I1);

        // TODO vector write in for C in ColMajor
        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I2);
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<false, FlatmmPipeline::kPadN>{});
            }
            else
            {
                return pad_tensor_view(c_tensor_view,
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<FlatmmPipeline::kPadM, false>{});
            }
        }();

        return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view);
    }

    template <typename PadView>
    CK_TILE_DEVICE static auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view      = views.at(I0);
        const auto& b_flat_pad_view = views.at(I1);
        const auto& c_pad_view      = views.at(I2);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::KPerBlock>{}),
                                        {i_m, 0});
            }
            else
            {
                return make_tile_window(a_pad_view,
                                        make_tuple(number<TilePartitioner::KPerBlock>{},
                                                   number<TilePartitioner::MPerBlock>{}),
                                        {0, i_m});
            }
        }();

        const auto& b_flat_block_window =
            make_tile_window(b_flat_pad_view,
                             make_tuple(number<FlatmmPipeline::flatNPerWarp>{},
                                        number<FlatmmPipeline::flatKPerWarp>{}),
                             {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(idxN)), 0});

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_flat_block_window, c_block_window);
    }
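
    // Note: the flat-B window's row origin is i_n / WarpTile::at(idxN)
    // because each row of the flat view holds the K-fragments of exactly one
    // warp-sized N-tile (see MakeGemmTensorViews above).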

    CK_TILE_DEVICE static void RunFlatmm(const ADataType* a_ptr,
                                         const BDataType* b_flat_ptr,
                                         CDataType* c_ptr,
                                         void* smem_ptr,
                                         const FlatmmKernelArgs& kargs,
                                         const SplitKBatchOffset& splitk_batch_offset,
                                         const index_t block_idx_m,
                                         const index_t block_idx_n)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset);
        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window      = gemm_tile_windows.at(I0);
        const auto& b_flat_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window      = gemm_tile_windows.at(I2);
        const auto& c_block_tile        = FlatmmPipeline{}.template operator()(
            a_block_window, b_flat_block_window, num_loop, smem_ptr);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I2);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr);
    }
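
    // With k_batch > 1, each z-block computes a partial C over its K-slice,
    // so the epilogue is expected to combine partials in global memory
    // (memory_operation_enum::atomic_add); the guard in operator() below
    // skips configurations where that combination is not supported.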

    CK_TILE_DEVICE void operator()(FlatmmKernelArgs kargs) const
    {
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        const SplitKBatchOffset splitk_batch_offset(kargs);
        // options
        const ADataType* a_ptr =
            static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
        const BDataType* b_flat_ptr = static_cast<const BDataType*>(kargs.b_shuffle_ptr) +
                                      splitk_batch_offset.b_k_split_offset;
        CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);

        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
                       EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                       is_any_of<CDataType, fp16_t, bf16_t>::value))
        {
            RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
        }
    }
};

} // namespace ck_tile
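
// Example (illustrative only): a host-side launch sketch. `MyPartitioner`,
// `MyPipeline`, `MyEpilogue`, and the `launch` helper are placeholders for
// whatever tile partitioner, flatmm pipeline, epilogue, and kernel-launch
// wrapper a concrete application provides; they are not defined here.
//
//   using Kernel = ck_tile::FlatmmKernel<MyPartitioner, MyPipeline, MyEpilogue>;
//
//   ck_tile::FlatmmHostArgs host_args(a_dev, b_shuffle_dev, c_dev,
//                                     /*k_batch=*/1, M, N, K,
//                                     stride_A, stride_B, stride_C);
//   auto kargs = Kernel::MakeKernelArgs(host_args);
//   if(Kernel::IsSupportedArgument(kargs))
//   {
//       launch(Kernel{}, Kernel::GridSize(M, N, host_args.k_batch),
//              Kernel::BlockSize(), kargs);
//   }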