Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp Source File

gridwise_gemm_xdl_cshuffle_v1.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm, bool HasMainKBlockLoop>
21 __global__ void
22 #if CK_USE_LAUNCH_BOUNDS
24 #endif
25  kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
26 {
27 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
28  defined(__gfx94__))
29  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
30 
31  GridwiseGemm::template Run<HasMainKBlockLoop>(
32  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
33 #else
34  ignore = karg;
35 #endif // end of gfx908/gfx90a/gfx94 guard
36 }
37 
38 template <typename GridwiseGemm,
39  typename FloatA,
40  typename FloatB,
41  typename FloatC,
42  bool HasMainKBlockLoop>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
46 #endif
47  kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
48  const FloatB* __restrict__ p_b_grid,
49  FloatC* __restrict__ p_c_grid,
50  typename GridwiseGemm::Problem problem)
51 {
52 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
53  defined(__gfx94__))
54  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
55 
56  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
57 #else
58  ignore = p_a_grid;
59  ignore = p_b_grid;
60  ignore = p_c_grid;
61  ignore = problem;
62 #endif // end of gfx908/gfx90a/gfx94 guard
63 }
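// Illustrative host-side launch sketch (not part of this header; `Gemm`, `arg`
// and the null stream are assumptions). `Gemm` stands for a concrete
// instantiation of the gridwise GEMM struct below and `arg` for a populated
// Gemm::Argument; the grid size comes from CalculateGridSize and the workgroup
// size is the BlockSize value used to instantiate `Gemm`:
//
//   const auto grid_size = std::get<0>(Gemm::CalculateGridSize(arg.M, arg.N));
//   if(Gemm::CalculateHasMainKBlockLoop(arg.K))
//   {
//       hipLaunchKernelGGL((kernel_gemm_xdl_cshuffle_v1<Gemm, true>),
//                          dim3(grid_size), dim3(BlockSize), 0, nullptr, arg);
//   }
//   else
//   {
//       hipLaunchKernelGGL((kernel_gemm_xdl_cshuffle_v1<Gemm, false>),
//                          dim3(grid_size), dim3(BlockSize), 0, nullptr, arg);
//   }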
64 
65 template <typename ALayout,
66  typename BLayout,
67  typename CLayout,
68  typename FloatA,
69  typename FloatB,
70  typename FloatGemmAcc,
71  typename FloatCShuffle,
72  typename FloatC,
73  typename AElementwiseOperation,
74  typename BElementwiseOperation,
75  typename CElementwiseOperation,
 76  tensor_operation::device::GemmSpecialization GemmSpec,
 77  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
78  index_t NumGemmKPrefetchStage,
79  index_t BlockSize,
80  index_t MPerBlock,
81  index_t NPerBlock,
82  index_t KPerBlock,
83  index_t AK1Value,
84  index_t BK1Value,
85  index_t MPerXdl,
86  index_t NPerXdl,
87  index_t MXdlPerWave,
88  index_t NXdlPerWave,
89  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
90  typename ABlockTransferThreadClusterArrangeOrder,
91  typename ABlockTransferSrcAccessOrder,
92  index_t ABlockTransferSrcVectorDim,
93  index_t ABlockTransferSrcScalarPerVector,
94  index_t ABlockTransferDstScalarPerVector_AK1,
95  bool AThreadTransferSrcResetCoordinateAfterRun,
96  index_t ABlockLdsExtraM,
97  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
98  typename BBlockTransferThreadClusterArrangeOrder,
99  typename BBlockTransferSrcAccessOrder,
100  index_t BBlockTransferSrcVectorDim,
101  index_t BBlockTransferSrcScalarPerVector,
102  index_t BBlockTransferDstScalarPerVector_BK1,
103  bool BThreadTransferSrcResetCoordinateAfterRun,
104  index_t BBlockLdsExtraN,
105  index_t CShuffleMXdlPerWavePerShuffle,
106  index_t CShuffleNXdlPerWavePerShuffle,
107  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
108  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
109  LoopScheduler LoopSched,
110  PipelineVersion PipelineVer = PipelineVersion::v1,
111  typename ComputeTypeA = FloatC,
112  typename ComputeTypeB = ComputeTypeA>
 113 struct GridwiseGemm_xdl_cshuffle_v1
 114 {
115  static constexpr auto I0 = Number<0>{};
116  static constexpr auto I1 = Number<1>{};
117  static constexpr auto I2 = Number<2>{};
118  static constexpr auto I3 = Number<3>{};
119  static constexpr auto I4 = Number<4>{};
120  static constexpr auto I5 = Number<5>{};
121  static constexpr auto I6 = Number<6>{};
122  static constexpr auto I7 = Number<7>{};
123 
124  // K1 should be Number<...>
125  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
126  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
127  static constexpr auto AK1Number = Number<AK1Value>{};
128  static constexpr auto BK1Number = Number<BK1Value>{};
129 
 130  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 131 
132  __host__ static auto CalculateGridSize(index_t M, index_t N)
133  {
134  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
135  }
136 
137  __host__ static auto CalculateMPadded(index_t M)
138  {
139  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
140  }
141 
142  __host__ static auto CalculateNPadded(index_t N)
143  {
144  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
145  }
146 
147  __host__ static auto CalculateKPadded(index_t K)
148  {
149  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
150  }
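// Illustrative arithmetic for the padding helpers above (MPerBlock/KPerBlock
// values are assumptions, not defaults): with MPerBlock = 256 and M = 1000,
// CalculateMPadded returns integer_divide_ceil(1000, 256) * 256 = 4 * 256 = 1024;
// with KPerBlock = 32 and K = 100, CalculateKPadded returns 4 * 32 = 128. Each
// helper rounds its dimension up to the next multiple of the per-block tile size.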
151 
152  __host__ static auto CalculateAK0(index_t K)
153  {
155 
156  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
157  GemmSpec == GemmSpecialization::MNKPadding ||
158  GemmSpec == GemmSpecialization::KPadding ||
159  GemmSpec == GemmSpecialization::NKPadding)
160  {
161  return CalculateKPadded(K) / AK1Value;
162  }
163  else
164  {
165  return K / AK1Value;
166  }
167  }
168 
169  __host__ static auto CalculateBK0(index_t K)
170  {
172 
173  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
174  GemmSpec == GemmSpecialization::MNKPadding ||
175  GemmSpec == GemmSpecialization::KPadding ||
176  GemmSpec == GemmSpecialization::MKPadding)
177  {
178  return CalculateKPadded(K) / BK1Value;
179  }
180  else
181  {
182  return K / BK1Value;
183  }
184  }
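// Illustrative example (tuning values are assumptions): with KPerBlock = 32 and
// AK1Value = 8, AK0Number is the compile-time constant 32 / 8 = 4. Under a
// K-padding specialization with K = 100, CalculateAK0 returns
// CalculateKPadded(100) / 8 = 128 / 8 = 16; without K padding, K must already be
// divisible by AK1Value, e.g. K = 96 gives AK0 = 12. CalculateBK0 follows the
// same rule with BK1Value.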
185 
186  __host__ static auto CalculateMBlock(index_t M)
187  {
188  return math::integer_divide_floor(M, MPerBlock);
189  }
190 
191  __host__ static auto CalculateNBlock(index_t N)
192  {
193  return math::integer_divide_floor(N, NPerBlock);
194  }
195 
196  __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
197  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
198  {
199  const auto a_grid_desc_mraw_kraw = [&]() {
200  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
201  {
202  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
203  }
204  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
205  {
206  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
207  }
208  }();
209 
211 
212  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
213  GemmSpec == GemmSpecialization::MNKPadding)
214  {
215  // pad both M and K
216  const auto a_grid_desc_m_k =
217  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
219  make_right_pad_transform(K, KPad - K)),
222 
223  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
224  a_grid_desc_m_k,
229 
230  return a_grid_desc_ak0_m_ak1;
231  }
232  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
233  GemmSpec == GemmSpecialization::MNPadding)
234  {
235  // pad M, but not K
236  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
237  a_grid_desc_mraw_kraw,
239  make_right_pad_transform(M, MPad - M)),
242 
243  return a_grid_desc_ak0_m_ak1;
244  }
245  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
246  GemmSpec == GemmSpecialization::NKPadding)
247  {
248  // pad K, but not M
249  const auto a_grid_desc_m_k = transform_tensor_descriptor(
250  a_grid_desc_mraw_kraw,
254 
255  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
256  a_grid_desc_m_k,
261 
262  return a_grid_desc_ak0_m_ak1;
263  }
264  else
265  {
 266  // pad neither M nor K
267  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
268  a_grid_desc_mraw_kraw,
273 
274  return a_grid_desc_ak0_m_ak1;
275  }
276  }
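// Illustrative result (sizes are assumptions): for packed row-major A with
// M = 256, K = 64, StrideA = 64 and AK1 = 8 under the no-padding branch, the raw
// (M, K) = (256, 64) view is split along K into (AK0, M, AK1) = (8, 256, 8),
// which is the [AK0, M, AK1] shape consumed by the A blockwise copy in Run().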
277 
278  __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
279  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
280  {
281  const auto b_grid_desc_nraw_kraw = [&]() {
283  {
284  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
285  }
287  {
288  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
289  }
290  }();
291 
293 
294  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
295  GemmSpec == GemmSpecialization::MNKPadding)
296  {
297  // pad both N and K
298  const auto b_grid_desc_n_k =
299  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
301  make_right_pad_transform(K, KPad - K)),
304 
305  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
306  b_grid_desc_n_k,
311 
312  return b_grid_desc_bk0_n_bk1;
313  }
314  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
315  GemmSpec == GemmSpecialization::MNPadding)
316  {
317  // pad N, but not K
318  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
319  b_grid_desc_nraw_kraw,
321  make_right_pad_transform(N, NPad - N)),
324 
325  return b_grid_desc_bk0_n_bk1;
326  }
327  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
328  GemmSpec == GemmSpecialization::MKPadding)
329  {
330  // pad K, but not N
331  const auto b_grid_desc_n_k = transform_tensor_descriptor(
332  b_grid_desc_nraw_kraw,
336 
337  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
338  b_grid_desc_n_k,
343 
344  return b_grid_desc_bk0_n_bk1;
345  }
346  else
347  {
 348  // pad neither N nor K
349  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
350  b_grid_desc_nraw_kraw,
355 
356  return b_grid_desc_bk0_n_bk1;
357  }
358  }
359 
360  __host__ __device__ static auto
 361  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
 362  {
363  const auto c_grid_desc_mraw_nraw = [&]() {
365  {
366  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
367  }
369  {
370  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
371  }
372  }();
373 
375 
376  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
377  GemmSpec == GemmSpecialization::MNKPadding)
378  {
379  // pad M and N
380  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
382  make_right_pad_transform(N, NPad - N)),
385  }
386  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
387  GemmSpec == GemmSpecialization::MKPadding)
388  {
389  // pad M, but not N
391  c_grid_desc_mraw_nraw,
395  }
396  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
397  GemmSpec == GemmSpecialization::NKPadding)
398  {
399  // pad N, but not M
401  c_grid_desc_mraw_nraw,
405  }
406  else
407  {
 408  // pad neither M nor N
409  return c_grid_desc_mraw_nraw;
410  }
411  }
412 
413  struct Problem
414  {
415  __host__ Problem(index_t M_,
416  index_t N_,
417  index_t K_,
418  index_t StrideA_,
419  index_t StrideB_,
420  index_t StrideC_)
421  : M{M_},
422  N{N_},
423  K{K_},
424  StrideA{StrideA_},
425  StrideB{StrideB_},
 426  StrideC{StrideC_},
 427  MPadded{CalculateMPadded(M_)},
 428  NPadded{CalculateNPadded(N_)},
 429  KPadded{CalculateKPadded(K_)},
 430  AK0{CalculateAK0(K_)},
 431  BK0{CalculateBK0(K_)},
 432  MBlock{CalculateMBlock(M_)},
 433  NBlock{CalculateNBlock(N_)}
 434  {
435  }
436 
437  __host__ void Print() const
438  {
439  std::cout << "problem {"
440  << "M:" << M << ", "
441  << "N:" << N << ", "
442  << "K:" << K << ", "
443  << "SA:" << StrideA << ", "
444  << "SB:" << StrideB << ", "
445  << "SC:" << StrideC << ", "
446  << "MP:" << MPadded << ", "
447  << "NP:" << NPadded << ", "
448  << "KP:" << KPadded << ", "
449  << "AK0:" << AK0 << ", "
450  << "BK0:" << BK0 << ", "
451  << "MBlock: " << MBlock << ", "
452  << "NBlock: " << NBlock << "}" << std::endl;
453  }
454 
 455  index_t M;
 456  index_t N;
 457  index_t K;
 458  index_t StrideA;
 459  index_t StrideB;
 460  index_t StrideC;
 461  index_t MPadded;
 462  index_t NPadded;
 463  index_t KPadded;
 464  index_t AK0;
 465  index_t BK0;
 466  index_t MBlock;
 467  index_t NBlock;
 468  };
469 
470  // Argument
 471  struct Argument : public Problem
 472  {
473  __host__ Argument(const FloatA* p_a_grid_,
474  const FloatB* p_b_grid_,
475  FloatC* p_c_grid_,
476  index_t M_,
477  index_t N_,
478  index_t K_,
479  index_t StrideA_,
480  index_t StrideB_,
481  index_t StrideC_)
482  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
483  p_a_grid{p_a_grid_},
484  p_b_grid{p_b_grid_},
485  p_c_grid{p_c_grid_}
486  {
487  }
488 
489  const FloatA* p_a_grid;
490  const FloatB* p_b_grid;
491  FloatC* p_c_grid;
492  };
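// Illustrative construction of an Argument (sizes, strides and the `Gemm` alias
// for a concrete instantiation are assumptions): for row-major A (M x K) and
// C (M x N) with column-major B (K x N), the leading strides are typically
// StrideA = K, StrideB = K and StrideC = N, e.g.
//
//   Gemm::Argument arg{p_a, p_b, p_c,
//                      /*M=*/3840, /*N=*/4096, /*K=*/4096,
//                      /*StrideA=*/4096, /*StrideB=*/4096, /*StrideC=*/4096};
//
// Because Argument derives from Problem, the padded sizes, AK0/BK0 and the block
// counts are filled in by the Problem constructor, and `arg` can be passed to
// CheckValidity and to kernel_gemm_xdl_cshuffle_v1 as-is.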
493 
 494  // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
 495  using GridwiseGemmPipe = remove_cvref_t<
 496  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
497 
498  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
499  {
500  // A matrix in LDS memory, dst of blockwise copy
504  }
505 
506  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
507  {
508  // B matrix in LDS memory, dst of blockwise copy
512  }
513 
 514  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
 515  {
516  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
517  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
518 
519  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
521  make_tuple(I1,
523  I1,
525 
526  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
527  }
528 
529  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
530  {
531  // LDS allocation for A and B: be careful of alignment
532  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
533  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
534 
535  // lds max alignment
536  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
537 
538  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
539  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
540 
541  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
542  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
543 
544  // LDS allocation for C shuffle in LDS
545  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
547 
548  constexpr auto c_block_size =
549  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
550 
551  return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
552  b_block_space_size_aligned * sizeof(ComputeTypeB)),
553  c_block_size * sizeof(FloatCShuffle));
554  }
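// Illustrative LDS budget (tile and type choices are assumptions): with
// MPerBlock = 256, NPerBlock = 128, KPerBlock = 32, AK1 = BK1 = 8 and fp16
// compute types, the A tile holds roughly AK0Number * MPerBlock * AK1Number =
// 4 * 256 * 8 = 8192 elements (~16 KiB plus any ABlockLdsExtraM padding) and the
// B tile roughly 4 * 128 * 8 = 4096 elements (~8 KiB), so A + B needs ~24 KiB.
// The function returns the larger of that sum and the C-shuffle tile footprint,
// because the same p_shared buffer is reused for the C shuffle in Run().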
555 
 556  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
557  __host__ static constexpr bool CheckValidity(const Problem& problem)
558  {
559  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
560  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
561  "Invalid tuning param!");
562 
563  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
567  {
568  if(!(problem.M % MPerBlock == 0))
569  {
570  return false;
571  }
572  }
573 
574  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
578  {
579  if(!(problem.N % NPerBlock == 0))
580  {
581  return false;
582  }
583  }
584 
589  {
590  if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
591  !(CalculateKPadded(problem.K) % BK1Value == 0))
592  {
593  return false;
594  }
595  }
596  else
597  {
598  if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
599  {
600  return false;
601  }
602  }
603 
605  {
606  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
607  {
608  return false;
609  }
610  }
611  else
612  {
613  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
614  {
615  return false;
616  }
617  }
618 
620  {
621  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
622  {
623  return false;
624  }
625  }
626  else
627  {
628  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
629  {
630  return false;
631  }
632  }
633 
635  {
636  if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
637  {
638  return false;
639  }
640  }
641  else
642  {
643  if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
644  {
645  return false;
646  }
647  }
648 
649  // check gridwise gemm pipeline
650  const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
651 
652  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
653  {
654  return false;
655  }
656 
657  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
658  return true;
659  }
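// Illustrative rejection case (values are assumptions): with MPerBlock = 256,
// M = 1000 and a GemmSpec that does not pad M, CheckValidity returns false since
// 1000 % 256 != 0; an M-padding specialization (or an M that is a multiple of
// MPerBlock) passes. The K checks behave the same way against AK1Value/BK1Value,
// and the vector-width checks require the contiguous dimension of each tensor to
// be divisible by the corresponding ScalarPerVector parameter.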
660 
661  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
662  {
663  const index_t num_loop = K / KPerBlock;
664 
665  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
666  }
667 
668  template <typename CGridDesc>
 669  __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
 670  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
671  {
672  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
673  c_grid_desc_m_n,
678 
679  return c_grid_desc_mblock_mperblock_nblock_nperblock;
680  }
681 
682  // return block_id to C matrix tile idx (m0, n0) mapping
684 
685  template <bool HasMainKBlockLoop>
686  __device__ static void Run(const FloatA* __restrict__ p_a_grid,
687  const FloatB* __restrict__ p_b_grid,
688  FloatC* __restrict__ p_c_grid,
689  void* __restrict__ p_shared,
690  const Problem& problem)
691  {
692  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
693  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
694  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
695  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
696  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
697  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
698 
699  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
701  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
702 
703  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
704  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
705  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
706  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
707  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
708  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
709 
710  const AElementwiseOperation a_element_op{};
711  const BElementwiseOperation b_element_op{};
712  const CElementwiseOperation c_element_op{};
713 
714  // divide block work by [M, N]
715  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
716 
717  const auto block_work_idx =
718  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
719 
720  if(!block_2_ctile_map.ValidCTileIndex(
721  block_work_idx,
722  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
723  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
724  {
725  return;
726  }
727 
 728  // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
729  const index_t m_block_data_idx_on_grid =
730  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
731 
732  const index_t n_block_data_idx_on_grid =
733  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
734 
735  // lds max alignment
736  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
737 
738  // A matrix in LDS memory, dst of blockwise copy
739  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
740 
741  // B matrix in LDS memory, dst of blockwise copy
742  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
743 
744  // A matrix blockwise copy
745  auto a_blockwise_copy =
747  AElementwiseOperation,
751  ABlockTransferThreadClusterLengths_AK0_M_AK1,
752  ABlockTransferThreadClusterArrangeOrder,
753  FloatA,
754  ComputeTypeA,
755  decltype(a_grid_desc_ak0_m_ak1),
756  decltype(a_block_desc_ak0_m_ak1),
757  ABlockTransferSrcAccessOrder,
759  ABlockTransferSrcVectorDim,
760  2,
761  ABlockTransferSrcScalarPerVector,
762  ABlockTransferDstScalarPerVector_AK1,
763  1,
764  1,
765  AThreadTransferSrcResetCoordinateAfterRun,
766  true,
767  NumGemmKPrefetchStage>(
768  a_grid_desc_ak0_m_ak1,
769  make_multi_index(0, m_block_data_idx_on_grid, 0),
770  a_element_op,
771  a_block_desc_ak0_m_ak1,
772  make_multi_index(0, 0, 0),
774 
775  // B matrix blockwise copy
776  auto b_blockwise_copy =
778  BElementwiseOperation,
782  BBlockTransferThreadClusterLengths_BK0_N_BK1,
783  BBlockTransferThreadClusterArrangeOrder,
784  FloatB,
785  ComputeTypeB,
786  decltype(b_grid_desc_bk0_n_bk1),
787  decltype(b_block_desc_bk0_n_bk1),
788  BBlockTransferSrcAccessOrder,
790  BBlockTransferSrcVectorDim,
791  2,
792  BBlockTransferSrcScalarPerVector,
793  BBlockTransferDstScalarPerVector_BK1,
794  1,
795  1,
796  BThreadTransferSrcResetCoordinateAfterRun,
797  true,
798  NumGemmKPrefetchStage>(
799  b_grid_desc_bk0_n_bk1,
800  make_multi_index(0, n_block_data_idx_on_grid, 0),
801  b_element_op,
802  b_block_desc_bk0_n_bk1,
803  make_multi_index(0, 0, 0),
805 
806  // GEMM definition
807  // c_mtx += transpose(a_mtx) * b_mtx
808  // a_mtx[K0PerBlock, MPerBlock] is in LDS
809  // b_mtx[K0PerBlock, NPerBlock] is in LDS
810  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
811  // register
812  // sanity check
813  constexpr index_t KPack = math::max(
816 
818  BlockSize,
819  ComputeTypeA,
820  ComputeTypeB,
821  FloatGemmAcc,
822  decltype(a_block_desc_ak0_m_ak1),
823  decltype(b_block_desc_bk0_n_bk1),
824  MPerXdl,
825  NPerXdl,
826  MXdlPerWave,
827  NXdlPerWave,
828  KPack,
829  LoopSched>();
830 
831  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
832 
833  // LDS allocation for A and B: be careful of alignment
834  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
835  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
836 
837  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
838  static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
839 
840  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
841  static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
842  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
843 
844  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
845  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
846 
847  // gridwise GEMM pipeline
848  static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
849  const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
850 
851  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
852  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
853  KPerBlock);
854 
855  gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
856  a_block_desc_ak0_m_ak1,
857  a_blockwise_copy,
858  a_grid_buf,
859  a_block_buf,
860  a_block_slice_copy_step,
861  b_grid_desc_bk0_n_bk1,
862  b_block_desc_bk0_n_bk1,
863  b_blockwise_copy,
864  b_grid_buf,
865  b_block_buf,
866  b_block_slice_copy_step,
867  blockwise_gemm,
868  c_thread_buf,
869  num_k_block_main_loop);
870 
871  // shuffle C and write out
872  {
873  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
874  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
875  "wrong!");
876 
877  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
878  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
879 
880  // TODO: hacky, fix it!
881  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
882  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
883 
884  // TODO: hacky, fix it!
885  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
886  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
887  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
888 
889  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
890  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
891  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
892  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
893  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
894  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
895  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
896  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
897 
898  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
900 
901  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
902  static_cast<FloatCShuffle*>(p_shared),
903  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
904 
905  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
906  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
907  make_tuple(
910  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
911  M1, // M1 = MWave
912  M2, // M2 * M3 * M4 = MPerXdl
913  M3,
914  M4)),
917  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
918  N1, // N1 = NWave
919  N2))), // N2 = NPerXdl
921  make_tuple(
923 
924  // calculate origin of thread output tensor on global memory
925  // blockwise GEMM c matrix starting index
926  const auto c_thread_mtx_on_block =
927  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
928 
929  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
930  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
931 
932  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
934  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
937 
938  const auto m_thread_data_on_block_idx =
939  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
940  make_multi_index(m_thread_data_on_block));
941 
942  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
947 
948  const auto n_thread_data_on_block_idx =
949  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
950  make_multi_index(n_thread_data_on_block));
951 
952  // shuffle: threadwise copy C from VGPR to LDS
953  auto c_thread_copy_vgpr_to_lds =
955  FloatCShuffle,
956  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
957  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
959  Sequence<CShuffleMXdlPerWavePerShuffle,
960  CShuffleNXdlPerWavePerShuffle,
961  I1,
962  I1,
963  M2,
964  I1,
965  M4,
966  I1>,
968  7,
969  1,
971  1,
972  true>{
973  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
975  0,
976  m_thread_data_on_block_idx[I1],
977  n_thread_data_on_block_idx[I1],
978  m_thread_data_on_block_idx[I2],
979  m_thread_data_on_block_idx[I3],
980  m_thread_data_on_block_idx[I4],
981  n_thread_data_on_block_idx[I2]),
983 
984  // shuffle: blockwise copy C from LDS to global
985  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
986  ThisThreadBlock, // ThreadGroup
987  CElementwiseOperation, // ElementwiseOperation,
988  CGlobalMemoryDataOperation, // DstInMemOp,
989  Sequence<1,
990  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
991  1,
992  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
993  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
994  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
995  FloatCShuffle, // typename SrcData,
996  FloatC, // typename DstData,
997  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
998  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
999  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1000  3, // index_t VectorDim,
1001  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1002  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1003  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1004  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1005  make_multi_index(0, 0, 0, 0),
1006  c_grid_desc_mblock_mperblock_nblock_nperblock,
1007  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1008  c_element_op};
1009 
1010  // space filling curve for threadwise C in VGPR
1011  constexpr auto sfc_c_vgpr =
1014  Sequence<CShuffleMXdlPerWavePerShuffle,
1015  CShuffleNXdlPerWavePerShuffle,
1016  1,
1017  1,
1018  M2,
1019  1,
1020  M4,
1021  1>>{};
1022 
1023  // space filling curve for shuffled blockwise C in global mem
1024  constexpr auto sfc_c_global =
1027  Sequence<1,
1028  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1029  1,
1030  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1031 
1032  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1033 
1034  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1035 
1036  static_for<0, num_access, 1>{}([&](auto access_id) {
1037  // make sure it's safe to write to LDS
1038  block_sync_lds();
1039 
1040  // each thread write its data from VGPR to LDS
1041  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1042  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1043  c_thread_buf,
1044  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1045  c_shuffle_block_buf);
1046 
1047  // make sure it's safe to read from LDS
1048  block_sync_lds();
1049 
1050  // each block copy its data from LDS to global
1051  c_shuffle_block_copy_lds_to_global.Run(
1052  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1053  c_shuffle_block_buf,
1054  c_grid_desc_mblock_mperblock_nblock_nperblock,
1055  c_grid_buf);
1056 
1057  if constexpr(access_id < num_access - 1)
1058  {
1059  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1060 
1061  // move on C
1062  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1063  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1064  }
1065  });
1066  }
1067  }
1068 };
1069 
1070 } // namespace ck