Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm, bool HasMainKBlockLoop>
21 __global__ void
22 #if CK_USE_LAUNCH_BOUNDS
23  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
24 #endif
25  kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
26 {
27 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
28  defined(__gfx94__))
29  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
30 
31  GridwiseGemm::template Run<HasMainKBlockLoop>(
32  karg.p_a_grid, karg.p_b_grid, karg.p_c_grid, p_shared, karg);
33 #else
34  ignore = karg;
35 #endif // end of if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
36 }
37 
38 template <typename GridwiseGemm,
39  typename FloatA,
40  typename FloatB,
41  typename FloatC,
42  bool HasMainKBlockLoop>
43 __global__ void
44 #if CK_USE_LAUNCH_BOUNDS
45  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
46 #endif
47  kernel_gemm_xdl_cshuffle_v1(const FloatA* __restrict__ p_a_grid,
48  const FloatB* __restrict__ p_b_grid,
49  FloatC* __restrict__ p_c_grid,
50  typename GridwiseGemm::Problem problem)
51 {
52 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
53  defined(__gfx94__))
54  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
55 
56  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem);
57 #else
58  ignore = p_a_grid;
59  ignore = p_b_grid;
60  ignore = p_c_grid;
61  ignore = problem;
62 #endif // end of if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__))
63 }
64 
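// ---------------------------------------------------------------------------------------------
// Illustrative host-side driver (a minimal sketch, not part of the original file). It shows how
// the Argument-based entry point above could be launched. "GemmInstance" stands for a fully
// specialized gridwise GEMM type from this header, the data type parameters must match its
// FloatA/FloatB/FloatC, "launch_gemm" and the hard-coded workgroup size of 256 threads are
// assumptions (the latter must equal the instance's BlockSize template argument), and the HIP
// runtime plus <tuple> are assumed to be available.
// ---------------------------------------------------------------------------------------------
template <typename GemmInstance, typename ADataType, typename BDataType, typename CDataType>
void launch_gemm(const ADataType* p_a, const BDataType* p_b, CDataType* p_c,
                 index_t M, index_t N, index_t K,
                 index_t StrideA, index_t StrideB, index_t StrideC,
                 hipStream_t stream)
{
    typename GemmInstance::Argument arg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};

    // reject problem shapes this tuning configuration cannot handle
    if(!GemmInstance::CheckValidity(arg))
    {
        return;
    }

    const index_t grid_size = std::get<0>(GemmInstance::CalculateGridSize(M, N));

    if(GemmInstance::CalculateHasMainKBlockLoop(K))
    {
        kernel_gemm_xdl_cshuffle_v1<GemmInstance, true><<<grid_size, 256, 0, stream>>>(arg);
    }
    else
    {
        kernel_gemm_xdl_cshuffle_v1<GemmInstance, false><<<grid_size, 256, 0, stream>>>(arg);
    }
}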
65 template <typename ALayout,
66  typename BLayout,
67  typename CLayout,
68  typename FloatA,
69  typename FloatB,
70  typename FloatGemmAcc,
71  typename FloatCShuffle,
72  typename FloatC,
73  typename AElementwiseOperation,
74  typename BElementwiseOperation,
75  typename CElementwiseOperation,
76  tensor_operation::device::GemmSpecialization GemmSpec,
77  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
78  index_t NumGemmKPrefetchStage,
79  index_t BlockSize,
80  index_t MPerBlock,
81  index_t NPerBlock,
82  index_t KPerBlock,
83  index_t AK1Value,
84  index_t BK1Value,
85  index_t MPerXdl,
86  index_t NPerXdl,
87  index_t MXdlPerWave,
88  index_t NXdlPerWave,
89  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
90  typename ABlockTransferThreadClusterArrangeOrder,
91  typename ABlockTransferSrcAccessOrder,
92  index_t ABlockTransferSrcVectorDim,
93  index_t ABlockTransferSrcScalarPerVector,
94  index_t ABlockTransferDstScalarPerVector_AK1,
95  bool AThreadTransferSrcResetCoordinateAfterRun,
96  index_t ABlockLdsExtraM,
97  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
98  typename BBlockTransferThreadClusterArrangeOrder,
99  typename BBlockTransferSrcAccessOrder,
100  index_t BBlockTransferSrcVectorDim,
101  index_t BBlockTransferSrcScalarPerVector,
102  index_t BBlockTransferDstScalarPerVector_BK1,
103  bool BThreadTransferSrcResetCoordinateAfterRun,
104  index_t BBlockLdsExtraN,
105  index_t CShuffleMXdlPerWavePerShuffle,
106  index_t CShuffleNXdlPerWavePerShuffle,
107  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
108  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
109  LoopScheduler LoopSched,
110  PipelineVersion PipelineVer = PipelineVersion::v1,
111  typename ComputeTypeA = FloatC,
112  typename ComputeTypeB = ComputeTypeA>
113 struct GridwiseGemm_xdl_cshuffle_v1
114 {
115  static constexpr auto I0 = Number<0>{};
116  static constexpr auto I1 = Number<1>{};
117  static constexpr auto I2 = Number<2>{};
118  static constexpr auto I3 = Number<3>{};
119  static constexpr auto I4 = Number<4>{};
120  static constexpr auto I5 = Number<5>{};
121  static constexpr auto I6 = Number<6>{};
122  static constexpr auto I7 = Number<7>{};
123 
124  // K1 should be Number<...>
125  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
126  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
127  static constexpr auto AK1Number = Number<AK1Value>{};
128  static constexpr auto BK1Number = Number<BK1Value>{};
129 
130  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
131 
132  __host__ static auto CalculateGridSize(index_t M, index_t N)
133  {
134  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, 1);
135  }
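// Note (illustrative, not part of the original file): the launch grid is effectively
// one-dimensional; the first tuple element is the workgroup count computed by Block2CTileMap,
// which for this tile mapping is essentially one workgroup per MPerBlock x NPerBlock output
// tile, i.e. roughly CalculateMPadded(M) / MPerBlock * CalculateNPadded(N) / NPerBlock blocks.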
136 
137  __host__ static auto CalculateMPadded(index_t M)
138  {
139  return math::integer_divide_ceil(M, MPerBlock) * MPerBlock;
140  }
141 
142  __host__ static auto CalculateNPadded(index_t N)
143  {
144  return math::integer_divide_ceil(N, NPerBlock) * NPerBlock;
145  }
146 
147  __host__ static auto CalculateKPadded(index_t K)
148  {
149  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
150  }
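// Worked example (illustrative): with MPerBlock = 256, NPerBlock = 128 and KPerBlock = 32, a
// problem of size M = 1000, N = 500, K = 60 is padded (when the chosen GemmSpec requests it) to
// MPadded = ceil(1000 / 256) * 256 = 1024, NPadded = ceil(500 / 128) * 128 = 512 and
// KPadded = ceil(60 / 32) * 32 = 64, so every padded dimension is a whole number of block tiles.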
151 
152  __host__ static auto CalculateAK0(index_t K)
153  {
154  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
155 
156  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
157  GemmSpec == GemmSpecialization::MNKPadding ||
158  GemmSpec == GemmSpecialization::KPadding ||
159  GemmSpec == GemmSpecialization::NKPadding)
160  {
161  return CalculateKPadded(K) / AK1Value;
162  }
163  else
164  {
165  return K / AK1Value;
166  }
167  }
168 
169  __host__ static auto CalculateBK0(index_t K)
170  {
171  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
172 
173  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
174  GemmSpec == GemmSpecialization::MNKPadding ||
175  GemmSpec == GemmSpecialization::KPadding ||
176  GemmSpec == GemmSpecialization::MKPadding)
177  {
178  return CalculateKPadded(K) / BK1Value;
179  }
180  else
181  {
182  return K / BK1Value;
183  }
184  }
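// Worked example (illustrative): continuing the numbers above with AK1Value = BK1Value = 8, a
// K-padding specialization yields AK0 = BK0 = KPadded / 8 = 64 / 8 = 8, whereas a non-padding
// specialization divides the raw K by 8 and therefore needs K itself to be divisible by
// AK1Value/BK1Value, which is exactly what CheckValidity() enforces further below.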
185 
186  __host__ static auto CalculateMBlock(index_t M)
187  {
188  return math::integer_divide_floor(M, MPerBlock);
189  }
190 
191  __host__ static auto CalculateNBlock(index_t N)
192  {
193  return math::integer_divide_floor(N, NPerBlock);
194  }
195 
196  __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
197  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
198  {
199  const auto a_grid_desc_mraw_kraw = [&]() {
200  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
201  {
202  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
203  }
204  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
205  {
206  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
207  }
208  }();
209 
210  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
211 
212  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
213  GemmSpec == GemmSpecialization::MNKPadding)
214  {
215  // pad both M and K
216  const auto a_grid_desc_m_k =
217  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
219  make_right_pad_transform(K, KPad - K)),
222 
223  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
224  a_grid_desc_m_k,
229 
230  return a_grid_desc_ak0_m_ak1;
231  }
232  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
233  GemmSpec == GemmSpecialization::MNPadding)
234  {
235  // pad M, but not K
236  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
237  a_grid_desc_mraw_kraw,
239  make_right_pad_transform(M, MPad - M)),
242 
243  return a_grid_desc_ak0_m_ak1;
244  }
245  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
246  GemmSpec == GemmSpecialization::NKPadding)
247  {
248  // pad K, but not M
249  const auto a_grid_desc_m_k = transform_tensor_descriptor(
250  a_grid_desc_mraw_kraw,
254 
255  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
256  a_grid_desc_m_k,
261 
262  return a_grid_desc_ak0_m_ak1;
263  }
264  else
265  {
266  // do not pad M or K
267  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
268  a_grid_desc_mraw_kraw,
273 
274  return a_grid_desc_ak0_m_ak1;
275  }
276  }
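// Illustrative sketch (not part of the original file): the AK0/M/AK1 view splits the K axis of
// A into AK0 chunks of AK1 contiguous elements, so coordinate (ak0, m, ak1) addresses element
// (m, k) of the raw matrix with k = ak0 * AK1 + ak1. The hypothetical helper below restates
// that unmerge arithmetic for a row-major, unpadded A; MakeBGridDescriptor_BK0_N_BK1 applies
// the same split to the K axis of B.
__host__ __device__ static constexpr index_t
IllustrativeAElementOffset(index_t ak0, index_t m, index_t ak1, index_t StrideA)
{
    const index_t k = ak0 * AK1Value + ak1; // undo the K -> (AK0, AK1) unmerge transform
    return m * StrideA + k;                 // flat offset of element (m, k) in row-major A
}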
277 
278  __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
279  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
280  {
281  const auto b_grid_desc_nraw_kraw = [&]() {
282  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
283  {
284  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
285  }
286  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
287  {
288  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
289  }
290  }();
291 
292  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
293 
294  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
295  GemmSpec == GemmSpecialization::MNKPadding)
296  {
297  // pad both N and K
298  const auto b_grid_desc_n_k =
299  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
301  make_right_pad_transform(K, KPad - K)),
304 
305  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
306  b_grid_desc_n_k,
311 
312  return b_grid_desc_bk0_n_bk1;
313  }
314  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
315  GemmSpec == GemmSpecialization::MNPadding)
316  {
317  // pad N, but not K
318  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
319  b_grid_desc_nraw_kraw,
321  make_right_pad_transform(N, NPad - N)),
324 
325  return b_grid_desc_bk0_n_bk1;
326  }
327  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
328  GemmSpec == GemmSpecialization::MKPadding)
329  {
330  // pad K, but not N
331  const auto b_grid_desc_n_k = transform_tensor_descriptor(
332  b_grid_desc_nraw_kraw,
336 
337  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
338  b_grid_desc_n_k,
343 
344  return b_grid_desc_bk0_n_bk1;
345  }
346  else
347  {
348  // do not pad N or K
349  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
350  b_grid_desc_nraw_kraw,
355 
356  return b_grid_desc_bk0_n_bk1;
357  }
358  }
359 
360  __host__ __device__ static auto
361  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
362  {
363  const auto c_grid_desc_mraw_nraw = [&]() {
364  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
365  {
366  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
367  }
368  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
369  {
370  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
371  }
372  }();
373 
374  using GemmSpecialization = tensor_operation::device::GemmSpecialization;
375 
376  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
377  GemmSpec == GemmSpecialization::MNKPadding)
378  {
379  // pad M and N
380  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
382  make_right_pad_transform(N, NPad - N)),
385  }
386  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
387  GemmSpec == GemmSpecialization::MKPadding)
388  {
389  // pad M, but not N
391  c_grid_desc_mraw_nraw,
395  }
396  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
397  GemmSpec == GemmSpecialization::NKPadding)
398  {
399  // pad N, but not M
401  c_grid_desc_mraw_nraw,
405  }
406  else
407  {
408  // do not pad M or N
409  return c_grid_desc_mraw_nraw;
410  }
411  }
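// Note (illustrative, not part of the original file): the right-pad transforms above only
// enlarge the coordinate space to MPad x NPad; the padded rows/columns have no backing memory
// and the transfer classes skip them through their coordinate validity checks, so a padded C
// descriptor never reads or writes outside the original M x N buffer.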
412 
413  struct Problem
414  {
415  __host__ Problem(index_t M_,
416  index_t N_,
417  index_t K_,
418  index_t StrideA_,
419  index_t StrideB_,
420  index_t StrideC_)
421  : M{M_},
422  N{N_},
423  K{K_},
424  StrideA{StrideA_},
425  StrideB{StrideB_},
426  StrideC{StrideC_},
427  MPadded{CalculateMPadded(M_)},
428  NPadded{CalculateNPadded(N_)},
429  KPadded{CalculateKPadded(K_)},
430  AK0{CalculateAK0(K_)},
431  BK0{CalculateBK0(K_)},
432  MBlock{CalculateMBlock(M_)},
433  NBlock{CalculateNBlock(N_)}
434  {
435  }
436 
437  __host__ void Print() const
438  {
439  std::cout << "problem {"
440  << "M:" << M << ", "
441  << "N:" << N << ", "
442  << "K:" << K << ", "
443  << "SA:" << StrideA << ", "
444  << "SB:" << StrideB << ", "
445  << "SC:" << StrideC << ", "
446  << "MP:" << MPadded << ", "
447  << "NP:" << NPadded << ", "
448  << "KP:" << KPadded << ", "
449  << "AK0:" << AK0 << ", "
450  << "BK0:" << BK0 << ", "
451  << "MBlock: " << MBlock << ", "
452  << "NBlock: " << NBlock << "}" << std::endl;
453  }
454 
455  index_t M;
456  index_t N;
457  index_t K;
458  index_t StrideA;
459  index_t StrideB;
460  index_t StrideC;
461  index_t MPadded;
462  index_t NPadded;
463  index_t KPadded;
464  index_t AK0;
465  index_t BK0;
466  index_t MBlock;
467  index_t NBlock;
468  };
469 
470  // Argument
471  struct Argument : public Problem
472  {
473  __host__ Argument(const FloatA* p_a_grid_,
474  const FloatB* p_b_grid_,
475  FloatC* p_c_grid_,
476  index_t M_,
477  index_t N_,
478  index_t K_,
479  index_t StrideA_,
480  index_t StrideB_,
481  index_t StrideC_)
482  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
483  p_a_grid{p_a_grid_},
484  p_b_grid{p_b_grid_},
485  p_c_grid{p_c_grid_}
486  {
487  }
488 
489  const FloatA* p_a_grid;
490  const FloatB* p_b_grid;
491  FloatC* p_c_grid;
492  };
493 
494  // FIXME: pass GridwiseGemmPipe as a template argument into GridwiseGemm
495  using GridwiseGemmPipe = remove_cvref_t<
496  decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
497 
498  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
499  {
500  // A matrix in LDS memory, dst of blockwise copy
504  }
505 
506  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
507  {
508  // B matrix in LDS memory, dst of blockwise copy
512  }
513 
514  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
515  {
516  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
517  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
518 
519  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
521  make_tuple(I1,
523  I1,
525 
526  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
527  }
528 
529  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
530  {
531  // LDS allocation for A and B: be careful of alignment
532  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
533  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
534 
535  // lds max alignment
536  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
537 
538  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
539  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
540 
541  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
542  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
543 
544  // LDS allocation for C shuffle in LDS
545  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
546  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
547 
548  constexpr auto c_block_size =
549  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
550 
551  return math::max((a_block_space_size_aligned * sizeof(ComputeTypeA) +
552  b_block_space_size_aligned * sizeof(ComputeTypeB)),
553  c_block_size * sizeof(FloatCShuffle));
554  }
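// Worked example (illustrative, ignoring the optional ABlockLdsExtraM/BBlockLdsExtraN padding):
// for a 256 x 128 x 32 tile with fp16 compute types, the A and B tiles occupy roughly
// (256 * 32 + 128 * 32) * 2 bytes = 24576 bytes of LDS, while a 64 x 64 C-shuffle slice needs
// 64 * 64 * 2 bytes = 8192 bytes; the function returns the larger of the two regions because
// the C shuffle reuses the same LDS space once the main GEMM loop has finished.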
555 
556  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
557  __host__ static constexpr bool CheckValidity(const Problem& problem)
558  {
559  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
560  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
561  "Invalid tuning param!");
562 
563  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
564  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
565  GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
566  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
567  {
568  if(!(problem.M % MPerBlock == 0))
569  {
570  return false;
571  }
572  }
573 
574  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
575  GemmSpec == tensor_operation::device::GemmSpecialization::MNPadding ||
576  GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
577  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding))
578  {
579  if(!(problem.N % NPerBlock == 0))
580  {
581  return false;
582  }
583  }
584 
585  if constexpr(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
586  GemmSpec == tensor_operation::device::GemmSpecialization::MKPadding ||
587  GemmSpec == tensor_operation::device::GemmSpecialization::NKPadding ||
588  GemmSpec == tensor_operation::device::GemmSpecialization::MNKPadding)
589  {
590  if(!(CalculateKPadded(problem.K) % AK1Value == 0) ||
591  !(CalculateKPadded(problem.K) % BK1Value == 0))
592  {
593  return false;
594  }
595  }
596  else
597  {
598  if(!(problem.K % AK1Value == 0) || !(problem.K % BK1Value == 0))
599  {
600  return false;
601  }
602  }
603 
604  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
605  {
606  if(problem.K % ABlockTransferSrcScalarPerVector != 0)
607  {
608  return false;
609  }
610  }
611  else
612  {
613  if(problem.M % ABlockTransferSrcScalarPerVector != 0)
614  {
615  return false;
616  }
617  }
618 
619  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
620  {
621  if(problem.N % BBlockTransferSrcScalarPerVector != 0)
622  {
623  return false;
624  }
625  }
626  else
627  {
628  if(problem.K % BBlockTransferSrcScalarPerVector != 0)
629  {
630  return false;
631  }
632  }
633 
634  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
635  {
636  if(problem.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
637  {
638  return false;
639  }
640  }
641  else
642  {
643  if(problem.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
644  {
645  return false;
646  }
647  }
648 
649  // check gridwise gemm pipeline
650  const auto num_k_loop = (CalculateAK0(problem.K) * AK1Value) / KPerBlock;
651 
652  if(!GridwiseGemmPipe::IsSupported(num_k_loop))
653  {
654  return false;
655  }
656 
657  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
658  return true;
659  }
660 
661  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
662  {
663  const index_t num_loop = K / KPerBlock;
664 
665  return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
666  }
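// Worked example (illustrative): with KPerBlock = 32, a K of 256 gives num_loop = 8 and the
// kernel is launched with HasMainKBlockLoop = true; roughly speaking, the flag tells the
// selected pipeline whether it has to execute its main K loop at all or can fall through to
// the prologue/epilogue iterations only.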
667 
668  template <typename CGridDesc>
669  __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
670  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
671  {
672  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
673  c_grid_desc_m_n,
678 
679  return c_grid_desc_mblock_mperblock_nblock_nperblock;
680  }
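// Illustrative sketch (not part of the original file): the 4-d view addresses C by tile index
// and by offset inside the tile, i.e. coordinate (mblock, mper, nblock, nper) maps to element
// (m, n) of c_grid_desc_m_n with m = mblock * MPerBlock + mper and n = nblock * NPerBlock + nper.
// The hypothetical helper below restates that unmerge arithmetic.
__host__ __device__ static constexpr auto
IllustrativeCTileCoordinateToMN(index_t mblock, index_t mper, index_t nblock, index_t nper)
{
    return make_tuple(mblock * MPerBlock + mper, nblock * NPerBlock + nper);
}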
681 
682  // return block_id to C matrix tile idx (m0, n0) mapping
684 
685  template <bool HasMainKBlockLoop>
686  __device__ static void Run(const FloatA* __restrict__ p_a_grid,
687  const FloatB* __restrict__ p_b_grid,
688  FloatC* __restrict__ p_c_grid,
689  void* __restrict__ p_shared,
690  const Problem& problem)
691  {
692  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
693  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
694  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
695  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
696  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
697  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
698 
699  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
700  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
701  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
702 
703  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
704  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
705  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
706  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
707  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
708  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
709 
710  const AElementwiseOperation a_element_op{};
711  const BElementwiseOperation b_element_op{};
712  const CElementwiseOperation c_element_op{};
713 
714  // divide block work by [M, N]
715  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N};
716 
717  const auto block_work_idx =
718  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
719 
720  if(!block_2_ctile_map.ValidCTileIndex(
721  block_work_idx,
722  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
723  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
724  {
725  return;
726  }
727 
728  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
729  const index_t m_block_data_idx_on_grid =
730  __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
731 
732  const index_t n_block_data_idx_on_grid =
733  __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
734 
735  // lds max alignment
736  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
737 
738  // A matrix in LDS memory, dst of blockwise copy
739  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
740 
741  // B matrix in LDS memory, dst of blockwise copy
742  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
743 
744  // A matrix blockwise copy
745  auto a_blockwise_copy =
746  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
747  AElementwiseOperation,
751  ABlockTransferThreadClusterLengths_AK0_M_AK1,
752  ABlockTransferThreadClusterArrangeOrder,
753  FloatA,
754  ComputeTypeA,
755  decltype(a_grid_desc_ak0_m_ak1),
756  decltype(a_block_desc_ak0_m_ak1),
757  ABlockTransferSrcAccessOrder,
759  ABlockTransferSrcVectorDim,
760  2,
761  ABlockTransferSrcScalarPerVector,
762  ABlockTransferDstScalarPerVector_AK1,
763  1,
764  1,
765  AThreadTransferSrcResetCoordinateAfterRun,
766  true,
767  NumGemmKPrefetchStage>(
768  a_grid_desc_ak0_m_ak1,
769  make_multi_index(0, m_block_data_idx_on_grid, 0),
770  a_element_op,
771  a_block_desc_ak0_m_ak1,
772  make_multi_index(0, 0, 0),
774 
775  // B matrix blockwise copy
776  auto b_blockwise_copy =
777  ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
778  BElementwiseOperation,
782  BBlockTransferThreadClusterLengths_BK0_N_BK1,
783  BBlockTransferThreadClusterArrangeOrder,
784  FloatB,
785  ComputeTypeB,
786  decltype(b_grid_desc_bk0_n_bk1),
787  decltype(b_block_desc_bk0_n_bk1),
788  BBlockTransferSrcAccessOrder,
790  BBlockTransferSrcVectorDim,
791  2,
792  BBlockTransferSrcScalarPerVector,
793  BBlockTransferDstScalarPerVector_BK1,
794  1,
795  1,
796  BThreadTransferSrcResetCoordinateAfterRun,
797  true,
798  NumGemmKPrefetchStage>(
799  b_grid_desc_bk0_n_bk1,
800  make_multi_index(0, n_block_data_idx_on_grid, 0),
801  b_element_op,
802  b_block_desc_bk0_n_bk1,
803  make_multi_index(0, 0, 0),
805 
806  // GEMM definition
807  // c_mtx += transpose(a_mtx) * b_mtx
808  // a_mtx[K0PerBlock, MPerBlock] is in LDS
809  // b_mtx[K0PerBlock, NPerBlock] is in LDS
810  // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
811  // register
812  // sanity check
813  constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
814  constexpr bool is_single_rate_mfma =
816  lcm_AK1_BK1 <= 4) ||
817  (is_same<ComputeTypeA, int8_t>::value && lcm_AK1_BK1 <= 8) ||
819  lcm_AK1_BK1 < 32))
820  ? true
821  : false;
822  constexpr auto is_scale_mfma = false;
823  constexpr index_t KPack = math::max(lcm_AK1_BK1,
824  MfmaSelector<ComputeTypeA,
825  MPerXdl,
826  NPerXdl,
827  ComputeTypeB,
828  is_single_rate_mfma,
829  is_scale_mfma>::selected_mfma.k_per_blk);
830 
831  auto blockwise_gemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector<
832  BlockSize,
833  ComputeTypeA,
834  ComputeTypeB,
835  FloatGemmAcc,
836  decltype(a_block_desc_ak0_m_ak1),
837  decltype(b_block_desc_bk0_n_bk1),
838  MPerXdl,
839  NPerXdl,
840  MXdlPerWave,
841  NXdlPerWave,
842  KPack,
843  LoopSched>();
844 
845  auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
846 
847  // LDS allocation for A and B: be careful of alignment
848  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
849  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
850 
851  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
852  static_cast<ComputeTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
853 
854  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
855  static_cast<ComputeTypeB*>(p_shared) + a_block_space_size_aligned,
856  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
857 
858  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
859  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
860 
861  // gridwise GEMM pipeline
862  static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
863  const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
864 
865  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
866  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
867  KPerBlock);
868 
869  gridwise_gemm_pipeline.template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
870  a_block_desc_ak0_m_ak1,
871  a_blockwise_copy,
872  a_grid_buf,
873  a_block_buf,
874  a_block_slice_copy_step,
875  b_grid_desc_bk0_n_bk1,
876  b_block_desc_bk0_n_bk1,
877  b_blockwise_copy,
878  b_grid_buf,
879  b_block_buf,
880  b_block_slice_copy_step,
881  blockwise_gemm,
882  c_thread_buf,
883  num_k_block_main_loop);
884 
885  // shuffle C and write out
886  {
887  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
888  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
889  "wrong!");
890 
891  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
892  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
893 
894  // TODO: hacky, fix it!
895  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
896  blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
897 
898  // TODO: hacky, fix it!
899  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
900  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
901  blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
902 
903  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
904  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
905  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
906  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
907  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
908  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
909  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
910  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
911 
912  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
913  GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
914 
915  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
916  static_cast<FloatCShuffle*>(p_shared),
917  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
918 
919  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
920  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
921  make_tuple(
924  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
925  M1, // M1 = MWave
926  M2, // M2 * M3 * M4 = MPerXdl
927  M3,
928  M4)),
931  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
932  N1, // N1 = NWave
933  N2))), // N2 = NPerXdl
935  make_tuple(
937 
938  // calculate origin of thread output tensor on global memory
939  // blockwise GEMM c matrix starting index
940  const auto c_thread_mtx_on_block =
941  blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
942 
943  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
944  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
945 
946  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
948  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
951 
952  const auto m_thread_data_on_block_idx =
953  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
954  make_multi_index(m_thread_data_on_block));
955 
956  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
961 
962  const auto n_thread_data_on_block_idx =
963  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
964  make_multi_index(n_thread_data_on_block));
965 
966  // shuffle: threadwise copy C from VGPR to LDS
967  auto c_thread_copy_vgpr_to_lds =
968  ThreadwiseTensorSliceTransfer_v1r3<FloatGemmAcc,
969  FloatCShuffle,
970  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
971  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
973  Sequence<CShuffleMXdlPerWavePerShuffle,
974  CShuffleNXdlPerWavePerShuffle,
975  I1,
976  I1,
977  M2,
978  I1,
979  M4,
980  I1>,
982  7,
983  1,
985  1,
986  true>{
987  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
989  0,
990  m_thread_data_on_block_idx[I1],
991  n_thread_data_on_block_idx[I1],
992  m_thread_data_on_block_idx[I2],
993  m_thread_data_on_block_idx[I3],
994  m_thread_data_on_block_idx[I4],
995  n_thread_data_on_block_idx[I2]),
997 
998  // shuffle: blockwise copy C from LDS to global
999  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1000  ThisThreadBlock, // ThreadGroup
1001  CElementwiseOperation, // ElementwiseOperation,
1002  CGlobalMemoryDataOperation, // DstInMemOp,
1003  Sequence<1,
1004  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1005  1,
1006  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1007  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1008  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1009  FloatCShuffle, // typename SrcData,
1010  FloatC, // typename DstData,
1011  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1012  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1013  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1014  3, // index_t VectorDim,
1015  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1016  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1017  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1018  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1019  make_multi_index(0, 0, 0, 0),
1020  c_grid_desc_mblock_mperblock_nblock_nperblock,
1021  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1022  c_element_op};
1023 
1024  // space filling curve for threadwise C in VGPR
1025  constexpr auto sfc_c_vgpr =
1028  Sequence<CShuffleMXdlPerWavePerShuffle,
1029  CShuffleNXdlPerWavePerShuffle,
1030  1,
1031  1,
1032  M2,
1033  1,
1034  M4,
1035  1>>{};
1036 
1037  // space filling curve for shuffled blockwise C in global mem
1038  constexpr auto sfc_c_global =
1041  Sequence<1,
1042  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1043  1,
1044  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1045 
1046  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1047 
1048  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1049 
1050  static_for<0, num_access, 1>{}([&](auto access_id) {
1051  // make sure it's safe to write to LDS
1052  block_sync_lds();
1053 
1054  // each thread writes its data from VGPR to LDS
1055  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1056  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1057  c_thread_buf,
1058  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1059  c_shuffle_block_buf);
1060 
1061  // make sure it's safe to read from LDS
1062  block_sync_lds();
1063 
1064  // each block copies its data from LDS to global
1065  c_shuffle_block_copy_lds_to_global.Run(
1066  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1067  c_shuffle_block_buf,
1068  c_grid_desc_mblock_mperblock_nblock_nperblock,
1069  c_grid_buf);
1070 
1071  if constexpr(access_id < num_access - 1)
1072  {
1073  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1074 
1075  // move on C
1076  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1077  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1078  }
1079  });
1080  }
1081  }
1082 };
1083 
1084 } // namespace ck