Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp Source File

gridwise_gemm_xdl_cshuffle_v2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm, bool HasMainKBlockLoop, index_t TailNum = 3>
21 __global__ void
22 #if CK_USE_LAUNCH_BOUNDS
23  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
24 #endif
25  // __attribute__((amdgpu_waves_per_eu(1, 1)))
26  kernel_gemm_xdl_cshuffle_v2(typename GridwiseGemm::Argument karg)
27 {
28 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
29  defined(__gfx94__))
30  // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
31  // can operate on different LDS chunks at the same time without an ordering dependency
32  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
33  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
34 
35  GridwiseGemm::template Run<HasMainKBlockLoop, TailNum>(
36  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared_0, p_shared_1, karg);
37 #else
38  ignore = karg;
39 #endif // end of the gfx908/gfx90a/gfx94 guard
40 }
41 
42 template <typename GridwiseGemm,
43  typename FloatA,
44  typename FloatB,
45  typename FloatC,
46  bool HasMainKBlockLoop>
47 __global__ void
48 #if CK_USE_LAUNCH_BOUNDS
49  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, 1)
50 #endif
51  kernel_gemm_xdl_cshuffle_v2(const FloatA* p_a_grid,
52  const FloatB* p_b_grid,
53  FloatC* p_c_grid,
54  typename GridwiseGemm::Problem problem)
55 {
56 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
57  defined(__gfx94__))
58  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
59  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
60 
61  GridwiseGemm::template Run<HasMainKBlockLoop>(
62  p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem);
63 #else
64  ignore = p_a_grid;
65  ignore = p_b_grid;
66  ignore = p_c_grid;
67  ignore = problem;
68 #endif // end of the gfx908/gfx90a/gfx94 guard
69 }
70 
71 template <typename ALayout,
72  typename BLayout,
73  typename CLayout,
74  typename FloatA,
75  typename FloatB,
76  typename FloatGemmAcc,
77  typename FloatCShuffle,
78  typename FloatC,
79  typename AElementwiseOperation,
80  typename BElementwiseOperation,
81  typename CElementwiseOperation,
82  tensor_operation::device::GemmSpecialization GemmSpec,
83  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
84  index_t NumGemmKPrefetchStage,
85  index_t BlockSize,
86  index_t MPerBlock,
87  index_t NPerBlock,
88  index_t KPerBlock,
89  index_t AK1Value,
90  index_t BK1Value,
91  index_t MPerXdl,
92  index_t NPerXdl,
93  index_t MXdlPerWave,
94  index_t NXdlPerWave,
95  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
96  typename ABlockTransferThreadClusterArrangeOrder,
97  typename ABlockTransferSrcAccessOrder,
98  index_t ABlockTransferSrcVectorDim,
99  index_t ABlockTransferSrcScalarPerVector,
100  index_t ABlockTransferDstScalarPerVector_AK1,
101  bool AThreadTransferSrcResetCoordinateAfterRun,
102  index_t ABlockLdsExtraM,
103  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
104  typename BBlockTransferThreadClusterArrangeOrder,
105  typename BBlockTransferSrcAccessOrder,
106  index_t BBlockTransferSrcVectorDim,
107  index_t BBlockTransferSrcScalarPerVector,
108  index_t BBlockTransferDstScalarPerVector_BK1,
109  bool BThreadTransferSrcResetCoordinateAfterRun,
110  index_t BBlockLdsExtraN,
111  index_t CShuffleMXdlPerWavePerShuffle,
112  index_t CShuffleNXdlPerWavePerShuffle,
113  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
114  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
115  LoopScheduler LoopSched,
116  PipelineVersion PipelineVer = PipelineVersion::v1,
117  typename ComputeTypeA = FloatC,
118  typename ComputeTypeB = ComputeTypeA>
119 struct GridwiseGemm_xdl_cshuffle_v2
120 {
121  static constexpr auto I0 = Number<0>{};
122  static constexpr auto I1 = Number<1>{};
123  static constexpr auto I2 = Number<2>{};
124  static constexpr auto I3 = Number<3>{};
125  static constexpr auto I4 = Number<4>{};
126  static constexpr auto I5 = Number<5>{};
127  static constexpr auto I6 = Number<6>{};
128  static constexpr auto I7 = Number<7>{};
129 
130  // K1 should be Number<...>
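 // AK0Number/BK0Number below express KPerBlock as a count of AK1-/BK1-element vectors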
131  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
132  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
133  static constexpr auto AK1Number = Number<AK1Value>{};
134  static constexpr auto BK1Number = Number<BK1Value>{};
135 
136  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
137 
138  __host__ static auto CalculateGridSize(index_t M, index_t N)
139  {
141  }
142 
143  __host__ static auto CalculateMPadded(index_t M)
144  {
145  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
146  }
147 
148  __host__ static auto CalculateNPadded(index_t N)
149  {
150  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
151  }
152 
153  __host__ static auto CalculateKPadded(index_t K)
154  {
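 // rounds K up to the next multiple of KPerBlock, e.g. K = 1000, KPerBlock = 64 -> 1024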
155  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
156  }
157 
158  __host__ static auto CalculateAK0(index_t K)
159  {
161 
162  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
163  GemmSpec == GemmSpecialization::MNKPadding ||
164  GemmSpec == GemmSpecialization::KPadding ||
165  GemmSpec == GemmSpecialization::NKPadding)
166  {
167  return CalculateKPadded(K) / AK1Value;
168  }
169  else
170  {
171  return K / AK1Value;
172  }
173  }
174 
175  __host__ static auto CalculateBK0(index_t K)
176  {
178 
179  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
180  GemmSpec == GemmSpecialization::MNKPadding ||
181  GemmSpec == GemmSpecialization::KPadding ||
182  GemmSpec == GemmSpecialization::MKPadding)
183  {
184  return CalculateKPadded(K) / BK1Value;
185  }
186  else
187  {
188  return K / BK1Value;
189  }
190  }
191 
192  __host__ static auto CalculateMBlock(index_t M)
193  {
194  return math::integer_divide_floor(M, MPerBlock);
195  }
196 
197  __host__ static auto CalculateNBlock(index_t N)
198  {
199  return math::integer_divide_floor(N, NPerBlock);
200  }
201 
202  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
203  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
204  {
205  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
206  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
207 
209  TileDesc_K0_MN_K1{},
215  }
216 
217  __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
218  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
219  {
220  const auto a_grid_desc_mraw_kraw = [&]() {
221  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
222  {
223  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
224  }
225  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
226  {
227  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
228  }
229  }();
230 
232 
233  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
234  GemmSpec == GemmSpecialization::MNKPadding)
235  {
236  // pad both M and K
237  const auto a_grid_desc_m_k =
238  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
240  make_right_pad_transform(K, KPad - K)),
243 
244  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
245  a_grid_desc_m_k,
250 
251  return a_grid_desc_ak0_m_ak1;
252  }
253  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
254  GemmSpec == GemmSpecialization::MNPadding)
255  {
256  // pad M, but not K
257  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
258  a_grid_desc_mraw_kraw,
260  make_right_pad_transform(M, MPad - M)),
263 
264  return a_grid_desc_ak0_m_ak1;
265  }
266  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
267  GemmSpec == GemmSpecialization::NKPadding)
268  {
269  // pad K, but not M
270  const auto a_grid_desc_m_k = transform_tensor_descriptor(
271  a_grid_desc_mraw_kraw,
275 
276  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
277  a_grid_desc_m_k,
282 
283  return a_grid_desc_ak0_m_ak1;
284  }
285  else
286  {
287  // do not pad M or K
288  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
289  a_grid_desc_mraw_kraw,
294 
295  return a_grid_desc_ak0_m_ak1;
296  }
297  }
298 
299  __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
300  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
301  {
302  const auto b_grid_desc_nraw_kraw = [&]() {
304  {
305  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
306  }
308  {
309  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
310  }
311  }();
312 
314 
315  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
316  GemmSpec == GemmSpecialization::MNKPadding)
317  {
318  // pad both N and K
319  const auto b_grid_desc_n_k =
320  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
322  make_right_pad_transform(K, KPad - K)),
325 
326  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
327  b_grid_desc_n_k,
332 
333  return b_grid_desc_bk0_n_bk1;
334  }
335  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
336  GemmSpec == GemmSpecialization::MNPadding)
337  {
338  // pad N, but not K
339  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
340  b_grid_desc_nraw_kraw,
342  make_right_pad_transform(N, NPad - N)),
345 
346  return b_grid_desc_bk0_n_bk1;
347  }
348  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
349  GemmSpec == GemmSpecialization::MKPadding)
350  {
351  // pad K, but not N
352  const auto b_grid_desc_n_k = transform_tensor_descriptor(
353  b_grid_desc_nraw_kraw,
357 
358  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
359  b_grid_desc_n_k,
364 
365  return b_grid_desc_bk0_n_bk1;
366  }
367  else
368  {
369  // do not pad N or K
370  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
371  b_grid_desc_nraw_kraw,
376 
377  return b_grid_desc_bk0_n_bk1;
378  }
379  }
380 
381  template <typename ABlockDesc_AK0_M_AK1>
382  __host__ __device__ static constexpr auto
383  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
384  {
385  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
386 
387  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
388  }
389 
390  template <typename BBlockDesc_BK0_N_BK1>
391  __host__ __device__ static constexpr auto
392  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
393  {
394  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
395 
396  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
397  }
398 
399  __host__ __device__ static auto
400  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
401  {
402  const auto c_grid_desc_mraw_nraw = [&]() {
404  {
405  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
406  }
408  {
409  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
410  }
411  }();
412 
414 
415  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
416  GemmSpec == GemmSpecialization::MNKPadding)
417  {
418  // pad M and N
419  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
421  make_right_pad_transform(N, NPad - N)),
424  }
425  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
426  GemmSpec == GemmSpecialization::MKPadding)
427  {
428  // pad M, but not N
430  c_grid_desc_mraw_nraw,
434  }
435  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
436  GemmSpec == GemmSpecialization::NKPadding)
437  {
438  // pad N, but not M
440  c_grid_desc_mraw_nraw,
444  }
445  else
446  {
447  // do not pad M or N
448  return c_grid_desc_mraw_nraw;
449  }
450  }
451 
452  struct Problem
453  {
454  __host__ Problem(index_t M_,
455  index_t N_,
456  index_t K_,
457  index_t StrideA_,
458  index_t StrideB_,
459  index_t StrideC_)
460  : M{M_},
461  N{N_},
462  K{K_},
463  StrideA{StrideA_},
464  StrideB{StrideB_},
465  StrideC{StrideC_},
466  MPadded{CalculateMPadded(M_)},
467  NPadded{CalculateNPadded(N_)},
468  KPadded{CalculateKPadded(K_)},
469  AK0{CalculateAK0(K_)},
470  BK0{CalculateBK0(K_)},
471  MBlock{CalculateMBlock(M_)},
472  NBlock{CalculateNBlock(N_)}
473  {
474  }
475 
476  __host__ void Print() const
477  {
478  std::cout << "problem {"
479  << "M:" << M << ", "
480  << "N:" << N << ", "
481  << "K:" << K << ", "
482  << "SA:" << StrideA << ", "
483  << "SB:" << StrideB << ", "
484  << "SC:" << StrideC << ", "
485  << "MP:" << MPadded << ", "
486  << "NP:" << NPadded << ", "
487  << "KP:" << KPadded << ", "
488  << "AK0:" << AK0 << ", "
489  << "BK0:" << BK0 << ", "
490  << "MBlock: " << MBlock << ", "
491  << "NBlock: " << NBlock << "}" << std::endl;
492  }
493 
493 
494  index_t M;
495  index_t N;
496  index_t K;
497  index_t StrideA;
498  index_t StrideB;
499  index_t StrideC;
500  index_t MPadded;
501  index_t NPadded;
502  index_t KPadded;
503  index_t AK0;
504  index_t BK0;
505  index_t MBlock;
506  index_t NBlock;
507  };
508 
509  // Argument
510  struct Argument : public Problem
511  {
512  __host__ Argument(const FloatA* p_a_grid_,
513  const FloatB* p_b_grid_,
514  FloatC* p_c_grid_,
515  index_t M_,
516  index_t N_,
517  index_t K_,
518  index_t StrideA_,
519  index_t StrideB_,
520  index_t StrideC_)
521  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
522  p_a_grid{p_a_grid_},
523  p_b_grid{p_b_grid_},
524  p_c_grid{p_c_grid_}
525  {
526  }
527 
528  const FloatA* p_a_grid;
529  const FloatB* p_b_grid;
530  FloatC* p_c_grid;
531  };
532 
533  // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
534  using GridwiseGemmPipe = remove_cvref_t<
535  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
536 
537  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
538  {
539  // A matrix in LDS memory, dst of blockwise copy
543  }
544 
545  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
546  {
547  // B matrix in LDS memory, dst of blockwise copy
551  }
552 
553  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
554  {
555  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
556  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
557 
558  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
560  make_tuple(I1,
562  I1,
564 
565  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
566  }
567 
568  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
569  {
570  // LDS allocation for A and B: be careful of alignment
571  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
572  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
573 
574  // lds max alignment
575  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
576 
577  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
578  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
579 
580  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
581  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
582 
583  // LDS allocation for C shuffle in LDS
584  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
586 
587  constexpr auto c_block_size =
588  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
589 
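 // the A/B tiles and the C-shuffle tile reuse the same LDS region, so the requirement
 // is the maximum of the two sizes rather than their sum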
590  return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
591  b_block_space_size_aligned * sizeof(ComputeTypeB)),
592  c_block_size * sizeof(FloatCShuffle));
593  }
594 
595  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
596  __host__ static constexpr bool CheckValidity(const Problem& problem)
597  {
598  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
599  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
600  "Invalid tuning param!");
601 
602  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
606  {
607  if(!(problem.M % MPerBlock == 0))
608  {
609  return false;
610  }
611  }
612 
613  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
617  {
618  if(!(problem.N % NPerBlock == 0))
619  {
620  return false;
621  }
622  }
623 
628  {
629  if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
630  !(CalculateKPadded(problem.K) % BK1Value == 0))
631  {
632  return false;
633  }
634  }
635  else
636  {
637  if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
638  {
639  return false;
640  }
641  }
642 
644  {
645  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
646  {
647  return false;
648  }
649  }
650  else
651  {
652  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
653  {
654  return false;
655  }
656  }
657 
659  {
660  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
661  {
662  return false;
663  }
664  }
665  else
666  {
667  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
668  {
669  return false;
670  }
671  }
672 
674  {
675  if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
676  {
677  return false;
678  }
679  }
680  else
681  {
682  if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
683  {
684  return false;
685  }
686  }
687 
688  // check gridwise gemm pipeline
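 // the pipelined main loop requires at least four K-block iterations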
689  const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
690 
691  if(num_k_loop < 4)
692  {
693  return false;
694  }
695 
696  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
697  return true;
698  }
699 
700  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
701  {
702  const index_t num_loop = K / KPerBlock;
703 
704  return num_loop > 3;
705  }
706 
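 // TailNum is the number of K-block iterations drained after the pipelined main loop:
 // 3 when the total loop count is odd, 2 when it is even (matches the kernel's TailNum
 // template parameter)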
707  __host__ static constexpr index_t CalculateKBlockLoopTailNum(index_t K)
708  {
709  const index_t num_loop = K / KPerBlock;
710 
711  if(num_loop % 2 == 1)
712  return 3;
713  else
714  return 2;
715  }
716 
717  template <typename CGridDesc>
719  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
720  {
721  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
722  c_grid_desc_m_n,
727 
728  return c_grid_desc_mblock_mperblock_nblock_nperblock;
729  }
730 
731  // return block_id to C matrix tile idx (m0, n0) mapping
732  // if arch = gfx942
734 
735  template <bool HasMainKBlockLoop, index_t TailNum = 3>
736  __device__ static void Run(const FloatA* p_a_grid,
737  const FloatB* p_b_grid,
738  FloatC* p_c_grid,
739  void* p_shared_0,
740  void* p_shared_1,
741  const Problem& problem)
742  {
743  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
744  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
745  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
746  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
747  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
748  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
749 
750  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
751  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
752  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
753 
754  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
755  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
756  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
757  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
758  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
759  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
760 
761  const AElementwiseOperation a_element_op{};
762  const BElementwiseOperation b_element_op{};
763  const CElementwiseOperation c_element_op{};
764 
765  // divide block work by [M, N]
766  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
767 
768  const auto block_work_idx =
769  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
770 
771  if(!block_2_ctile_map.ValidCTileIndex(
772  block_work_idx,
773  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
774  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
775  {
776  return;
777  }
778 #if 0
779  if(threadIdx.x == 0){
780  printf("Hardware assigned No. %03d workgroup of logical C tile (%02d, %02d) on %d th XCC Die, %d th SE, %d th CU\n",
781  get_block_1d_id(),
782  block_work_idx[I0],
783  block_work_idx[I1],
784  __smid()>>6 & 0xf,
785  __smid()>>4 & 0x3,
786  __smid() & 0xf);
787  }
788 #endif
789  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
790  const index_t m_block_data_idx_on_grid =
791  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
792 
793  const index_t n_block_data_idx_on_grid =
794  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
795 
796  // lds max alignment
797  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
798 
799  // A matrix in LDS memory, dst of blockwise copy
800  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
801 
802  // B matrix in LDS memory, dst of blockwise copy
803  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
804 
805  // A matrix blockwise copy
806  auto a_blockwise_copy =
808  AElementwiseOperation,
812  ABlockTransferThreadClusterLengths_AK0_M_AK1,
813  ABlockTransferThreadClusterArrangeOrder,
814  FloatA,
815  ComputeTypeA,
816  decltype(a_grid_desc_ak0_m_ak1),
817  decltype(a_block_desc_ak0_m_ak1),
818  ABlockTransferSrcAccessOrder,
820  ABlockTransferSrcVectorDim,
821  2,
822  ABlockTransferSrcScalarPerVector,
823  ABlockTransferDstScalarPerVector_AK1,
824  1,
825  1,
826  AThreadTransferSrcResetCoordinateAfterRun,
827  true>(
828  a_grid_desc_ak0_m_ak1,
829  make_multi_index(0, m_block_data_idx_on_grid, 0),
830  a_element_op,
831  a_block_desc_ak0_m_ak1,
832  make_multi_index(0, 0, 0),
834 
835  // B matrix blockwise copy
836  auto b_blockwise_copy =
838  BElementwiseOperation,
842  BBlockTransferThreadClusterLengths_BK0_N_BK1,
843  BBlockTransferThreadClusterArrangeOrder,
844  FloatB,
845  ComputeTypeB,
846  decltype(b_grid_desc_bk0_n_bk1),
847  decltype(b_block_desc_bk0_n_bk1),
848  BBlockTransferSrcAccessOrder,
850  BBlockTransferSrcVectorDim,
851  2,
852  BBlockTransferSrcScalarPerVector,
853  BBlockTransferDstScalarPerVector_BK1,
854  1,
855  1,
856  BThreadTransferSrcResetCoordinateAfterRun,
857  true>(
858  b_grid_desc_bk0_n_bk1,
859  make_multi_index(0, n_block_data_idx_on_grid, 0),
860  b_element_op,
861  b_block_desc_bk0_n_bk1,
862  make_multi_index(0, 0, 0),
864 
865  // GEMM definition
866  // c_mtx += transpose(a_mtx) * b_mtx
867  // a_mtx[K0PerBlock, MPerBlock] is in LDS
868  // b_mtx[K0PerBlock, NPerBlock] is in LDS
869  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
870  // registers
871  // sanity check
872  constexpr index_t KPack =
875 
876  // auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
877  // BlockSize,
878  // ComputeType,
879  // FloatGemmAcc,
880  // decltype(a_block_desc_ak0_m_ak1),
881  // decltype(b_block_desc_bk0_n_bk1),
882  // MPerXdl,
883  // NPerXdl,
884  // MXdlPerWave,
885  // NXdlPerWave,
886  // KPack,
887  // LoopSched>();
888  auto blockwise_gemm_pipeline = BlockwiseGemmXdlops_pipeline_v4<
889  BlockSize,
890  ComputeTypeA,
891  FloatGemmAcc,
892  decltype(a_block_desc_ak0_m_ak1),
893  decltype(b_block_desc_bk0_n_bk1),
894  decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(a_block_desc_ak0_m_ak1)),
895  decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(b_block_desc_bk0_n_bk1)),
896  MPerBlock,
897  NPerBlock,
898  KPerBlock,
899  MPerXdl,
900  NPerXdl,
901  MXdlPerWave,
902  NXdlPerWave,
903  KPack>{}; // TransposeC
904 
905  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
906 
907  // LDS allocation for A and B: be careful of alignment
908  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
909  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
910 
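 // ping/pong LDS tiles: the pipeline computes on the A/B tile in one buffer while the
 // blockwise copies stage the next K tile into the other (hence the two __shared__
 // allocations passed in as p_shared_0 / p_shared_1)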
911  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
912  static_cast<ComputeTypeA*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
913 
914  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
915  static_cast<ComputeTypeB*>(p_shared_0) + a_block_space_size_aligned,
916  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
917 
918  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
919  static_cast<ComputeTypeA*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
920 
921  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
922  static_cast<ComputeTypeB*>(p_shared_1) + a_block_space_size_aligned,
923  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
924 
925  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
926  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
927 
928  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
929  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
930 
931  // gridwise GEMM pipeline
932  static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
933  // const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
934 
935  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
936  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
937  KPerBlock);
938 
939  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
940  a_block_desc_ak0_m_ak1,
941  a_blockwise_copy,
942  a_grid_buf,
943  a_block_bufs,
944  a_block_slice_copy_step,
945  b_grid_desc_bk0_n_bk1,
946  b_block_desc_bk0_n_bk1,
947  b_blockwise_copy,
948  b_grid_buf,
949  b_block_bufs,
950  b_block_slice_copy_step,
951  c_thread_buf,
952  num_k_block_main_loop);
953 
954  // shuffle C and write out
955  {
956  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
957  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
958  "wrong!");
959 
960  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
961  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
962 
963  // TODO: hacky, fix it!
964  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
965  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
966 
967  // TODO: hacky, fix it!
968  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
969  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
970  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
971 
972  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
973  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
974  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
975  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
976  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
977  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
978  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
979  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
980 
981  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
983 
984  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
985  static_cast<FloatCShuffle*>(p_shared_0),
986  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
987 
988  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
989  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
990  make_tuple(
993  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
994  M1, // M1 = MWave
995  M2, // M2 * M3 * M4 = MPerXdl
996  M3,
997  M4)),
1000  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1001  N1, // N1 = NWave
1002  N2))), // N2 = NPerXdl
1004  make_tuple(
1006 
1007  // calculate origin of thread output tensor on global memory
1008  // blockwise GEMM c matrix starting index
1009  const auto c_thread_mtx_on_block =
1010  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1011 
1012  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1013  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1014 
1015  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1017  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1019  make_tuple(Sequence<0>{}));
1020 
1021  const auto m_thread_data_on_block_idx =
1022  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1023  make_multi_index(m_thread_data_on_block));
1024 
1025  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1029  make_tuple(Sequence<0>{}));
1030 
1031  const auto n_thread_data_on_block_idx =
1032  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1033  make_multi_index(n_thread_data_on_block));
1034 
1035  // shuffle: threadwise copy C from VGPR to LDS
1036  auto c_thread_copy_vgpr_to_lds =
1038  FloatCShuffle,
1039  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1040  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1042  Sequence<CShuffleMXdlPerWavePerShuffle,
1043  CShuffleNXdlPerWavePerShuffle,
1044  I1,
1045  I1,
1046  M2,
1047  I1,
1048  M4,
1049  I1>,
1051  7,
1052  1,
1054  1,
1055  true>{
1056  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1057  make_multi_index(0,
1058  0,
1059  m_thread_data_on_block_idx[I1],
1060  n_thread_data_on_block_idx[I1],
1061  m_thread_data_on_block_idx[I2],
1062  m_thread_data_on_block_idx[I3],
1063  m_thread_data_on_block_idx[I4],
1064  n_thread_data_on_block_idx[I2]),
1066 
1067  // shuffle: blockwise copy C from LDS to global
1068  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1069  ThisThreadBlock, // ThreadGroup
1070  CElementwiseOperation, // ElementwiseOperation,
1071  CGlobalMemoryDataOperation, // DstInMemOp,
1072  Sequence<1,
1073  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1074  1,
1075  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1076  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1077  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1078  FloatCShuffle, // typename SrcData,
1079  FloatC, // typename DstData,
1080  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1081  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1082  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1083  3, // index_t VectorDim,
1084  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1085  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1086  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1087  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1088  make_multi_index(0, 0, 0, 0),
1089  c_grid_desc_mblock_mperblock_nblock_nperblock,
1090  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1091  c_element_op};
1092 
1093  // space filling curve for threadwise C in VGPR
1094  constexpr auto sfc_c_vgpr =
1097  Sequence<CShuffleMXdlPerWavePerShuffle,
1098  CShuffleNXdlPerWavePerShuffle,
1099  1,
1100  1,
1101  M2,
1102  1,
1103  M4,
1104  1>>{};
1105 
1106  // space filling curve for shuffled blockwise C in global mem
1107  constexpr auto sfc_c_global =
1110  Sequence<1,
1111  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1112  1,
1113  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1114 
1115  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1116 
1117  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1118 
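 // for each shuffle step: sync, copy one C sub-tile from VGPR to LDS, sync again, then
 // copy the shuffled sub-tile from LDS to global memory and advance the global window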
1119  static_for<0, num_access, 1>{}([&](auto access_id) {
1120  // make sure it's safe to write to LDS
1121  block_sync_lds();
1122 
1123  // each thread write its data from VGPR to LDS
1124  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1125  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1126  c_thread_buf,
1127  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1128  c_shuffle_block_buf);
1129 
1130  // make sure it's safe to read from LDS
1131  block_sync_lds();
1132 
1133  // each block copy its data from LDS to global
1134  c_shuffle_block_copy_lds_to_global.Run(
1135  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1136  c_shuffle_block_buf,
1137  c_grid_desc_mblock_mperblock_nblock_nperblock,
1138  c_grid_buf);
1139 
1140  if constexpr(access_id < num_access - 1)
1141  {
1142  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1143 
1144  // move on C
1145  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1146  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1147  }
1148  });
1149  }
1150  }
1151 };
1152 
1153 } // namespace ck
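
Usage note (editorial): the snippet below is a minimal host-side sketch of how this gridwise GEMM is typically driven, assuming Gemm is a fully specialized instantiation of the struct above and that valid grid/block launch dimensions have already been chosen elsewhere. The helper name launch_gemm_v2, the data-type parameters, and the grid/block/stream arguments are illustrative and not part of Composable Kernel.

#include <hip/hip_runtime.h>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v2.hpp" // this file

// Hedged sketch: "Gemm" must be a concrete GridwiseGemm_xdl_cshuffle_v2 instantiation
// defined elsewhere; grid/block/stream are assumed to be valid for that instantiation.
template <typename Gemm, typename ADataType, typename BDataType, typename CDataType>
void launch_gemm_v2(const ADataType* p_a, const BDataType* p_b, CDataType* p_c,
                    ck::index_t M, ck::index_t N, ck::index_t K,
                    ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
                    dim3 grid, dim3 block, hipStream_t stream)
{
    // Argument derives from Problem, so it carries the padded sizes, AK0/BK0 and the
    // block counts computed in the Problem constructor above.
    typename Gemm::Argument arg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};

    if(!Gemm::CheckValidity(arg))
    {
        return; // tuning parameters do not fit this problem size
    }

    // HasMainKBlockLoop and TailNum are compile-time kernel parameters, so the runtime
    // values have to be dispatched to matching kernel instantiations.
    const bool has_main_loop = Gemm::CalculateHasMainKBlockLoop(K);
    const ck::index_t tail   = Gemm::CalculateKBlockLoopTailNum(K);

    if(has_main_loop && tail == 3)
    {
        hipLaunchKernelGGL((ck::kernel_gemm_xdl_cshuffle_v2<Gemm, true, 3>),
                           grid, block, 0, stream, arg);
    }
    else if(has_main_loop && tail == 2)
    {
        hipLaunchKernelGGL((ck::kernel_gemm_xdl_cshuffle_v2<Gemm, true, 2>),
                           grid, block, 0, stream, arg);
    }
    // the remaining HasMainKBlockLoop/TailNum combinations are dispatched analogously
}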