gridwise_gemm_xdl_cshuffle_v2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm, bool HasMainKBlockLoop, index_t TailNum = 3>
21 __global__ void
22 #if CK_USE_LAUNCH_BOUNDS
23  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
24 #endif
25  // __attribute__((amdgpu_waves_per_eu(1, 1)))
26  kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
27 {
28 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
29  defined(__gfx94__))
30  // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
31  // can operate on different LDS chunks at the same time without an ordering dependency
32  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
33  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
34 
35  GridwiseGemm::template Run<HasMainKBlockLoop, TailNum>(
36  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
37 #else
38  ignore = karg;
39 #endif // end of if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
40 }
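// Illustrative host-side sketch, not part of the original header: one way the
// Argument-based kernel wrapper above could be dispatched. The helper name and the
// explicit grid_size/block_size/stream parameters are assumptions for illustration
// (hipStream_t and dim3 come from the HIP runtime headers); in practice the
// device-level op that wraps this gridwise GEMM owns the real launch logic.
template <typename GridwiseGemm>
void launch_kernel_gemm_xdl_cshuffle_v2_sketch(typename GridwiseGemm::Argument karg,
                                               index_t grid_size,
                                               index_t block_size,
                                               hipStream_t stream)
{
    // Derive the compile-time pipeline variant from the runtime K, mirroring the
    // HasMainKBlockLoop / TailNum template parameters of the kernel above.
    const bool has_main_loop = GridwiseGemm::CalculateHasMainKBlockLoop(karg.K);
    const index_t tail_num   = GridwiseGemm::CalculateKBlockLoopTailNum(karg.K);

    if(has_main_loop && tail_num == 3)
    {
        kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 3>
            <<<dim3(grid_size), dim3(block_size), 0, stream>>>(karg);
    }
    else if(has_main_loop)
    {
        kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, true, 2>
            <<<dim3(grid_size), dim3(block_size), 0, stream>>>(karg);
    }
    else
    {
        kernel_gemm_xdl_cshuffle_v2<GridwiseGemm, false>
            <<<dim3(grid_size), dim3(block_size), 0, stream>>>(karg);
    }
}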
41 
42 template <typename GridwiseGemm,
43  typename FloatA,
44  typename FloatB,
45  typename FloatC,
46  bool HasMainKBlockLoop>
47 __global__ void
48 #if CK_USE_LAUNCH_BOUNDS
49  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
50 #endif
51  kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid,
52  const FloatB* p_b_grid,
53  FloatC* p_c_grid,
54  typename GridwiseGemm::Problem problem)
55 {
56 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
57  defined(__gfx94__))
58  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
59  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
60 
61  GridwiseGemm::template Run<HasMainKBlockLoop>(
62  p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem);
63 #else
64  ignore = p_a_grid;
65  ignore = p_b_grid;
66  ignore = p_c_grid;
67  ignore = problem;
68 #endif // end of if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
69 }
70 
71 template <typename ALayout,
72  typename BLayout,
73  typename CLayout,
74  typename FloatA,
75  typename FloatB,
76  typename FloatGemmAcc,
77  typename FloatCShuffle,
78  typename FloatC,
79  typename AElementwiseOperation,
80  typename BElementwiseOperation,
81  typename CElementwiseOperation,
82  tensor_operation::device::GemmSpecialization GemmSpec,
83  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
84  index_t NumGemmKPrefetchStage,
85  index_t BlockSize,
86  index_t MPerBlock,
87  index_t NPerBlock,
88  index_t KPerBlock,
89  index_t AK1Value,
90  index_t BK1Value,
91  index_t MPerXdl,
92  index_t NPerXdl,
93  index_t MXdlPerWave,
94  index_t NXdlPerWave,
95  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
96  typename ABlockTransferThreadClusterArrangeOrder,
97  typename ABlockTransferSrcAccessOrder,
98  index_t ABlockTransferSrcVectorDim,
99  index_t ABlockTransferSrcScalarPerVector,
100  index_t ABlockTransferDstScalarPerVector_AK1,
101  bool AThreadTransferSrcResetCoordinateAfterRun,
102  index_t ABlockLdsExtraM,
103  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
104  typename BBlockTransferThreadClusterArrangeOrder,
105  typename BBlockTransferSrcAccessOrder,
106  index_t BBlockTransferSrcVectorDim,
107  index_t BBlockTransferSrcScalarPerVector,
108  index_t BBlockTransferDstScalarPerVector_BK1,
109  bool BThreadTransferSrcResetCoordinateAfterRun,
110  index_t BBlockLdsExtraN,
111  index_t CShuffleMXdlPerWavePerShuffle,
112  index_t CShuffleNXdlPerWavePerShuffle,
113  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
114  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
115  LoopScheduler LoopSched,
116  PipelineVersion PipelineVer = PipelineVersion::v1,
117  typename ComputeTypeA = FloatC,
118  typename ComputeTypeB = ComputeTypeA>
119 struct GridwiseGemm_xdl_cshuffle_v2
120 {
121  static constexpr auto I0 = Number<0>{};
122  static constexpr auto I1 = Number<1>{};
123  static constexpr auto I2 = Number<2>{};
124  static constexpr auto I3 = Number<3>{};
125  static constexpr auto I4 = Number<4>{};
126  static constexpr auto I5 = Number<5>{};
127  static constexpr auto I6 = Number<6>{};
128  static constexpr auto I7 = Number<7>{};
129 
130  // K1 should be Number<...>
131  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
132  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
133  static constexpr auto AK1Number = Number<AK1Value>{};
134  static constexpr auto BK1Number = Number<BK1Value>{};
135 
136  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
137 
138  __host__ static auto CalculateGridSize(index_t M, index_t N)
139  {
141  }
142 
143  __host__ static auto CalculateMPadded(index_t M)
144  {
145  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
146  }
147 
148  __host__ static auto CalculateNPadded(index_t N)
149  {
150  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
151  }
152 
153  __host__ static auto CalculateKPadded(index_t K)
154  {
155  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
156  }
157 
158  __host__ static auto CalculateAK0(index_t K)
159  {
161 
162  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
163  GemmSpec == GemmSpecialization::MNKPadding ||
164  GemmSpec == GemmSpecialization::KPadding ||
165  GemmSpec == GemmSpecialization::NKPadding)
166  {
167  return CalculateKPadded(K) / AK1Value;
168  }
169  else
170  {
171  return K / AK1Value;
172  }
173  }
174 
175  __host__ static auto CalculateBK0(index_t K)
176  {
178 
179  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
180  GemmSpec == GemmSpecialization::MNKPadding ||
181  GemmSpec == GemmSpecialization::KPadding ||
182  GemmSpec == GemmSpecialization::MKPadding)
183  {
184  return CalculateKPadded(K) / BK1Value;
185  }
186  else
187  {
188  return K / BK1Value;
189  }
190  }
191 
192  __host__ static auto CalculateMBlock(index_t M)
193  {
194  return math::integer_divide_floor(M, MPerBlock);
195  }
196 
197  __host__ static auto CalculateNBlock(index_t N)
198  {
199  return math::integer_divide_floor(N, NPerBlock);
200  }
201 
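 // Worked example (illustrative only, with assumed tile sizes): for MPerBlock = 256,
 // NPerBlock = 128, KPerBlock = 64 and a problem of M = 1000, N = 1000, K = 100,
 // the helpers above give
 //   CalculateMPadded(1000) = ceil(1000 / 256) * 256 = 1024
 //   CalculateNPadded(1000) = ceil(1000 / 128) * 128 = 1024
 //   CalculateKPadded(100)  = ceil(100 / 64) * 64    = 128
 // while CalculateMBlock/CalculateNBlock use integer_divide_floor, so they count
 // only fully covered tiles of the unpadded sizes (1000 / 256 = 3, 1000 / 128 = 7).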
202  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
203  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
204  {
205  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
206  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
207 
209  TileDesc_K0_MN_K1{},
215  }
216 
217  __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
218  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
219  {
220  const auto a_grid_desc_mraw_kraw = [&]() {
221  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
222  {
223  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
224  }
225  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
226  {
227  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
228  }
229  }();
230 
232 
233  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
234  GemmSpec == GemmSpecialization::MNKPadding)
235  {
236  // pad both M and K
237  const auto a_grid_desc_m_k =
238  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
240  make_right_pad_transform(K, KPad - K)),
243 
244  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
245  a_grid_desc_m_k,
250 
251  return a_grid_desc_ak0_m_ak1;
252  }
253  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
254  GemmSpec == GemmSpecialization::MNPadding)
255  {
256  // pad M, but not K
257  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
258  a_grid_desc_mraw_kraw,
260  make_right_pad_transform(M, MPad - M)),
263 
264  return a_grid_desc_ak0_m_ak1;
265  }
266  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
267  GemmSpec == GemmSpecialization::NKPadding)
268  {
269  // pad K, but not M
270  const auto a_grid_desc_m_k = transform_tensor_descriptor(
271  a_grid_desc_mraw_kraw,
275 
276  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
277  a_grid_desc_m_k,
282 
283  return a_grid_desc_ak0_m_ak1;
284  }
285  else
286  {
287  // do not pad M or K
288  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
289  a_grid_desc_mraw_kraw,
294 
295  return a_grid_desc_ak0_m_ak1;
296  }
297  }
298 
299  __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
300  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
301  {
302  const auto b_grid_desc_nraw_kraw = [&]() {
303  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
304  {
305  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
306  }
307  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
308  {
309  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
310  }
311  }();
312 
314 
315  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
316  GemmSpec == GemmSpecialization::MNKPadding)
317  {
318  // pad both N and K
319  const auto b_grid_desc_n_k =
320  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
322  make_right_pad_transform(K, KPad - K)),
325 
326  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
327  b_grid_desc_n_k,
332 
333  return b_grid_desc_bk0_n_bk1;
334  }
335  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
336  GemmSpec == GemmSpecialization::MNPadding)
337  {
338  // pad N, but not K
339  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
340  b_grid_desc_nraw_kraw,
342  make_right_pad_transform(N, NPad - N)),
345 
346  return b_grid_desc_bk0_n_bk1;
347  }
348  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
349  GemmSpec == GemmSpecialization::MKPadding)
350  {
351  // pad K, but not N
352  const auto b_grid_desc_n_k = transform_tensor_descriptor(
353  b_grid_desc_nraw_kraw,
357 
358  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
359  b_grid_desc_n_k,
364 
365  return b_grid_desc_bk0_n_bk1;
366  }
367  else
368  {
369  // do not pad N or K
370  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
371  b_grid_desc_nraw_kraw,
376 
377  return b_grid_desc_bk0_n_bk1;
378  }
379  }
380 
381  template <typename ABlockDesc_AK0_M_AK1>
382  __host__ __device__ static constexpr auto
383  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
384  {
385  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
386 
387  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
388  }
389 
390  template <typename BBlockDesc_BK0_N_BK1>
391  __host__ __device__ static constexpr auto
392  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
393  {
394  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
395 
396  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
397  }
398 
399  __host__ __device__ static auto
400  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
401  {
402  const auto c_grid_desc_mraw_nraw = [&]() {
403  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
404  {
405  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
406  }
407  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
408  {
409  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
410  }
411  }();
412 
414 
415  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
416  GemmSpec == GemmSpecialization::MNKPadding)
417  {
418  // pad M and N
419  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
421  make_right_pad_transform(N, NPad - N)),
424  }
425  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
426  GemmSpec == GemmSpecialization::MKPadding)
427  {
428  // pad M, but not N
430  c_grid_desc_mraw_nraw,
434  }
435  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
436  GemmSpec == GemmSpecialization::NKPadding)
437  {
438  // pad N, but not M
440  c_grid_desc_mraw_nraw,
444  }
445  else
446  {
447  // do not pad M or N
448  return c_grid_desc_mraw_nraw;
449  }
450  }
451 
452  struct Problem
453  {
454  __host__ Problem(index_t M_,
455  index_t N_,
456  index_t K_,
457  index_t StrideA_,
458  index_t StrideB_,
459  index_t StrideC_)
460  : M{M_},
461  N{N_},
462  K{K_},
463  StrideA{StrideA_},
464  StrideB{StrideB_},
465  StrideC{StrideC_},
466  MPadded{CalculateMPadded(M_)},
467  NPadded{CalculateNPadded(N_)},
468  KPadded{CalculateKPadded(K_)},
469  AK0{CalculateAK0(K_)},
470  BK0{CalculateBK0(K_)},
471  MBlock{CalculateMBlock(M_)},
472  NBlock{CalculateNBlock(N_)}
473  {
474  }
475 
476  __host__ void Print() const
477  {
478  std::cout << "problem {"
479  << "M:" << M << ", "
480  << "N:" << N << ", "
481  << "K:" << K << ", "
482  << "SA:" << StrideA << ", "
483  << "SB:" << StrideB << ", "
484  << "SC:" << StrideC << ", "
485  << "MP:" << MPadded << ", "
486  << "NP:" << NPadded << ", "
487  << "KP:" << KPadded << ", "
488  << "AK0:" << AK0 << ", "
489  << "BK0:" << BK0 << ", "
490  << "MBlock: " << MBlock << ", "
491  << "NBlock: " << NBlock << "}" << std::endl;
492  }
493 
494  index_t M;
495  index_t N;
496  index_t K;
497  index_t StrideA;
498  index_t StrideB;
499  index_t StrideC;
500  index_t MPadded;
501  index_t NPadded;
502  index_t KPadded;
503  index_t AK0;
504  index_t BK0;
505  index_t MBlock;
506  index_t NBlock;
507  };
508 
509  // Argument
510  struct Argument : public Problem
511  {
512  __host__ Argument(const FloatA* p_a_grid_,
513  const FloatB* p_b_grid_,
514  FloatC* p_c_grid_,
515  index_t M_,
516  index_t N_,
517  index_t K_,
518  index_t StrideA_,
519  index_t StrideB_,
520  index_t StrideC_)
521  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
522  p_a_grid{p_a_grid_},
523  p_b_grid{p_b_grid_},
524  p_c_grid{p_c_grid_}
525  {
526  }
527 
528  const FloatA* p_a_grid;
529  const FloatB* p_b_grid;
530  FloatC* p_c_grid;
531  };
532 
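 // Illustrative usage sketch, not part of the original header. The concrete
 // instantiation, pointer names and sizes below are assumptions:
 //
 //   using Gemm = GridwiseGemm_xdl_cshuffle_v2</* ...tuning parameters... */>;
 //   Gemm::Argument karg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
 //   karg.Print();                        // Argument inherits Problem's fields
 //   if(Gemm::CheckValidity(karg))
 //   {
 //       // launch kernel_gemm_xdl_cshuffle_v2<Gemm, ...>(karg)
 //   }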
533  // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
534  using GridwiseGemmPipe = remove_cvref_t<
535  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
536 
537  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
538  {
539  // A matrix in LDS memory, dst of blockwise copy
543  }
544 
545  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
546  {
547  // B matrix in LDS memory, dst of blockwise copy
551  }
552 
553  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
554  {
555  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
556  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
557 
558  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
560  make_tuple(I1,
562  I1,
564 
565  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
566  }
567 
568  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
569  {
570  // LDS allocation for A and B: be careful of alignment
571  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
572  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
573 
574  // lds max alignment
575  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
576 
577  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
578  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
579 
580  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
581  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
582 
583  // LDS allocation for C shuffle in LDS
584  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
585  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
586 
587  constexpr auto c_block_size =
588  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
589 
590  return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
591  b_block_space_size_aligned * sizeof(ComputeTypeB)),
592  c_block_size * sizeof(FloatCShuffle));
593  }
594 
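 // Worked example (illustrative only, with assumed parameters): for AK1Value =
 // BK1Value = 8, KPerBlock = 64, MPerBlock = 256, NPerBlock = 128 and 16-bit
 // ComputeTypeA/ComputeTypeB, ignoring any ABlockLdsExtraM/BBlockLdsExtraN padding,
 // the A tile holds (64 / 8) * 256 * 8 = 16384 elements and the B tile
 // (64 / 8) * 128 * 8 = 8192 elements, i.e. (16384 + 8192) * 2 bytes = 48 KiB of LDS
 // per buffer. The function returns the larger of this A+B footprint and the
 // C-shuffle footprint, because the C shuffle reuses the same LDS once the main
 // loop has finished (see the use of p_shared_0 in Run below).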
595  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
596  __host__ static constexpr bool CheckValidity(const Problem& problem)
597  {
598  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
599  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
600  "Invalid tuning param!");
601 
602  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
603  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
604  GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
605  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
606  {
607  if(!(problem.M % MPerBlock == 0))
608  {
609  return false;
610  }
611  }
612 
613  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
614  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
615  GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
616  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
617  {
618  if(!(problem.N % NPerBlock == 0))
619  {
620  return false;
621  }
622  }
623 
624  if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
625  GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
626  GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
627  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
628  {
629  if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
630  !(CalculateKPadded(problem.K) % BK1Value == 0))
631  {
632  return false;
633  }
634  }
635  else
636  {
637  if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
638  {
639  return false;
640  }
641  }
642 
643  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
644  {
645  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
646  {
647  return false;
648  }
649  }
650  else
651  {
652  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
653  {
654  return false;
655  }
656  }
657 
658  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
659  {
660  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
661  {
662  return false;
663  }
664  }
665  else
666  {
667  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
668  {
669  return false;
670  }
671  }
672 
673  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
674  {
675  if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
676  {
677  return false;
678  }
679  }
680  else
681  {
682  if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
683  {
684  return false;
685  }
686  }
687 
688  // check gridwise gemm pipeline
689  const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
690 
691  if(num_k_loop < 4)
692  {
693  return false;
694  }
695 
696  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
697  return true;
698  }
699 
700  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
701  {
702  const index_t num_loop = K / KPerBlock;
703 
704  return num_loop > 3;
705  }
706 
707  __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K)
708  {
709  const index_t num_loop = K / KPerBlock;
710 
711  if(num_loop % 2 == 1)
712  return 3;
713  else
714  return 2;
715  }
716 
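 // Worked example (illustrative only, with an assumed KPerBlock = 64): K = 512
 // gives 512 / 64 = 8 K-block iterations, so CalculateHasMainKBlockLoop returns
 // true (8 > 3) and CalculateKBlockLoopTailNum returns 2 (even loop count), while
 // K = 448 gives 7 iterations and a tail number of 3. These two values are what a
 // host wrapper would pass as the HasMainKBlockLoop / TailNum template arguments
 // of kernel_gemm_xdl_cshuffle_v2.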
717  template <typename CGridDesc>
718  __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
719  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
720  {
721  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
722  c_grid_desc_m_n,
727 
728  return c_grid_desc_mblock_mperblock_nblock_nperblock;
729  }
730 
731  // return block_id to C matrix tile idx (m0, n0) mapping
732  // if arch = gfx942
734 
735  template <bool HasMainKBlockLoop, index_t TailNum = 3>
736  __device__ static void Run(const FloatA* p_a_grid,
737  const FloatB* p_b_grid,
738  FloatC* p_c_grid,
739  void* p_shared_0,
740  void* p_shared_1,
741  const Problem& problem)
742  {
743  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
744  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
745  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
746  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
747  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
748  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
749 
750  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
752  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
753 
754  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
755  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
756  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
757  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
758  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
759  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
760 
761  const AElementwiseOperation a_element_op{};
762  const BElementwiseOperation b_element_op{};
763  const CElementwiseOperation c_element_op{};
764 
765  // divide block work by [M, N]
766  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
767 
768  const auto block_work_idx =
769  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
770 
771  if(!block_2_ctile_map.ValidCTileIndex(
772  block_work_idx,
773  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
774  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
775  {
776  return;
777  }
778 #if 0
779  if(threadIdx.x == 0){
780  printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n",
781  get_block_1d_id(),
782  block_work_idx[I0],
783  block_work_idx[I1],
784  __smid()>>6 & 0xf,
785  __smid()>>4 & 0x3,
786  __smid() & 0xf);
787  }
788 #endif
789  // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
790  const index_t m_block_data_idx_on_grid =
791  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
792 
793  const index_t n_block_data_idx_on_grid =
794  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
795 
796  // lds max alignment
797  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
798 
799  // A matrix in LDS memory, dst of blockwise copy
800  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
801 
802  // B matrix in LDS memory, dst of blockwise copy
803  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
804 
805  // A matrix blockwise copy
806  auto a_blockwise_copy =
808  AElementwiseOperation,
812  ABlockTransferThreadClusterLengths_AK0_M_AK1,
813  ABlockTransferThreadClusterArrangeOrder,
814  FloatA,
815  ComputeTypeA,
816  decltype(a_grid_desc_ak0_m_ak1),
817  decltype(a_block_desc_ak0_m_ak1),
818  ABlockTransferSrcAccessOrder,
820  ABlockTransferSrcVectorDim,
821  2,
822  ABlockTransferSrcScalarPerVector,
823  ABlockTransferDstScalarPerVector_AK1,
824  1,
825  1,
826  AThreadTransferSrcResetCoordinateAfterRun,
827  true>(
828  a_grid_desc_ak0_m_ak1,
829  make_multi_index(0, m_block_data_idx_on_grid, 0),
830  a_element_op,
831  a_block_desc_ak0_m_ak1,
832  make_multi_index(0, 0, 0),
834 
835  // B matrix blockwise copy
836  auto b_blockwise_copy =
838  BElementwiseOperation,
842  BBlockTransferThreadClusterLengths_BK0_N_BK1,
843  BBlockTransferThreadClusterArrangeOrder,
844  FloatB,
845  ComputeTypeB,
846  decltype(b_grid_desc_bk0_n_bk1),
847  decltype(b_block_desc_bk0_n_bk1),
848  BBlockTransferSrcAccessOrder,
850  BBlockTransferSrcVectorDim,
851  2,
852  BBlockTransferSrcScalarPerVector,
853  BBlockTransferDstScalarPerVector_BK1,
854  1,
855  1,
856  BThreadTransferSrcResetCoordinateAfterRun,
857  true>(
858  b_grid_desc_bk0_n_bk1,
859  make_multi_index(0, n_block_data_idx_on_grid, 0),
860  b_element_op,
861  b_block_desc_bk0_n_bk1,
862  make_multi_index(0, 0, 0),
864 
865  // GEMM definition
866  // c_mtx += transpose(a_mtx) * b_mtx
867  // a_mtx[K0PerBlock, MPerBlock] is in LDS
868  // b_mtx[K0PerBlock, NPerBlock] is in LDS
869  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
870  // register
871  // sanity check
872  constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
873  constexpr bool is_single_rate_mfma =
875  lcm_AK1_BK1 <= 4) ||
876  (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
878  lcm_AK1_BK1 < 32))
879  ? true
880  : false;
881  constexpr auto is_scale_mfma = false;
882  constexpr index_t KPack = math::max(lcm_AK1_BK1,
883  MfmaSelector<ComputeTypeA,
884  MPerXdl,
885  NPerXdl,
886  ComputeTypeA,
887  is_single_rate_mfma,
888  is_scale_mfma>::selected_mfma.k_per_blk);
889 
890  // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
891  // BlockSize,
892  // ComputeType,
893  // FloatGemmAcc,
894  // decltype(a_block_desc_ak0_m_ak1),
895  // decltype(b_block_desc_bk0_n_bk1),
896  // MPerXdl,
897  // NPerXdl,
898  // MXdlPerWave,
899  // NXdlPerWave,
900  // KPack,
901  // LoopSched>();
902  auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4<
903  BlockSize,
904  ComputeTypeA,
905  FloatGemmAcc,
906  decltype(a_block_desc_ak0_m_ak1),
907  decltype(b_block_desc_bk0_n_bk1),
908  decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
909  decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
910  MPerBlock,
911  NPerBlock,
912  KPerBlock,
913  MPerXdl,
914  NPerXdl,
915  MXdlPerWave,
916  NXdlPerWave,
917  KPack>{}; // TransposeC
918 
919  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
920 
921  // LDS allocation for A and B: be careful of alignment
922  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
923  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
924 
925  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
926  static_cast<ComputeTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
927 
928  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
929  static_cast<ComputeTypeB*>(p_shared_0) + a_block_space_size_aligned,
930  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
931 
932  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
933  static_cast<ComputeTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
934 
935  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
936  static_cast<ComputeTypeB*>(p_shared_1) + a_block_space_size_aligned,
937  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
938 
939  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
940  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
941 
942  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
943  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
944 
945  // gridwise GEMM pipeline
946  static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
947  // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
948 
949  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
950  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
951  KPerBlock);
952 
953  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
954  a_block_desc_ak0_m_ak1,
955  a_blockwise_copy,
956  a_grid_buf,
957  a_block_bufs,
958  a_block_slice_copy_step,
959  b_grid_desc_bk0_n_bk1,
960  b_block_desc_bk0_n_bk1,
961  b_blockwise_copy,
962  b_grid_buf,
963  b_block_bufs,
964  b_block_slice_copy_step,
965  c_thread_buf,
966  num_k_block_main_loop);
967 
968  // shuffle C and write out
969  {
970  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
971  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
972  "wrong!");
973 
974  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
975  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
976 
977  // TODO: hacky, fix it!
978  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
979  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
980 
981  // TODO: hacky, fix it!
982  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
983  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
984  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
985 
986  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
987  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
988  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
989  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
990  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
991  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
992  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
993  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
994 
995  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
996  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
997 
998  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
999  static_cast<FloatCShuffle*>(p_shared_0),
1000  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1001 
1002  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1003  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1004  make_tuple(
1007  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1008  M1, // M1 = MWave
1009  M2, // M2 * M3 * M4 = MPerXdl
1010  M3,
1011  M4)),
1014  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1015  N1, // N1 = NWave
1016  N2))), // N2 = NPerXdl
1018  make_tuple(
1020 
1021  // calculate origin of thread output tensor on global memory
1022  // blockwise GEMM c matrix starting index
1023  const auto c_thread_mtx_on_block =
1024  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1025 
1026  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1027  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1028 
1029  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1031  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1033  make_tuple(Sequence<0>{}));
1034 
1035  const auto m_thread_data_on_block_idx =
1036  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1037  make_multi_index(m_thread_data_on_block));
1038 
1039  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1043  make_tuple(Sequence<0>{}));
1044 
1045  const auto n_thread_data_on_block_idx =
1046  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1047  make_multi_index(n_thread_data_on_block));
1048 
1049  // shuffle: threadwise copy C from VGPR to LDS
1050  auto c_thread_copy_vgpr_to_lds =
1052  FloatCShuffle,
1053  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1054  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1056  Sequence<CShuffleMXdlPerWavePerShuffle,
1057  CShuffleNXdlPerWavePerShuffle,
1058  I1,
1059  I1,
1060  M2,
1061  I1,
1062  M4,
1063  I1>,
1065  7,
1066  1,
1068  1,
1069  true>{
1070  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1071  make_multi_index(0,
1072  0,
1073  m_thread_data_on_block_idx[I1],
1074  n_thread_data_on_block_idx[I1],
1075  m_thread_data_on_block_idx[I2],
1076  m_thread_data_on_block_idx[I3],
1077  m_thread_data_on_block_idx[I4],
1078  n_thread_data_on_block_idx[I2]),
1080 
1081  // shuffle: blockwise copy C from LDS to global
1082  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1083  ThisThreadBlock, // ThreadGroup
1084  CElementwiseOperation, // ElementwiseOperation,
1085  CGlobalMemoryDataOperation, // DstInMemOp,
1086  Sequence<1,
1087  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1088  1,
1089  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1090  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1091  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1092  FloatCShuffle, // typename SrcData,
1093  FloatC, // typename DstData,
1094  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1095  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1096  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1097  3, // index_t VectorDim,
1098  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1099  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1100  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1101  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1102  make_multi_index(0, 0, 0, 0),
1103  c_grid_desc_mblock_mperblock_nblock_nperblock,
1104  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1105  c_element_op};
1106 
1107  // space filling curve for threadwise C in VGPR
1108  constexpr auto sfc_c_vgpr =
1111  Sequence<CShuffleMXdlPerWavePerShuffle,
1112  CShuffleNXdlPerWavePerShuffle,
1113  1,
1114  1,
1115  M2,
1116  1,
1117  M4,
1118  1>>{};
1119 
1120  // space filling curve for shuffled blockwise C in global mem
1121  constexpr auto sfc_c_global =
1124  Sequence<1,
1125  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1126  1,
1127  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1128 
1129  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1130 
1131  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1132 
1133  static_for<0, num_access, 1>{}([&](auto access_id) {
1134  // make sure it's safe to write to LDS
1135  block_sync_lds();
1136 
1137  // each thread write its data from VGPR to LDS
1138  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1139  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1140  c_thread_buf,
1141  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1142  c_shuffle_block_buf);
1143 
1144  // make sure it's safe to read from LDS
1145  block_sync_lds();
1146 
1147  // each block copy its data from LDS to global
1148  c_shuffle_block_copy_lds_to_global.Run(
1149  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1150  c_shuffle_block_buf,
1151  c_grid_desc_mblock_mperblock_nblock_nperblock,
1152  c_grid_buf);
1153 
1154  if constexpr(access_id < num_access - 1)
1155  {
1156  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1157 
1158  // move on C
1159  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1160  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1161  }
1162  });
1163  }
1164  }
1165 };
1166 
1167 } // namespace ck