/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r1.hpp Source File
gridwise_gemm_xdlops_v3r1.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
18 
19 namespace ck {
20 
// GPU kernel entry point (listed in this page's index as kernel_gemm_xdlops_v3r1).
// On supported targets it declares a static __shared__ byte array sized by
// GridwiseGemm::GetSharedMemoryNumberOfByte() and forwards every argument, plus that
// array, to GridwiseGemm::Run<HasMainK0BlockLoop>.
// On any other device target the body only assigns the arguments to `ignore`.
21 template <typename GridwiseGemm,
22  typename FloatAB,
23  typename FloatC,
24  typename AGridDesc_AK0_M_AK1,
25  typename BGridDesc_BK0_N_BK1,
26  typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CElementwiseOperation,
30  typename Block2CTileMap,
31  bool HasMainK0BlockLoop>
32 __global__ void
33 #if CK_USE_LAUNCH_BOUNDS
35 #endif
37  const FloatAB* __restrict__ p_a_grid,
38  const FloatAB* __restrict__ p_b_grid,
39  FloatC* __restrict__ p_c_grid,
40  const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
41  const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
42  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
43  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
44  const AElementwiseOperation a_element_op,
45  const BElementwiseOperation b_element_op,
46  const CElementwiseOperation c_element_op,
47  const Block2CTileMap block_2_ctile_map)
48 {
// The real body is compiled for non-device passes and for gfx908 / gfx90a / gfx94 device passes.
49 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
50  defined(__gfx94__))
// Block-shared scratch; Run() reinterprets it as FloatAB for the A/B tiles and as
// FloatCShuffle for the C shuffle stage.
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  GridwiseGemm::template Run<HasMainK0BlockLoop>(
54  p_a_grid,
55  p_b_grid,
56  p_c_grid,
57  p_shared,
58  a_grid_desc_ak0_m_ak1,
59  b_grid_desc_bk0_n_bk1,
60  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
61  a_element_op,
62  b_element_op,
63  c_element_op,
64  block_2_ctile_map);
65 #else
// Other device targets: the kernel does no work. The arguments are assigned to `ignore`,
// presumably to keep them from being reported as unused - confirm.
66  ignore = p_a_grid;
67  ignore = p_b_grid;
68  ignore = p_c_grid;
69  ignore = a_grid_desc_ak0_m_ak1;
70  ignore = b_grid_desc_bk0_n_bk1;
71  ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
72  ignore = a_element_op;
73  ignore = b_element_op;
74  ignore = c_element_op;
75  ignore = block_2_ctile_map;
76 #endif // end of if (!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
77 }
78 
79 template <
80  index_t BlockSize,
81  typename FloatAB,
82  typename FloatAcc,
83  typename FloatCShuffle,
84  typename FloatC,
85  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
86  typename AGridDesc_AK0_M_AK1,
87  typename BGridDesc_BK0_N_BK1,
88  typename CGridDesc_M_N,
89  typename AElementwiseOperation,
90  typename BElementwiseOperation,
91  typename CElementwiseOperation,
92  index_t MPerBlock,
93  index_t NPerBlock,
94  index_t KPerBlock,
95  index_t AK1Value,
96  index_t BK1Value,
97  index_t MPerXdl,
98  index_t NPerXdl,
99  index_t MXdlPerWave,
100  index_t NXdlPerWave,
101  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
102  typename ABlockTransferThreadClusterArrangeOrder,
103  typename ABlockTransferSrcAccessOrder,
104  index_t ABlockTransferSrcVectorDim,
105  index_t ABlockTransferSrcScalarPerVector,
106  index_t ABlockTransferDstScalarPerVector_K1,
107  bool AThreadTransferSrcResetCoordinateAfterRun,
108  bool ABlockLdsExtraM,
109  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
110  typename BBlockTransferThreadClusterArrangeOrder,
111  typename BBlockTransferSrcAccessOrder,
112  index_t BBlockTransferSrcVectorDim,
113  index_t BBlockTransferSrcScalarPerVector,
114  index_t BBlockTransferDstScalarPerVector_K1,
115  bool BThreadTransferSrcResetCoordinateAfterRun,
116  bool BBlockLdsExtraN,
117  index_t CShuffleMXdlPerWavePerShuffle,
118  index_t CShuffleNXdlPerWavePerShuffle,
119  typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
120  index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
121  index_t NumGemmKPrefetchStage = 1,
122  PipelineVersion PipelineVer = PipelineVersion::v1>
124 {
125  static constexpr auto I0 = Number<0>{};
126  static constexpr auto I1 = Number<1>{};
127  static constexpr auto I2 = Number<2>{};
128  static constexpr auto I3 = Number<3>{};
129  static constexpr auto I4 = Number<4>{};
130  static constexpr auto I5 = Number<5>{};
131  static constexpr auto I6 = Number<6>{};
132  static constexpr auto I7 = Number<7>{};
133 
134  // K1 should be Number<...>
135  static constexpr auto AK0 = Number<KPerBlock / AK1Value>{};
136  static constexpr auto BK0 = Number<KPerBlock / BK1Value>{};
137  static constexpr auto AK1 = Number<AK1Value>{};
138  static constexpr auto BK1 = Number<BK1Value>{};
139 
141 
143  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
144 
// Returns the compile-time descriptor of the A tile as it is laid out in LDS.
// ABlockLdsExtraM selects between two layouts. The else-branch describes lengths
// (AK0, MPerBlock, AK1) and passes max_lds_align (= AK1) as the alignment argument.
// NOTE(review): the ABlockLdsExtraM branch body and the else-branch callee are not
// visible in this listing - check the real header before relying on the layout.
145  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
146  {
147  constexpr auto max_lds_align = AK1;
148 
149  // A matrix in LDS memory, dst of blockwise copy
150  constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
151  if constexpr(ABlockLdsExtraM)
152  {
156  }
157  else
158  {
160  make_tuple(AK0, Number<MPerBlock>{}, AK1), max_lds_align);
161  }
162  }();
163 
164  return a_block_desc_ak0_m_ak1;
165  }
166 
// Returns the compile-time descriptor of the B tile as it is laid out in LDS.
// BBlockLdsExtraN selects between two layouts. The else-branch describes lengths
// (BK0, NPerBlock, BK1) and passes max_lds_align (= BK1) as the alignment argument.
// NOTE(review): the BBlockLdsExtraN branch body and the else-branch callee are not
// visible in this listing - check the real header before relying on the layout.
167  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
168  {
169  constexpr auto max_lds_align = BK1;
170 
171  // B matrix in LDS memory, dst of blockwise copy
172  constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
173  if constexpr(BBlockLdsExtraN)
174  {
178  }
179  else
180  {
182  make_tuple(BK0, Number<NPerBlock>{}, BK1), max_lds_align);
183  }
184  }();
185 
186  return b_block_desc_bk0_n_bk1;
187  }
188 
// Returns the compile-time descriptor used for the C shuffle tile (its element space size
// is what GetSharedMemoryNumberOfByte() and Run() use for the C shuffle LDS buffer).
// MWave / NWave are MPerBlock and NPerBlock divided by the per-wave extent
// (XdlPerWave * PerXdl) in the respective dimension.
// NOTE(review): the function-name line and most of the make_tuple(...) length list are
// missing from this listing, so the exact lengths cannot be confirmed from here.
189  __host__ __device__ static constexpr auto
191  {
192  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
193  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
194 
195  constexpr auto
196  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
198  make_tuple(I1,
201  I1,
204 
205  return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
206  }
207 
// Returns the number of LDS bytes the kernel reserves: the larger of
//   (a) (A tile space + B tile space) * sizeof(FloatAB), and
//   (b) C shuffle tile space * sizeof(FloatCShuffle).
// A max, not a sum, is taken: Run() addresses both the A/B buffers and the C shuffle
// buffer starting from the same p_shared pointer.
208  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
209  {
210  // LDS allocation for A and B: be careful of alignment
211  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
212 
213  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
214 
// NOTE(review): A is rounded up to a multiple of AK1 here, but Run() places the B buffer
// at an offset rounded up to lcm(AK1, BK1). If AK1 != BK1 the offset used by Run() can be
// larger than the size accounted for here - confirm the reservation is always sufficient.
215  constexpr auto a_block_space_size_aligned =
216  math::integer_least_multiple(a_block_desc_ak0_m_ak1.GetElementSpaceSize(), AK1);
217 
218  constexpr auto b_block_space_size_aligned =
219  math::integer_least_multiple(b_block_desc_bk0_n_bk1.GetElementSpaceSize(), BK1);
220 
221  // LDS allocation for C shuffle in LDS
222  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
224 
225  constexpr auto c_block_size =
226  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
227  .GetElementSpaceSize();
228 
// Element counts are converted to bytes before comparing, since the two stages use
// different element types.
229  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
230  sizeof(FloatAB),
231  c_block_size * sizeof(FloatCShuffle));
232  }
233 
234  // Checks whether a problem described by the given descriptors can be run by this gridwise GEMM.
// Returns false when any of the following holds, true otherwise:
//   - the C descriptor's (M, N) lengths differ from A's M length / B's N length,
//   - M, N or K is not a multiple of MPerBlock, NPerBlock or KPerBlock respectively,
//   - GridwiseGemmPipe::IsSupported rejects the K loop count (K / KPerBlock),
//   - block_2_ctile_map.CheckValidity rejects the C descriptor.
// The block_id -> C tile mapping itself is supplied by the caller as block_2_ctile_map.
235  template <typename Block2CTileMap>
236  __host__ __device__ static constexpr bool
237  CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
238  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
239  const CGridDesc_M_N& c_grid_desc_m_n,
240  const Block2CTileMap& block_2_ctile_map)
241  {
242  // static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value &&
243  // is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
244  // "wrong! K1 need to be known at compile-time");
245 
// Compile-time check: the block tile must split evenly into per-wave XDL tiles.
246  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
247  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
248  "Invalid tuning param!");
249 
// NOTE(review): K is derived from the A descriptor only (AK0 length * AK1 length);
// B's K extent is not compared against it in this function.
250  const auto M = a_grid_desc_ak0_m_ak1.GetLength(I1);
251  const auto N = b_grid_desc_bk0_n_bk1.GetLength(I1);
252  const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
253 
254  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
255  return false;
256 
// No partial tiles: every dimension must be an exact multiple of its block size.
257  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
258  return false;
259 
260  // check gridwise gemm pipeline
261  const auto num_k_loop = K / KPerBlock;
262 
263  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
264  {
265  return false;
266  }
267 
268  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
269  {
270  return false;
271  }
272 
273  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
274  return true;
275  }
276 
// Converts a K extent into a KPerBlock iteration count (integer division, so any
// remainder is dropped) and returns GridwiseGemmPipe::CalculateHasMainLoop for that count.
277  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
278  {
279  const index_t num_loop = K / KPerBlock;
280 
281  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
282  }
283 
// Derives the C grid descriptor returned to callers from the plain (M, N) descriptor
// c_grid_desc_m_n.
// MBlock / NBlock are M / MPerBlock and N / NPerBlock using integer division;
// CheckValidity() rejects problems where M or N is not a multiple of the block size.
// NOTE(review): the function-name line and the transform arguments are missing from this
// listing, so the exact dimension split cannot be confirmed from here.
284  __host__ __device__ static constexpr auto
286  const CGridDesc_M_N& c_grid_desc_m_n)
287  {
288  const auto M = c_grid_desc_m_n.GetLength(I0);
289  const auto N = c_grid_desc_m_n.GetLength(I1);
290 
291  const auto MBlock = M / MPerBlock;
292  const auto NBlock = N / NPerBlock;
293 
294  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
295  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
296 
297  const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
299  c_grid_desc_m_n,
306 
307  return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
308  }
309 
310  // return block_id to C matrix tile idx (m0, n0) mapping
// The map is built from c_grid_desc_m_n alone: the two index_t parameters are unnamed
// (their former names M01 / N01 are commented out) and are not used.
// NOTE(review): the line naming the returned map type is missing from this listing.
311  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
312  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
313  {
315  c_grid_desc_m_n);
316  }
320  CGridDesc_M_N{}))>;
321 
323  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
324 
325  template <bool HasMainK0BlockLoop, typename Block2CTileMap>
326  __device__ static void
327  Run(const FloatAB* __restrict__ p_a_grid,
328  const FloatAB* __restrict__ p_b_grid,
329  FloatC* __restrict__ p_c_grid,
330  void* __restrict__ p_shared,
331  const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
332  const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
334  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
335  const AElementwiseOperation& a_element_op,
336  const BElementwiseOperation& b_element_op,
337  const CElementwiseOperation& c_element_op,
338  const Block2CTileMap& block_2_ctile_map)
339  {
340  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
341  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
342  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
343  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
344  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
345  p_c_grid,
346  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
347  .GetElementSpaceSize());
348 
349  // divide block work by [M, N]
350  const auto block_work_idx =
351  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
352 
353  if(!block_2_ctile_map.ValidCTileIndex(
354  block_work_idx,
355  make_tuple(
356  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
357  .GetLength(I0),
358  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
359  .GetLength(I3))))
360  {
361  return;
362  }
363 
364  // HACK: this force m/n_block_data_idx_on_grid into SGPR
365  const index_t m_block_data_idx_on_grid =
366  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
367 
368  const index_t n_block_data_idx_on_grid =
369  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
370 
371  // lds max alignment
372  constexpr auto max_lds_align = math::lcm(AK1, BK1);
373 
374  // A matrix in LDS memory, dst of blockwise copy
375  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
376 
377  // B matrix in LDS memory, dst of blockwise copy
378  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
379 
380  // A matrix blockwise copy
381  auto a_blockwise_copy =
383  AElementwiseOperation,
387  ABlockTransferThreadClusterLengths_AK0_M_AK1,
388  ABlockTransferThreadClusterArrangeOrder,
389  FloatAB,
390  FloatAB,
391  decltype(a_grid_desc_ak0_m_ak1),
392  decltype(a_block_desc_ak0_m_ak1),
393  ABlockTransferSrcAccessOrder,
395  ABlockTransferSrcVectorDim,
396  2,
397  ABlockTransferSrcScalarPerVector,
398  ABlockTransferDstScalarPerVector_K1,
399  1,
400  1,
401  AThreadTransferSrcResetCoordinateAfterRun,
402  true,
403  NumGemmKPrefetchStage>(
404  a_grid_desc_ak0_m_ak1,
405  make_multi_index(0, m_block_data_idx_on_grid, 0),
406  a_element_op,
407  a_block_desc_ak0_m_ak1,
408  make_multi_index(0, 0, 0),
410 
411  // B matrix blockwise copy
412  auto b_blockwise_copy =
414  BElementwiseOperation,
418  BBlockTransferThreadClusterLengths_BK0_N_BK1,
419  BBlockTransferThreadClusterArrangeOrder,
420  FloatAB,
421  FloatAB,
422  decltype(b_grid_desc_bk0_n_bk1),
423  decltype(b_block_desc_bk0_n_bk1),
424  BBlockTransferSrcAccessOrder,
426  BBlockTransferSrcVectorDim,
427  2,
428  BBlockTransferSrcScalarPerVector,
429  BBlockTransferDstScalarPerVector_K1,
430  1,
431  1,
432  BThreadTransferSrcResetCoordinateAfterRun,
433  true,
434  NumGemmKPrefetchStage>(
435  b_grid_desc_bk0_n_bk1,
436  make_multi_index(0, n_block_data_idx_on_grid, 0),
437  b_element_op,
438  b_block_desc_bk0_n_bk1,
439  make_multi_index(0, 0, 0),
441 
442  // GEMM definition
443  // c_mtx += transpose(a_mtx) * b_mtx
444  // a_mtx[K0PerBlock, MPerBlock] is in LDS
445  // b_mtx[K0PerBlock, NPerBlock] is in LDS
446  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
447  // register
448  // sanity check
449  constexpr auto lcm_AK1_BK1 = math::lcm(AK1, BK1);
450  constexpr bool is_single_rate_mfma =
452  lcm_AK1_BK1 <= 4) ||
453  (is_same<FloatAB, int8_t>::value && lcm_AK1_BK1 <= 8) ||
455  lcm_AK1_BK1 < 32))
456  ? true
457  : false;
458  constexpr auto is_scale_mfma = false;
459  constexpr index_t k_pack = math::max(
460  lcm_AK1_BK1,
462  selected_mfma.k_per_blk);
463 
464  auto blockwise_gemm =
466  FloatAB,
467  FloatAB,
468  FloatAcc,
469  decltype(a_block_desc_ak0_m_ak1),
470  decltype(b_block_desc_bk0_n_bk1),
471  MPerXdl,
472  NPerXdl,
473  MXdlPerWave,
474  NXdlPerWave,
475  k_pack>{};
476 
477  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
478 
479  // LDS allocation for A and B: be careful of alignment
480  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
481  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
482 
483  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
484  static_cast<FloatAB*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
485 
486  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
487  static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
488  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
489 
490  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
491  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
492 
493  // gridwise GEMM pipeline
494  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
495  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
496  KPerBlock);
497 
498  GridwiseGemmPipe::template Run<HasMainK0BlockLoop>(a_grid_desc_ak0_m_ak1,
499  a_block_desc_ak0_m_ak1,
500  a_blockwise_copy,
501  a_grid_buf,
502  a_block_buf,
503  a_block_slice_copy_step,
504  b_grid_desc_bk0_n_bk1,
505  b_block_desc_bk0_n_bk1,
506  b_blockwise_copy,
507  b_grid_buf,
508  b_block_buf,
509  b_block_slice_copy_step,
510  blockwise_gemm,
511  c_thread_buf,
512  num_k_block_main_loop);
513 
514  // shuffle C and write out
515  {
516  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
517  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
518  "wrong!");
519 
520  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
521  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
522 
523  // TODO: hacky, fix it!
524  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
525  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
526 
527  // TODO: hacky, fix it!
528  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
529  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
530  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
531 
532  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
533  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
534  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
535  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
536  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
537  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
538  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
539  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
540 
541  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
543 
544  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
545  static_cast<FloatCShuffle*>(p_shared),
546  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
547  .GetElementSpaceSize());
548 
549  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
550  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
551  make_tuple(
552  make_freeze_transform(I0), // freeze mblock
554  Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per shuffle
556  make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
557  make_freeze_transform(I0), // freeze nblock
559  Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per shuffle
561  make_tuple(N1, N2))), // N1 = NWave, N2 = NPerXdl
563  Sequence<1>{},
564  Sequence<2>{},
565  Sequence<3>{},
566  Sequence<4>{},
567  Sequence<5>{}),
569  Sequence<0>{},
571  Sequence<>{},
572  Sequence<1>{},
573  Sequence<3, 7>{}));
574 
575  // calculate origin of thread output tensor on global memory
576  // blockwise GEMM c matrix starting index
577  const auto c_thread_mtx_on_block =
578  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
579 
580  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
581  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
582 
583  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
585  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
588 
589  const auto m_thread_data_on_block_idx =
590  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
591  make_multi_index(m_thread_data_on_block));
592 
593  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
598 
599  const auto n_thread_data_on_block_idx =
600  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
601  make_multi_index(n_thread_data_on_block));
602 
603  // VGPR to LDS
604  auto c_thread_copy_vgpr_to_lds =
606  FloatCShuffle,
607  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
608  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
610  Sequence<CShuffleMXdlPerWavePerShuffle,
611  CShuffleNXdlPerWavePerShuffle,
612  I1,
613  I1,
614  M2,
615  I1,
616  M4,
617  I1>,
619  7,
620  1,
622  1,
623  true>{
624  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
626  0,
627  m_thread_data_on_block_idx[I1],
628  n_thread_data_on_block_idx[I1],
629  m_thread_data_on_block_idx[I2],
630  m_thread_data_on_block_idx[I3],
631  m_thread_data_on_block_idx[I4],
632  n_thread_data_on_block_idx[I2]),
634 
635  // LDS to global
636  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
637  ThisThreadBlock, // ThreadGroup
638  CElementwiseOperation, // ElementwiseOperation,
639  CGlobalMemoryDataOperation, // DstInMemOp,
640  Sequence<1,
641  CShuffleMXdlPerWavePerShuffle,
642  MWave * MPerXdl,
643  1,
644  CShuffleNXdlPerWavePerShuffle,
645  NWave * NPerXdl>, // BlockSliceLengths,
646  CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
647  Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
648  FloatCShuffle, // typename SrcData,
649  FloatC, // typename DstData,
650  decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
651  decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
652  Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
653  5, // index_t VectorDim,
654  CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
655  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
656  false> // bool ThreadTransferDstResetCoordinateAfterRun>
657  {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
658  make_multi_index(0, 0, 0, 0, 0, 0),
659  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
660  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
661  c_element_op};
662 
663  constexpr auto mxdlperwave_forward_step =
664  make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
665  constexpr auto nxdlperwave_forward_step =
666  make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
667  constexpr auto nxdlperwave_backward_step =
668  make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
669 
670  static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
671  constexpr auto mxdlperwave = mxdlperwave_iter;
672 
673  static_for<0,
674  NXdlPerWave,
675  CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
676  constexpr bool nxdlperwave_forward_sweep =
677  (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
678 
679  constexpr index_t nxdlperwave_value =
680  nxdlperwave_forward_sweep
681  ? nxdlperwave_iter
682  : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
683 
684  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
685 
686  // make sure it's safe to do ds_write
687  block_sync_lds();
688 
689  // VGPR to LDS
690  c_thread_copy_vgpr_to_lds.Run(
691  c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
692  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
693  c_thread_buf,
694  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
695  c_shuffle_block_buf);
696 
697  // make sure it's safe to do ds_read
698  block_sync_lds();
699 
700  // LDS to global
701  c_block_copy_lds_to_global.Run(
702  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
703  c_shuffle_block_buf,
704  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
705  c_grid_buf);
706 
707  // move on nxdlperwave dimension
708  if constexpr(nxdlperwave_forward_sweep &&
709  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
710  {
711  c_block_copy_lds_to_global.MoveDstSliceWindow(
712  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
713  nxdlperwave_forward_step);
714  }
715  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
716  {
717  c_block_copy_lds_to_global.MoveDstSliceWindow(
718  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
719  nxdlperwave_backward_step);
720  }
721  });
722 
723  // move on mxdlperwave dimension
724  if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
725  {
726  c_block_copy_lds_to_global.MoveDstSliceWindow(
727  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
728  mxdlperwave_forward_step);
729  }
730  });
731  }
732  }
733 };
734 
735 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:269
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:278
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:300
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
__global__ void kernel_gemm_xdlops_v3r1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:36
Definition: block_to_ctile_map.hpp:260
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
Definition: gridwise_gemm_xdlops_v3r1.hpp:124
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:327
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v3r1.hpp:277
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdlops_v3r1.hpp:167
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v3r1.hpp:140
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v3r1.hpp:131
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_xdlops_v3r1.hpp:311
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v3r1.hpp:130
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v3r1.hpp:132
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v3r1.hpp:129
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:237
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v3r1.hpp:143
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v3r1.hpp:125
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r1.hpp:320
static constexpr auto AK0
Definition: gridwise_gemm_xdlops_v3r1.hpp:135
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v3r1.hpp:285
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_xdlops_v3r1.hpp:323
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
Definition: gridwise_gemm_xdlops_v3r1.hpp:190
static constexpr auto BK0
Definition: gridwise_gemm_xdlops_v3r1.hpp:136
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdlops_v3r1.hpp:145
static constexpr auto BK1
Definition: gridwise_gemm_xdlops_v3r1.hpp:138
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v3r1.hpp:126
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v3r1.hpp:128
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v3r1.hpp:208
static constexpr auto AK1
Definition: gridwise_gemm_xdlops_v3r1.hpp:137
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v3r1.hpp:127
Definition: xdlops_gemm.hpp:942
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:308