Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp Source File

flatmm_kernel.hpp
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
12 
13 namespace ck_tile {
14 struct FlatmmProblem
15 {
16  CK_TILE_HOST FlatmmProblem() = default;
17  CK_TILE_HOST FlatmmProblem(
18  index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
19  : M(M_), N(N_), K(K_), stride_A(stride_A_), stride_B(stride_B_), stride_C(stride_C_)
20  {
21  }
22 
23  index_t M;
24  index_t N;
25  index_t K;
26  index_t stride_A;
27  index_t stride_B;
28  index_t stride_C;
29 };
30 
31 template <int SharedGranularityMN, int SharedGranularityK = 0>
32 struct FlatmmScalePointer
33 {
34  static constexpr int GranularityMN = SharedGranularityMN;
35  static constexpr int GranularityK = SharedGranularityK;
36 
37  const float* ptr;
38 
39  CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
40  CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
41  CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t length_)
42  : ptr(ptr_)
43  {
44  }
45 
46  CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
47  {
48  FlatmmScalePointer ret = *this;
49  if constexpr(GranularityMN == 0)
50  {
51  ret.ptr = ptr + offset / GranularityK;
52  }
53  else
54  {
56  }
57  return ret;
58  }
59 
60  CK_TILE_HOST_DEVICE float operator[](index_t i) const = delete;
61 };
62 
63 template <int SharedGranularityMN>
64 struct FlatmmScalePointer<SharedGranularityMN, 0>
65 {
66  static constexpr int GranularityMN = SharedGranularityMN;
67  static constexpr int GranularityK = 0;
68 
69  static_assert(GranularityMN != 0);
70 
71  const float* ptr;
72  index_t length;
73 
74  CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
75  CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_), length(1) {}
76  CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t length_)
77  : ptr(ptr_), length(length_)
78  {
79  }
80 
81  CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
82  {
83  FlatmmScalePointer ret = *this;
84  if constexpr(GranularityMN == 1)
85  {
86  ret.ptr = ptr + offset;
87  ret.length = length - offset;
88  }
89  else
90  {
91  ret.ptr = ptr + offset / GranularityMN;
92  ret.length = length - offset / GranularityMN;
93  }
94  return ret;
95  }
96 
97  CK_TILE_HOST_DEVICE float operator[](index_t i) const
98  {
99  // with additional oob check
100  if constexpr(GranularityMN == 1)
101  return i < length ? ptr[i] : 0;
102  else
103  return i / GranularityMN < length ? ptr[i / GranularityMN] : 0;
104  }
105 };
106 
107 // shared granularityMN = -1 means no scale
108 template <>
109 struct FlatmmScalePointer<-1, 0>
110 {
111  static constexpr int GranularityMN = -1;
112  static constexpr int GranularityK = 0;
113 
114  const float* ptr = nullptr;
115 
116  CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
117  CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
118  CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, index_t) {}
119 
120  CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
121  {
122  return FlatmmScalePointer{};
123  }
124  CK_TILE_HOST_DEVICE constexpr float operator[](index_t) const
125  {
126  return 1; // always return 1, it doesn't change the result
127  }
128 };
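The three FlatmmScalePointer variants above select the scaling mode at compile time: the primary template covers scales grouped along K, the <GranularityMN, 0> specialization covers per-row/per-column scales with a bounds-checked operator[], and the <-1, 0> specialization disables scaling. A minimal host-side sketch (not part of this header; the buffer names and the K-group size of 128 are hypothetical) of how each variant could be instantiated:

#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"

void scale_pointer_sketch(const float* per_row_scale, const float* grouped_scale, int M)
{
    using namespace ck_tile;

    // Per-row scaling: one float per M row, operator[] returns 0 past `length`.
    FlatmmScalePointer<1> row_scale{per_row_scale, M};

    // Scales grouped along K (primary template); operator[] is deleted here and
    // operator+ advances the pointer in units of GranularityK.
    FlatmmScalePointer<0, 128> k_grouped_scale{grouped_scale};

    // No scaling: the pointer is ignored and operator[] always yields 1.0f.
    FlatmmScalePointer<-1> no_scale{nullptr};

    (void)row_scale; (void)k_grouped_scale; (void)no_scale;
}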
129 
130 template <index_t NumDTensor = 0>
131 struct BaseFlatmmHostArgs
132 {
133  CK_TILE_HOST BaseFlatmmHostArgs() = default;
134  CK_TILE_HOST BaseFlatmmHostArgs(const void* a_ptr_,
135  const void* b_ptr_,
136  const std::array<const void*, NumDTensor>& ds_ptr_,
137  void* e_ptr_,
138  index_t k_batch_,
139  index_t M_,
140  index_t N_,
141  index_t K_,
142  index_t stride_A_,
143  index_t stride_B_,
144  const std::array<index_t, NumDTensor>& stride_Ds_,
145  index_t stride_E_)
146  : a_ptr(a_ptr_),
147  b_ptr(b_ptr_),
148  ds_ptr(ds_ptr_),
149  e_ptr(e_ptr_),
150  M(M_),
151  N(N_),
152  K(K_),
153  stride_A(stride_A_),
154  stride_B(stride_B_),
155  stride_Ds(stride_Ds_),
156  stride_E(stride_E_),
157  k_batch(k_batch_)
158  {
159  }
160 
161  const void* a_ptr;
162  const void* b_ptr;
163  const std::array<const void*, NumDTensor> ds_ptr;
164  union
165  {
166  void* e_ptr;
167  void* c_ptr;
168  };
169  index_t M;
170  index_t N;
171  index_t K;
172  index_t stride_A;
173  index_t stride_B;
174  const std::array<index_t, NumDTensor> stride_Ds;
175  union
176  {
177  index_t stride_E;
178  index_t stride_C;
179  };
180 
181  index_t k_batch;
182 };
183 template <class ScaleM = FlatmmScalePointer<-1>,
184  class ScaleN = FlatmmScalePointer<-1>,
185  index_t NumDTensor = 0>
186 struct ScaleFlatmmHostArgs : public BaseFlatmmHostArgs<NumDTensor>
187 {
188  CK_TILE_HOST ScaleFlatmmHostArgs() = default;
189  CK_TILE_HOST ScaleFlatmmHostArgs(const void* a_ptr_,
190  const void* b_shuffle_ptr_,
191  const std::array<const void*, NumDTensor>& ds_ptr_,
192  void* c_ptr_,
193  index_t k_batch_,
194  index_t M_,
195  index_t N_,
196  index_t K_,
197  index_t stride_A_,
198  index_t stride_B_,
199  const std::array<index_t, NumDTensor>& stride_Ds_,
200  index_t stride_C_,
201  ScaleM scale_m_ = nullptr,
202  ScaleN scale_n_ = nullptr)
203  : BaseFlatmmHostArgs(a_ptr_,
204  b_shuffle_ptr_,
205  ds_ptr_,
206  c_ptr_,
207  k_batch_,
208  M_,
209  N_,
210  K_,
211  stride_A_,
212  stride_B_,
213  stride_Ds_,
214  stride_C_),
215  scale_m(scale_m_),
216  scale_n(scale_n_)
217  {
218  }
219  ScaleM scale_m = nullptr;
220  ScaleN scale_n = nullptr;
221 };
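ScaleFlatmmHostArgs bundles everything the host needs to describe one flatmm problem: pointers, sizes, strides, split-K factor, optional D tensors and optional scale pointers. A minimal sketch (not part of this header; the pointer names and stride choices are assumptions for a row-major A/C problem with no D tensors and no scaling):

#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"

// ScaleM/ScaleN default to FlatmmScalePointer<-1> (no scaling), NumDTensor to 0.
using HostArgs = ck_tile::ScaleFlatmmHostArgs<>;

HostArgs make_host_args(const void* a_dev, const void* b_shuffled_dev, void* c_dev,
                        ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
    return HostArgs{a_dev, b_shuffled_dev, /*ds_ptr*/ {}, c_dev,
                    /*k_batch*/ 1, M, N, K,
                    /*stride_A*/ K,  // row-major A
                    /*stride_B*/ K,  // assumed; depends on the B shuffle layout
                    /*stride_Ds*/ {}, /*stride_C*/ N};
}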
222 
223 template <int NumberTensor = 0>
226 
227 template <class ScaleM, class ScaleN, index_t NumDTensor = 0>
228 struct FlatmmKernelArgs
229 {
230  const void* a_ptr;
231  // const void* b_shuffle_ptr;
232  const void* b_ptr;
233  const std::array<const void*, NumDTensor> ds_ptr;
234  void* e_ptr;
235  index_t M;
236  index_t N;
237  index_t K;
238  index_t stride_A;
239  index_t stride_B;
240  std::array<index_t, NumDTensor> stride_Ds;
241  index_t stride_E;
242  index_t k_batch;
243  ScaleM scale_m_ptr = nullptr;
244  ScaleN scale_n_ptr = nullptr;
245 };
246 
247 template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
248 struct FlatmmKernel
249 {
250  using TilePartitioner = remove_cvref_t<TilePartitioner_>;
251  using FlatmmPipeline = remove_cvref_t<FlatmmPipeline_>;
253  using BlockGemmShape = remove_cvref_t<typename FlatmmPipeline::BlockGemmShape>;
254  using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
255  using ALayout = remove_cvref_t<typename FlatmmPipeline::ALayout>;
256  using BLayout = remove_cvref_t<typename FlatmmPipeline::BLayout>;
257  using ELayout = remove_cvref_t<typename FlatmmPipeline::CLayout>;
258  using DsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
259  using DsDataType = remove_cvref_t<typename EpiloguePipeline::DsDataType>;
260  static constexpr index_t kBlockSize = FlatmmPipeline::BlockSize;
261  static constexpr bool UsePersistentKernel = FlatmmPipeline::UsePersistentKernel;
262 
263  using ADataType = remove_cvref_t<typename FlatmmPipeline::ADataType>;
264  using BDataType = remove_cvref_t<typename FlatmmPipeline::BDataType>;
265  // Below type is actually accumulation data type - the output of block GEMM.
266  using EDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
267 
268  static constexpr index_t NumDTensor = DsDataType::size();
269 
270  static constexpr auto I0 = number<0>();
271  static constexpr auto I1 = number<1>();
272  static constexpr auto I2 = number<2>();
273  static constexpr auto I3 = number<3>();
274 
275  static_assert(DsLayout::size() == DsDataType::size(),
276  "The size of DsLayout and DsDataType should be the same");
277  // using KernelArgs = FlatmmKernelArgs<DsLayout::size()>;
278 
279  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
280  {
281  // clang-format off
282  return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
283  // clang-format on
284  }
285 
286  CK_TILE_HOST static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
287  {
288  assert(!UsePersistentKernel);
289  return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
290  }
291 
292  template <class ScaleM, class ScaleN>
293  CK_TILE_HOST static constexpr auto
294  GridSize(const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs)
295  {
296  if constexpr(UsePersistentKernel)
297  {
298  hipDeviceProp_t prop;
299  int deviceId = 0; // default device
300 
301  constexpr int block_size = FlatmmKernel::BlockSize().x;
302  int dync_smem_size = 0;
303  int maxActiveBlocksPerCU = 0;
304 
305  [[maybe_unused]] auto e = hipGetDeviceProperties(&prop, deviceId);
306 
307  e = hipOccupancyMaxActiveBlocksPerMultiprocessor(
308  &maxActiveBlocksPerCU,
309  reinterpret_cast<void*>(
310  kentry<1, FlatmmKernel, FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>>),
311  block_size,
312  dync_smem_size);
313 
314  const int persistent_block_size = prop.multiProcessorCount * maxActiveBlocksPerCU;
315  const int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
316 
317  // std::cout << "maxActiveBlocksPerCU: " << maxActiveBlocksPerCU
318  // << ", persistent_block_size: " << persistent_block_size
319  // << ", total_work_tile_cnt: " << total_work_tile_cnt << std::endl;
320 
321  assert(kargs.k_batch == 1);
322  return dim3(min(persistent_block_size, total_work_tile_cnt), 1, kargs.k_batch);
323  }
324  else
325  {
326  return dim3(TilePartitioner::GridSize(kargs.M, kargs.N), 1, kargs.k_batch);
327  }
328  }
329 
330  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
331 
332  template <class ScaleM, class ScaleN>
333  CK_TILE_HOST static constexpr FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>
334  MakeKernelArgs(const ScaleFlatmmHostArgs<ScaleM, ScaleN, DsDataType::size()>& hostArgs)
335  {
336  return {hostArgs.a_ptr,
337  hostArgs.b_ptr,
338  hostArgs.ds_ptr,
339  hostArgs.e_ptr,
340  hostArgs.M,
341  hostArgs.N,
342  hostArgs.K,
343  hostArgs.stride_A,
344  hostArgs.stride_B,
345  hostArgs.stride_Ds,
346  hostArgs.stride_E,
347  hostArgs.k_batch,
348  hostArgs.scale_m,
349  hostArgs.scale_n};
350  }
351 
352  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPingSize()
353  {
354  return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
355  }
356  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPongSize()
357  {
358  return FlatmmPipeline::GetSmemSize();
359  }
360 
361  struct SplitKBatchOffset
362  {
363  template <class KernelArgs>
364  __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
365  {
366  constexpr auto N1 = BlockGemmShape::WarpTile::at(number<1>{});
367  constexpr auto K1 = BlockGemmShape::WarpTile::at(number<2>{});
368  const index_t K_t = kargs.k_batch * K1;
369  const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
370 
371  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
372  {
373  a_k_split_offset = k_id * KRead;
374  }
375  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
376  {
377  a_k_split_offset = k_id * KRead * kargs.stride_A;
378  }
379 
380  if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
381  {
382  b_k_split_offset = k_id * KRead * kargs.stride_B * N1;
383  }
384  else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
385  {
386  b_k_split_offset = k_id * KRead * N1;
387  }
388 
389  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
390  {
391  splitted_k = KRead;
392  }
393  else
394  {
395  splitted_k = kargs.K - KRead * (kargs.k_batch - 1);
396  }
397  }
398 
399  index_t a_k_split_offset;
400  index_t b_k_split_offset;
401  index_t splitted_k;
402  };
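SplitKBatchOffset rounds the per-batch K chunk up to a multiple of the warp-tile K extent and gives the last k-batch whatever remains. A small compile-time check of that arithmetic (illustrative sizes, not part of this header):

// K split across 4 k-batches with a warp-tile K extent (K1) of 32.
constexpr int K = 1000, k_batch = 4, K1 = 32;
constexpr int K_t   = k_batch * K1;             // 128
constexpr int KRead = (K + K_t - 1) / K_t * K1; // ceil(1000 / 128) * 32 = 256
static_assert(KRead == 256, "k-batches 0..2 each read 256 elements of K");
static_assert(K - KRead * (k_batch - 1) == 232, "the last k-batch reads the 232-element remainder");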
403 
404  template <class KernelArgs>
405  CK_TILE_HOST static bool IsSupportedArgument(const KernelArgs& kargs)
406  {
407  if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
408  is_any_of<EDataType, fp16_t, bf16_t>::value)
409  {
410  if(kargs.k_batch != 1)
411  {
412  std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
413  return false;
414  }
415  }
416  if constexpr(UsePersistentKernel)
417  {
418  if(kargs.k_batch != 1)
419  {
420  std::cerr << "Persistent mode doesn't support Kbatch >1 !" << std::endl;
421  return false;
422  }
423  }
424 
425  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
426  {
427  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
428  {
429  std::cerr << "Can't support K that is not a multiple of KPerBlock"
430  " without padding!"
431  << std::endl;
432  return false;
433  }
434  if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
435  {
436  std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
437  return false;
438  }
439  }
440  else
441  {
442  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
443  {
444  std::cerr << "Can't support M that is not a multiple of MPerBlock"
445  " without padding!"
446  << std::endl;
447  return false;
448  }
449  if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
450  {
451  std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
452  return false;
453  }
454  }
455 
456  if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
457  {
458  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
459  {
460  std::cerr << "Can't support N that is not a multiple of NPerBlock"
461  " without padding!"
462  << std::endl;
463  return false;
464  }
465  if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
466  {
467  std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
468  return false;
469  }
470  }
471  else
472  {
473  if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
474  {
475  std::cerr << "Can't support K that is not a multiple of KPerBlock"
476  " without padding!"
477  << std::endl;
478  return false;
479  }
480  if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
481  {
482  std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
483  return false;
484  }
485  }
486 
487  bool DTesnorIsValid = {true};
488  static_for<0, NumDTensor, 1>{}([&](auto index) {
489  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
490  if(std::is_same_v<DiLayout, ELayout> == false)
491  {
492  DTesnorIsValid = false;
493  }
494  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
495  {
496  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
497  {
498  CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
499  "NPerBlock without padding!");
500  DTesnorIsValid = false;
501  }
502  if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
503  {
504  CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
505  DTesnorIsValid = false;
506  }
507  }
508  else
509  {
510  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
511  {
512  CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
513  "MPerBlock without padding!");
514 
515  DTesnorIsValid = false;
516  }
517  if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
518  {
519  CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
520  DTesnorIsValid = false;
521  }
522  }
523  });
524 
525  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
526  {
527  if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
528  {
529  std::cerr << "Can't support N that is not a multiple of NPerBlock"
530  " without padding!"
531  << std::endl;
532  return false;
533  }
534  if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
535  {
536  std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
537  return false;
538  }
539  }
540  else
541  {
542  if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
543  {
544  std::cerr << "Can't support M that is not a multiple of MPerBlock"
545  " without padding!"
546  << std::endl;
547  return false;
548  }
549  if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
550  {
551  std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
552  return false;
553  }
554  }
555  return DTesnorIsValid;
556  }
557 
558  template <memory_operation_enum DstInMemOp = memory_operation_enum::set, class KernelArgs>
559  CK_TILE_DEVICE static auto
560  MakeGemmTensorViews(const ADataType* a_ptr,
561  const BDataType* b_flat_ptr,
562  const std::array<const void*, NumDTensor>& ds_ptr,
563  EDataType* e_ptr,
564  const KernelArgs& kargs,
565  const SplitKBatchOffset& splitk_batch_offset)
566  {
567  const auto& a_tensor_view = [&]() {
568  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
569  {
570  return make_naive_tensor_view<address_space_enum::global>(
571  a_ptr,
572  make_tuple(kargs.M, splitk_batch_offset.splitted_k),
573  make_tuple(kargs.stride_A, 1),
574  number<FlatmmPipeline::GetVectorSizeA()>{},
575  number<1>{});
576  }
577  else
578  {
579  return make_naive_tensor_view<address_space_enum::global>(
580  a_ptr,
581  make_tuple(splitk_batch_offset.splitted_k, kargs.M),
582  make_tuple(kargs.stride_A, 1),
583  number<FlatmmPipeline::GetVectorSizeA()>{},
584  number<1>{});
585  }
586  }();
587 
588  index_t kFlatK =
589  FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
590  index_t kFlatN = kargs.N * kargs.K / kFlatK;
591  const auto& b_flat_tensor_view = [&]() {
592  return make_naive_tensor_view<address_space_enum::global>(
593  b_flat_ptr,
594  make_tuple(kFlatN, kFlatK),
595  make_tuple(kFlatK, 1),
596  number<FlatmmPipeline::GetVectorSizeB()>{},
597  number<1>{});
598  }();
599 
600  const auto& ds_tensor_view = generate_tuple(
601  [&](auto i) {
602  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
603  using DDataType_ = remove_cvref_t<std::tuple_element_t<i.value, DsDataType>>;
604  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
605  {
606  return make_naive_tensor_view<address_space_enum::global>(
607  static_cast<const DDataType_*>(ds_ptr[i]),
608  make_tuple(kargs.M, kargs.N),
609  make_tuple(kargs.stride_Ds[i], 1),
610  number<EpiloguePipeline::GetVectorSizeD(i)>{},
611  number<1>{});
612  }
613  else
614  {
615  return make_naive_tensor_view<address_space_enum::global>(
616  static_cast<const DDataType_*>(ds_ptr[i]),
617  make_tuple(kargs.N, kargs.M),
618  make_tuple(kargs.stride_Ds[i], 1),
619  number<EpiloguePipeline::GetVectorSizeD(i)>{},
620  number<1>{});
621  }
622  },
623  number<NumDTensor>{});
624 
625  // TODO: enable vector write for C in ColMajor
626  const auto& e_tensor_view = [&]() {
627  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
628  {
629  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
630  e_ptr,
631  make_tuple(kargs.M, kargs.N),
632  make_tuple(kargs.stride_E, 1),
633  number<EpiloguePipeline::GetVectorSizeC()>{},
634  number<1>{});
635  }
636  else
637  {
638  return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
639  e_ptr,
640  make_tuple(kargs.N, kargs.M),
641  make_tuple(kargs.stride_E, 1),
642  number<1>{},
643  number<1>{});
644  }
645  }();
646 
647  constexpr int ScaleGranularityM = decltype(kargs.scale_m_ptr)::GranularityMN;
648  constexpr int ScaleGranularityN = decltype(kargs.scale_n_ptr)::GranularityMN;
649 
650  constexpr int ScaleGranularityKA = decltype(kargs.scale_m_ptr)::GranularityK;
651  constexpr int ScaleGranularityKB = decltype(kargs.scale_n_ptr)::GranularityK;
652 
653  auto scale_stride_m = ScaleGranularityM == 0 ? 0 // per-tensor scale
654  : 1; // per-token scale
655  auto scale_stride_n = ScaleGranularityN == 0 ? 0 // per-tensor scale
656  : 1; // per-channel scale
657 
658  static_assert(ScaleGranularityM == 0 || ScaleGranularityM == 1 || ScaleGranularityM == -1,
659  "only support per-tensor or per-row scaling");
660  static_assert(ScaleGranularityN == 0 || ScaleGranularityN == 1 || ScaleGranularityN == -1,
661  "only support per-tensor or per-column scaling");
662 
663  const auto scale_m_view = make_naive_tensor_view<address_space_enum::global>(
664  kargs.scale_m_ptr.ptr,
665  make_tuple(kargs.M / ScaleGranularityM,
666  ScaleGranularityKA == 0
667  ? 1
668  : splitk_batch_offset.splitted_k /
669  (ScaleGranularityKA != 0 ? ScaleGranularityKA : 1)),
670  make_tuple(scale_stride_m, 0),
671  number < ScaleGranularityM == 1 ? FlatmmPipeline::GetVectorSizeA() : 1 > {},
672  number<1>{});
673  const auto scale_n_view = make_naive_tensor_view<address_space_enum::global>(
674  kargs.scale_n_ptr.ptr,
675  make_tuple(ScaleGranularityKB == 0
676  ? 1
677  : (splitk_batch_offset.splitted_k /
678  (ScaleGranularityKB != 0 ? ScaleGranularityKB : 1)),
679  kargs.N / ScaleGranularityN),
680  make_tuple(0, scale_stride_n),
681  number < ScaleGranularityN == 1 ? FlatmmPipeline::GetVectorSizeB() : 1 > {},
682  number<1>{});
683 
684  return make_tuple(a_tensor_view,
685  b_flat_tensor_view,
686  ds_tensor_view,
687  e_tensor_view,
688  scale_m_view,
689  scale_n_view);
690  }
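The flattened B view above reinterprets the pre-shuffled N x K weight as a kFlatN x kFlatK row-major matrix; the two shapes always cover the same number of elements. A small compile-time check with illustrative sizes (flatKPerWarp is pipeline-dependent; 512 is an assumption):

constexpr int N = 512, K = 256;
constexpr int warp_k       = 32;  // BlockGemmShape::WarpTile::at(I2)
constexpr int flatKPerWarp = 512; // FlatmmPipeline::flatKPerWarp (assumed)
constexpr int kFlatK = flatKPerWarp * (K / warp_k); // 512 * 8 = 4096
constexpr int kFlatN = N * K / kFlatK;              // 131072 / 4096 = 32
static_assert(kFlatN * kFlatK == N * K, "the flat view is a pure reshape of B");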
691 
692  template <typename TensorView>
693  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
694  {
695  const auto& a_pad_view = [&]() {
696  const auto& a_tensor_view = views.at(I0);
697  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
698  {
699  return pad_tensor_view(a_tensor_view,
703  }
704  else
705  {
706  return pad_tensor_view(a_tensor_view,
710  }
711  }();
712 
713  const auto& b_flat_tensor_view = views.at(I1);
714 
715  const auto& ds_pad_view = generate_tuple(
716  [&](auto i) {
717  const auto& d_tensor_view = views.at(I2);
718  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
719  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
720  {
721  return pad_tensor_view(d_tensor_view[i],
725  }
726  else
727  {
728  return pad_tensor_view(d_tensor_view[i],
732  }
733  },
734  number<NumDTensor>{});
735 
736  // TODO vector write in for C in ColMajor
737  const auto& e_pad_view = [&]() {
738  const auto& e_tensor_view = views.at(I3);
739  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
740  {
741  return pad_tensor_view(e_tensor_view,
745  }
746  else
747  {
748  return pad_tensor_view(e_tensor_view,
752  }
753  }();
754 
755  return make_tuple(a_pad_view,
756  b_flat_tensor_view,
757  ds_pad_view,
758  e_pad_view,
759  views.at(number<4>{}),
760  views.at(number<5>{}));
761  }
762 
763  template <typename PadView>
764  CK_TILE_DEVICE static auto
765  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
766  {
767  const auto& a_pad_view = views.at(I0);
768  const auto& b_flat_pad_view = views.at(I1);
769  const auto& ds_pad_view = views.at(I2);
770  const auto& e_pad_view = views.at(I3);
771 
772  const auto& a_block_window = [&]() {
773  if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
774  {
775  return make_tile_window(a_pad_view,
778  {i_m, 0});
779  }
780  else
781  {
782  return make_tile_window(a_pad_view,
785  {0, i_m});
786  }
787  }();
788 
789  const auto& b_flat_block_window =
790  make_tile_window(b_flat_pad_view,
793  {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});
794 
795  const auto ds_block_window = generate_tuple(
796  [&](auto i) {
797  using DiLayout = remove_cvref_t<std::tuple_element_t<i.value, DsLayout>>;
798  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
799  {
800  return make_tile_window(ds_pad_view[i],
803  {i_m, i_n});
804  }
805  else
806  {
807  return make_tile_window(ds_pad_view[i],
810  {i_n, i_m});
811  }
812  },
813  number<NumDTensor>{});
814 
815  auto e_block_window = make_tile_window(
816  e_pad_view,
817  make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
818  {i_m, i_n});
819 
820  constexpr int ScaleGranularityKA = 0; // decltype(kargs.scale_m_ptr)::GranularityK;
821  constexpr int ScaleGranularityKB = 0; // decltype(kargs.scale_n_ptr)::GranularityK;
822 
823  auto scale_m_window = make_tile_window(views.at(number<4>{}),
824  make_tuple(number<TilePartitioner::MPerBlock>{},
825  number < ScaleGranularityKA == 0
826  ? TilePartitioner::NPerBlock
827  : TilePartitioner::KPerBlock > {}),
828  {i_m, 0});
829  auto scale_n_window = make_tile_window(views.at(number<5>{}),
830  make_tuple(number < ScaleGranularityKB == 0
831  ? TilePartitioner::MPerBlock
832  : TilePartitioner::KPerBlock > {},
833  number<TilePartitioner::NPerBlock>{}),
834  {0, i_n});
835 
836  return make_tuple(a_block_window,
837  b_flat_block_window,
838  ds_block_window,
839  e_block_window,
840  scale_m_window,
841  scale_n_window);
842  }
843 
844  template <class ScaleM, class ScaleN, bool UseDefaultScheduler = true>
845  CK_TILE_DEVICE static void
846  RunFlatmm(const ADataType* a_ptr,
847  const BDataType* b_flat_ptr,
848  const std::array<const void*, NumDTensor>& ds_ptr,
849  EDataType* e_ptr,
850  void* smem_ptr_ping,
851  void* smem_ptr_pong,
852  const FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()>& kargs,
853  const SplitKBatchOffset& splitk_batch_offset,
854  const index_t block_idx_m,
855  const index_t block_idx_n)
856  {
857  // Create Gemm tensor views, pad views and tile windows
858  const auto& gemm_tensor_views_tuple =
859  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
860  a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
861  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
862  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
863 
864  const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
865 
866  // Run GEMM cooperatively by whole workgroup.
867  const auto& a_block_window = gemm_tile_windows.at(I0);
868  const auto& b_flat_block_window = gemm_tile_windows.at(I1);
869  const auto& d_block_window = gemm_tile_windows.at(I2);
870  const auto& c_block_tile = FlatmmPipeline{}.template operator()(
871  a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
872 
873  auto scale_m_window = gemm_tile_windows.at(number<4>{});
874  auto scale_n_window = gemm_tile_windows.at(number<5>{});
875 
876  // Run Epilogue Pipeline
877  if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
878  {
879  auto& c_block_window = gemm_tile_windows.at(I3);
880  EpiloguePipeline{}.template
881  operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
882  c_block_window,
883  c_block_tile,
884  d_block_window,
885  smem_ptr_ping,
886  scale_m_window,
887  scale_n_window);
888  }
889  else if(UseDefaultScheduler || (get_warp_id() == 0))
890  {
891  // Run Epilogue Pipeline
892  auto& c_block_window = gemm_tile_windows.at(I3);
893  EpiloguePipeline{}.template
894  operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
895  c_block_window, c_block_tile, d_block_window, smem_ptr_ping);
896  }
897  }
898 
899  template <class ScaleM, class ScaleN>
900  CK_TILE_DEVICE void operator()(FlatmmKernelArgs<ScaleM, ScaleN, DsDataType::size()> kargs,
901  int partition_idx = blockIdx.x) const
902  {
903  int total_work_tile_cnt = TilePartitioner::GridSize(kargs.M, kargs.N);
904 
905  do
906  {
907  const auto [iM, iN] =
908  TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(partition_idx);
909  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
910  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
911 
912  const SplitKBatchOffset splitk_batch_offset(kargs);
913  // options
914  const ADataType* a_ptr =
915  static_cast<const ADataType*>(kargs.a_ptr) + splitk_batch_offset.a_k_split_offset;
916  const BDataType* b_flat_ptr =
917  static_cast<const BDataType*>(kargs.b_ptr) + splitk_batch_offset.b_k_split_offset;
918  EDataType* e_ptr = static_cast<EDataType*>(kargs.e_ptr);
919 
920  // allocate LDS
921  __shared__ char smem_ptr_ping[GetSmemPingSize()];
922  __shared__ char smem_ptr_pong[GetSmemPongSize()];
923 
924  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
925  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
926  is_any_of<EDataType, fp16_t, bf16_t>::value))
927  {
928  constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
929  RunFlatmm<ScaleM, ScaleN, scheduler_type>(a_ptr,
930  b_flat_ptr,
931  kargs.ds_ptr,
932  e_ptr,
933  smem_ptr_ping,
934  smem_ptr_pong,
935  kargs,
936  splitk_batch_offset,
937  i_m,
938  i_n);
939  }
940  partition_idx += gridDim.x;
941  } while(UsePersistentKernel && partition_idx < total_work_tile_cnt);
942  }
943 };
944 
945 } // namespace ck_tile
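Putting the pieces together, a host-side driver would construct the host arguments, convert them with MakeKernelArgs, validate them with IsSupportedArgument, size the launch with GridSize/BlockSize and then launch the kernel. A sketch under stated assumptions: the launch goes through the kentry<> entry point referenced by GridSize() above, production code would normally use ck_tile's host launch helpers, and Kernel/HostArgs are application-defined instantiations of FlatmmKernel and ScaleFlatmmHostArgs.

#include <stdexcept>
#include <hip/hip_runtime.h>
#include "ck_tile/ops/flatmm/kernel/flatmm_kernel.hpp"

template <typename Kernel, typename HostArgs>
void run_flatmm(const HostArgs& host_args, hipStream_t stream)
{
    auto kargs = Kernel::MakeKernelArgs(host_args);

    if(!Kernel::IsSupportedArgument(kargs))
        throw std::runtime_error("flatmm: unsupported argument combination");

    const dim3 grids  = Kernel::GridSize(kargs); // occupancy-limited grid in persistent mode
    const dim3 blocks = Kernel::BlockSize();

    // Assumed launch path; the kernel allocates its LDS statically, so the
    // dynamic shared-memory size is 0.
    ck_tile::kentry<1, Kernel, decltype(kargs)><<<grids, blocks, 0, stream>>>(kargs);
}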