Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp (docs-6.4.3) Source File

1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck {
18 
19 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
20 // pipelines in the same kernel function. Blockers:
21 // 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
22 // operate on two distinct LDS chunks.
23 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
24 // not be able to share the same LDS buffer if __shared__ were declared inside the block GEMM pipeline
25 template <typename GridwiseGemm,
26  bool HasMainKBlockLoop,
27  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
28  index_t MinimumOccupancy = 1,
29  TailNumber TailNum = TailNumber::Full>
30 __global__ void
31 #if CK_USE_LAUNCH_BOUNDS
32  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
33 #endif
34  // __attribute__((amdgpu_waves_per_eu(1, 1)))
35  kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
36 {
37 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
38  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
39 
40  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
41 
42  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
43  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
44  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
45  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
46  p_shared,
47  karg);
48 #else
49  ignore = karg;
50 #endif // end of if (defined(__gfx9__))
51 }
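// Illustrative host-side launch sketch (a minimal sketch, not part of this header; the
// specialization `Gemm`, the pointer names and the launch line below are assumptions for
// illustration only):
//
//   using Gemm = GridwiseGemm_xdl_cshuffle_v3</* full template parameter list */>;
//   auto karg  = typename Gemm::Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};
//   if(Gemm::CheckValidity(karg))
//   {
//       const auto [grid_x, grid_y, grid_z] = Gemm::CalculateGridSize(karg.M, karg.N, karg.KBatch);
//       // kernel_gemm_xdl_cshuffle_v3<Gemm, /*HasMainKBlockLoop*/ true,
//       //                             InMemoryDataOperationEnum::Set>
//       //     <<<dim3(grid_x, grid_y, grid_z), dim3(/*BlockSize template argument*/)>>>(karg);
//   }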
52 
53 template <typename GridwiseGemm,
54  bool HasMainKBlockLoop,
55  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
56  index_t MinimumOccupancy = 1,
57  TailNumber TailNum = TailNumber::Full>
58 __global__ void
59 #if CK_USE_LAUNCH_BOUNDS
60  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
61 #endif
62  // __attribute__((amdgpu_waves_per_eu(1, 1)))
63  kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
64 {
65 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
66  // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
67  // operate on different LDS chunks at the same time without an ordering dependency
68  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
69  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
70 
71  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
72 
73  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
74  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
75  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
76  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
77  p_shared_0,
78  p_shared_1,
79  karg);
80 #else
81  ignore = karg;
82 #endif // end of if (defined(__gfx9__))
83 }
84 
85 template <typename ALayout,
86  typename BLayout,
87  typename CLayout,
88  typename ADataType,
89  typename BDataType,
90  typename AccDataType,
91  typename CShuffleDataType,
92  typename CDataType,
93  typename AElementwiseOperation,
94  typename BElementwiseOperation,
95  typename CElementwiseOperation,
97  index_t BlockSize,
98  index_t MPerBlock,
99  index_t NPerBlock,
100  index_t KPerBlock,
101  index_t AK1Value,
102  index_t BK1Value,
103  index_t MPerXdl,
104  index_t NPerXdl,
105  index_t MXdlPerWave,
106  index_t NXdlPerWave,
107  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
108  typename ABlockTransferThreadClusterArrangeOrder,
109  typename ABlockTransferSrcAccessOrder,
110  index_t ABlockTransferSrcVectorDim,
111  index_t ABlockTransferSrcScalarPerVector,
112  index_t ABlockTransferDstScalarPerVector_AK1,
113  bool AThreadTransferSrcResetCoordinateAfterRun,
114  index_t ABlockLdsExtraM,
115  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
116  typename BBlockTransferThreadClusterArrangeOrder,
117  typename BBlockTransferSrcAccessOrder,
118  index_t BBlockTransferSrcVectorDim,
119  index_t BBlockTransferSrcScalarPerVector,
120  index_t BBlockTransferDstScalarPerVector_BK1,
121  bool BThreadTransferSrcResetCoordinateAfterRun,
122  index_t BBlockLdsExtraN,
123  index_t CShuffleMXdlPerWavePerShuffle,
124  index_t CShuffleNXdlPerWavePerShuffle,
125  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
126  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
129  typename ComputeTypeA = CDataType,
130  typename ComputeTypeB = ComputeTypeA,
131  bool PermuteA = false,
132  bool PermuteB = false>
133 struct GridwiseGemm_xdl_cshuffle_v3
134 {
135  static constexpr auto I0 = Number<0>{};
136  static constexpr auto I1 = Number<1>{};
137  static constexpr auto I2 = Number<2>{};
138  static constexpr auto I3 = Number<3>{};
139  static constexpr auto I4 = Number<4>{};
140  static constexpr auto I5 = Number<5>{};
141  static constexpr auto I6 = Number<6>{};
142  static constexpr auto I7 = Number<7>{};
143 
144  // K1 should be Number<...>
145  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
146  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
147  static constexpr auto AK1Number = Number<AK1Value>{};
148  static constexpr auto BK1Number = Number<BK1Value>{};
149 
150  static constexpr index_t KPack =
153 
155 
156  static constexpr index_t APackedSize = []() {
157  if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
158  return 2;
159  else
160  return 1;
161  }();
162 
163  static constexpr index_t BPackedSize = []() {
164  if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
165  return 2;
166  else
167  return 1;
168  }();
169 
170  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
171  {
172  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
173  }
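// Illustrative numbers: CalculateGridSize returns (number of C tiles, 1, KBatch). Assuming the
// default Block2CTileMap yields MBlock * NBlock tiles, M = N = 4096 with MPerBlock = NPerBlock = 256
// and KBatch = 2 give a (256, 1, 2) launch grid.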
174 
175  __host__ static auto CalculateMPadded(index_t M)
176  {
177  return math::integer_least_multiple(M, MPerBlock);
178  }
179 
180  __host__ static auto CalculateNPadded(index_t N)
181  {
182  return math::integer_least_multiple(N, NPerBlock);
183  }
184 
185  __host__ static auto CalculateKPadded(index_t K)
186  {
187  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
188  }
189 
190  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
191  {
192  auto K_t = K_Batch * KPerBlock;
193  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
194  }
195 
196  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
197  {
198  auto K_t = K_Batch * KPerBlock;
199  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
200  }
201 
202  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
203  {
204  auto K_t = K_Batch * KPerBlock;
205  return (K + K_t - 1) / K_t * KPerBlock;
206  }
207 
208  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
209  {
210  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
211  auto K_t = K_Batch * KReadVec;
212  return (K + K_t - 1) / K_t * KReadVec;
213  }
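// Worked example for the split-K helpers above (illustrative numbers only): with KPerBlock = 64,
// AK1Value = BK1Value = 8, K = 1000 and K_Batch = 2:
//   CalculateKPadded(1000, 2)   = ((1000 + 127) / 128) * 64       = 8 * 64 = 512
//   CalculateAK0Padded(1000, 2) = ((1000 + 127) / 128) * (64 / 8) = 8 * 8  = 64
//   CalculateKRead(1000, 2)     = ((1000 + 15) / 16) * 8          = 63 * 8 = 504
// i.e. each split-K batch reads 504 elements along K; the last batch covers the remaining 496.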
214 
215  __host__ static auto CalculateMBlock(index_t M)
216  {
217  return math::integer_divide_ceil(M, MPerBlock);
218  }
219 
220  __host__ static auto CalculateNBlock(index_t N)
221  {
222  return math::integer_divide_ceil(N, NPerBlock);
223  }
224 
225  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
226  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
227  {
228  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
229  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
230 
232  TileDesc_K0_MN_K1{},
238  }
239 
240  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
241  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
242  {
243  const auto a_grid_desc_mraw_kraw = [&]() {
244  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
245  {
246  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
247  }
248  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
249  {
250  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
251  }
252  }();
253 
255 
256  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
257  GemmSpec == GemmSpecialization::MNKPadding)
258  {
259  // pad both M and K
260  const auto a_grid_desc_m_k =
261  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
263  make_right_pad_transform(K, KPad - K)),
266 
267  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
268  a_grid_desc_m_k,
273 
274  return a_grid_desc_ak0_m_ak1;
275  }
276  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
277  GemmSpec == GemmSpecialization::MNPadding)
278  {
279  // pad M, but not K
280  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
281  a_grid_desc_mraw_kraw,
283  make_right_pad_transform(M, MPad - M)),
286 
287  return a_grid_desc_ak0_m_ak1;
288  }
289  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
290  GemmSpec == GemmSpecialization::NKPadding)
291  {
292  // pad K, but not M
293  const auto a_grid_desc_m_k = transform_tensor_descriptor(
294  a_grid_desc_mraw_kraw,
298 
299  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
300  a_grid_desc_m_k,
305 
306  return a_grid_desc_ak0_m_ak1;
307  }
308  else
309  {
310  // not pad M or K
311  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
312  a_grid_desc_mraw_kraw,
317 
318  return a_grid_desc_ak0_m_ak1;
319  }
320  }
321 
322  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
323  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
324  {
325  const auto b_grid_desc_nraw_kraw = [&]() {
327  {
328  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
329  }
331  {
332  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
333  }
334  }();
335 
337 
338  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
339  GemmSpec != GemmSpecialization::Default),
340  "pk_i4_t does not support padding");
341 
342  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
343  GemmSpec == GemmSpecialization::MNKPadding)
344  {
345  // pad both N and K
346  const auto b_grid_desc_n_k =
347  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
349  make_right_pad_transform(K, KPad - K)),
352 
353  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
354  b_grid_desc_n_k,
359 
360  return b_grid_desc_bk0_n_bk1;
361  }
362  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
363  GemmSpec == GemmSpecialization::MNPadding)
364  {
365  // pad N, but not K
366  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
367  b_grid_desc_nraw_kraw,
369  make_right_pad_transform(N, NPad - N)),
372 
373  return b_grid_desc_bk0_n_bk1;
374  }
375  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
376  GemmSpec == GemmSpecialization::MKPadding)
377  {
378  // pad K, but not N
379  const auto b_grid_desc_n_k = transform_tensor_descriptor(
380  b_grid_desc_nraw_kraw,
384 
385  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
386  b_grid_desc_n_k,
391 
392  return b_grid_desc_bk0_n_bk1;
393  }
394  else
395  {
396  if constexpr(!PermuteB)
397  {
398  // not pad N or K
399  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
400  b_grid_desc_nraw_kraw,
405 
406  return b_grid_desc_bk0_n_bk1;
407  }
408  else
409  {
410  // Pre-shuffled Weight
411  // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
412  constexpr index_t BK01 = KPerBlock / BK1Value;
413  const index_t BK0_ = StrideB / BK1Value;
414  const index_t BK00 = BK0_ / BK01;
415 
416  const auto b_grid_desc_bk00_n_bk01_bk1_permute =
417  make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
418 
419  const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
420  b_grid_desc_bk00_n_bk01_bk1_permute,
423  make_pass_through_transform(BK1Value)),
426 
427  return b_grid_desc_bk0_n_bk1_permute;
428  }
429  }
430  }
431 
432  template <typename ABlockDesc_AK0_M_AK1>
433  __host__ __device__ static constexpr auto
434  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
435  {
436  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
437 
438  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
439  }
440 
441  template <typename BBlockDesc_BK0_N_BK1>
442  __host__ __device__ static constexpr auto
443  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
444  {
445  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
446 
447  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
448  }
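// Illustrative: the wave counts above follow directly from the tile shape, e.g. with
// MPerBlock = 256, MXdlPerWave = 4, MPerXdl = 32 -> MWaves = 256 / (4 * 32) = 2, and with
// NPerBlock = 128, NXdlPerWave = 2, NPerXdl = 32 -> NWaves = 128 / (2 * 32) = 2.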
449 
450  __host__ __device__ static auto
452  {
453  const auto c_grid_desc_mraw_nraw = [&]() {
455  {
456  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
457  }
459  {
460  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
461  }
462  }();
463 
464  // pad M and N
465  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
467  make_right_pad_transform(N, NPad - N)),
470 #if 0
472 
473  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
474  GemmSpec == GemmSpecialization::MNKPadding)
475  {
476  // pad M and N
477  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
479  make_right_pad_transform(N, NPad - N)),
482  }
483  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
484  GemmSpec == GemmSpecialization::MKPadding)
485  {
486  // pad M, but not N
488  c_grid_desc_mraw_nraw,
492  }
493  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
494  GemmSpec == GemmSpecialization::NKPadding)
495  {
496  // pad N, but not M
498  c_grid_desc_mraw_nraw,
502  }
503  else
504  {
505  // not pad M or N
506  return c_grid_desc_mraw_nraw;
507  }
508 #endif
509  }
510 
511  struct Problem
512  {
513  __host__ Problem(index_t M_,
514  index_t N_,
515  index_t K_,
516  index_t StrideA_,
517  index_t StrideB_,
518  index_t StrideC_,
519  index_t KBatch_)
520  : M{M_},
521  N{N_},
522  K{K_},
523  StrideA{StrideA_},
524  StrideB{StrideB_},
525  StrideC{StrideC_},
526  KBatch{KBatch_},
529  KRead{CalculateKRead(K_, KBatch_)},
530  KPadded{CalculateKPadded(K_, KBatch_)},
531  AK0{CalculateAK0Padded(K_, KBatch_)},
532  BK0{CalculateBK0Padded(K_, KBatch_)},
533  MBlock{CalculateMBlock(M_)},
535  {
536  }
537 
538  __host__ void Print() const
539  {
540  std::cout << "problem {"
541  << "M:" << M << ", "
542  << "N:" << N << ", "
543  << "K:" << K << ", "
544  << "SA:" << StrideA << ", "
545  << "SB:" << StrideB << ", "
546  << "SC:" << StrideC << ", "
547  << "MP:" << MPadded << ", "
548  << "NP:" << NPadded << ", "
549  << "KRead:" << KRead << ", "
550  << "KP:" << KPadded << ", "
551  << "AK0:" << AK0 << ", "
552  << "BK0:" << BK0 << ", "
553  << "MBlock: " << MBlock << ", "
554  << "NBlock: " << NBlock << "}" << std::endl;
555  }
556 
557  index_t M;
558  index_t N;
559  index_t K;
563  index_t KBatch;
566  index_t KRead;
568  index_t AK0;
569  index_t BK0;
570  index_t MBlock;
571  index_t NBlock;
572  };
573 
574  // Argument
575  struct Argument : public tensor_operation::device::BaseArgument, public Problem
576  {
577  __host__ Argument(const ADataType* p_a_grid_,
578  const BDataType* p_b_grid_,
579  CDataType* p_c_grid_,
580  index_t M_,
581  index_t N_,
582  index_t K_,
583  index_t StrideA_,
584  index_t StrideB_,
585  index_t StrideC_,
586  index_t k_batch_,
587  bool is_reduce_ = false)
588  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
589  p_a_grid{p_a_grid_},
590  p_b_grid{p_b_grid_},
591  p_c_grid{p_c_grid_},
592  is_reduce(is_reduce_)
593  {
594  }
595 
596  __host__ __device__ inline bool IsReduceAdd() const
597  {
598  return (Problem::KBatch > 1) && is_reduce;
599  }
600 
601  __host__ __device__ inline bool IsAtomicAdd() const
602  {
603  return (Problem::KBatch > 1) && (!is_reduce);
604  }
605 
606  const ADataType* p_a_grid;
607  const BDataType* p_b_grid;
608  CDataType* p_c_grid;
609  bool is_reduce;
610  };
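// Note on split-K result combination (derived from the predicates above): with KBatch > 1 the
// per-batch partial sums are typically combined either by a separate reduction pass over a
// workspace (is_reduce == true, IsReduceAdd) or by atomic adds directly into C
// (is_reduce == false, IsAtomicAdd); with KBatch == 1 neither path is taken and C is written once.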
611 
613  {
614 
615  __device__ SplitKBatchOffset(Argument& karg)
616  {
617  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
618  {
619  a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
620  }
621  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
622  {
623  a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
624  }
625 
626  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
627  {
628  b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
629  }
630  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
631  {
632  if constexpr(!PermuteB)
633  {
634  b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
635  }
636  else
637  {
638  const int k0_offset = karg.KRead * karg.N;
639  b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
640  }
641  }
642 
643  if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
644  {
645  karg.K = karg.KRead;
646  }
647  else
648  {
649  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
650  }
651 
652  if(karg.IsReduceAdd())
653  {
654  c_reduce_offset = blockIdx.z * karg.M * karg.N;
655  }
656  else
657  {
658  c_reduce_offset = 0;
659  }
660  }
661 
665  };
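// Worked example (illustrative): for row-major A with KRead = 504, APackedSize = 1 and
// blockIdx.z = 1, a_k_split_offset = 1 * 504 elements; for the last split-K batch
// (blockIdx.z == KBatch - 1) karg.K is shrunk to K - KRead * (KBatch - 1) so the tail batch
// only consumes the remaining K elements.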
666 
667  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
668  {
669  // A matrix in LDS memory, dst of blockwise copy
670  if constexpr(ABlockLdsExtraM)
671  {
675  }
676  // The xor tensor transformation requires extra, unnecessary VGPR usage and can cause register
677  // spills in some cases.
679  {
680  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
681  constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
682  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
683  make_tuple(
684  AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
686 
687  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
688  a_lds_block_desc,
694 
695  constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
696  a_lds_block_desc_permuted,
702 
703  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
704  a_lds_block_desc_ak0_mldslayer_m_ak1,
711 
712  return a_lds_block_desc_ak0_m_ak1;
713  }
714  else // ColumnMajor A
715  {
716  // The kfold and mpair dimensions are not always required.
717  // More dimensions in the merge_transform increase the difficulty of generating immediate-argument
718  // offsets for the compiler.
719  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
720  constexpr auto M1 = MPerBlock / M0;
721 
722  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
723  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
724  constexpr auto KThreadRead = 64 / MPerXdl;
725  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
726 
727  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
728  ? 1
729  : 128 / (AK1Number * M0 * sizeof(ADataType));
730  constexpr auto KThreadReadPerm =
731  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
732  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
733  : KThreadRead;
734 
735  // 1 <= mpair <= M0
736  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
737  ? 1
738  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
739  ? M0
740  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
741 
742  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
746  Number<kfold * M0 / mpair>{},
747  Number<mpair>{},
748  AK1Number));
749 
750  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
751  a_lds_block_desc,
752  make_tuple(
756  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
759  make_tuple(
761  make_tuple(
763 
764  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
765  a_lds_block_desc_permuted,
766  make_tuple(
774  Sequence<1>{},
775  Sequence<2>{},
776  Sequence<3>{},
777  Sequence<4>{},
778  Sequence<5>{}),
780  Sequence<2>{},
781  Sequence<0, 3>{},
782  Sequence<4, 5>{},
783  Sequence<6>{},
784  Sequence<7>{}));
785 
786  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
787  a_lds_block_desc_unmerged,
790  Number<KThreadWrite / kfold / KThreadReadPerm>{},
791  Number<kfold>{},
798 
799  return a_lds_block_desc_ak0_m_ak1;
800  }
801  }
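// The permutations above (the MLdsLayer split or the KThreadReadPerm/kfold/mpair splits) only
// change how the A tile is laid out inside LDS, with the goal of reducing LDS bank conflicts;
// the logical (AK0, M, AK1) view handed to the blockwise copy and the GEMM pipeline is unchanged.
// GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1 below applies the same scheme to the B tile.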
802 
803  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
804  {
805  // B matrix in LDS memory, dst of blockwise copy
806  if constexpr(BBlockLdsExtraN)
807  {
811  }
813  {
814  // treat NLdsLayer * K0 as a logical bank
815  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
816  constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
817  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
818  make_tuple(
819  BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
821 
822  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
823  b_lds_block_desc,
829 
830  constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
831  b_lds_block_desc_permuted,
837 
838  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
839  b_lds_block_desc_bk0_nldslayer_n_bk1,
846 
847  return b_lds_block_desc_bk0_n_bk1;
848  }
849  else // RowMajor B
850  {
851  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
852  constexpr auto N1 = NPerBlock / N0;
853 
854  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
855  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
856  constexpr auto KThreadRead = 64 / NPerXdl;
857  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
858 
859  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
860  ? 1
861  : 128 / (BK1Number * N0 * sizeof(BDataType));
862  constexpr auto KThreadReadPerm =
863  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
864  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
865  : KThreadRead;
866 
867  // 1 <= npair <= N0
868  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
869  ? 1
870  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
871  ? N0
872  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
873 
874  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
878  Number<kfold * N0 / npair>{},
879  Number<npair>{},
880  BK1Number));
881 
882  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
883  b_lds_block_desc,
884  make_tuple(
888  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
891  make_tuple(
893  make_tuple(
895 
896  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
897  b_lds_block_desc_permuted,
898  make_tuple(
906  Sequence<1>{},
907  Sequence<2>{},
908  Sequence<3>{},
909  Sequence<4>{},
910  Sequence<5>{}),
912  Sequence<2>{},
913  Sequence<0, 3>{},
914  Sequence<4, 5>{},
915  Sequence<6>{},
916  Sequence<7>{}));
917 
918  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
919  b_lds_block_desc_unmerged,
922  Number<KThreadWrite / kfold / KThreadReadPerm>{},
923  Number<kfold>{},
930 
931  return b_lds_block_desc_bk0_n_bk1;
932  }
933  }
934 
936  {
937  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
938  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
939 
940  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
942  make_tuple(I1,
944  I1,
946 
947  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
948  }
949 
952  BlkGemmPipelineVer,
953  BlkGemmPipeSched,
954  BlockSize,
955  ADataType,
956  BDataType,
957  ComputeTypeA,
958  AccDataType,
965  ABlockTransferSrcScalarPerVector,
966  BBlockTransferSrcScalarPerVector,
967  MPerBlock,
968  NPerBlock,
969  KPerBlock,
970  MPerXdl,
971  NPerXdl,
972  MXdlPerWave,
973  NXdlPerWave,
974  KPack>())>;
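// The blockwise GEMM pipeline is selected at compile time from the BlkGemmPipeSched /
// BlkGemmPipelineVer template parameters passed to BlockGemmPipeline_Selector; Run() below
// drives it with a single LDS buffer, while Run_2Lds() drives the ping-pong variant that
// alternates between two LDS buffers.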
975 
976  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
977  {
978  // LDS allocation for A and B: be careful of alignment
979  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
980  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
981 
982  // lds max alignment
983  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
984 
985  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
986  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
987 
988  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
989  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
990 
991  // LDS allocation for C shuffle in LDS
992  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
994 
995  constexpr auto c_block_size =
996  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
997 
998  return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
999  b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
1000  c_block_size * sizeof(CShuffleDataType));
1001  }
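// Rough sizing example (illustrative, ignores the extra-LDS padding and alignment options): for
// fp16 A/B with MPerBlock = NPerBlock = 128 and KPerBlock = 32, the A and B tiles each take
// 128 * 32 * sizeof(half) = 8 KiB, so the A+B term is 16 KiB; the returned size is the maximum
// of that and the CShuffle staging tile in CShuffleDataType.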
1002 
1003  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1004  __host__ static constexpr bool CheckValidity(const Argument& karg)
1005  {
1006  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1007  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1008  "Invalid tuning param!");
1009 
1010  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1015  {
1016  if(!(karg.M % MPerBlock == 0))
1017  {
1018  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1019  {
1020  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1021  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1022  << std::endl;
1023  }
1024  return false;
1025  }
1026  }
1027 
1028  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1033  {
1034  if(!(karg.N % NPerBlock == 0))
1035  {
1036  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1037  {
1038  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1039  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1040  << std::endl;
1041  }
1042  return false;
1043  }
1044  }
1045 
1046  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1050  {
1051 
1052  auto K_t = karg.KBatch * KPerBlock;
1053  if(!(karg.K % K_t == 0))
1054  {
1055  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1056  {
1057  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1058  << karg.K << " " << __FILE__ << ":" << __LINE__
1059  << ", in function: " << __func__ << std::endl;
1060  }
1061  return false;
1062  }
1063  }
1064  else
1065  {
1066  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1067  auto K_t = karg.KBatch * KReadVec;
1068  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1069  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1070  {
1071  return false;
1072  }
1073  }
1074 
1076  {
1077  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1078  {
1079  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1080  {
1081  std::cout << "Arg K (" << karg.K
1082  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1083  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1084  << __LINE__ << ", in function: " << __func__ << std::endl;
1085  }
1086  return false;
1087  }
1088  }
1089  else
1090  {
1091  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1092  {
1093  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1094  {
1095  std::cout << "Arg M (" << karg.M
1096  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1097  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1098  << __LINE__ << ", in function: " << __func__ << std::endl;
1099  }
1100  return false;
1101  }
1102  }
1103 
1105  {
1106  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1107  {
1108  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1109  {
1110  std::cout << "Arg N (" << karg.N
1111  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1112  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1113  << __LINE__ << ", in function: " << __func__ << std::endl;
1114  }
1115  return false;
1116  }
1117  }
1118  else
1119  {
1120  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1121  {
1122  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1123  {
1124  std::cout << "Arg K (" << karg.K
1125  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1126  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1127  << __LINE__ << ", in function: " << __func__ << std::endl;
1128  }
1129  return false;
1130  }
1131  }
1132 
1134  {
1135  if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1136  {
1137  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1138  {
1139  std::cout << "Arg N (" << karg.N
1140  << ") value is not a multiple of "
1141  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1142  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1143  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1144  << std::endl;
1145  }
1146  return false;
1147  }
1148  }
1149  else
1150  {
1151  if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1152  {
1153  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1154  {
1155  std::cout << "Arg M (" << karg.M
1156  << ") value is not a multiple of "
1157  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1158  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1159  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1160  << std::endl;
1161  }
1162  return false;
1163  }
1164  }
1165 
1166  if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
1167  is_same<remove_cvref_t<CDataType>, float>::value ||
1169  is_same<remove_cvref_t<CDataType>, int32_t>::value))
1170  {
1171  if(!karg.IsReduceAdd())
1172  {
1173  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1174  {
1175  std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
1176  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1177  }
1178  if(karg.KBatch > 1)
1179  {
1180  return false;
1181  }
1182  }
1183  }
1184 
1185  // check gridwise gemm pipeline
1186  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1187 
1188  if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1189  {
1190  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1191  {
1192  return false;
1193  }
1194  }
1195 
1196  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1197  return true;
1198  }
1199 
1200  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1201  {
1202  const index_t num_loop = K / KPerBlock;
1203 
1204  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1205  }
1206 
1207  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1208  {
1209  const index_t num_loop = K / KPerBlock;
1210 
1211  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1212  }
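// Illustrative dispatch note (host side, an assumption about typical use): the launcher evaluates
// CalculateHasMainKBlockLoop() and CalculateKBlockLoopTailNum() on the per-batch K and uses the
// results as the HasMainKBlockLoop / TailNum template arguments when instantiating
// kernel_gemm_xdl_cshuffle_v3 (or the _2lds variant).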
1213 
1214  template <typename CGridDesc>
1215  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1216  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1217  {
1218  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1219  c_grid_desc_m_n,
1224 
1225  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1226  }
1227 
1228  // return block_id to C matrix tile idx (m0, n0) mapping
1229  // if arch = gfx942
1231  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1232 
1233  template <typename AGridDesc_AK0_M_K1,
1234  typename BGridDesc_BK0_N_K1,
1235  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1236  bool HasMainKBlockLoop,
1237  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1238  TailNumber TailNum = TailNumber::Odd>
1239  __device__ static void Run(const ADataType* p_a_grid,
1240  const BDataType* p_b_grid,
1241  CDataType* p_c_grid,
1242  void* p_shared,
1243  const Problem& problem,
1244  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1245  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1246  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1247  c_grid_desc_mblock_mperblock_nblock_nperblock)
1248  {
1249  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1250  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1251  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1252  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1253  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1254  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1255 
1256  const AElementwiseOperation a_element_op{};
1257  const BElementwiseOperation b_element_op{};
1258  const CElementwiseOperation c_element_op{};
1259 
1260  // divide block work by [M, N]
1261  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1262 
1263  const auto block_work_idx =
1264  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1265 
1266  if(!block_2_ctile_map.ValidCTileIndex(
1267  block_work_idx,
1268  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1269  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1270  {
1271  return;
1272  }
1273 
1274  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1275  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1276 
1277  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1278  const index_t m_block_data_idx_on_grid =
1279  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1280 
1281  const index_t n_block_data_idx_on_grid =
1282  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1283 
1284  // lds max alignment
1285  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1286 
1287  // A matrix in LDS memory, dst of blockwise copy
1288  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1289 
1290  // B matrix in LDS memory, dst of blockwise copy
1291  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1292 
1293  // A matrix blockwise copy
1294  auto a_blockwise_copy =
1296  AElementwiseOperation,
1300  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1301  ABlockTransferThreadClusterArrangeOrder,
1302  ADataType,
1303  ADataType,
1304  decltype(a_grid_desc_ak0_m_ak1),
1305  decltype(a_block_desc_ak0_m_ak1),
1306  ABlockTransferSrcAccessOrder,
1308  ABlockTransferSrcVectorDim,
1309  2,
1310  ABlockTransferSrcScalarPerVector,
1311  ABlockTransferDstScalarPerVector_AK1,
1312  1,
1313  1,
1314  AThreadTransferSrcResetCoordinateAfterRun,
1315  true,
1316  BlockwiseGemmPipe::GlobalBufferNum>(
1317  a_grid_desc_ak0_m_ak1,
1318  make_multi_index(0, m_block_data_idx_on_grid, 0),
1319  a_element_op,
1320  a_block_desc_ak0_m_ak1,
1321  make_multi_index(0, 0, 0),
1323 
1324  // B matrix blockwise copy
1325  auto b_blockwise_copy =
1327  BElementwiseOperation,
1331  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1332  BBlockTransferThreadClusterArrangeOrder,
1333  BDataType,
1334  BDataType,
1335  decltype(b_grid_desc_bk0_n_bk1),
1336  decltype(b_block_desc_bk0_n_bk1),
1337  BBlockTransferSrcAccessOrder,
1339  BBlockTransferSrcVectorDim,
1340  2,
1341  BBlockTransferSrcScalarPerVector,
1342  BBlockTransferDstScalarPerVector_BK1,
1343  1,
1344  1,
1345  BThreadTransferSrcResetCoordinateAfterRun,
1346  true,
1347  BlockwiseGemmPipe::GlobalBufferNum>(
1348  b_grid_desc_bk0_n_bk1,
1349  make_multi_index(0, n_block_data_idx_on_grid, 0),
1350  b_element_op,
1351  b_block_desc_bk0_n_bk1,
1352  make_multi_index(0, 0, 0),
1354 
1355  // LDS allocation for A and B: be careful of alignment
1356  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1357  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1358 
1359  // Cast after lds
1360  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1361  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1362 
1363  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1364  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
1365  sizeof(ADataType) /
1366  APackedSize),
1367  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1368 
1369  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1370  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1371 
1372  // Blockwise GEMM pipeline
1373  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1374  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1375  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1376 
1377  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1378  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1379  KPerBlock);
1380 
1381  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1382  a_block_desc_ak0_m_ak1,
1383  a_blockwise_copy,
1384  a_grid_buf,
1385  a_block_buf,
1386  a_block_slice_copy_step,
1387  b_grid_desc_bk0_n_bk1,
1388  b_block_desc_bk0_n_bk1,
1389  b_blockwise_copy,
1390  b_grid_buf,
1391  b_block_buf,
1392  b_block_slice_copy_step,
1393  c_thread_buf,
1394  num_k_block_main_loop);
1395 
1396  // shuffle C and write out
1397  {
1398  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1399  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1400  "wrong!");
1401 
1402  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1403  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1404 
1405  // TODO: hacky, fix it!
1406  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1407  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1408 
1409  // TODO: hacky, fix it!
1410  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1411  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1412  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1413 
1414  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1415  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1416  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1417  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1418  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1419  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1420  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1421  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1422 
1423  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1425 
1426  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1427  static_cast<CShuffleDataType*>(p_shared),
1428  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1429 
1430  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1431  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1432  make_tuple(
1435  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1436  M1, // M1 = MWave
1437  M2, // M2 * M3 * M4 = MPerXdl
1438  M3,
1439  M4)),
1442  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1443  N1, // N1 = NWave
1444  N2))), // N2 = NPerXdl
1446  make_tuple(
1448 
1449  // calculate origin of thread output tensor on global memory
1450  // blockwise GEMM c matrix starting index
1451  const auto c_thread_mtx_on_block =
1452  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1453 
1454  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1455  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1456 
1457  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1459  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1461  make_tuple(Sequence<0>{}));
1462 
1463  const auto m_thread_data_on_block_idx =
1464  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1465  make_multi_index(m_thread_data_on_block));
1466 
1467  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1471  make_tuple(Sequence<0>{}));
1472 
1473  const auto n_thread_data_on_block_idx =
1474  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1475  make_multi_index(n_thread_data_on_block));
1476 
1477  // shuffle: threadwise copy C from VGPR to LDS
1478  auto c_thread_copy_vgpr_to_lds =
1480  CShuffleDataType,
1481  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1482  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1484  Sequence<CShuffleMXdlPerWavePerShuffle,
1485  CShuffleNXdlPerWavePerShuffle,
1486  I1,
1487  I1,
1488  M2,
1489  I1,
1490  M4,
1491  I1>,
1493  7,
1494  1,
1496  1,
1497  true>{
1498  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1499  make_multi_index(0,
1500  0,
1501  m_thread_data_on_block_idx[I1],
1502  n_thread_data_on_block_idx[I1],
1503  m_thread_data_on_block_idx[I2],
1504  m_thread_data_on_block_idx[I3],
1505  m_thread_data_on_block_idx[I4],
1506  n_thread_data_on_block_idx[I2]),
1508 
1509  // shuffle: blockwise copy C from LDS to global
1510  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1511  ThisThreadBlock, // ThreadGroup
1512  CElementwiseOperation, // ElementwiseOperation,
1513  CGlobalMemoryDataOperation, // DstInMemOp,
1514  Sequence<1,
1515  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1516  1,
1517  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1518  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1519  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1520  CShuffleDataType, // typename SrcData,
1521  CDataType, // typename DstData,
1522  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1523  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1524  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1525  3, // index_t VectorDim,
1526  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1527  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1528  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1529  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1530  make_multi_index(0, 0, 0, 0),
1531  c_grid_desc_mblock_mperblock_nblock_nperblock,
1532  make_multi_index(block_m_id, 0, block_n_id, 0),
1533  c_element_op};
1534 
1535  // space filling curve for threadwise C in VGPR
1536  constexpr auto sfc_c_vgpr =
1539  Sequence<CShuffleMXdlPerWavePerShuffle,
1540  CShuffleNXdlPerWavePerShuffle,
1541  1,
1542  1,
1543  M2,
1544  1,
1545  M4,
1546  1>>{};
1547 
1548  // space filling curve for shuffled blockwise C in global mem
1549  constexpr auto sfc_c_global =
1552  Sequence<1,
1553  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1554  1,
1555  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1556 
1557  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1558 
1559  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1560 
1561  static_for<0, num_access, 1>{}([&](auto access_id) {
1562  // make sure it's safe to write to LDS
1563  block_sync_lds();
1564 
1565  // each thread write its data from VGPR to LDS
1566  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1567  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1568  c_thread_buf,
1569  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1570  c_shuffle_block_buf);
1571 
1572  // make sure it's safe to read from LDS
1573  block_sync_lds();
1574 
1575  // each block copy its data from LDS to global
1576  c_shuffle_block_copy_lds_to_global.Run(
1577  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1578  c_shuffle_block_buf,
1579  c_grid_desc_mblock_mperblock_nblock_nperblock,
1580  c_grid_buf);
1581 
1582  if constexpr(access_id < num_access - 1)
1583  {
1584  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1585 
1586  // move on C
1587  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1588  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1589  }
1590  });
1591  }
1592  }
1593 
1594  template <bool HasMainKBlockLoop,
1595  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1596  TailNumber TailNum = TailNumber::Odd>
1597  __device__ static void Run(const ADataType* p_a_grid,
1598  const BDataType* p_b_grid,
1599  CDataType* p_c_grid,
1600  void* p_shared,
1601  const Problem& problem)
1602  {
1603  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1604  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1605  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1606  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1607  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1608  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1609  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1611  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1612 
1613  Run<decltype(a_grid_desc_ak0_m_ak1),
1614  decltype(b_grid_desc_bk0_n_bk1),
1615  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1616  HasMainKBlockLoop,
1617  CGlobalMemoryDataOperation,
1618  TailNum>(p_a_grid,
1619  p_b_grid,
1620  p_c_grid,
1621  p_shared,
1622  problem,
1623  a_grid_desc_ak0_m_ak1,
1624  b_grid_desc_bk0_n_bk1,
1625  c_grid_desc_mblock_mperblock_nblock_nperblock);
1626  }
1627 
1628  template <typename AGridDesc_AK0_M_K1,
1629  typename BGridDesc_BK0_N_K1,
1630  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1631  bool HasMainKBlockLoop,
1632  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1633  TailNumber TailNum = TailNumber::Odd>
1634  __device__ static void Run_2Lds(const ADataType* p_a_grid,
1635  const BDataType* p_b_grid,
1636  CDataType* p_c_grid,
1637  void* p_shared_0,
1638  void* p_shared_1,
1639  const Problem& problem,
1640  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1641  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1642  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1643  c_grid_desc_mblock_mperblock_nblock_nperblock)
1644  {
1645  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1646  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1647  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1648  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1649  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1650  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1651 
1652  const AElementwiseOperation a_element_op{};
1653  const BElementwiseOperation b_element_op{};
1654  const CElementwiseOperation c_element_op{};
1655 
1656  // divide block work by [M, N]
1657  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1658 
1659  const auto block_work_idx =
1661 
1662  if(!block_2_ctile_map.ValidCTileIndex(
1663  block_work_idx,
1664  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1665  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1666  {
1667  return;
1668  }
1669 
1670  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1671  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1672 
1673  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1674  const index_t m_block_data_idx_on_grid =
1675  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1676 
1677  const index_t n_block_data_idx_on_grid =
1678  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1679 
1680  // lds max alignment
1681  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1682 
1683  // A matrix in LDS memory, dst of blockwise copy
1684  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1685 
1686  // B matrix in LDS memory, dst of blockwise copy
1687  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1688 
1689  // A matrix blockwise copy
1690  auto a_blockwise_copy =
1692  AElementwiseOperation,
1696  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1697  ABlockTransferThreadClusterArrangeOrder,
1698  ADataType,
1699  ADataType,
1700  decltype(a_grid_desc_ak0_m_ak1),
1701  decltype(a_block_desc_ak0_m_ak1),
1702  ABlockTransferSrcAccessOrder,
1704  ABlockTransferSrcVectorDim,
1705  2,
1706  ABlockTransferSrcScalarPerVector,
1707  ABlockTransferDstScalarPerVector_AK1,
1708  1,
1709  1,
1710  AThreadTransferSrcResetCoordinateAfterRun,
1711  true,
1712  BlockwiseGemmPipe::GlobalBufferNum>(
1713  a_grid_desc_ak0_m_ak1,
1714  make_multi_index(0, m_block_data_idx_on_grid, 0),
1715  a_element_op,
1716  a_block_desc_ak0_m_ak1,
1717  make_multi_index(0, 0, 0),
1719 
1720  // B matrix blockwise copy
1721  auto b_blockwise_copy =
1723  BElementwiseOperation,
1727  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1728  BBlockTransferThreadClusterArrangeOrder,
1729  BDataType,
1730  BDataType,
1731  decltype(b_grid_desc_bk0_n_bk1),
1732  decltype(b_block_desc_bk0_n_bk1),
1733  BBlockTransferSrcAccessOrder,
1735  BBlockTransferSrcVectorDim,
1736  2,
1737  BBlockTransferSrcScalarPerVector,
1738  BBlockTransferDstScalarPerVector_BK1,
1739  1,
1740  1,
1741  BThreadTransferSrcResetCoordinateAfterRun,
1742  true,
1743  BlockwiseGemmPipe::GlobalBufferNum>(
1744  b_grid_desc_bk0_n_bk1,
1745  make_multi_index(0, n_block_data_idx_on_grid, 0),
1746  b_element_op,
1747  b_block_desc_bk0_n_bk1,
1748  make_multi_index(0, 0, 0),
1750 
1751  // LDS allocation for A and B: be careful of alignment
1752  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1753  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1754 
1755  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1756  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1757 
1758  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1759  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
1760  a_block_space_size_aligned * sizeof(ADataType)),
1761  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1762 
1763  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1764  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1765 
1766  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1767  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
1768  a_block_space_size_aligned * sizeof(ADataType)),
1769  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1770 
1771  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1772  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1773 
1774  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1775  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1776 
1777  // Blockwise GEMM pipeline
1778  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1779  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1780  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1781 
1782  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1783  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1784  KPerBlock);
1785 
1786  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1787  a_block_desc_ak0_m_ak1,
1788  a_blockwise_copy,
1789  a_grid_buf,
1790  a_block_bufs,
1791  a_block_slice_copy_step,
1792  b_grid_desc_bk0_n_bk1,
1793  b_block_desc_bk0_n_bk1,
1794  b_blockwise_copy,
1795  b_grid_buf,
1796  b_block_bufs,
1797  b_block_slice_copy_step,
1798  c_thread_buf,
1799  num_k_block_main_loop);
1800 
1801  // shuffle C and write out
1802  {
1803  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1804  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1805  "wrong!");
1806 
1807  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1808  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1809 
1810  // TODO: hacky, fix it!
1811  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1812  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1813 
1814  // TODO: hacky, fix it!
1815  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1816  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1817  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1818 
1819  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1820  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1821  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1822  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1823  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1824  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1825  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1826  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1827 
1828  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1830 
1831  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1832  static_cast<CShuffleDataType*>(p_shared_0),
1833  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1834 
1835  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1836  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1837  make_tuple(
1840  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1841  M1, // M1 = MWave
1842  M2, // M2 * M3 * M4 = MPerXdl
1843  M3,
1844  M4)),
1847  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1848  N1, // N1 = NWave
1849  N2))), // N2 = NPerXdl
1851  make_tuple(
1853 
1854  // calculate origin of thread output tensor on global memory
1855  // blockwise GEMM c matrix starting index
1856  const auto c_thread_mtx_on_block =
1857  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1858 
1859  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1860  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1861 
1862  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1864  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1866  make_tuple(Sequence<0>{}));
1867 
1868  const auto m_thread_data_on_block_idx =
1869  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1870  make_multi_index(m_thread_data_on_block));
1871 
1872  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1876  make_tuple(Sequence<0>{}));
1877 
1878  const auto n_thread_data_on_block_idx =
1879  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1880  make_multi_index(n_thread_data_on_block));
1881 
1882  // shuffle: threadwise copy C from VGPR to LDS
1883  auto c_thread_copy_vgpr_to_lds =
1885  CShuffleDataType,
1886  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1887  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1889  Sequence<CShuffleMXdlPerWavePerShuffle,
1890  CShuffleNXdlPerWavePerShuffle,
1891  I1,
1892  I1,
1893  M2,
1894  I1,
1895  M4,
1896  I1>,
1897  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1898  7,
1899  1,
1900  InMemoryDataOperationEnum::Set,
1901  1,
1902  true>{
1903  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1904  make_multi_index(0,
1905  0,
1906  m_thread_data_on_block_idx[I1],
1907  n_thread_data_on_block_idx[I1],
1908  m_thread_data_on_block_idx[I2],
1909  m_thread_data_on_block_idx[I3],
1910  m_thread_data_on_block_idx[I4],
1911  n_thread_data_on_block_idx[I2]),
1912  ck::tensor_operation::element_wise::PassThrough{}};
1913 
1914  // shuffle: blockwise copy C from LDS to global
1915  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1916  ThisThreadBlock, // ThreadGroup
1917  CElementwiseOperation, // ElementwiseOperation,
1918  CGlobalMemoryDataOperation, // DstInMemOp,
1919  Sequence<1,
1920  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1921  1,
1922  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1923  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1924  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1925  CShuffleDataType, // typename SrcData,
1926  CDataType, // typename DstData,
1927  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1928  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1929  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1930  3, // index_t VectorDim,
1931  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1932  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1933  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1934  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1935  make_multi_index(0, 0, 0, 0),
1936  c_grid_desc_mblock_mperblock_nblock_nperblock,
1937  make_multi_index(block_m_id, 0, block_n_id, 0),
1938  c_element_op};
1939 
1940  // space filling curve for threadwise C in VGPR
1941  constexpr auto sfc_c_vgpr =
1942  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
1943  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
1944  Sequence<CShuffleMXdlPerWavePerShuffle,
1945  CShuffleNXdlPerWavePerShuffle,
1946  1,
1947  1,
1948  M2,
1949  1,
1950  M4,
1951  1>>{};
1952 
1953  // space filling curve for shuffled blockwise C in global mem
1954  constexpr auto sfc_c_global =
1955  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
1956  Sequence<0, 2, 1, 3>,
1957  Sequence<1,
1958  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1959  1,
1960  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1961 
1962  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1963 
1964  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1965 
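 // Worked example (illustrative values, not taken from this file): with MXdlPerWave =
 // NXdlPerWave = 4 and CShuffleMXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle = 2,
 // both curves perform (4 / 2) * (4 / 2) = 4 accesses, so the C tile is staged through the
 // shuffle LDS buffer in four VGPR->LDS->global passes by the loop below.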
1966  static_for<0, num_access, 1>{}([&](auto access_id) {
1967  // make sure it's safe to write to LDS
1968  block_sync_lds();
1969 
1970  // each thread write its data from VGPR to LDS
1971  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1972  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1973  c_thread_buf,
1974  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1975  c_shuffle_block_buf);
1976 
1977  // make sure it's safe to read from LDS
1978  block_sync_lds();
1979 
1980  // each block copy its data from LDS to global
1981  c_shuffle_block_copy_lds_to_global.Run(
1982  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1983  c_shuffle_block_buf,
1984  c_grid_desc_mblock_mperblock_nblock_nperblock,
1985  c_grid_buf);
1986 
1987  if constexpr(access_id < num_access - 1)
1988  {
1989  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1990 
1991  // move on C
1992  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1993  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1994  }
1995  });
1996  }
1997  }
1998 
1999  template <bool HasMainKBlockLoop,
2000  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2001  TailNumber TailNum = TailNumber::Odd>
2002  __device__ static void Run_2Lds(const ADataType* p_a_grid,
2003  const BDataType* p_b_grid,
2004  CDataType* p_c_grid,
2005  void* p_shared_0,
2006  void* p_shared_1,
2007  const Problem& problem)
2008  {
2009  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2010  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2011  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2012  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2013  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2014  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2015 
2016  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2017  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2018  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2019 
2020  Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2021  decltype(b_grid_desc_bk0_n_bk1),
2022  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2023  HasMainKBlockLoop,
2024  CGlobalMemoryDataOperation,
2025  TailNum>(p_a_grid,
2026  p_b_grid,
2027  p_c_grid,
2028  p_shared_0,
2029  p_shared_1,
2030  problem,
2031  a_grid_desc_ak0_m_ak1,
2032  b_grid_desc_bk0_n_bk1,
2033  c_grid_desc_mblock_mperblock_nblock_nperblock);
2034  }
2035 };
2036 
2037 } // namespace ck
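
The listing above only defines the device-side grid-wise GEMM. As a rough orientation, the following is a minimal host-side sketch of how one of the kernels declared at the top of this file could be driven. It is not code from this file: `Gridwise` stands for some concrete instantiation of the grid-wise GEMM class defined here, the block size of 256 and the use of `InMemoryDataOperationEnum::Set` are illustrative assumptions, and only the "has main K-block loop, default tail" specialization is dispatched.

#include <hip/hip_runtime.h>
// plus the CK header that provides the definitions in this file (assumed available)

template <typename Gridwise> // Gridwise: an assumed instantiation of the class above
void launch_gemm_sketch(const typename Gridwise::Argument& karg, hipStream_t stream)
{
    // Host-side validity check exposed by the grid-wise class.
    if(!Gridwise::CheckValidity(karg))
        return;

    // One workgroup per (MBlock, NBlock) tile times the split-K batch count.
    const dim3 grid(Gridwise::CalculateGridSize(karg.M, karg.N, karg.KBatch));
    const dim3 block(256); // assumed BlockSize of the instantiation

    if(Gridwise::CalculateHasMainKBlockLoop(karg.K))
    {
        // Single-LDS pipeline with the default TailNumber; other TailNumber values,
        // the 2-LDS kernel and the atomic-add split-K path are dispatched analogously.
        ck::kernel_gemm_xdl_cshuffle_v3<Gridwise,
                                        true,
                                        ck::InMemoryDataOperationEnum::Set>
            <<<grid, block, 0, stream>>>(karg);
    }
}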