/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v3r3.hpp Source File
gridwise_gemm_xdlops_v3r3.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
// Grid-level entry point of the XDL "v3r3" GEMM (listed as kernel_gemm_xdlops_v3r3 in the
// cross-reference index of this page). The kernel itself only reserves an LDS scratch
// buffer and forwards every argument unchanged to GridwiseGemm::Run<HasMainKBlockLoop>.
//
// Template parameters:
//   GridwiseGemm          - gridwise GEMM type; must provide GetSharedMemoryNumberOfByte()
//                           and a templated Run<HasMainKBlockLoop>(...).
//   FloatAB / FloatC      - element types of A/B and of C/C0/C1.
//   AGridDesc_K0_M_K1,
//   BGridDesc_K0_N_K1     - descriptors of A as [K0, M, K1] and B as [K0, N, K1].
//   C/C0/C1GridDescriptor_MBlock_..._NWaveNPerXdl
//                         - 6-d blocked descriptors of C, C0 and C1.
//   Block2CTileMap        - type of the workgroup-id -> C-tile mapping.
//   HasMainKBlockLoop     - compile-time flag forwarded to GridwiseGemm::Run.
// Parameters:
//   p_a_grid, p_b_grid    - A and B input pointers (read-only, __restrict__).
//   p_c_grid              - C output pointer.
//   p_c0_grid, p_c1_grid  - additional read-only FloatC inputs, forwarded to Run.
//   a/b/c_element_op      - elementwise operators, forwarded to Run.
//   block_2_ctile_map     - workgroup-id -> C-tile mapping, forwarded to Run.
//
// A real body is compiled only for the host pass and for the gfx908, gfx90a and gfx94
// device targets. On any other device target the kernel discards its arguments and
// does nothing.
//
// NOTE(review): source lines 35 and 37 are missing from this rendered listing. Line 35 is
// the attribute guarded by CK_USE_LAUNCH_BOUNDS (presumably __launch_bounds__ built from
// CK_MAX_THREAD_PER_BLOCK / CK_MIN_BLOCK_PER_CU). Line 37 is the line carrying the kernel name.
20 template <typename GridwiseGemm,
21  typename FloatAB,
22  typename FloatC,
23  typename AGridDesc_K0_M_K1,
24  typename BGridDesc_K0_N_K1,
25  typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
26  typename C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
27  typename C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
28  typename AElementwiseOperation,
29  typename BElementwiseOperation,
30  typename CElementwiseOperation,
31  typename Block2CTileMap,
32  bool HasMainKBlockLoop>
33 __global__ void
34 #if CK_USE_LAUNCH_BOUNDS
36 #endif
38  const FloatAB* __restrict__ p_a_grid,
39  const FloatAB* __restrict__ p_b_grid,
40  FloatC* __restrict__ p_c_grid,
41  const FloatC* __restrict__ p_c0_grid,
42  const FloatC* __restrict__ p_c1_grid,
43  const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
44  const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
45  const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
46  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
47  const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
48  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
49  const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
50  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
51  const AElementwiseOperation a_element_op,
52  const BElementwiseOperation b_element_op,
53  const CElementwiseOperation c_element_op,
54  const Block2CTileMap block_2_ctile_map)
55 {
56 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
57  defined(__gfx94__))
// Static LDS scratch buffer, sized at compile time by the gridwise GEMM and handed to
// Run() as p_shared.
58  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
59 
60  GridwiseGemm::template Run<HasMainKBlockLoop>(
61  p_a_grid,
62  p_b_grid,
63  p_c_grid,
64  p_c0_grid,
65  p_c1_grid,
66  p_shared,
67  a_grid_desc_k0_m_k1,
68  b_grid_desc_k0_n_k1,
69  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
70  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
71  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
72  a_element_op,
73  b_element_op,
74  c_element_op,
75  block_2_ctile_map);
76 #else
// Unsupported device target: the body is empty. Every argument is assigned to `ignore`
// so that it counts as used.
77  ignore = p_a_grid;
78  ignore = p_b_grid;
79  ignore = p_c_grid;
80  ignore = p_c0_grid;
81  ignore = p_c1_grid;
82  ignore = a_grid_desc_k0_m_k1;
83  ignore = b_grid_desc_k0_n_k1;
84  ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
85  ignore = c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
86  ignore = c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
87  ignore = a_element_op;
88  ignore = b_element_op;
89  ignore = c_element_op;
90  ignore = block_2_ctile_map;
91 #endif // end of if (!defined(__HIP_DEVICE_COMPILE__) || gfx908 || gfx90a || gfx94)
92 }
93 
// Gridwise GEMM built on XDL instructions ("v3r3" variant), with a fused C epilogue that
// additionally reads two C-shaped tensors, C0 and C1.
//
// Layout, as named by the descriptor types: A is [K0, M, K1], B is [K0, N, K1], and
// C/C0/C1 are [M, N]. K1 (= K1Value) is the innermost K dimension. Following the
// "GEMM definition" comment in Run(), C[M, N] accumulates sum over K of A[K, M] * B[K, N].
// Each workgroup computes one MPerBlock x NPerBlock tile of C and loops over K0 in steps
// of K0PerBlock.
//
// Runtime shape requirements, checked by CheckValidity():
//   - M % MPerBlock == 0, N % NPerBlock == 0, K0 % K0PerBlock == 0;
//   - A and B agree on K0 and K1;
//   - A/B agree with C on M/N.
//
// LDS usage:
//   - The A and B tiles are placed back to back; the A size is rounded up to a multiple
//     of K1 elements.
//   - After the main K loop, the same LDS is reused as the C-shuffle buffer.
//   - GetSharedMemoryNumberOfByte() therefore returns the larger of the two footprints.
//
// NOTE(review): source line 138 (the struct name) and several other lines are missing
// from this rendered listing. Declarations that span the gaps are incomplete here.
94 template <
95  index_t BlockSize,
96  typename FloatAB,
97  typename FloatAcc,
98  typename FloatC,
99  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
100  typename AGridDesc_K0_M_K1,
101  typename BGridDesc_K0_N_K1,
102  typename CGridDesc_M_N,
103  typename C0GridDesc_M_N,
104  typename C1GridDesc_M_N,
105  typename AElementwiseOperation,
106  typename BElementwiseOperation,
107  typename CElementwiseOperation,
108  index_t MPerBlock,
109  index_t NPerBlock,
110  index_t K0PerBlock,
111  index_t MPerXdl,
112  index_t NPerXdl,
113  index_t K1Value,
114  index_t MXdlPerWave,
115  index_t NXdlPerWave,
116  typename ABlockTransferThreadClusterLengths_K0_M_K1,
117  typename ABlockTransferThreadClusterArrangeOrder,
118  typename ABlockTransferSrcAccessOrder,
119  index_t ABlockTransferSrcVectorDim,
120  index_t ABlockTransferSrcScalarPerVector,
121  index_t ABlockTransferDstScalarPerVector_K1,
122  bool AThreadTransferSrcResetCoordinateAfterRun,
123  bool ABlockLdsExtraM,
124  typename BBlockTransferThreadClusterLengths_K0_N_K1,
125  typename BBlockTransferThreadClusterArrangeOrder,
126  typename BBlockTransferSrcAccessOrder,
127  index_t BBlockTransferSrcVectorDim,
128  index_t BBlockTransferSrcScalarPerVector,
129  index_t BBlockTransferDstScalarPerVector_K1,
130  bool BThreadTransferSrcResetCoordinateAfterRun,
131  bool BBlockLdsExtraN,
// Number of XDL tiles per wave (along M and N) moved through LDS in one C-shuffle step.
// Must divide MXdlPerWave / NXdlPerWave respectively (static_assert in Run()).
132  index_t CShuffleMXdlPerWavePerShuffle,
133  index_t CShuffleNXdlPerWavePerShuffle,
134  typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
135  index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
// K-loop pipeline selection, passed to GridwiseGemmPipeline_Selector.
136  index_t NumGemmKPrefetchStage = 1,
137  PipelineVersion PipelineVer = PipelineVersion::v1>
139 {
140  static constexpr auto I0 = Number<0>{};
141  static constexpr auto I1 = Number<1>{};
142  static constexpr auto I2 = Number<2>{};
143  static constexpr auto I3 = Number<3>{};
144  static constexpr auto I4 = Number<4>{};
145  static constexpr auto I5 = Number<5>{};
146  static constexpr auto I6 = Number<6>{};
147  static constexpr auto I7 = Number<7>{};
148 
149  // K1 should be Number<...>
// K1 is a compile-time Number<> because it is used as a static tensor length and as the
// LDS alignment; CheckValidity() static_asserts this.
150  static constexpr auto K1 = Number<K1Value>{};
151 
153 
// GridwiseGemmPipe: the K-loop pipeline chosen by PipelineVer / NumGemmKPrefetchStage
// (the first line of this alias is missing from the listing).
155  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
156 
// LDS descriptor of the A tile, [K0PerBlock, MPerBlock, K1].
// With ABlockLdsExtraM a different layout is used; its lines (165-167) are missing here,
// presumably an M-padded layout. Otherwise it is a naive descriptor aligned to K1.
157  __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
158  {
159  constexpr auto max_lds_align = K1;
160 
161  // A matrix in LDS memory, dst of blockwise copy
162  constexpr auto a_block_desc_k0_m_k1 = [&]() {
163  if constexpr(ABlockLdsExtraM)
164  {
168  }
169  else
170  {
172  make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1), max_lds_align);
173  }
174  }();
175 
176  return a_block_desc_k0_m_k1;
177  }
178 
// LDS descriptor of the B tile, [K0PerBlock, NPerBlock, K1].
// With BBlockLdsExtraN a different layout is used; its lines (187-189) are missing here,
// presumably an N-padded layout. Otherwise it is a naive descriptor aligned to K1.
179  __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
180  {
181  constexpr auto max_lds_align = K1;
182 
183  // B matrix in LDS memory, dst of blockwise copy
184  constexpr auto b_block_desc_k0_n_k1 = [&]() {
185  if constexpr(BBlockLdsExtraN)
186  {
190  }
191  else
192  {
194  make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
195  }
196  }();
197 
198  return b_block_desc_k0_n_k1;
199  }
200 
// LDS descriptor of one C-shuffle chunk in the 6-d
// [MBlock, MXdlPerWave, MWave*MPerXdl, NBlock, NXdlPerWave, NWave*NPerXdl] layout, with
// MBlock = NBlock = 1. The remaining lengths are on missing lines 211-212 / 214-215;
// presumably they are CShuffleMXdlPerWavePerShuffle, MWave*MPerXdl,
// CShuffleNXdlPerWavePerShuffle and NWave*NPerXdl, matching the block slice used in Run().
// NOTE(review): the name line (202) is missing. Per the cross-reference index the name is
// GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl;
// the first "NXdlPerWave" looks like a typo for "MXdlPerWave".
201  __host__ __device__ static constexpr auto
203  {
204  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
205  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
206 
207  constexpr auto
208  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
210  make_tuple(I1,
213  I1,
216 
217  return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
218  }
219 
// Bytes of LDS needed per workgroup:
//   max( (A tile + B tile, each rounded up to a multiple of K1 elements) * sizeof(FloatAB),
//        C-shuffle chunk * sizeof(FloatC) ).
// The maximum (not the sum) is enough because Run() reuses the A/B region for the
// C shuffle after the main K loop.
220  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
221  {
222  // LDS allocation for A and B: be careful of alignment
223  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
224 
225  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
226 
227  constexpr auto max_lds_align = K1;
228 
229  constexpr auto a_block_space_size_aligned =
230  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
231 
232  constexpr auto b_block_space_size_aligned =
233  math::integer_least_multiple(b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
234 
235  // LDS allocation for C shuffle in LDS
236  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
238 
239  constexpr auto c_block_size =
240  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
241  .GetElementSpaceSize();
242 
243  return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
244  sizeof(FloatAB),
245  c_block_size * sizeof(FloatC));
246  }
247 
248  // Checks whether the runtime problem shape is supported by this tiling.
// Tuning-parameter errors are compile-time static_asserts. Unsupported runtime shapes
// return false. The block -> C-tile mapping is validated via
// block_2_ctile_map.CheckValidity(); the M01/N01 arguments of MakeDefaultBlock2CTileMap
// are ignored.
// NOTE(review): the C0 and C1 descriptors are not validated here. They presumably need
// the same [M, N] extents as C; callers should verify this.
249  template <typename Block2CTileMap>
250  __host__ __device__ static constexpr bool
251  CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
252  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
253  const CGridDesc_M_N& c_grid_desc_m_n,
254  const Block2CTileMap& block_2_ctile_map)
255  {
256  static_assert(is_known_at_compile_time<remove_cv_t<decltype(K1)>>::value,
257  "wrong! K1 need to be known at compile-time");
258 
259  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
260  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
261  "Invalid tuning param!");
262 
263  const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
264  const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
265  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
266 
// A, B and C must agree on M, N, K0 and K1.
267  if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1) &&
268  K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
269  K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
270  return false;
271 
// No tail handling: every dimension must be an exact multiple of its block tile.
272  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
273  return false;
274 
275  // check gridwise gemm pipeline
276  const auto num_k_loop = K0 / K0PerBlock;
277 
278  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
279  {
280  return false;
281  }
282 
283  if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
284  {
285  return false;
286  }
287 
288  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
289  return true;
290  }
291 
// Given the total K extent (K0 * K1), returns whether the selected pipeline needs its
// main loop. The number of K iterations is K / (K0PerBlock * K1), using integer division.
292  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
293  {
294  const index_t num_loop = K / (K0PerBlock * K1);
295 
296  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
297  }
298 
// Reshapes an [M, N] C-like descriptor (used for C, C0 and C1) into the 6-d blocked view
// [MBlock, MXdlPerWave, MWave*MPerXdl, NBlock, NXdlPerWave, NWave*NPerXdl], where
// MBlock = M / MPerBlock and NBlock = N / NPerBlock.
// The transform itself is on missing lines 314-321, so the exact split is not visible here.
299  template <typename CGridDesc_M_N_>
300  __host__ __device__ static constexpr auto
302  const CGridDesc_M_N_& c_grid_desc_m_n)
303  {
304  const auto M = c_grid_desc_m_n.GetLength(I0);
305  const auto N = c_grid_desc_m_n.GetLength(I1);
306 
307  const auto MBlock = M / MPerBlock;
308  const auto NBlock = N / NPerBlock;
309 
310  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
311  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
312 
313  const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
315  c_grid_desc_m_n,
322 
323  return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
324  }
325 
326  // return block_id to C matrix tile idx (m0, n0) mapping
// The M01 / N01 arguments are accepted only for interface compatibility and are ignored
// (their names are commented out). The constructed map type is on missing line 330.
327  __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(
328  const CGridDesc_M_N& c_grid_desc_m_n, index_t /* M01 */, index_t /* N01 */)
329  {
331  c_grid_desc_m_n);
332  }
// Type aliases deduced from the factories above. Per the cross-reference index they are
// CGridDescriptor_..., C0GridDescriptor_..., C1GridDescriptor_..._NWaveNPerXdl and
// DefaultBlock2CTileMap; their leading lines are missing from this listing.
336  CGridDesc_M_N{}))>;
337 
341  C0GridDesc_M_N{}))>;
342 
346  C1GridDesc_M_N{}))>;
347 
349  remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))>;
350 
// Device-side GEMM body executed by every workgroup:
//   1. Maps get_block_1d_id() to an (MBlock, NBlock) tile index via block_2_ctile_map,
//      and returns early if the index is outside the C grid.
//   2. Streams A/B tiles from global memory into LDS and runs the blockwise XDL GEMM
//      through GridwiseGemmPipe for K0 / K0PerBlock iterations, accumulating into the
//      per-thread c_thread_buf.
//   3. Writes C out through LDS in chunks of
//      CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle XDL tiles.
//      ThreadGroupTensorSliceTransfer_v6r3 reads each chunk from LDS together with the
//      matching C0/C1 elements, applies c_element_op and stores with
//      CGlobalMemoryDataOperation. How c_element_op combines the three inputs is defined
//      outside this file.
// p_shared must provide at least GetSharedMemoryNumberOfByte() bytes of LDS.
// NOTE(review): the parameter types of the three C descriptors (lines 361/363/365) are
// missing from this listing. Per the cross-reference index they are the
// C/C0/C1GridDescriptor_MBlock_..._NWaveNPerXdl aliases above.
351  template <bool HasMainKBlockLoop, typename Block2CTileMap>
352  __device__ static void
353  Run(const FloatAB* __restrict__ p_a_grid,
354  const FloatAB* __restrict__ p_b_grid,
355  FloatC* __restrict__ p_c_grid,
356  const FloatC* __restrict__ p_c0_grid,
357  const FloatC* __restrict__ p_c1_grid,
358  void* __restrict__ p_shared,
359  const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
360  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
362  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
364  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
366  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
367  const AElementwiseOperation& a_element_op,
368  const BElementwiseOperation& b_element_op,
369  const CElementwiseOperation& c_element_op,
370  const Block2CTileMap& block_2_ctile_map)
371  {
// Wrap the global tensors as dynamic buffers sized by their descriptors.
372  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
373  p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
374  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
375  p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
376  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
377  p_c_grid,
378  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
379  .GetElementSpaceSize());
380  auto c0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
381  p_c0_grid,
382  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
383  .GetElementSpaceSize());
384  auto c1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
385  p_c1_grid,
386  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
387  .GetElementSpaceSize());
388 
389  const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
390 
391  // divide block work by [M, N]
392  const auto block_work_idx =
393  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
394 
// Lengths I0 / I3 of the 6-d C descriptor are MBlock / NBlock. block_work_idx depends
// only on the block id, so this early return is taken by the whole workgroup or by none
// of it. No block_sync_lds() below is reached by only part of the block.
395  if(!block_2_ctile_map.ValidCTileIndex(
396  block_work_idx,
397  make_tuple(
398  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
399  .GetLength(I0),
400  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
401  .GetLength(I3))))
402  {
403  return;
404  }
405 
406  // HACK: readfirstlane forces m/n_block_data_idx_on_grid into SGPRs (wave-uniform values)
407  const index_t m_block_data_idx_on_grid =
408  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
409 
410  const index_t n_block_data_idx_on_grid =
411  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
412 
413  // lds max alignment
414  constexpr auto max_lds_align = K1;
415 
416  // A matrix in LDS memory, dst of blockwise copy
417  constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
418 
419  // B matrix in LDS memory, dst of blockwise copy
420  constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
421 
422  // A matrix blockwise copy
// Global -> LDS copy of A, starting at (k0 = 0, m = m_block_data_idx_on_grid, k1 = 0)
// and writing the LDS tile from its origin. The copy-type name and slice lengths are on
// missing lines 424 / 426-428.
423  auto a_blockwise_copy =
425  AElementwiseOperation,
429  ABlockTransferThreadClusterLengths_K0_M_K1,
430  ABlockTransferThreadClusterArrangeOrder,
431  FloatAB,
432  FloatAB,
433  decltype(a_grid_desc_k0_m_k1),
434  decltype(a_block_desc_k0_m_k1),
435  ABlockTransferSrcAccessOrder,
437  ABlockTransferSrcVectorDim,
438  2,
439  ABlockTransferSrcScalarPerVector,
440  ABlockTransferDstScalarPerVector_K1,
441  1,
442  1,
443  AThreadTransferSrcResetCoordinateAfterRun,
444  true>(
445  a_grid_desc_k0_m_k1,
446  make_multi_index(0, m_block_data_idx_on_grid, 0),
447  a_element_op,
448  a_block_desc_k0_m_k1,
449  make_multi_index(0, 0, 0),
451 
452  // B matrix blockwise copy
// Global -> LDS copy of B, starting at (k0 = 0, n = n_block_data_idx_on_grid, k1 = 0).
453  auto b_blockwise_copy =
455  BElementwiseOperation,
459  BBlockTransferThreadClusterLengths_K0_N_K1,
460  BBlockTransferThreadClusterArrangeOrder,
461  FloatAB,
462  FloatAB,
463  decltype(b_grid_desc_k0_n_k1),
464  decltype(b_block_desc_k0_n_k1),
465  BBlockTransferSrcAccessOrder,
467  BBlockTransferSrcVectorDim,
468  2,
469  BBlockTransferSrcScalarPerVector,
470  BBlockTransferDstScalarPerVector_K1,
471  1,
472  1,
473  BThreadTransferSrcResetCoordinateAfterRun,
474  true>(
475  b_grid_desc_k0_n_k1,
476  make_multi_index(0, n_block_data_idx_on_grid, 0),
477  b_element_op,
478  b_block_desc_k0_n_k1,
479  make_multi_index(0, 0, 0),
481 
482  // GEMM definition
483  // c_mtx += transpose(a_mtx) * b_mtx
484  // a_mtx[K0PerBlock, MPerBlock] is in LDS
485  // b_mtx[K0PerBlock, NPerBlock] is in LDS
486  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
487  // register
488  // sanity check -- NOTE(review): no check is actually performed at this point
489 
// Blockwise XDL GEMM reading A/B from LDS and accumulating in FloatAcc
// (the type name is on missing line 491).
490  auto blockwise_gemm =
492  FloatAB,
493  FloatAB,
494  FloatAcc,
495  decltype(a_block_desc_k0_m_k1),
496  decltype(b_block_desc_k0_n_k1),
497  MPerXdl,
498  NPerXdl,
499  MXdlPerWave,
500  NXdlPerWave,
501  K1>{};
502 
503  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
504 
505  // LDS allocation for A and B: be careful of alignment
// A starts at p_shared; B follows at the K1-aligned end of A.
506  constexpr auto a_block_space_size_aligned =
507  math::integer_least_multiple(a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
508 
509  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
510  static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
511 
512  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
513  static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
514  b_block_desc_k0_n_k1.GetElementSpaceSize());
515 
// Per-iteration advance of the A/B source windows: K0PerBlock along the K0 dimension.
516  constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
517  constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
518 
519  // gridwise GEMM pipeline
// Runs K0 / K0PerBlock iterations (kept in an SGPR via readfirstlane), accumulating
// into c_thread_buf.
520  const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
521 
522  GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
523  a_block_desc_k0_m_k1,
524  a_blockwise_copy,
525  a_grid_buf,
526  a_block_buf,
527  a_block_slice_copy_step,
528  b_grid_desc_k0_n_k1,
529  b_block_desc_k0_n_k1,
530  b_blockwise_copy,
531  b_grid_buf,
532  b_block_buf,
533  b_block_slice_copy_step,
534  blockwise_gemm,
535  c_thread_buf,
536  K0BlockMainLoop);
537 
538  // shuffle C and write out
// Each step copies one CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle
// chunk of XDL tiles from registers to LDS (VGPR -> LDS), then from LDS to global memory
// together with C0/C1 (LDS -> global).
539  {
540  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
541  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
542  "wrong!");
543 
544  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
545  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
546 
547  // TODO: hacky, fix it!
548  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
549  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
550 
551  // TODO: hacky, fix it!
552  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
553  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
554  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
555 
556  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
557  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
558  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
559  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
560  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
561  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
562  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
563  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
564 
565  constexpr auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
567 
// The C-shuffle buffer starts at p_shared, so it aliases the A/B tiles used by the main
// loop. The block_sync_lds() before each ds_write below keeps this safe.
568  auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
569  static_cast<FloatC*>(p_shared),
570  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
571  .GetElementSpaceSize());
572 
// View of the LDS chunk in the blockwise GEMM's 8-d thread layout
// (M0, N0, M1, N1, M2, M3, M4, N2), used as the VGPR -> LDS destination.
573  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
574  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
575  make_tuple(make_freeze_transform(I0), // freeze mblock
577  Number<CShuffleMXdlPerWavePerShuffle>{}), // M0 (MXdlPerWave) per
578  // shuffle
580  make_tuple(M1, M2, M3, M4)), // M1 = MWave, M2 * M3 * M4 = MPerXdl
581  make_freeze_transform(I0), // freeze nblock
583  Number<CShuffleNXdlPerWavePerShuffle>{}), // N0 (NXdlPerWave) per
584  // shuffle
586  make_tuple(N1, N2))), // N1 * N2 = NWave * NPerXdl (presumably N1 = NWave, N2 = NPerXdl)
588  Sequence<1>{},
589  Sequence<2>{},
590  Sequence<3>{},
591  Sequence<4>{},
592  Sequence<5>{}),
594  Sequence<0>{},
596  Sequence<>{},
597  Sequence<1>{},
598  Sequence<3, 7>{})
599 
600  );
601 
602  // calculate origin of thread output tensor on global memory
603  // blockwise GEMM c matrix starting index
604  const auto c_thread_mtx_on_block =
605  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
606 
607  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
608  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
609 
// Decompose the thread's flat m / n offsets within the block into
// (M0, M1, M2, M3, M4) and (N0, N1, N2) coordinates.
610  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
612  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
615 
616  const auto m_thread_data_on_block_idx =
617  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
618  make_multi_index(m_thread_data_on_block));
619 
620  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
625 
626  const auto n_thread_data_on_block_idx =
627  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
628  make_multi_index(n_thread_data_on_block));
629 
630  // VGPR to LDS
631  auto c_thread_copy_vgpr_to_lds =
633  FloatC,
634  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
635  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
637  Sequence<CShuffleMXdlPerWavePerShuffle,
638  CShuffleNXdlPerWavePerShuffle,
639  I1,
640  I1,
641  M2,
642  I1,
643  M4,
644  I1>,
646  7,
647  1,
649  1,
650  true>{
651  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
653  0,
654  m_thread_data_on_block_idx[I1],
655  n_thread_data_on_block_idx[I1],
656  m_thread_data_on_block_idx[I2],
657  m_thread_data_on_block_idx[I3],
658  m_thread_data_on_block_idx[I4],
659  n_thread_data_on_block_idx[I2]),
661 
// LDS -> global transfer. Src0 is the LDS C chunk; Src1/Src2 are the C0/C1 global
// windows. The destination is C, written with CGlobalMemoryDataOperation after applying
// c_element_op. All global windows start at this workgroup's
// (MBlock, NBlock) = (block_work_idx[I0], block_work_idx[I1]) tile.
662  auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r3<
663  ThisThreadBlock, // ThreadGroup
664  CElementwiseOperation, // ElementwiseOperation,
665  CGlobalMemoryDataOperation, // DstInMemOp,
666  Sequence<1,
667  CShuffleMXdlPerWavePerShuffle,
668  MWave * MPerXdl,
669  1,
670  CShuffleNXdlPerWavePerShuffle,
671  NWave * NPerXdl>, // BlockSliceLengths,
672  CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
673  Sequence<0, 1, 2, 3, 4, 5>, // typename ThreadClusterArrangeOrder,
674  FloatC, // typename Src0Data,
675  FloatC, // typename Src1Data,
676  FloatC, // typename Src2Data,
677  FloatC, // typename DstData,
678  decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
679  decltype(c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
680  decltype(c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
681  decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
682  Sequence<0, 1, 2, 3, 4, 5>, // typename DimAccessOrder,
683  5, // index_t VectorDim,
684  CBlockTransferScalarPerVector_NWaveNPerXdl, // index_t ScalarPerVector,
685  true, // bool ThreadTransferSrc0ResetCoordinateAfterRun,
686  false, // bool ThreadTransferSrc1ResetCoordinateAfterRun,
687  false, // bool ThreadTransferSrc2ResetCoordinateAfterRun,
688  false> // bool ThreadTransferDstResetCoordinateAfterRun>
689  {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
690  make_multi_index(0, 0, 0, 0, 0, 0),
691  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
692  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
693  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
694  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
695  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
696  make_multi_index(block_work_idx[I0], 0, 0, block_work_idx[I1], 0, 0),
697  c_element_op};
698 
// Window steps (one chunk along MXdlPerWave, or one chunk forward/backward along
// NXdlPerWave) applied to the C0, C1 and C global windows between chunks.
699  constexpr auto mxdlperwave_forward_step =
700  make_multi_index(0, CShuffleMXdlPerWavePerShuffle, 0, 0, 0, 0);
701  constexpr auto nxdlperwave_forward_step =
702  make_multi_index(0, 0, 0, 0, CShuffleNXdlPerWavePerShuffle, 0);
703  constexpr auto nxdlperwave_backward_step =
704  make_multi_index(0, 0, 0, 0, -CShuffleNXdlPerWavePerShuffle, 0);
705 
// Serpentine traversal of the chunk grid. Even M chunks sweep N forward and odd M chunks
// sweep N backward, so between consecutive chunks the global windows move by exactly one
// step and never need to be reset.
706  static_for<0, MXdlPerWave, CShuffleMXdlPerWavePerShuffle>{}([&](auto mxdlperwave_iter) {
707  constexpr auto mxdlperwave = mxdlperwave_iter;
708 
709  static_for<0,
710  NXdlPerWave,
711  CShuffleNXdlPerWavePerShuffle>{}([&](auto nxdlperwave_iter) {
712  constexpr bool nxdlperwave_forward_sweep =
713  (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
714 
715  constexpr index_t nxdlperwave_value =
716  nxdlperwave_forward_sweep
717  ? nxdlperwave_iter
718  : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
719 
720  constexpr auto nxdlperwave = Number<nxdlperwave_value>{};
721 
722  // make sure it's safe to do ds_write
// (this also orders the writes after earlier LDS reads of the aliased A/B tiles or of
// the previous chunk)
723  block_sync_lds();
724 
725  // VGPR to LDS
726  c_thread_copy_vgpr_to_lds.Run(
727  c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
728  make_tuple(mxdlperwave, nxdlperwave, I0, I0, I0, I0, I0, I0),
729  c_thread_buf,
730  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
731  c_block_buf);
732 
733  // make sure it's safe to do ds_read
734  block_sync_lds();
735 
736  // LDS to global
737  c_block_copy_lds_to_global.Run(
738  c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
739  c_block_buf,
740  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
741  c0_grid_buf,
742  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
743  c1_grid_buf,
744  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
745  c_grid_buf);
746 
747  // move on nxdlperwave dimension
// (no move after the last chunk of a sweep; the M step below takes over)
748  if constexpr(nxdlperwave_forward_sweep &&
749  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
750  {
751  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
752  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
753  nxdlperwave_forward_step);
754 
755  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
756  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
757  nxdlperwave_forward_step);
758 
759  c_block_copy_lds_to_global.MoveDstSliceWindow(
760  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
761  nxdlperwave_forward_step);
762  }
763  else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
764  {
765  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
766  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
767  nxdlperwave_backward_step);
768 
769  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
770  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
771  nxdlperwave_backward_step);
772 
773  c_block_copy_lds_to_global.MoveDstSliceWindow(
774  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
775  nxdlperwave_backward_step);
776  }
777  });
778 
779  // move on mxdlperwave dimension
780  if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
781  {
782  c_block_copy_lds_to_global.MoveSrc1SliceWindow(
783  c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
784  mxdlperwave_forward_step);
785 
786  c_block_copy_lds_to_global.MoveSrc2SliceWindow(
787  c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
788  mxdlperwave_forward_step);
789 
790  c_block_copy_lds_to_global.MoveDstSliceWindow(
791  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
792  mxdlperwave_forward_step);
793  }
794  });
795  }
796  }
797 };
798 
799 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:34
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:33
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:22
__global__ void kernel_gemm_xdlops_v3r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const FloatC *__restrict__ p_c1_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:37
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:298
Definition: block_to_ctile_map.hpp:260
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
Definition: gridwise_gemm_xdlops_v3r3.hpp:139
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v3r3.hpp:143
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const FloatC *__restrict__ p_c0_grid, const FloatC *__restrict__ p_c1_grid, void *__restrict__ p_shared, const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:353
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v3r3.hpp:140
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_xdlops_v3r3.hpp:327
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v3r3.hpp:301
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v3r3.hpp:142
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v3r3.hpp:145
__host__ static constexpr __device__ auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_v3r3.hpp:179
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v3r3.hpp:147
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_xdlops_v3r3.hpp:349
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v3r3.hpp:292
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v3r3.hpp:220
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:336
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(C1GridDesc_M_N{}))> C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:346
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(C0GridDesc_M_N{}))> C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r3.hpp:341
__host__ static constexpr __device__ auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_v3r3.hpp:157
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v3r3.hpp:155
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
Definition: gridwise_gemm_xdlops_v3r3.hpp:202
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v3r3.hpp:144
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v3r3.hpp:146
static constexpr auto K1
Definition: gridwise_gemm_xdlops_v3r3.hpp:150
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v3r3.hpp:141
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r3.hpp:251
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v3r3.hpp:152
Definition: sequence.hpp:43
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r3.hpp:40
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:10
Definition: is_known_at_compile_time.hpp:14
Definition: functional2.hpp:31
Definition: unary_element_wise_operation.hpp:241