Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp Source File

gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck {
18 
19 template <typename ALayout,
20  typename BLayout,
21  typename CLayout,
22  typename ADataType,
23  typename BDataType,
24  typename AccDataType,
25  typename CShuffleDataType,
26  typename CDataType,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CElementwiseOperation,
31  index_t BlockSize,
32  index_t MPerBlock,
33  index_t NPerBlock,
34  index_t KPerBlock,
35  index_t AK1Value,
36  index_t BK1Value,
37  index_t MPerXdl,
38  index_t NPerXdl,
39  index_t MXdlPerWave,
40  index_t NXdlPerWave,
41  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
42  typename ABlockTransferThreadClusterArrangeOrder,
43  typename ABlockTransferSrcAccessOrder,
44  index_t ABlockTransferSrcVectorDim,
45  index_t ABlockTransferSrcScalarPerVector,
46  index_t ABlockTransferDstScalarPerVector_AK1,
47  bool AThreadTransferSrcResetCoordinateAfterRun,
48  index_t ABlockLdsExtraM,
49  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
50  typename BBlockTransferThreadClusterArrangeOrder,
51  typename BBlockTransferSrcAccessOrder,
52  index_t BBlockTransferSrcVectorDim,
53  index_t BBlockTransferSrcScalarPerVector,
54  index_t BBlockTransferDstScalarPerVector_BK1,
55  bool BThreadTransferSrcResetCoordinateAfterRun,
56  index_t BBlockLdsExtraN,
57  index_t CShuffleMXdlPerWavePerShuffle,
58  index_t CShuffleNXdlPerWavePerShuffle,
59  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
61  BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
62  BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
63  typename ComputeTypeA = CDataType,
64  typename ComputeTypeB = ComputeTypeA>
66 {
67  static constexpr auto I0 = Number<0>{};
68  static constexpr auto I1 = Number<1>{};
69  static constexpr auto I2 = Number<2>{};
70  static constexpr auto I3 = Number<3>{};
71  static constexpr auto I4 = Number<4>{};
72  static constexpr auto I5 = Number<5>{};
73  static constexpr auto I6 = Number<6>{};
74  static constexpr auto I7 = Number<7>{};
75 
76  // K1 should be Number<...>
77  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
78  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
79  static constexpr auto AK1Number = Number<AK1Value>{};
80  static constexpr auto BK1Number = Number<BK1Value>{};
81 
82  static constexpr index_t KPack =
83  math::max(math::lcm(AK1Number, BK1Number),
84  MfmaSelector<ComputeTypeA, MPerXdl, NPerXdl>::selected_mfma.k_per_blk);
85 
86  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
87 
88  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
89  {
90  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), KBatch, Batch);
91  }
92 
93  __host__ static auto CalculateMPadded(index_t M)
94  {
95  return math::integer_least_multiple(M, MPerBlock);
96  }
97 
98  __host__ static auto CalculateNPadded(index_t N)
99  {
100  return math::integer_least_multiple(N, NPerBlock);
101  }
102 
103  __host__ static auto CalculateKPadded(index_t K)
104  {
105  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
106  }
107 
108  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
109  {
110  auto K_t = K_Batch * KPerBlock;
111  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
112  }
113 
114  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
115  {
116  auto K_t = K_Batch * KPerBlock;
117  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
118  }
119 
120  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
121  {
122  auto K_t = K_Batch * KPerBlock;
123  return (K + K_t - 1) / K_t * KPerBlock;
124  }
125 
126  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
127  {
128  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
129  auto K_t = K_Batch * KReadVec;
130  return (K + K_t - 1) / K_t * KReadVec;
131  }
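 // Worked example (illustrative numbers only, not taken from this file): with
 // KPerBlock = 32, AK1 = BK1 = 8, K = 1000 and K_Batch = 2:
 //   CalculateKPadded(1000, 2)   = ceil(1000 / 64) * 32      = 16 * 32 = 512
 //   CalculateAK0Padded(1000, 2) = ceil(1000 / 64) * (32 / 8) = 16 * 4  = 64
 //   CalculateKRead(1000, 2)     = ceil(1000 / 16) * lcm(8,8) = 63 * 8  = 504
 // i.e. each K split is rounded up so every block moves whole AK1/BK1 vectors.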
132 
133  __host__ static auto CalculateMBlock(index_t M)
134  {
135  return math::integer_divide_ceil(M, MPerBlock);
136  }
137 
138  __host__ static auto CalculateNBlock(index_t N)
139  {
140  return math::integer_divide_ceil(N, NPerBlock);
141  }
142 
143  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
144  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
145  {
146  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
147  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
148 
150  TileDesc_K0_MN_K1{},
156  }
157 
158  template <typename ABlockDesc_AK0_M_AK1>
159  __host__ __device__ static constexpr auto
160  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
161  {
162  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
163 
164  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
165  }
166 
167  template <typename BBlockDesc_BK0_N_BK1>
168  __host__ __device__ static constexpr auto
169  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
170  {
171  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
172 
173  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
174  }
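 // Example (illustrative values only): with MPerBlock = 256, MXdlPerWave = 4 and
 // MPerXdl = 32, MWaves = 256 / (4 * 32) = 2, so the A tile is split into 4 XDL
 // tiles per wave across 2 waves along M; NWaves plays the same role for the B tile.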
175 
176  struct Problem
177  {
178  __host__ Problem(index_t M_,
179  index_t N_,
180  index_t K_,
181  index_t StrideA_,
182  index_t StrideB_,
183  index_t StrideC_,
184  index_t KBatch_)
185  : M{M_},
186  N{N_},
187  K{K_},
188  StrideA{StrideA_},
189  StrideB{StrideB_},
190  StrideC{StrideC_},
191  KBatch{KBatch_},
192  MPadded{CalculateMPadded(M_)},
193  NPadded{CalculateNPadded(N_)},
194  KRead{CalculateKRead(K_, KBatch_)},
195  KPadded{CalculateKPadded(K_, KBatch_)},
196  AK0{CalculateAK0Padded(K_, KBatch_)},
197  BK0{CalculateBK0Padded(K_, KBatch_)},
198  MBlock{CalculateMBlock(M_)},
199  NBlock{CalculateNBlock(N_)}
200  {
201  }
202 
203  __host__ void Print() const
204  {
205  std::cout << "problem {"
206  << "M:" << M << ", "
207  << "N:" << N << ", "
208  << "K:" << K << ", "
209  << "SA:" << StrideA << ", "
210  << "SB:" << StrideB << ", "
211  << "SC:" << StrideC << ", "
212  << "MP:" << MPadded << ", "
213  << "NP:" << NPadded << ", "
214  << "KRead:" << KRead << ", "
215  << "KP:" << KPadded << ", "
216  << "AK0:" << AK0 << ", "
217  << "BK0:" << BK0 << ", "
218  << "MBlock: " << MBlock << ", "
219  << "NBlock: " << NBlock << "}" << std::endl;
220  }
221 
222  index_t M;
223  index_t N;
224  index_t K;
225  index_t StrideA;
226  index_t StrideB;
227  index_t StrideC;
228  index_t KBatch;
229  index_t MPadded;
230  index_t NPadded;
231  index_t KRead;
232  index_t KPadded;
233  index_t AK0;
234  index_t BK0;
235  index_t MBlock;
236  index_t NBlock;
237  };
238 
239  // Argument
241  {
242  __host__ Argument(const ADataType* p_a_grid_,
243  const BDataType* p_b_grid_,
244  CDataType* p_c_grid_,
245  index_t M_,
246  index_t N_,
247  index_t K_,
248  index_t StrideA_,
249  index_t StrideB_,
250  index_t StrideC_,
251  index_t k_batch_)
252  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
253  p_a_grid{p_a_grid_},
254  p_b_grid{p_b_grid_},
255  p_c_grid{p_c_grid_}
256  {
257  }
258 
259  const ADataType* p_a_grid;
260  const BDataType* p_b_grid;
261  CDataType* p_c_grid;
262  };
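 // Usage sketch (hedged; the instantiation alias and the sizes below are
 // illustrative assumptions, not part of this file). A host-side caller packs the
 // problem description into an Argument before launching the grid:
 //
 //   using Gemm = /* a concrete instantiation of this gridwise GEMM */;
 //   Gemm::Argument arg{p_a, p_b, p_c,
 //                      /*M*/ 1024, /*N*/ 1024, /*K*/ 4096,
 //                      /*StrideA*/ 4096, /*StrideB*/ 4096, /*StrideC*/ 1024,
 //                      /*k_batch*/ 1};
 //   arg.Print(); // dumps the derived padded sizes inherited from Problem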
263 
264  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
265  {
266  // A matrix in LDS memory, dst of blockwise copy
267  if constexpr(ABlockLdsExtraM)
268  {
272  }
273  // xor tensor transformation request more unnecessary vgpr usage, would cause register spill
274  // in some cases.
276  {
277  constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
278  ? 1
279  : 32 * 4 / KPerBlock / sizeof(ADataType);
280  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
281  make_tuple(
282  AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
284 
285  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
286  a_lds_block_desc,
292 
293  constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
294  a_lds_block_desc_permuted,
300 
301  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
302  a_lds_block_desc_ak0_mldslayer_m_ak1,
309 
310  return a_lds_block_desc_ak0_m_ak1;
311  }
312  else // ColumnMajor A
313  {
314  // kfold and mpair dimension is not always required.
315  // more dimension in merge_transform increase the difficulty of generating immarg offset
316  // for compiler.
317  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
318  constexpr auto M1 = MPerBlock / M0;
319 
320  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
321  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
322  constexpr auto KThreadRead = 64 / MPerXdl;
323  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
324 
325  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
326  ? 1
327  : 128 / (AK1Number * M0 * sizeof(ADataType));
328  constexpr auto KThreadReadPerm =
329  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
330  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
331  : KThreadRead;
332 
333  // 1<=mpair<=n0
334  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
335  ? 1
336  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
337  ? M0
338  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
339 
340  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
344  Number<kfold * M0 / mpair>{},
345  Number<mpair>{},
346  AK1Number));
347 
348  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
349  a_lds_block_desc,
350  make_tuple(
354  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
357  make_tuple(
359  make_tuple(
361 
362  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
363  a_lds_block_desc_permuted,
364  make_tuple(
372  Sequence<1>{},
373  Sequence<2>{},
374  Sequence<3>{},
375  Sequence<4>{},
376  Sequence<5>{}),
378  Sequence<2>{},
379  Sequence<0, 3>{},
380  Sequence<4, 5>{},
381  Sequence<6>{},
382  Sequence<7>{}));
383 
384  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
385  a_lds_block_desc_unmerged,
388  Number<KThreadWrite / kfold / KThreadReadPerm>{},
389  Number<kfold>{},
396 
397  return a_lds_block_desc_ak0_m_ak1;
398  }
399  }
400 
401  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
402  {
403  // B matrix in LDS memory, dst of blockwise copy
404  if constexpr(BBlockLdsExtraN)
405  {
409  }
411  {
412  // NLdsLayer * K0 as logical Bank
413  constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
414  ? 1
415  : 32 * 4 / KPerBlock / sizeof(BDataType);
416  ;
417  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
418  make_tuple(
419  BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
421 
422  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
423  b_lds_block_desc,
429 
430  constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
431  b_lds_block_desc_permuted,
437 
438  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
439  b_lds_block_desc_bk0_nldslayer_n_bk1,
446 
447  return b_lds_block_desc_bk0_n_bk1;
448  }
449  else // RowMajor B
450  {
451  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
452  constexpr auto N1 = NPerBlock / N0;
453 
454  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
455  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
456  constexpr auto KThreadRead = 64 / NPerXdl;
457  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
458 
459  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
460  ? 1
461  : 128 / (BK1Number * N0 * sizeof(BDataType));
462  constexpr auto KThreadReadPerm =
463  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
464  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
465  : KThreadRead;
466 
467  // 1<=npair<=n0
468  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
469  ? 1
470  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
471  ? N0
472  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
473 
474  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
478  Number<kfold * N0 / npair>{},
479  Number<npair>{},
480  BK1Number));
481 
482  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
483  b_lds_block_desc,
484  make_tuple(
488  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
491  make_tuple(
493  make_tuple(
495 
496  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
497  b_lds_block_desc_permuted,
498  make_tuple(
506  Sequence<1>{},
507  Sequence<2>{},
508  Sequence<3>{},
509  Sequence<4>{},
510  Sequence<5>{}),
512  Sequence<2>{},
513  Sequence<0, 3>{},
514  Sequence<4, 5>{},
515  Sequence<6>{},
516  Sequence<7>{}));
517 
518  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
519  b_lds_block_desc_unmerged,
522  Number<KThreadWrite / kfold / KThreadReadPerm>{},
523  Number<kfold>{},
530 
531  return b_lds_block_desc_bk0_n_bk1;
532  }
533  }
534 
536  {
537  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
538  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
539 
540  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
542  make_tuple(I1,
544  I1,
546 
547  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
548  }
549 
550  using BlockwiseGemmPipe =
551  remove_cvref_t<decltype(BlockGemmPipeline_Selector<
552  BlkGemmPipelineVer,
553  BlkGemmPipeSched,
554  BlockSize,
555  ADataType,
556  BDataType,
557  ComputeTypeA,
558  AccDataType,
559  decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
560  decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
561  decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(
562  GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
563  decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(
564  GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
565  ABlockTransferSrcScalarPerVector,
566  BBlockTransferSrcScalarPerVector,
567  MPerBlock,
568  NPerBlock,
569  KPerBlock,
570  MPerXdl,
571  NPerXdl,
572  MXdlPerWave,
573  NXdlPerWave,
574  KPack>())>;
575 
576  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
577  {
578  // LDS allocation for A and B: be careful of alignment
579  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
580  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
581 
582  // lds max alignment
583  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
584 
585  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
586  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
587 
588  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
589  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
590 
591  // LDS allocation for C shuffle in LDS
592  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
594 
595  constexpr auto c_block_size =
596  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
597 
598  return math::max((a_block_space_size_aligned * sizeof(ADataType) +
599  b_block_space_size_aligned * sizeof(BDataType)),
600  c_block_size * sizeof(CShuffleDataType));
601  }
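 // Rough sizing sketch (assumed 2-byte A/B types and a 128x128x64 tile; these
 // numbers are not taken from this file): the A tile needs about 64 * 128 * 2 B
 // = 16 KiB and the B tile the same, so the A+B term is ~32 KiB. The C-shuffle
 // tile reuses the same LDS once the main loop has finished, which is why the
 // allocation is max(A + B, C-shuffle) rather than their sum.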
602 
603  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
604  {
605  const index_t num_loop = K / KPerBlock;
606 
607  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
608  }
609 
610  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
611  {
612  const index_t num_loop = K / KPerBlock;
613 
614  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
615  }
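 // Host-side dispatch sketch (hedged; k_per_split is a hypothetical name, not a
 // symbol from this file): the launcher usually queries both helpers and selects
 // the matching Run instantiation, e.g.
 //   const bool has_main_loop = CalculateHasMainKBlockLoop(k_per_split);
 //   const auto tail_num      = CalculateKBlockLoopTailNum(k_per_split);
 // and then dispatches to Run<..., HasMainKBlockLoop, ..., TailNum>.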
616 
617  template <typename CGridDesc>
618  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
619  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
620  {
621  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
622  c_grid_desc_m_n,
627 
628  return c_grid_desc_mblock_mperblock_nblock_nperblock;
629  }
630 
631  // return block_id to C matrix tile idx (m0, n0) mapping
632  // if arch = gfx942
634 
635  template <typename AGridDesc_AK0_M_K1,
636  typename BGridDesc_BK0_N_K1,
637  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
638  bool HasMainKBlockLoop,
639  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
640  TailNumber TailNum = TailNumber::Odd>
641  __device__ static void Run(const ADataType* p_a_grid,
642  const BDataType* p_b_grid,
643  CDataType* p_c_grid,
644  void* p_shared,
645  const Problem& problem,
646  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
647  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
648  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
649  c_grid_desc_mblock_mperblock_nblock_nperblock,
650  const index_t k_id = 0)
651  {
652  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
653  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
654  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
655  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
656  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
657  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
658 
659  const AElementwiseOperation a_element_op{};
660  const BElementwiseOperation b_element_op{};
661  const CElementwiseOperation c_element_op{};
662 
663  // divide block work by [M, N]
664  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
665 
666  const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
667  make_multi_index(static_cast<index_t>(blockIdx.x)));
668 
669  if(!block_2_ctile_map.ValidCTileIndex(
670  block_work_idx,
671  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
672  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
673  {
674  return;
675  }
676 
677  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
678  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
679 
680  // HACK: this force m/n_block_data_idx_on_grid into SGPR
681  const index_t m_block_data_idx_on_grid =
682  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
683 
684  const index_t n_block_data_idx_on_grid =
685  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
686 
687  // lds max alignment
688  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
689 
690  // A matrix in LDS memory, dst of blockwise copy
691  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
692 
693  // B matrix in LDS memory, dst of blockwise copy
694  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
695 
696  // A matrix blockwise copy
697  auto a_blockwise_copy =
699  AElementwiseOperation,
703  ABlockTransferThreadClusterLengths_AK0_M_AK1,
704  ABlockTransferThreadClusterArrangeOrder,
705  ADataType,
706  ADataType,
707  decltype(a_grid_desc_ak0_m_ak1),
708  decltype(a_block_desc_ak0_m_ak1),
709  ABlockTransferSrcAccessOrder,
711  ABlockTransferSrcVectorDim,
712  2,
713  ABlockTransferSrcScalarPerVector,
714  ABlockTransferDstScalarPerVector_AK1,
715  1,
716  1,
717  AThreadTransferSrcResetCoordinateAfterRun,
718  true,
719  BlockwiseGemmPipe::GlobalBufferNum>(
720  a_grid_desc_ak0_m_ak1,
721  make_multi_index(k_id, m_block_data_idx_on_grid, 0),
722  a_element_op,
723  a_block_desc_ak0_m_ak1,
724  make_multi_index(0, 0, 0),
726 
727  // B matrix blockwise copy
728  auto b_blockwise_copy =
730  BElementwiseOperation,
734  BBlockTransferThreadClusterLengths_BK0_N_BK1,
735  BBlockTransferThreadClusterArrangeOrder,
736  BDataType,
737  BDataType,
738  decltype(b_grid_desc_bk0_n_bk1),
739  decltype(b_block_desc_bk0_n_bk1),
740  BBlockTransferSrcAccessOrder,
742  BBlockTransferSrcVectorDim,
743  2,
744  BBlockTransferSrcScalarPerVector,
745  BBlockTransferDstScalarPerVector_BK1,
746  1,
747  1,
748  BThreadTransferSrcResetCoordinateAfterRun,
749  true,
750  BlockwiseGemmPipe::GlobalBufferNum>(
751  b_grid_desc_bk0_n_bk1,
752  make_multi_index(k_id, n_block_data_idx_on_grid, 0),
753  b_element_op,
754  b_block_desc_bk0_n_bk1,
755  make_multi_index(0, 0, 0),
757 
758  // LDS allocation for A and B: be careful of alignment
759  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
760  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
761 
762  // Cast after lds
763  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
764  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
765 
766  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
767  static_cast<BDataType*>(p_shared) +
768  a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
769  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
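 // A and B share one LDS allocation: the B tile starts right after the A tile,
 // offset by the aligned A footprint expressed in BDataType elements.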
770 
771  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
772  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
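 // Each main-loop iteration advances the A/B copy windows by one KPerBlock slice:
 // KPerBlock / AK1 (resp. BK1) steps along the AK0/BK0 dimension, 0 along M/N and K1.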
773 
774  // Blockwise GEMM pipeline
775  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
776  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
777  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
778 
779  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
780  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
781  (KPerBlock * problem.KBatch));
782 
783  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
784  a_block_desc_ak0_m_ak1,
785  a_blockwise_copy,
786  a_grid_buf,
787  a_block_buf,
788  a_block_slice_copy_step,
789  b_grid_desc_bk0_n_bk1,
790  b_block_desc_bk0_n_bk1,
791  b_blockwise_copy,
792  b_grid_buf,
793  b_block_buf,
794  b_block_slice_copy_step,
795  c_thread_buf,
796  num_k_block_main_loop);
797 
798  // shuffle C and write out
799  {
800  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
801  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
802  "wrong!");
803 
804  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
805  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
806 
807  // TODO: hacky, fix it!
808  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
809  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
810 
811  // TODO: hacky, fix it!
812  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
813  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
814  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
815 
816  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
817  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
818  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
819  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
820  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
821  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
822  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
823  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
824 
825  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
827 
828  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
829  static_cast<CShuffleDataType*>(p_shared),
830  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
831 
832  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
833  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
834  make_tuple(
837  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
838  M1, // M1 = MWave
839  M2, // M2 * M3 * M4 = MPerXdl
840  M3,
841  M4)),
844  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
845  N1, // N1 = NWave
846  N2))), // N2 = NPerXdl
848  make_tuple(
850 
851  // calculate origin of thread output tensor on global memory
852  // blockwise GEMM c matrix starting index
853  const auto c_thread_mtx_on_block =
854  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
855 
856  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
857  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
858 
859  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
861  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
864 
865  const auto m_thread_data_on_block_idx =
866  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
867  make_multi_index(m_thread_data_on_block));
868 
869  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
874 
875  const auto n_thread_data_on_block_idx =
876  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
877  make_multi_index(n_thread_data_on_block));
878 
879  // shuffle: threadwise copy C from VGPR to LDS
880  auto c_thread_copy_vgpr_to_lds =
882  CShuffleDataType,
883  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
884  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
886  Sequence<CShuffleMXdlPerWavePerShuffle,
887  CShuffleNXdlPerWavePerShuffle,
888  I1,
889  I1,
890  M2,
891  I1,
892  M4,
893  I1>,
895  7,
896  1,
898  1,
899  true>{
900  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
902  0,
903  m_thread_data_on_block_idx[I1],
904  n_thread_data_on_block_idx[I1],
905  m_thread_data_on_block_idx[I2],
906  m_thread_data_on_block_idx[I3],
907  m_thread_data_on_block_idx[I4],
908  n_thread_data_on_block_idx[I2]),
910 
911  // shuffle: blockwise copy C from LDS to global
912  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
913  ThisThreadBlock, // ThreadGroup
914  CElementwiseOperation, // ElementwiseOperation,
915  CGlobalMemoryDataOperation, // DstInMemOp,
916  Sequence<1,
917  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
918  1,
919  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
920  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
921  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
922  CShuffleDataType, // typename SrcData,
923  CDataType, // typename DstData,
924  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
925  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
926  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
927  3, // index_t VectorDim,
928  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
929  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
930  false> // bool ThreadTransferDstResetCoordinateAfterRun>
931  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
932  make_multi_index(0, 0, 0, 0),
933  c_grid_desc_mblock_mperblock_nblock_nperblock,
934  make_multi_index(block_m_id, 0, block_n_id, 0),
935  c_element_op};
936 
937  // space filling curve for threadwise C in VGPR
938  constexpr auto sfc_c_vgpr =
941  Sequence<CShuffleMXdlPerWavePerShuffle,
942  CShuffleNXdlPerWavePerShuffle,
943  1,
944  1,
945  M2,
946  1,
947  M4,
948  1>>{};
949 
950  // space filling curve for shuffled blockwise C in global mem
951  constexpr auto sfc_c_global =
954  Sequence<1,
955  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
956  1,
957  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
958 
959  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
960 
961  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
962 
963  static_for<0, num_access, 1>{}([&](auto access_id) {
964  // make sure it's safe to write to LDS
965  block_sync_lds();
966 
967  // each thread write its data from VGPR to LDS
968  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
969  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
970  c_thread_buf,
971  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
972  c_shuffle_block_buf);
973 
974  // make sure it's safe to read from LDS
975  block_sync_lds();
976 
977  // each block copy its data from LDS to global
978  c_shuffle_block_copy_lds_to_global.Run(
979  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
980  c_shuffle_block_buf,
981  c_grid_desc_mblock_mperblock_nblock_nperblock,
982  c_grid_buf);
983 
984  if constexpr(access_id < num_access - 1)
985  {
986  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
987 
988  // move on C
989  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
990  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
991  }
992  });
993  }
994  }
995 
996  template <typename AGridDesc_AK0_M_K1,
997  typename BGridDesc_BK0_N_K1,
998  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
999  bool HasMainKBlockLoop,
1000  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1001  TailNumber TailNum = TailNumber::Odd>
1002  __device__ static void Run_2Lds(const ADataType* p_a_grid,
1003  const BDataType* p_b_grid,
1004  CDataType* p_c_grid,
1005  void* p_shared_0,
1006  void* p_shared_1,
1007  const Problem& problem,
1008  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1009  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1010  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1011  c_grid_desc_mblock_mperblock_nblock_nperblock,
1012  const index_t k_id = 0)
1013  {
1014  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1015  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1016  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1017  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1018  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1019  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1020 
1021  const AElementwiseOperation a_element_op{};
1022  const BElementwiseOperation b_element_op{};
1023  const CElementwiseOperation c_element_op{};
1024 
1025  // divide block work by [M, N]
1026  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1027 
1028  const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
1029  make_multi_index(static_cast<index_t>(blockIdx.x)));
1030 
1031  if(!block_2_ctile_map.ValidCTileIndex(
1032  block_work_idx,
1033  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1034  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1035  {
1036  return;
1037  }
1038 
1039  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1040  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1041 
1042  // HACK: this force m/n_block_data_idx_on_grid into SGPR
1043  const index_t m_block_data_idx_on_grid =
1044  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1045 
1046  const index_t n_block_data_idx_on_grid =
1047  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1048 
1049  // lds max alignment
1050  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1051 
1052  // A matrix in LDS memory, dst of blockwise copy
1053  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1054 
1055  // B matrix in LDS memory, dst of blockwise copy
1056  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1057 
1058  // A matrix blockwise copy
1059  auto a_blockwise_copy =
1061  AElementwiseOperation,
1065  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1066  ABlockTransferThreadClusterArrangeOrder,
1067  ADataType,
1068  ADataType,
1069  decltype(a_grid_desc_ak0_m_ak1),
1070  decltype(a_block_desc_ak0_m_ak1),
1071  ABlockTransferSrcAccessOrder,
1073  ABlockTransferSrcVectorDim,
1074  2,
1075  ABlockTransferSrcScalarPerVector,
1076  ABlockTransferDstScalarPerVector_AK1,
1077  1,
1078  1,
1079  AThreadTransferSrcResetCoordinateAfterRun,
1080  true,
1081  BlockwiseGemmPipe::GlobalBufferNum>(
1082  a_grid_desc_ak0_m_ak1,
1083  make_multi_index(k_id, m_block_data_idx_on_grid, 0),
1084  a_element_op,
1085  a_block_desc_ak0_m_ak1,
1086  make_multi_index(0, 0, 0),
1088 
1089  // B matrix blockwise copy
1090  auto b_blockwise_copy =
1092  BElementwiseOperation,
1096  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1097  BBlockTransferThreadClusterArrangeOrder,
1098  BDataType,
1099  BDataType,
1100  decltype(b_grid_desc_bk0_n_bk1),
1101  decltype(b_block_desc_bk0_n_bk1),
1102  BBlockTransferSrcAccessOrder,
1104  BBlockTransferSrcVectorDim,
1105  2,
1106  BBlockTransferSrcScalarPerVector,
1107  BBlockTransferDstScalarPerVector_BK1,
1108  1,
1109  1,
1110  BThreadTransferSrcResetCoordinateAfterRun,
1111  true,
1112  BlockwiseGemmPipe::GlobalBufferNum>(
1113  b_grid_desc_bk0_n_bk1,
1114  make_multi_index(k_id, n_block_data_idx_on_grid, 0),
1115  b_element_op,
1116  b_block_desc_bk0_n_bk1,
1117  make_multi_index(0, 0, 0),
1119 
1120  // LDS allocation for A and B: be careful of alignment
1121  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1122  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1123 
1124  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1125  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1126 
1127  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1128  static_cast<BDataType*>(p_shared_0) +
1129  a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
1130  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1131 
1132  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1133  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1134 
1135  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1136  static_cast<BDataType*>(p_shared_1) +
1137  a_block_space_size_aligned * sizeof(ADataType) / sizeof(BDataType),
1138  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1139 
1140  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1141  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
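 // a_block_bufs / b_block_bufs pair the ping (p_shared_0) and pong (p_shared_1)
 // LDS tiles; the pipeline alternates between them so the next K slice can be
 // prefetched into one buffer while the MFMA loop consumes the other.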
1142 
1143  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1144  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1145 
1146  // Blockwise GEMM pipeline
1147  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1148  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1149  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1150 
1151  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1152  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1153  (KPerBlock * problem.KBatch));
1154 
1155  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1156  a_block_desc_ak0_m_ak1,
1157  a_blockwise_copy,
1158  a_grid_buf,
1159  a_block_bufs,
1160  a_block_slice_copy_step,
1161  b_grid_desc_bk0_n_bk1,
1162  b_block_desc_bk0_n_bk1,
1163  b_blockwise_copy,
1164  b_grid_buf,
1165  b_block_bufs,
1166  b_block_slice_copy_step,
1167  c_thread_buf,
1168  num_k_block_main_loop);
1169 
1170  // shuffle C and write out
1171  {
1172  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1173  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1174  "wrong!");
1175 
1176  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1177  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1178 
1179  // TODO: hacky, fix it!
1180  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1181  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1182 
1183  // TODO: hacky, fix it!
1184  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1185  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1186  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1187 
1188  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1189  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1190  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1191  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1192  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1193  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1194  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1195  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1196 
1197  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1199 
1200  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1201  static_cast<CShuffleDataType*>(p_shared_0),
1202  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1203 
1204  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1205  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1206  make_tuple(
1209  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1210  M1, // M1 = MWave
1211  M2, // M2 * M3 * M4 = MPerXdl
1212  M3,
1213  M4)),
1216  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1217  N1, // N1 = NWave
1218  N2))), // N2 = NPerXdl
1220  make_tuple(
1222 
1223  // calculate origin of thread output tensor on global memory
1224  // blockwise GEMM c matrix starting index
1225  const auto c_thread_mtx_on_block =
1226  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1227 
1228  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1229  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1230 
1231  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1233  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1235  make_tuple(Sequence<0>{}));
1236 
1237  const auto m_thread_data_on_block_idx =
1238  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1239  make_multi_index(m_thread_data_on_block));
1240 
1241  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1245  make_tuple(Sequence<0>{}));
1246 
1247  const auto n_thread_data_on_block_idx =
1248  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1249  make_multi_index(n_thread_data_on_block));
1250 
1251  // shuffle: threadwise copy C from VGPR to LDS
1252  auto c_thread_copy_vgpr_to_lds =
1254  CShuffleDataType,
1255  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1256  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1258  Sequence<CShuffleMXdlPerWavePerShuffle,
1259  CShuffleNXdlPerWavePerShuffle,
1260  I1,
1261  I1,
1262  M2,
1263  I1,
1264  M4,
1265  I1>,
1267  7,
1268  1,
1270  1,
1271  true>{
1272  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1273  make_multi_index(0,
1274  0,
1275  m_thread_data_on_block_idx[I1],
1276  n_thread_data_on_block_idx[I1],
1277  m_thread_data_on_block_idx[I2],
1278  m_thread_data_on_block_idx[I3],
1279  m_thread_data_on_block_idx[I4],
1280  n_thread_data_on_block_idx[I2]),
1282 
1283  // shuffle: blockwise copy C from LDS to global
1284  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1285  ThisThreadBlock, // ThreadGroup
1286  CElementwiseOperation, // ElementwiseOperation,
1287  CGlobalMemoryDataOperation, // DstInMemOp,
1288  Sequence<1,
1289  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1290  1,
1291  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1292  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1293  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1294  CShuffleDataType, // typename SrcData,
1295  CDataType, // typename DstData,
1296  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1297  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1298  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1299  3, // index_t VectorDim,
1300  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1301  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1302  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1303  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1304  make_multi_index(0, 0, 0, 0),
1305  c_grid_desc_mblock_mperblock_nblock_nperblock,
1306  make_multi_index(block_m_id, 0, block_n_id, 0),
1307  c_element_op};
1308 
1309  // space filling curve for threadwise C in VGPR
1310  constexpr auto sfc_c_vgpr =
1313  Sequence<CShuffleMXdlPerWavePerShuffle,
1314  CShuffleNXdlPerWavePerShuffle,
1315  1,
1316  1,
1317  M2,
1318  1,
1319  M4,
1320  1>>{};
1321 
1322  // space filling curve for shuffled blockwise C in global mem
1323  constexpr auto sfc_c_global =
1326  Sequence<1,
1327  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1328  1,
1329  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1330 
1331  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1332 
1333  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1334 
1335  static_for<0, num_access, 1>{}([&](auto access_id) {
1336  // make sure it's safe to write to LDS
1337  block_sync_lds();
1338 
1339  // each thread write its data from VGPR to LDS
1340  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1341  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1342  c_thread_buf,
1343  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1344  c_shuffle_block_buf);
1345 
1346  // make sure it's safe to read from LDS
1347  block_sync_lds();
1348 
1349  // each block copy its data from LDS to global
1350  c_shuffle_block_copy_lds_to_global.Run(
1351  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1352  c_shuffle_block_buf,
1353  c_grid_desc_mblock_mperblock_nblock_nperblock,
1354  c_grid_buf);
1355 
1356  if constexpr(access_id < num_access - 1)
1357  {
1358  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1359 
1360  // move on C
1361  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1362  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1363  }
1364  });
1365  }
1366  }
1367 };
1368 
1369 } // namespace ck