Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp Source File

gridwise_gemm_multiple_abd_xdl_cshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
19 
20 namespace ck {
21 
22 // GEMM:
23 // input : A0[M, K], A1[M, K]
24 // input : B0[N, K], B1[N, K]
25 // input : D0[M, N], D1[M, N], ...
26 // output : E[M, N]
27 // C = a_op(A0, A1, ...) * b_op(B0, B1, ...)
28 // E = cde_op(C, D0, D1, ...)
29 // Assume:
30 // D0, D1, ... and E have the same layout
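// Illustrative example (hypothetical elementwise ops, not prescribed by this kernel):
// with two A tensors, one B tensor and two D tensors, choosing
//   a_op(a0, a1)      = a0 * a1
//   b_op(b0)          = b0
//   cde_op(c, d0, d1) = c + d0 + d1
// a single output element is
//   E[m, n] = sum_k( (A0[m, k] * A1[m, k]) * B0[n, k] ) + D0[m, n] + D1[m, n]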
31 template <typename AsDataType,
32  typename BsDataType,
33  typename AComputeDataType_,
34  typename AccDataType,
35  typename CShuffleDataType,
36  typename DsDataType,
37  typename EDataType,
38  typename AElementwiseOperation,
39  typename BElementwiseOperation,
40  typename CDEElementwiseOperation,
41  InMemoryDataOperationEnum EGlobalMemoryDataOperation,
42  index_t NumGemmKPrefetchStage,
43  index_t BlockSize,
44  index_t MPerBlock,
45  index_t NPerBlock,
46  index_t KPerBlock,
47  index_t AK1Value,
48  index_t BK1Value,
49  index_t MPerXdl,
50  index_t NPerXdl,
51  index_t MXdlPerWave,
52  index_t NXdlPerWave,
53  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
54  typename ABlockTransferThreadClusterArrangeOrder,
55  typename ABlockTransferSrcAccessOrder,
56  index_t ABlockTransferSrcVectorDim,
57  index_t ABlockTransferSrcScalarPerVector,
58  index_t ABlockTransferDstScalarPerVector_AK1,
59  bool AThreadTransferSrcResetCoordinateAfterRun,
60  index_t ABlockLdsExtraM,
61  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
62  typename BBlockTransferThreadClusterArrangeOrder,
63  typename BBlockTransferSrcAccessOrder,
64  index_t BBlockTransferSrcVectorDim,
65  index_t BBlockTransferSrcScalarPerVector,
66  index_t BBlockTransferDstScalarPerVector_BK1,
67  bool BThreadTransferSrcResetCoordinateAfterRun,
68  index_t BBlockLdsExtraN,
69  index_t CShuffleMXdlPerWavePerShuffle,
70  index_t CShuffleNXdlPerWavePerShuffle,
71  typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
72  index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
73  LoopScheduler LoopSched,
 75  typename BComputeDataType_ = AComputeDataType_>
 76 struct GridwiseGemmMultipleABD_xdl_cshuffle
 77 {
78  static constexpr index_t NumATensor = AsDataType::Size();
79  static constexpr index_t NumBTensor = BsDataType::Size();
80  static constexpr index_t NumDTensor = DsDataType::Size();
 81 
 82  using GemmSpecialization = ck::tensor_operation::device::GemmSpecialization;
 83 
84  static constexpr auto I0 = Number<0>{};
85  static constexpr auto I1 = Number<1>{};
86  static constexpr auto I2 = Number<2>{};
87  static constexpr auto I3 = Number<3>{};
88  static constexpr auto I4 = Number<4>{};
89  static constexpr auto I5 = Number<5>{};
90  static constexpr auto I6 = Number<6>{};
91  static constexpr auto I7 = Number<7>{};
92 
93  // K1 should be Number<...>
94  static constexpr auto AK1 = Number<AK1Value>{};
95  static constexpr auto BK1 = Number<BK1Value>{};
96  static constexpr auto AK0PerBlock = Number<KPerBlock / AK1Value>{};
97  static constexpr auto BK0PerBlock = Number<KPerBlock / BK1Value>{};
 98 
 99  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 100 
 101  using GridwiseGemmPipe = remove_cvref_t<
 102  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
103 
104 #if CK_GFX90A_DENORM_WORKAROUND
105  using AComputeDataType =
107  using BComputeDataType =
109 #else
110  using AComputeDataType = AComputeDataType_;
111  using BComputeDataType = BComputeDataType_;
112 #endif
113 
114  __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
115  {
116  // A matrix in LDS memory, dst of blockwise copy
120  }
121 
122  __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
123  {
124  // B matrix in LDS memory, dst of blockwise copy
128  }
129 
130  __host__ __device__ static constexpr auto
132  {
133  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
134  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
135 
136  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
138  make_tuple(I1,
140  I1,
142 
143  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
144  }
145 
146  static constexpr auto MakeAsGridPointer()
147  {
148  return generate_tuple(
149  [&](auto i) {
150  using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
151 
152  return static_cast<const ADataType*>(nullptr);
153  },
155  }
156 
157  static constexpr auto MakeBsGridPointer()
158  {
159  return generate_tuple(
160  [&](auto i) {
161  using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
162 
163  return static_cast<const BDataType*>(nullptr);
164  },
166  }
167 
168  // ck::Tuple<const D0DataType*, const D1DataType*, ...>
169  static constexpr auto MakeDsGridPointer()
170  {
171  return generate_tuple(
172  [&](auto i) {
173  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
174 
175  return static_cast<const DDataType*>(nullptr);
176  },
178  }
179 
180  __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
181  {
182  // LDS allocation for A and B: be careful of alignment
183  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
184  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
185 
186  // lds max alignment
187  constexpr auto max_lds_align = math::lcm(AK1, BK1);
188 
189  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
190  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
191 
192  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
193  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
194 
195  // LDS allocation for C shuffle in LDS
196  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
198 
199  constexpr auto c_block_size =
200  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
201 
202  return math::max(a_block_space_size_aligned * sizeof(AComputeDataType) +
203  b_block_space_size_aligned * sizeof(BComputeDataType),
204  c_block_size * sizeof(CShuffleDataType));
205  }
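// Rough sizing sketch (assumed tuning values, not defaults): with MPerBlock = NPerBlock = 128,
// KPerBlock = 32, AK1 = BK1 = 8, no extra LDS padding, and 16-bit A/B compute types, the A and B
// tiles occupy (32 * 128 + 32 * 128) elements * 2 bytes = 16384 bytes; this is compared against the
// C-shuffle staging tile of
// (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) x (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl)
// elements of CShuffleDataType, and the larger of the two is the LDS byte count returned above.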
206 
207  // A desc for source in blockwise copy
208  template <typename AGridDesc_M_K>
209  __host__ __device__ static constexpr auto
210  MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K& a_grid_desc_m_k)
211  {
212  const auto M = a_grid_desc_m_k.GetLength(I0);
213  const auto K = a_grid_desc_m_k.GetLength(I1);
214 
215  const auto AK0 = K / AK1;
216 
217  return transform_tensor_descriptor(a_grid_desc_m_k,
222  }
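// Example (assumed sizes): for K = 256 and AK1 = 8, AK0 = 32, so an [M, K] view of A becomes an
// [AK0, M, AK1] = [32, M, 8] view matching the layout expected by the blockwise copy.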
223 
224  template <typename AsGridDesc_M_K>
225  __host__ __device__ static constexpr auto
226  MakeDefaultAsGridDescriptor_AK0_M_AK1(const AsGridDesc_M_K& as_grid_desc_m_k)
227  {
228  return generate_tuple(
229  [&](auto i) { return MakeDefaultAGridDescriptor_AK0_M_AK1(as_grid_desc_m_k[i]); },
231  }
232 
233  // B desc for source in blockwise copy
234  template <typename BGridDesc_N_K>
235  __host__ __device__ static constexpr auto
236  MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K& b_grid_desc_n_k)
237  {
238  const auto N = b_grid_desc_n_k.GetLength(I0);
239  const auto K = b_grid_desc_n_k.GetLength(I1);
240 
241  const auto BK0 = K / BK1;
242 
243  return transform_tensor_descriptor(b_grid_desc_n_k,
248  }
249 
250  template <typename BsGridDesc_N_K>
251  __host__ __device__ static constexpr auto
252  MakeDefaultBsGridDescriptor_BK0_N_BK1(const BsGridDesc_N_K& bs_grid_desc_n_k)
253  {
254  return generate_tuple(
255  [&](auto i) { return MakeDefaultBGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k[i]); },
257  }
258 
259  // E desc for destination in blockwise copy
260  template <typename EGridDesc_M_N>
261  __host__ __device__ static constexpr auto
262  MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N& e_grid_desc_m_n)
263  {
264  const auto M = e_grid_desc_m_n.GetLength(I0);
265  const auto N = e_grid_desc_m_n.GetLength(I1);
266 
267  const auto MBlock = M / MPerBlock;
268  const auto NBlock = N / NPerBlock;
269 
270  const auto e_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
271  e_grid_desc_m_n,
276 
277  return e_grid_desc_mblock_mperblock_nblock_nperblock;
278  }
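// Example (assumed sizes): for M = 512, N = 1024, MPerBlock = NPerBlock = 128, the [M, N] view of E
// becomes an [MBlock, MPerBlock, NBlock, NPerBlock] = [4, 128, 8, 128] view, so each workgroup can
// index its (MPerBlock x NPerBlock) output tile directly.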
279 
280  // Ds desc for source in blockwise copy
281  template <typename DsGridDesc_M_N>
282  __host__ __device__ static constexpr auto
283  MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N& ds_grid_desc_m_n)
284  {
285  return generate_tuple(
286  [&](auto i) {
287  return MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n[i]);
288  },
290  }
291 
292  // return block_id to E matrix tile idx (m0, n0) mapping
293  template <typename EGridDesc_M_N>
294  __host__ __device__ static constexpr auto
295  MakeDefaultBlock2ETileMap(const EGridDesc_M_N& e_grid_desc_m_n)
296  {
298  e_grid_desc_m_n);
299  }
300 
 301  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
302  template <typename AsGridDesc_M_K,
303  typename BsGridDesc_N_K,
304  typename DsGridDesc_M_N,
305  typename EGridDesc_M_N,
306  typename Block2ETileMap>
307  __host__ __device__ static constexpr bool CheckValidity(const AsGridDesc_M_K& as_grid_desc_m_k,
308  const BsGridDesc_N_K& bs_grid_desc_n_k,
309  const DsGridDesc_M_N& ds_grid_desc_m_n,
310  const EGridDesc_M_N& e_grid_desc_m_n,
311  const Block2ETileMap& block_2_etile_map)
312  {
313  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
314  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
315  "Invalid tuning param!");
316  static_assert(KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
317  "KPerBlock must be divisible by AK1Value and BK1Value!");
318 
319  const auto M = as_grid_desc_m_k[I0].GetLength(I0);
320  const auto N = bs_grid_desc_n_k[I0].GetLength(I0);
321  const auto AK = as_grid_desc_m_k[I0].GetLength(I1);
322  const auto BK = bs_grid_desc_n_k[I0].GetLength(I1);
323 
324  // check consistency of desc
325  if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) && AK == BK))
326  {
327  return false;
328  }
329 
330  constexpr long_index_t TwoGB = (long_index_t{1} << 31);
331 
332  bool valid = true;
333  static_for<0, NumATensor, 1>{}([&](auto i) {
334  using ADataType = remove_cvref_t<tuple_element_t<i.value, AsDataType>>;
335  valid =
336  valid && (as_grid_desc_m_k[i].GetElementSpaceSize() * sizeof(ADataType) <= TwoGB);
337  valid = valid && (M == as_grid_desc_m_k[i].GetLength(I0) &&
338  AK == as_grid_desc_m_k[i].GetLength(I1));
339  });
340 
341  static_for<0, NumBTensor, 1>{}([&](auto i) {
342  using BDataType = remove_cvref_t<tuple_element_t<i.value, BsDataType>>;
343  valid =
344  valid && (bs_grid_desc_n_k[i].GetElementSpaceSize() * sizeof(BDataType) <= TwoGB);
345  valid = valid && (N == bs_grid_desc_n_k[i].GetLength(I0) &&
346  BK == bs_grid_desc_n_k[i].GetLength(I1));
347  });
348 
349  static_for<0, NumDTensor, 1>{}([&](auto i) {
350  valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
351  N == ds_grid_desc_m_n[i].GetLength(I1));
352  });
353 
354  if(!valid)
355  {
356  return false;
357  }
358 
359  // check tile size
360  if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && AK % KPerBlock == 0))
361  {
362  return false;
363  }
364 
365  // check gridwise gemm pipeline
366  const auto num_k_loop = AK / KPerBlock;
367 
368  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
369  {
370  return false;
371  }
372 
373  // check block-to-E-tile
374  if(!block_2_etile_map.CheckValidity(e_grid_desc_m_n))
375  {
376  return false;
377  }
378 
379  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
380  // check tensor size: cannot be larger than 2GB each
381 
382  if(!(e_grid_desc_m_n.GetElementSpaceSize() * sizeof(EDataType) <= TwoGB))
383  {
384  return false;
385  }
386 
387  return true;
388  }
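// Typical host-side validity check, sketched with placeholder arguments (Ms, Ks, Ns, StrideAs,
// StrideBs, StrideDs, M, N, StrideE and the layout/GemmSpec parameters are assumptions):
//   const auto as_m_k   = MakeAsGridDescriptor_M_K<AsLayout, GemmSpec>(Ms, Ks, StrideAs);
//   const auto bs_n_k   = MakeBsGridDescriptor_N_K<BsLayout, GemmSpec>(Ns, Ks, StrideBs);
//   const auto ds_m_n   = MakeDsGridDescriptor_M_N<DsLayout, GemmSpec>(Ms, Ns, StrideDs);
//   const auto e_m_n    = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
//   const auto tile_map = MakeDefaultBlock2ETileMap(e_m_n);
//   if(!CheckValidity(as_m_k, bs_n_k, ds_m_n, e_m_n, tile_map)) { /* reject this instance */ }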
389 
390  __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
391  {
392  const index_t num_loop = K / KPerBlock;
393 
394  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
395  }
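// Example (assumed sizes): K = 4096 and KPerBlock = 32 give num_loop = 128; the selected pipeline
// then reports whether a main K loop is required on top of its prefetch stages.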
396 
397  using AsGridPointer = decltype(MakeAsGridPointer());
398  using BsGridPointer = decltype(MakeBsGridPointer());
399  using DsGridPointer = decltype(MakeDsGridPointer());
400 
401  template <typename ALayout, GemmSpecialization GemmSpec>
402  __host__ __device__ static auto
404  {
405  constexpr auto matrix_padder =
407  MPerBlock, NPerBlock, KPerBlock};
408 
409  const auto a_grid_desc_mraw_kraw = [&]() {
410  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
411  {
412  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
413  make_tuple(StrideA, I1));
414  }
415  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
416  {
417  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
418  make_tuple(I1, StrideA));
419  }
420  }();
421 
422  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
423  }
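// Example (assumed values): a row-major A with MRaw = 1000, KRaw = 64 and StrideA = 64 yields
// lengths (1000, 64) and strides (64, 1); PadADescriptor_M_K then rounds M and/or K up to multiples
// of MPerBlock / KPerBlock when GemmSpec enables that padding (e.g. 1000 -> 1024 for MPerBlock = 128).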
424 
425  template <typename AsLayout, GemmSpecialization GemmSpec>
426  __host__ __device__ static auto MakeAsGridDescriptor_M_K(
427 #ifdef CK_CODE_GEN_RTC
428  const ck::Array<index_t, NumATensor>& MRaws,
429  const ck::Array<index_t, NumATensor>& KRaws,
430  const ck::Array<index_t, NumATensor>& AsStride
431 #else
432  const std::array<index_t, NumATensor>& MRaws,
433  const std::array<index_t, NumATensor>& KRaws,
434  const std::array<index_t, NumATensor>& AsStride
435 #endif
436  )
437  {
438  return generate_tuple(
439  [&](auto i) {
440  using ALayout = remove_cvref_t<tuple_element_t<i.value, AsLayout>>;
441 
442  return MakeAGridDescriptor_M_K<ALayout, GemmSpec>(MRaws[i], KRaws[i], AsStride[i]);
443  },
445  }
446 
447  template <typename BLayout, GemmSpecialization GemmSpec>
448  __host__ __device__ static auto
449  MakeBGridDescriptor_N_K(const index_t NRaw, const index_t KRaw, const index_t StrideB)
450  {
451  constexpr auto matrix_padder =
453  MPerBlock, NPerBlock, KPerBlock};
454 
455  const auto b_grid_desc_nraw_kraw = [&]() {
457  {
458  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
459  make_tuple(I1, StrideB));
460  }
462  {
463  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
464  make_tuple(StrideB, I1));
465  }
466  }();
467 
468  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
469  }
470 
471  template <typename BsLayout, GemmSpecialization GemmSpec>
472  __host__ __device__ static auto MakeBsGridDescriptor_N_K(
473 #ifdef CK_CODE_GEN_RTC
474  const ck::Array<index_t, NumBTensor>& NRaws,
475  const ck::Array<index_t, NumBTensor>& KRaws,
476  const ck::Array<index_t, NumBTensor>& BsStride
477 #else
478  const std::array<index_t, NumBTensor>& NRaws,
479  const std::array<index_t, NumBTensor>& KRaws,
480  const std::array<index_t, NumBTensor>& BsStride
481 #endif
482  )
483  {
484  return generate_tuple(
485  [&](auto i) {
486  using BLayout = remove_cvref_t<tuple_element_t<i.value, BsLayout>>;
487 
488  return MakeBGridDescriptor_N_K<BLayout, GemmSpec>(NRaws[i], KRaws[i], BsStride[i]);
489  },
491  }
492 
493  template <typename ELayout, GemmSpecialization GemmSpec>
494  __host__ __device__ static auto
496  {
497  constexpr auto matrix_padder =
499  MPerBlock, NPerBlock, KPerBlock};
500  const auto e_grid_desc_mraw_nraw = [&]() {
502  {
503  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
504  make_tuple(StrideE, I1));
505  }
507  {
508  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
509  make_tuple(I1, StrideE));
510  }
511  }();
512 
513  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
514  }
515 
516  template <typename DsLayout, GemmSpecialization GemmSpec>
517  __host__ __device__ static auto MakeDsGridDescriptor_M_N(
518 #ifdef CK_CODE_GEN_RTC
519  const ck::Array<index_t, NumDTensor>& MRaws,
520  const ck::Array<index_t, NumDTensor>& NRaws,
521  const ck::Array<index_t, NumDTensor>& DsStride
522 #else
523  const std::array<index_t, NumDTensor>& MRaws,
524  const std::array<index_t, NumDTensor>& NRaws,
525  const std::array<index_t, NumDTensor>& DsStride
526 #endif
527  )
528  {
529  return generate_tuple(
530  [&](auto i) {
531  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
532 
533  return MakeEGridDescriptor_M_N<DLayout, GemmSpec>(MRaws[i], NRaws[i], DsStride[i]);
534  },
536  }
537 
538  __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }
539 
540  template <bool HasMainKBlockLoop,
541  typename AsGridDesc_AK0_M_AK1,
542  typename BsGridDesc_BK0_N_BK1,
543  typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
544  typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
545  typename Block2ETileMap>
546  __device__ static void Run(AsGridPointer p_as_grid,
547  BsGridPointer p_bs_grid,
548  DsGridPointer p_ds_grid,
549  EDataType* __restrict__ p_e_grid,
550  void* __restrict__ p_shared,
551  const AElementwiseOperation& a_element_op,
552  const BElementwiseOperation& b_element_op,
553  const CDEElementwiseOperation& cde_element_op,
554  const AsGridDesc_AK0_M_AK1 as_grid_desc_ak0_m_ak1,
555  const BsGridDesc_BK0_N_BK1 bs_grid_desc_bk0_n_bk1,
556  const DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
557  ds_grid_desc_mblock_mperblock_nblock_nperblock,
558  const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
559  e_grid_desc_mblock_mperblock_nblock_nperblock,
560  const Block2ETileMap& block_2_etile_map)
561  {
562  const auto as_grid_buf = generate_tuple(
563  [&](auto i) {
564  return make_dynamic_buffer<AddressSpaceEnum::Global>(
565  p_as_grid[i], as_grid_desc_ak0_m_ak1[i].GetElementSpaceSize());
566  },
568 
569  const auto bs_grid_buf = generate_tuple(
570  [&](auto i) {
571  return make_dynamic_buffer<AddressSpaceEnum::Global>(
572  p_bs_grid[i], bs_grid_desc_bk0_n_bk1[i].GetElementSpaceSize());
573  },
575 
576  const auto ds_grid_buf = generate_tuple(
577  [&](auto i) {
578  return make_dynamic_buffer<AddressSpaceEnum::Global>(
579  p_ds_grid[i],
580  ds_grid_desc_mblock_mperblock_nblock_nperblock[i].GetElementSpaceSize());
581  },
583 
584  auto e_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
585  p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
586 
587  // divide block work by [M, N]
588  const auto block_work_idx =
589  block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
590 
591  if(!block_2_etile_map.ValidCTileIndex(
592  block_work_idx,
593  make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
594  e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
595  {
596  return;
597  }
 598  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
599  const index_t m_block_data_idx_on_grid =
600  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
601 
602  const index_t n_block_data_idx_on_grid =
603  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
604 
605  // lds max alignment
606  constexpr auto max_lds_align = math::lcm(AK1, BK1);
607 
608  // A matrix in LDS memory, dst of blockwise copy
609  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
610 
611  // B matrix in LDS memory, dst of blockwise copy
612  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
613 
614  const auto idx_as_block_begin =
615  generate_tuple([&](auto) { return make_multi_index(0, m_block_data_idx_on_grid, 0); },
617 
618  auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
620  AsDataType,
622  decltype(as_grid_desc_ak0_m_ak1),
623  decltype(tie(a_block_desc_ak0_m_ak1)),
624  AElementwiseOperation,
627  ABlockTransferThreadClusterLengths_AK0_M_AK1,
628  ABlockTransferThreadClusterArrangeOrder,
629  ABlockTransferSrcAccessOrder,
631  ABlockTransferSrcVectorDim,
632  2,
633  ABlockTransferSrcScalarPerVector,
634  ABlockTransferDstScalarPerVector_AK1,
636  Sequence<true>>{as_grid_desc_ak0_m_ak1,
637  idx_as_block_begin,
638  tie(a_block_desc_ak0_m_ak1),
639  make_tuple(make_multi_index(0, 0, 0)),
640  a_element_op};
641 
642  const auto idx_bs_block_begin =
643  generate_tuple([&](auto) { return make_multi_index(0, n_block_data_idx_on_grid, 0); },
645 
646  auto b_blockwise_copy = ThreadGroupTensorSliceTransfer_v7r2<
648  BsDataType,
650  decltype(bs_grid_desc_bk0_n_bk1),
651  decltype(tie(b_block_desc_bk0_n_bk1)),
652  BElementwiseOperation,
655  BBlockTransferThreadClusterLengths_BK0_N_BK1,
656  BBlockTransferThreadClusterArrangeOrder,
657  BBlockTransferSrcAccessOrder,
659  BBlockTransferSrcVectorDim,
660  2,
661  BBlockTransferSrcScalarPerVector,
662  BBlockTransferDstScalarPerVector_BK1,
664  Sequence<true>>{bs_grid_desc_bk0_n_bk1,
665  idx_bs_block_begin,
666  tie(b_block_desc_bk0_n_bk1),
667  make_tuple(make_multi_index(0, 0, 0)),
668  b_element_op};
669 
670  // GEMM definition
671  // c_mtx += transpose(a_mtx) * b_mtx
672  // a_mtx[K0PerBlock, MPerBlock] is in LDS
673  // b_mtx[K0PerBlock, NPerBlock] is in LDS
674  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
675  // register
676  // sanity check
677  constexpr index_t KPack = math::max(
678  math::lcm(AK1, BK1),
680  .k_per_blk);
681 
683  BlockSize,
686  AccDataType,
687  decltype(a_block_desc_ak0_m_ak1),
688  decltype(b_block_desc_bk0_n_bk1),
689  MPerXdl,
690  NPerXdl,
691  MXdlPerWave,
692  NXdlPerWave,
693  KPack,
694  LoopSched>();
695 
696  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
697 
698  // LDS allocation for A and B: be careful of alignment
699  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
700  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
701 
702  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
703  static_cast<AComputeDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
704 
705  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
706  static_cast<BComputeDataType*>(p_shared) + a_block_space_size_aligned,
707  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
708 
709  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1, 0, 0);
710  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1, 0, 0);
711 
712  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
713  (as_grid_desc_ak0_m_ak1[I0].GetLength(I0) * as_grid_desc_ak0_m_ak1[I0].GetLength(I2)) /
714  KPerBlock);
715 
716  // gridwise GEMM pipeline
717  const auto gridwise_gemm_pipeline =
718  GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>();
719 
720  gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(as_grid_desc_ak0_m_ak1,
721  a_block_desc_ak0_m_ak1,
722  a_blockwise_copy,
723  as_grid_buf,
724  a_block_buf,
725  a_block_slice_copy_step,
726  bs_grid_desc_bk0_n_bk1,
727  b_block_desc_bk0_n_bk1,
728  b_blockwise_copy,
729  bs_grid_buf,
730  b_block_buf,
731  b_block_slice_copy_step,
732  blockwise_gemm,
733  c_thread_buf,
734  num_k_block_main_loop);
735 
736  // shuffle C and write out
737  {
738  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
739  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
740  "wrong!");
741 
742  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
743  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
744 
745  // TODO: hacky, fix it!
746  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
747  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
748 
749  // TODO: hacky, fix it!
750  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
751  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
752  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
753 
754  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
755  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
756  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
757  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
758  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
759  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
760  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
761  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
762 
763  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
765 
766  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
767  static_cast<CShuffleDataType*>(p_shared),
768  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
769 
770  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
771  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
772  make_tuple(
775  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
776  M1, // M1 = MWave
777  M2, // M2 * M3 * M4 = MPerXdl
778  M3,
779  M4)),
782  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
783  N1, // N1 = NWave
784  N2))), // N2 = NPerXdl
786  make_tuple(
788 
789  // calculate origin of thread output tensor on global memory
790  // blockwise GEMM c matrix starting index
791  const auto c_thread_mtx_on_block =
792  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
793 
794  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
795  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
796 
797  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
799  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
802 
803  const auto m_thread_data_on_block_idx =
804  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
805  make_multi_index(m_thread_data_on_block));
806 
807  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
812 
813  const auto n_thread_data_on_block_idx =
814  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
815  make_multi_index(n_thread_data_on_block));
816 
817  // shuffle: threadwise copy C from VGPR to LDS
818  auto c_thread_copy_vgpr_to_lds =
820  CShuffleDataType,
821  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
822  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
824  Sequence<CShuffleMXdlPerWavePerShuffle,
825  CShuffleNXdlPerWavePerShuffle,
826  I1,
827  I1,
828  M2,
829  I1,
830  M4,
831  I1>,
833  7,
834  1,
836  1,
837  true>{
838  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
840  0,
841  m_thread_data_on_block_idx[I1],
842  n_thread_data_on_block_idx[I1],
843  m_thread_data_on_block_idx[I2],
844  m_thread_data_on_block_idx[I3],
845  m_thread_data_on_block_idx[I4],
846  n_thread_data_on_block_idx[I2]),
848 
849  // tuple of reference to C/Ds tensor descriptors
850  const auto c_ds_desc_refs = concat_tuple_of_reference(
851  tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
852  generate_tie(
853  [&](auto i) -> const auto& // return type should be reference
854  { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
855  Number<NumDTensor>{}));
856 
 857  // tuple of reference to C/Ds buffers
858  const auto c_ds_buf_refs = concat_tuple_of_reference(
859  tie(c_shuffle_block_buf),
860  generate_tie(
861  [&](auto i) -> const auto& // return type should be reference
862  { return ds_grid_buf[i]; },
863  Number<NumDTensor>{}));
864 
865  // tuple of starting index of C/Ds blockwise copy
866  const auto idx_c_ds_block_begin = container_concat(
867  make_tuple(make_multi_index(0, 0, 0, 0)),
869  [&](auto) {
870  return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
871  },
872  Number<NumDTensor>{}));
873 
874  // blockwise copy C/D/E between LDS and global
875  auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r2<
877  decltype(container_concat(make_tuple(CShuffleDataType{}), DsDataType{})),
879  decltype(c_ds_desc_refs),
880  decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
881  CDEElementwiseOperation,
882  Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>, // FIXME: make Sequence
 883  // support arbitrary type
884  Sequence<1,
885  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
886  1,
887  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
888  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
889  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
890  Sequence<0, 1, 2, 3>, // typename SrcDimAccessOrder,
891  Sequence<0, 1, 2, 3>, // typename DstDimAccessOrder,
892  3, // index_t SrcVectorDim,
893  3, // index_t DstVectorDim,
894  CDEShuffleBlockTransferScalarPerVector_NPerBlock,
895  CDEShuffleBlockTransferScalarPerVector_NPerBlock,
899  false>>, // ThreadTransferSrcResetCoordinateAfterRunFlags
900  Sequence<false>> // ThreadTransferDstResetCoordinateAfterRunFlags
901  {c_ds_desc_refs,
902  idx_c_ds_block_begin,
903  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
904  make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
905  cde_element_op};
906 
907  // space filling curve for threadwise C in VGPR before shuffle
908  constexpr auto sfc_c_vgpr =
911  Sequence<CShuffleMXdlPerWavePerShuffle,
912  CShuffleNXdlPerWavePerShuffle,
913  1,
914  1,
915  M2,
916  1,
917  M4,
918  1>>{};
919 
920  // space filling curve for shuffled blockwise C/D/E
921  constexpr auto sfc_cde_block =
924  Sequence<1,
925  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
926  1,
927  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
928 
929  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
930 
931  static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
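// For example (assumed tuning values): MXdlPerWave = NXdlPerWave = 4 with
// CShuffleMXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle = 2 gives
// num_access = (4 / 2) * (4 / 2) = 4 LDS round trips in the loop below.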
932 
933  static_for<0, num_access, 1>{}([&](auto access_id) {
934  // make sure it's safe to write to LDS
935  block_sync_lds();
936 
937  // each thread write its data from VGPR to LDS
938  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
939  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
940  c_thread_buf,
941  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
942  c_shuffle_block_buf);
943 
944  // make sure it's safe to read from LDS
945  block_sync_lds();
946 
947  // each block copy its data from LDS to global
948  cde_block_copy_lds_and_global.Run(
949  c_ds_desc_refs,
950  c_ds_buf_refs,
951  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
952  tie(e_grid_buf));
953 
954  if constexpr(access_id < num_access - 1)
955  {
956  constexpr auto cde_lds_and_global_step =
957  sfc_cde_block.GetForwardStep(access_id);
958 
959  // move on Ds
960  static_for<0, NumDTensor, 1>{}([&](auto i) {
961  cde_block_copy_lds_and_global.MoveSrcSliceWindow(
962  c_ds_desc_refs, i + I1, cde_lds_and_global_step);
963  });
964 
965  // move on E
966  cde_block_copy_lds_and_global.MoveDstSliceWindow(
967  tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
968  I0,
969  cde_lds_and_global_step);
970  }
971  });
972  }
973  }
974 
975  template <bool HasMainKBlockLoop,
976  GemmSpecialization GemmSpec,
977  typename AsLayout,
978  typename BsLayout,
979  typename DsLayout,
980  typename ELayout,
981  typename Block2ETileMap>
982  __device__ static void Run(AsGridPointer p_as_grid,
983  BsGridPointer p_bs_grid,
984  DsGridPointer p_ds_grid,
985  void* __restrict__ p_e_grid_,
986  void* __restrict__ p_shared,
987  const AElementwiseOperation& a_element_op,
988  const BElementwiseOperation& b_element_op,
989  const CDEElementwiseOperation& cde_element_op,
990  const index_t M,
991  const index_t N,
992  const index_t K,
993 #ifdef CK_CODE_GEN_RTC
994  const ck::Array<index_t, NumATensor> StrideAs,
995  const ck::Array<index_t, NumBTensor> StrideBs,
996  const ck::Array<index_t, NumDTensor> StrideDs,
997 #else
998  const std::array<index_t, NumATensor> StrideAs,
999  const std::array<index_t, NumBTensor> StrideBs,
1000  const std::array<index_t, NumDTensor> StrideDs,
1001 #endif
1002  const index_t StrideE,
1003  const Block2ETileMap& block_2_etile_map)
1004  {
1005  using AsGridDesc_M_K =
1007  using BsGridDesc_N_K =
1009  using DsGridDesc_M_N =
1011 
1012  const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);
1013 
1014  AsGridDesc_M_K as_grid_desc_m_k;
1015  BsGridDesc_N_K bs_grid_desc_n_k;
1016  DsGridDesc_M_N ds_grid_desc_m_n;
1017 
1018  static_for<0, NumATensor, 1>{}([&](auto j) {
1019  using ALayout = remove_cvref_t<tuple_element_t<j.value, AsLayout>>;
1020 
1021  as_grid_desc_m_k(j) = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideAs[j]);
1022  });
1023 
1024  static_for<0, NumBTensor, 1>{}([&](auto j) {
1025  using BLayout = remove_cvref_t<tuple_element_t<j.value, BsLayout>>;
1026 
1027  bs_grid_desc_n_k(j) = MakeBGridDescriptor_N_K<BLayout, GemmSpec>(N, K, StrideBs[j]);
1028  });
1029 
1030  static_for<0, NumDTensor, 1>{}([&](auto j) {
1031  using DLayout = remove_cvref_t<tuple_element_t<j.value, DsLayout>>;
1032 
1033  ds_grid_desc_m_n(j) = MakeEGridDescriptor_M_N<DLayout, GemmSpec>(M, N, StrideDs[j]);
1034  });
1035 
1036  const auto e_grid_desc_m_n = MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);
1037 
1038  // tensor descriptors for block/thread-wise copy
1039  const auto as_grid_desc_ak0_m_ak1 = MakeDefaultAsGridDescriptor_AK0_M_AK1(as_grid_desc_m_k);
1040 
1041  const auto bs_grid_desc_bk0_n_bk1 = MakeDefaultBsGridDescriptor_BK0_N_BK1(bs_grid_desc_n_k);
1042 
1043  const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1045 
1046  const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1048 
1049  Run<HasMainKBlockLoop>(p_as_grid,
1050  p_bs_grid,
1051  p_ds_grid,
1052  p_e_grid,
1053  p_shared,
1054  a_element_op,
1055  b_element_op,
1056  cde_element_op,
1057  as_grid_desc_ak0_m_ak1,
1058  bs_grid_desc_bk0_n_bk1,
1059  ds_grid_desc_mblock_mperblock_nblock_nperblock,
1060  e_grid_desc_mblock_mperblock_nblock_nperblock,
1061  block_2_etile_map);
1062  }
1063 };
1064 
1065 } // namespace ck