Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
17 
18 namespace ck {
19 
20 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
21 // pipelines in the same kernel function. Blockers:
22 // 1. Two separate declarations of the __shared__ pointer are the key to making sure data
23 //    accesses operate on two distinct LDS chunks.
24 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C
25 //    cannot reuse the same LDS buffer when __shared__ is declared inside the block GEMM pipeline.
26 template <typename GridwiseGemm,
27  bool HasMainKBlockLoop,
28  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
29  index_t MinimumOccupancy = 1,
30  TailNumber TailNum = TailNumber::Full>
31 __global__ void
32 #if CK_USE_LAUNCH_BOUNDS
33  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
34 #endif
35  // __attribute__((amdgpu_waves_per_eu(1, 1)))
36  kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
37 {
38 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
39  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
40 
41  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
42 
43  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
44  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
45  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
46  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
47  p_shared,
48  karg);
49 #else
50  ignore = karg;
51 #endif // end of if (defined(__gfx9__))
52 }
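// Illustrative sketch (standalone, hypothetical helper, not the library's dispatch code, which
// lives in the device-level GEMM operation): a minimal host-side launch of the kernel above.
// `Gemm` is assumed to be a full instantiation of the gridwise GEMM defined below; `block_size`
// and `has_main_k_block_loop` are assumed to be supplied by the caller.
template <typename Gemm>
void launch_kernel_gemm_xdl_cshuffle_v3_example(typename Gemm::Argument karg,
                                                int block_size,
                                                bool has_main_k_block_loop,
                                                hipStream_t stream)
{
    // Grid: (number of C tiles, 1, KBatch), as produced by CalculateGridSize below.
    const auto [gdx, gdy, gdz] = Gemm::CalculateGridSize(karg.M, karg.N, karg.KBatch);

    if(has_main_k_block_loop)
        kernel_gemm_xdl_cshuffle_v3<Gemm, true, InMemoryDataOperationEnum::Set>
            <<<dim3(gdx, gdy, gdz), dim3(block_size), 0, stream>>>(karg);
    else
        kernel_gemm_xdl_cshuffle_v3<Gemm, false, InMemoryDataOperationEnum::Set>
            <<<dim3(gdx, gdy, gdz), dim3(block_size), 0, stream>>>(karg);
}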
53 
54 template <typename GridwiseGemm,
55  bool HasMainKBlockLoop,
56  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
57  index_t MinimumOccupancy = 1,
58  TailNumber TailNum = TailNumber::Full>
59 __global__ void
60 #if CK_USE_LAUNCH_BOUNDS
61  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
62 #endif
63  // __attribute__((amdgpu_waves_per_eu(1, 1)))
64  kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
65 {
66 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
67  // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write can
68  // operate on different LDS chunks at the same time without an ordering dependency.
69  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
70  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
71 
72  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
73 
74  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
75  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
76  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
77  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
78  p_shared_0,
79  p_shared_1,
80  karg);
81 #else
82  ignore = karg;
83 #endif // end of if (defined(__gfx9__))
84 }
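// Illustrative sketch (standalone, hypothetical helper): the two __shared__ allocations above
// implement a ping-pong (double-buffer) scheme. Conceptually, iteration i of the K loop computes
// out of one buffer while the prefetch for iteration i + 1 fills the other, so LDS reads and
// writes never touch the same chunk and need no ordering between them:
inline void lds_pingpong_schedule_example(char* p_lds_0, char* p_lds_1, int num_k_tiles)
{
    char* bufs[2] = {p_lds_0, p_lds_1};
    for(int i = 0; i < num_k_tiles; ++i)
    {
        char* compute_buf  = bufs[i % 2];       // tile consumed by the block GEMM this iteration
        char* prefetch_buf = bufs[(i + 1) % 2]; // tile filled by the global->LDS copy for i + 1
        // ... issue copy into prefetch_buf, run MFMAs on compute_buf, synchronize, repeat ...
        (void)compute_buf;
        (void)prefetch_buf;
    }
}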
85 
191 template <typename ALayout,
192  typename BLayout,
193  typename CLayout,
194  typename ADataType,
195  typename BDataType,
196  typename AccDataType,
197  typename CShuffleDataType,
198  typename CDataType,
199  typename AElementwiseOperation,
200  typename BElementwiseOperation,
201  typename CElementwiseOperation,
203  index_t BlockSize,
204  index_t MPerBlock,
205  index_t NPerBlock,
206  index_t KPerBlock,
207  index_t AK1Value,
208  index_t BK1Value,
209  index_t MPerXdl,
210  index_t NPerXdl,
211  index_t MXdlPerWave,
212  index_t NXdlPerWave,
213  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
214  typename ABlockTransferThreadClusterArrangeOrder,
215  typename ABlockTransferSrcAccessOrder,
216  index_t ABlockTransferSrcVectorDim,
217  index_t ABlockTransferSrcScalarPerVector,
218  index_t ABlockTransferDstScalarPerVector_AK1,
219  bool AThreadTransferSrcResetCoordinateAfterRun,
220  index_t ABlockLdsExtraM,
221  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
222  typename BBlockTransferThreadClusterArrangeOrder,
223  typename BBlockTransferSrcAccessOrder,
224  index_t BBlockTransferSrcVectorDim,
225  index_t BBlockTransferSrcScalarPerVector,
226  index_t BBlockTransferDstScalarPerVector_BK1,
227  bool BThreadTransferSrcResetCoordinateAfterRun,
228  index_t BBlockLdsExtraN,
229  index_t CShuffleMXdlPerWavePerShuffle,
230  index_t CShuffleNXdlPerWavePerShuffle,
231  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
232  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
235  typename ComputeTypeA = CDataType,
236  typename ComputeTypeB = ComputeTypeA,
237  bool PermuteA = false,
238  bool PermuteB = false,
239  bool DoElementwiseBeforeCShuffle = false>
241 {
242  static constexpr auto I0 = Number<0>{};
243  static constexpr auto I1 = Number<1>{};
244  static constexpr auto I2 = Number<2>{};
245  static constexpr auto I3 = Number<3>{};
246  static constexpr auto I4 = Number<4>{};
247  static constexpr auto I5 = Number<5>{};
248  static constexpr auto I6 = Number<6>{};
249  static constexpr auto I7 = Number<7>{};
250 
251  // K1 should be Number<...>
252  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
253  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
254  static constexpr auto AK1Number = Number<AK1Value>{};
255  static constexpr auto BK1Number = Number<BK1Value>{};
256 
257  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
258  static constexpr bool is_single_rate_mfma =
260  lcm_AK1_BK1 <= 4) ||
262  // gfx950 double-rate mfma16x16 requires at least 128 KPerBlock to consume
264  KPerBlock < 128 && MPerXdl == 16))
265  ? true
266  : false;
267  static constexpr auto is_scale_mfma = false;
268  static constexpr index_t KPack =
270  MfmaSelector<ComputeTypeA,
271  MPerXdl,
272  NPerXdl,
273  ComputeTypeA,
275  is_scale_mfma>::selected_mfma.k_per_blk);
276 
278 
279  static constexpr index_t APackedSize = []() {
281  return 2;
282  else
283  return 1;
284  }();
285 
286  static constexpr index_t BPackedSize = []() {
288  return 2;
289  else
290  return 1;
291  }();
292 
293  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
294  {
295  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
296  }
297 
298  __host__ static auto CalculateMPadded(index_t M)
299  {
300  return math::integer_least_multiple(M, MPerBlock);
301  }
302 
303  __host__ static auto CalculateNPadded(index_t N)
304  {
305  return math::integer_least_multiple(N, NPerBlock);
306  }
307 
308  __host__ static auto CalculateKPadded(index_t K)
309  {
310  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
311  }
312 
313  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
314  {
315  auto K_t = K_Batch * KPerBlock;
316  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
317  }
318 
319  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
320  {
321  auto K_t = K_Batch * KPerBlock;
322  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
323  }
324 
325  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
326  {
327  auto K_t = K_Batch * KPerBlock;
328  return (K + K_t - 1) / K_t * KPerBlock;
329  }
330 
331  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
332  {
333  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
334  auto K_t = K_Batch * KReadVec;
335  return (K + K_t - 1) / K_t * KReadVec;
336  }
337 
338  __host__ static auto CalculateMBlock(index_t M)
339  {
340  return math::integer_divide_ceil(M, MPerBlock);
341  }
342 
343  __host__ static auto CalculateNBlock(index_t N)
344  {
345  return math::integer_divide_ceil(N, NPerBlock);
346  }
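// Illustrative worked example (standalone; the names and values below are assumptions for
// illustration): the helpers above are plain round-up arithmetic. Assume a hypothetical tile with
// KPerBlock = 64 and AK1 = BK1 = 8, and a problem with K = 1000 split into KBatch = 2.
namespace gridwise_gemm_v3_padding_example {
constexpr int KPerBlk = 64, AK1 = 8, KBatch = 2, K = 1000;
constexpr int K_t     = KBatch * KPerBlk;                        // 128
constexpr int KPadded = (K + K_t - 1) / K_t * KPerBlk;           // CalculateKPadded   -> 512
constexpr int AK0     = (K + K_t - 1) / K_t * (KPerBlk / AK1);   // CalculateAK0Padded -> 64
constexpr int KRead   = (K + KBatch * 8 - 1) / (KBatch * 8) * 8; // CalculateKRead (lcm = 8) -> 504
static_assert(KPadded == 512 && AK0 == 64 && KRead == 504, "worked example");
// The first KBatch - 1 splits each read KRead = 504 elements along K; the last split covers the
// remaining K - 504 = 496 elements (see SplitKBatchOffset below).
} // namespace gridwise_gemm_v3_padding_example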
347 
348  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
349  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
350  {
351  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
352  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
353 
355  TileDesc_K0_MN_K1{},
361  }
362 
363  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
364  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
365  {
366  const auto a_grid_desc_mraw_kraw = [&]() {
367  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
368  {
369  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
370  }
371  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
372  {
373  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
374  }
375  }();
376 
378 
379  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
380  GemmSpec == GemmSpecialization::MNKPadding)
381  {
382  // pad both M and K
383  const auto a_grid_desc_m_k =
384  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
386  make_right_pad_transform(K, KPad - K)),
389 
390  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
391  a_grid_desc_m_k,
396 
397  return a_grid_desc_ak0_m_ak1;
398  }
399  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
400  GemmSpec == GemmSpecialization::MNPadding)
401  {
402  // pad M, but not K
403  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
404  a_grid_desc_mraw_kraw,
406  make_right_pad_transform(M, MPad - M)),
409 
410  return a_grid_desc_ak0_m_ak1;
411  }
412  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
413  GemmSpec == GemmSpecialization::NKPadding)
414  {
415  // pad K, but not M
416  const auto a_grid_desc_m_k = transform_tensor_descriptor(
417  a_grid_desc_mraw_kraw,
421 
422  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
423  a_grid_desc_m_k,
428 
429  return a_grid_desc_ak0_m_ak1;
430  }
431  else
432  {
433  // not pad M or K
434  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
435  a_grid_desc_mraw_kraw,
440 
441  return a_grid_desc_ak0_m_ak1;
442  }
443  }
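// Illustrative sketch (standalone; plain ints used in place of index_t): for the non-padded,
// row-major case the [AK0, M, AK1] view built above is just a reindexing of A(m, k) with
// k = ak0 * AK1 + ak1. A scalar reference for the resulting element offset:
inline int a_ak0_m_ak1_to_flat_offset_example(int ak0, int m, int ak1, int StrideA, int AK1)
{
    const int k = ak0 * AK1 + ak1; // unmerge of K into (AK0, AK1)
    return m * StrideA + k;        // row-major A: element offset of A(m, k)
}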
444 
445  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
446  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
447  {
448  const auto b_grid_desc_nraw_kraw = [&]() {
450  {
451  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
452  }
454  {
455  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
456  }
457  }();
458 
460 
461  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
462  GemmSpec != GemmSpecialization::Default),
463  "pk_i4_t does not support padding");
464 
465  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
466  GemmSpec == GemmSpecialization::MNKPadding)
467  {
468  // pad both N and K
469  const auto b_grid_desc_n_k =
470  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
472  make_right_pad_transform(K, KPad - K)),
475 
476  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
477  b_grid_desc_n_k,
482 
483  return b_grid_desc_bk0_n_bk1;
484  }
485  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
486  GemmSpec == GemmSpecialization::MNPadding)
487  {
488  // pad N, but not K
489  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
490  b_grid_desc_nraw_kraw,
492  make_right_pad_transform(N, NPad - N)),
495 
496  return b_grid_desc_bk0_n_bk1;
497  }
498  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
499  GemmSpec == GemmSpecialization::MKPadding)
500  {
501  // pad K, but not N
502  const auto b_grid_desc_n_k = transform_tensor_descriptor(
503  b_grid_desc_nraw_kraw,
507 
508  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
509  b_grid_desc_n_k,
514 
515  return b_grid_desc_bk0_n_bk1;
516  }
517  else
518  {
519  if constexpr(!PermuteB)
520  {
521  // not pad N or K
522  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
523  b_grid_desc_nraw_kraw,
528 
529  return b_grid_desc_bk0_n_bk1;
530  }
531  else
532  {
533  // Pre-shuffled Weight
534  // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
535  constexpr index_t BK01 = KPerBlock / BK1Value;
536  const index_t BK0_ = StrideB / BK1Value;
537  const index_t BK00 = BK0_ / BK01;
538 
539  const auto b_grid_desc_bk00_n_bk01_bk1_permute =
540  make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
541 
542  const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
543  b_grid_desc_bk00_n_bk01_bk1_permute,
546  make_pass_through_transform(BK1Value)),
549 
550  return b_grid_desc_bk0_n_bk1_permute;
551  }
552  }
553  }
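// Illustrative sketch (standalone; one plausible reading of the PermuteB "pre-shuffled weight"
// addressing above, stated as an assumption): with the packed layout [BK00, N, BK01, BK1] and
// bk0 split as bk0 = bk00 * BK01 + bk01, a (bk0, n, bk1) coordinate would land at
inline long preshuffled_b_offset_example(int bk0, int n, int bk1, int N, int BK01, int BK1)
{
    const int bk00 = bk0 / BK01;
    const int bk01 = bk0 % BK01;
    return ((static_cast<long>(bk00) * N + n) * BK01 + bk01) * BK1 + bk1;
}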
554 
555  template <typename ABlockDesc_AK0_M_AK1>
556  __host__ __device__ static constexpr auto
557  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
558  {
559  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
560 
561  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
562  }
563 
564  template <typename BBlockDesc_BK0_N_BK1>
565  __host__ __device__ static constexpr auto
566  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
567  {
568  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
569 
570  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
571  }
572 
573  __host__ __device__ static auto
575  {
576  const auto c_grid_desc_mraw_nraw = [&]() {
578  {
579  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
580  }
582  {
583  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
584  }
585  }();
586 
587  // pad M and N
588  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
590  make_right_pad_transform(N, NPad - N)),
593 #if 0
595 
596  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
597  GemmSpec == GemmSpecialization::MNKPadding)
598  {
599  // pad M and N
600  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
602  make_right_pad_transform(N, NPad - N)),
605  }
606  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
607  GemmSpec == GemmSpecialization::MKPadding)
608  {
609  // pad M, but not N
611  c_grid_desc_mraw_nraw,
615  }
616  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
617  GemmSpec == GemmSpecialization::NKPadding)
618  {
619  // pad N, but not M
621  c_grid_desc_mraw_nraw,
625  }
626  else
627  {
628  // not pad M or N
629  return c_grid_desc_mraw_nraw;
630  }
631 #endif
632  }
633 
634  struct Problem
635  {
636  __host__ Problem(index_t M_,
637  index_t N_,
638  index_t K_,
639  index_t StrideA_,
640  index_t StrideB_,
641  index_t StrideC_,
642  index_t KBatch_,
643  AElementwiseOperation a_element_op,
644  BElementwiseOperation b_element_op,
645  CElementwiseOperation c_element_op)
646  : M{M_},
647  N{N_},
648  K{K_},
649  StrideA{StrideA_},
650  StrideB{StrideB_},
651  StrideC{StrideC_},
652  KBatch{KBatch_},
655  KRead{CalculateKRead(K_, KBatch_)},
656  KPadded{CalculateKPadded(K_, KBatch_)},
657  AK0{CalculateAK0Padded(K_, KBatch_)},
658  BK0{CalculateBK0Padded(K_, KBatch_)},
659  MBlock{CalculateMBlock(M_)},
660  NBlock{CalculateNBlock(N_)},
661  a_element_op_{a_element_op},
662  b_element_op_{b_element_op},
663  c_element_op_{c_element_op}
664  {
665  }
666 
667  __host__ void Print() const
668  {
669  std::cout << "problem {"
670  << "M:" << M << ", "
671  << "N:" << N << ", "
672  << "K:" << K << ", "
673  << "SA:" << StrideA << ", "
674  << "SB:" << StrideB << ", "
675  << "SC:" << StrideC << ", "
676  << "MP:" << MPadded << ", "
677  << "NP:" << NPadded << ", "
678  << "KRead:" << KRead << ", "
679  << "KP:" << KPadded << ", "
680  << "AK0:" << AK0 << ", "
681  << "BK0:" << BK0 << ", "
682  << "MBlock: " << MBlock << ", "
683  << "NBlock: " << NBlock << "}" << std::endl;
684  }
685 
701  AElementwiseOperation a_element_op_;
702  BElementwiseOperation b_element_op_;
703  CElementwiseOperation c_element_op_;
704  };
705 
706  // Argument
708  {
709  __host__ Argument(const ADataType* p_a_grid_,
710  const BDataType* p_b_grid_,
711  CDataType* p_c_grid_,
712  index_t M_,
713  index_t N_,
714  index_t K_,
715  index_t StrideA_,
716  index_t StrideB_,
717  index_t StrideC_,
718  index_t k_batch_,
719  bool is_reduce_ = false,
720  AElementwiseOperation a_element_op = AElementwiseOperation{},
721  BElementwiseOperation b_element_op = BElementwiseOperation{},
722  CElementwiseOperation c_element_op = CElementwiseOperation{})
723  : Problem{M_,
724  N_,
725  K_,
726  StrideA_,
727  StrideB_,
728  StrideC_,
729  k_batch_,
730  a_element_op,
731  b_element_op,
732  c_element_op},
733  p_a_grid{p_a_grid_},
734  p_b_grid{p_b_grid_},
735  p_c_grid{p_c_grid_},
736  is_reduce(is_reduce_)
737  {
738  }
739 
740  __host__ __device__ inline bool IsReduceAdd() const
741  {
742  return (Problem::KBatch > 1) && is_reduce;
743  }
744 
745  __host__ __device__ inline bool IsAtomicAdd() const
746  {
747  return (Problem::KBatch > 1) && (!is_reduce);
748  }
749 
750  const ADataType* p_a_grid;
751  const BDataType* p_b_grid;
752  CDataType* p_c_grid;
753  bool is_reduce;
754  };
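// Illustrative sketch (standalone, hypothetical helper): with KBatch > 1 the split-K partial
// results have to be combined, and is_reduce selects between the two strategies exposed above.
template <typename Gridwise>
InMemoryDataOperationEnum choose_c_global_memory_op_example(const typename Gridwise::Argument& arg)
{
    // IsAtomicAdd(): every split accumulates into the same C buffer with atomic adds.
    // IsReduceAdd(): each split writes its own M x N slice (see c_reduce_offset below) and a
    //                separate reduction pass sums the slices.
    return arg.IsAtomicAdd() ? InMemoryDataOperationEnum::AtomicAdd
                             : InMemoryDataOperationEnum::Set;
}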
755 
757  {
758 
759  __device__ SplitKBatchOffset(Argument& karg)
760  {
761  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
762  {
763  a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
764  }
765  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
766  {
767  a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
768  }
769 
770  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
771  {
772  b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
773  }
774  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
775  {
776  if constexpr(!PermuteB)
777  {
778  b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
779  }
780  else
781  {
782  const int k0_offset = karg.KRead * karg.N;
783  b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
784  }
785  }
786 
787  if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
788  {
789  karg.K = karg.KRead;
790  }
791  else
792  {
793  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
794  }
795 
796  if(karg.IsReduceAdd())
797  {
798  c_reduce_offset = blockIdx.z * karg.M * karg.N;
799  }
800  else
801  {
802  c_reduce_offset = 0;
803  }
804  }
805 
809  };
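// Illustrative sketch (standalone; assumes row-major A, column-major B and no packing) of the
// per-batch K window computed above: batch z starts at element z * KRead along K, the first
// KBatch - 1 batches each cover KRead elements, and the last batch takes the remainder.
inline void splitk_window_example(int K, int KBatch, int KRead, int z, int& k_begin, int& k_len)
{
    k_begin = z * KRead;                                           // a_/b_k_split_offset above
    k_len   = (z < KBatch - 1) ? KRead : K - KRead * (KBatch - 1); // the karg.K rewrite above
}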
810 
811  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
812  {
813  // A matrix in LDS memory, dst of blockwise copy
814  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
815  {
816  // Writing the data into LDS this way causes bank conflicts, but don't worry: in v4 the whole
817  // main loop is available to hide them, and it may reduce VALU work in address computation.
821  }
822  // The xor tensor transformation requires more (unnecessary) VGPR usage and can cause register
823  // spills in some cases.
825  {
826  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
827  constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
828  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
829  make_tuple(
830  AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
832 
833  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
834  a_lds_block_desc,
840 
841  constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
842  a_lds_block_desc_permuted,
848 
849  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
850  a_lds_block_desc_ak0_mldslayer_m_ak1,
857 
858  return a_lds_block_desc_ak0_m_ak1;
859  }
860  else // ColumnMajor A
861  {
862  // The kfold and mpair dimensions are not always required.
863  // More dimensions in the merge_transform make it harder for the compiler to generate
864  // immediate-argument (immarg) offsets.
865  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
866  constexpr auto M1 = MPerBlock / M0;
867 
868  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
869  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
870  constexpr auto KThreadRead = 64 / MPerXdl;
871  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
872 
873  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
874  ? 1
875  : 128 / (AK1Number * M0 * sizeof(ADataType));
876  constexpr auto KThreadReadPerm =
877  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
878  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
879  : KThreadRead;
880 
881  // 1 <= mpair <= M0
882  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
883  ? 1
884  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
885  ? M0
886  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
887 
888  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
892  Number<kfold * M0 / mpair>{},
893  Number<mpair>{},
894  AK1Number));
895 
896  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
897  a_lds_block_desc,
898  make_tuple(
902  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
905  make_tuple(
907  make_tuple(
909 
910  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
911  a_lds_block_desc_permuted,
912  make_tuple(
920  Sequence<1>{},
921  Sequence<2>{},
922  Sequence<3>{},
923  Sequence<4>{},
924  Sequence<5>{}),
926  Sequence<2>{},
927  Sequence<0, 3>{},
928  Sequence<4, 5>{},
929  Sequence<6>{},
930  Sequence<7>{}));
931 
932  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
933  a_lds_block_desc_unmerged,
936  Number<KThreadWrite / kfold / KThreadReadPerm>{},
937  Number<kfold>{},
944 
945  return a_lds_block_desc_ak0_m_ak1;
946  }
947  }
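// Illustrative worked example (standalone; the fp16 tile below is an assumption) of the MLdsLayer
// heuristic above: with KPerBlock = 32 and no packing, LdsSize = 32 * 4 / 32 / 2 / 1 = 2, so the
// A block is viewed as (AK0 * 2, MPerBlock / 2, AK1) before the permutation.
namespace a_mldslayer_example {
constexpr int KPerBlk = 32, SizeofA = 2 /* fp16 */, APack = 1;
constexpr int LdsSize   = 32 * 4 / KPerBlk / SizeofA / APack;
constexpr int MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
static_assert(MLdsLayer == 2, "worked example");
} // namespace a_mldslayer_example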
948 
949  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
950  {
951  // B matrix in LDS memory, dst of blockwise copy
952  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
953  {
954  // Writing the data into LDS this way causes bank conflicts, but don't worry: in v4 the whole
955  // main loop is available to hide them, and it may reduce VALU work in address computation.
959  }
961  {
962  // NLdsLayer * K0 as logical Bank
963  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
964  constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
965  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
966  make_tuple(
967  BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
969 
970  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
971  b_lds_block_desc,
977 
978  constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
979  b_lds_block_desc_permuted,
985 
986  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
987  b_lds_block_desc_bk0_nldslayer_n_bk1,
994 
995  return b_lds_block_desc_bk0_n_bk1;
996  }
997  else // RowMajor B
998  {
999  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
1000  constexpr auto N1 = NPerBlock / N0;
1001 
1002  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
1003  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
1004  constexpr auto KThreadRead = 64 / NPerXdl;
1005  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
1006 
1007  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
1008  ? 1
1009  : 128 / (BK1Number * N0 * sizeof(BDataType));
1010  constexpr auto KThreadReadPerm =
1011  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1012  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1013  : KThreadRead;
1014 
1015  // 1 <= npair <= N0
1016  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1017  ? 1
1018  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1019  ? N0
1020  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1021 
1022  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1026  Number<kfold * N0 / npair>{},
1027  Number<npair>{},
1028  BK1Number));
1029 
1030  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1031  b_lds_block_desc,
1032  make_tuple(
1036  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1039  make_tuple(
1041  make_tuple(
1043 
1044  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1045  b_lds_block_desc_permuted,
1046  make_tuple(
1054  Sequence<1>{},
1055  Sequence<2>{},
1056  Sequence<3>{},
1057  Sequence<4>{},
1058  Sequence<5>{}),
1060  Sequence<2>{},
1061  Sequence<0, 3>{},
1062  Sequence<4, 5>{},
1063  Sequence<6>{},
1064  Sequence<7>{}));
1065 
1066  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1067  b_lds_block_desc_unmerged,
1070  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1071  Number<kfold>{},
1078 
1079  return b_lds_block_desc_bk0_n_bk1;
1080  }
1081  }
1082 
1084  {
1085  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1086  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1087 
1088  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1090  make_tuple(I1,
1092  I1,
1094 
1095  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1096  }
1097 
1100  BlkGemmPipelineVer,
1101  BlkGemmPipeSched,
1102  BlockSize,
1103  ADataType,
1104  BDataType,
1105  ComputeTypeA,
1106  AccDataType,
1113  ABlockTransferSrcScalarPerVector,
1114  BBlockTransferSrcScalarPerVector,
1115  MPerBlock,
1116  NPerBlock,
1117  KPerBlock,
1118  MPerXdl,
1119  NPerXdl,
1120  MXdlPerWave,
1121  NXdlPerWave,
1122  KPack>())>;
1123 
1124  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1125  {
1126  // LDS allocation for A and B: be careful of alignment
1127  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1128  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1129 
1130  // lds max alignment
1131  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1132 
1133  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1134  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1135 
1136  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1137  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1138 
1139  // LDS allocation for C shuffle in LDS
1140  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1142 
1143  constexpr auto c_block_size =
1144  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1145 
1146  return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
1147  b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
1148  c_block_size * sizeof(CShuffleDataType));
1149  }
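// Illustrative back-of-the-envelope for the formula above (the fp16 256 x 128 x 64 tile is an
// assumption; extra LDS padding ignored): the A block holds about 64 * 256 elements (~32 KiB) and
// the B block about 64 * 128 elements (~16 KiB), so the A + B term is ~48 KiB. The C-shuffle
// staging buffer reuses the same allocation (both Run variants cast p_shared to
// CShuffleDataType*), which is why the function returns the max of the two terms, not their sum.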
1150 
1151  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1152  __host__ static constexpr bool CheckValidity(const Argument& karg)
1153  {
1154  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1155  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1156  "Invalid tuning param!");
1157 
1158  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1163  {
1164  if(!(karg.M % MPerBlock == 0))
1165  {
1166  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1167  {
1168  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1169  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1170  << std::endl;
1171  }
1172  return false;
1173  }
1174  }
1175 
1176  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1181  {
1182  if(!(karg.N % NPerBlock == 0))
1183  {
1184  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1185  {
1186  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1187  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1188  << std::endl;
1189  }
1190  return false;
1191  }
1192  }
1193 
1194  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1198  {
1199 
1200  auto K_t = karg.KBatch * KPerBlock;
1201  if(!(karg.K % K_t == 0))
1202  {
1203  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1204  {
1205  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1206  << karg.K << " " << __FILE__ << ":" << __LINE__
1207  << ", in function: " << __func__ << std::endl;
1208  }
1209  return false;
1210  }
1211  }
1212  else
1213  {
1214  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1215  auto K_t = karg.KBatch * KReadVec;
1216  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1217  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1218  {
1219  return false;
1220  }
1221  }
1222 
1224  {
1225  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1226  {
1227  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1228  {
1229  std::cout << "Arg K (" << karg.K
1230  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1231  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1232  << __LINE__ << ", in function: " << __func__ << std::endl;
1233  }
1234  return false;
1235  }
1236  }
1237  else
1238  {
1239  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1240  {
1241  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1242  {
1243  std::cout << "Arg M (" << karg.M
1244  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1245  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1246  << __LINE__ << ", in function: " << __func__ << std::endl;
1247  }
1248  return false;
1249  }
1250  }
1251 
1253  {
1254  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1255  {
1256  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1257  {
1258  std::cout << "Arg N (" << karg.N
1259  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1260  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1261  << __LINE__ << ", in function: " << __func__ << std::endl;
1262  }
1263  return false;
1264  }
1265  }
1266  else
1267  {
1268  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1269  {
1270  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1271  {
1272  std::cout << "Arg K (" << karg.K
1273  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1274  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1275  << __LINE__ << ", in function: " << __func__ << std::endl;
1276  }
1277  return false;
1278  }
1279  }
1280 
1282  {
1283  if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1284  {
1285  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1286  {
1287  std::cout << "Arg N (" << karg.N
1288  << ") value is not a multiple of "
1289  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1290  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1291  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1292  << std::endl;
1293  }
1294  return false;
1295  }
1296  }
1297  else
1298  {
1299  if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1300  {
1301  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1302  {
1303  std::cout << "Arg M (" << karg.M
1304  << ") value is not a multiple of "
1305  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1306  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1307  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1308  << std::endl;
1309  }
1310  return false;
1311  }
1312  }
1313 
1314  if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
1315  is_same<remove_cvref_t<CDataType>, float>::value ||
1318  {
1319  if(!karg.IsReduceAdd())
1320  {
1321  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1322  {
1323  std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" << __FILE__
1324  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1325  }
1326  if(karg.KBatch > 1)
1327  {
1328  return false;
1329  }
1330  }
1331  }
1332 
1333  // check gridwise gemm pipeline
1334  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1335 
1336  if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1337  {
1338  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1339  {
1340  return false;
1341  }
1342  }
1343 
1344  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1345  return true;
1346  }
1347 
1348  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1349  {
1350  const index_t num_loop = K / KPerBlock;
1351 
1352  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1353  }
1354 
1355  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1356  {
1357  const index_t num_loop = K / KPerBlock;
1358 
1359  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1360  }
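// Illustrative sketch (standalone, hypothetical helper; the real dispatch lives in the
// device-level GEMM operation): how the three checks above could feed kernel selection.
// `k_per_split` is assumed to be the K length a single split-K batch actually iterates over.
template <typename Gridwise>
bool select_loop_shape_example(const typename Gridwise::Argument& arg,
                               index_t k_per_split,
                               bool& has_main_k_block_loop,
                               TailNumber& tail_num)
{
    if(!Gridwise::CheckValidity(arg))
        return false; // tuning parameters and problem sizes are incompatible
    has_main_k_block_loop = Gridwise::CalculateHasMainKBlockLoop(k_per_split);
    tail_num              = Gridwise::CalculateKBlockLoopTailNum(k_per_split);
    return true;
}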
1361 
1362  template <typename CGridDesc>
1363  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1364  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1365  {
1366  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1367  c_grid_desc_m_n,
1372 
1373  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1374  }
1375 
1376  // return block_id to C matrix tile idx (m0, n0) mapping
1377  // if arch = gfx942
1379  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1380 
1381  template <typename AGridDesc_AK0_M_K1,
1382  typename BGridDesc_BK0_N_K1,
1383  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1384  bool HasMainKBlockLoop,
1385  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1386  TailNumber TailNum = TailNumber::Odd>
1387  __device__ static void Run(const ADataType* p_a_grid,
1388  const BDataType* p_b_grid,
1389  CDataType* p_c_grid,
1390  void* p_shared,
1391  const Problem& problem,
1392  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1393  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1394  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1395  c_grid_desc_mblock_mperblock_nblock_nperblock)
1396  {
1397  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1398  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1399  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1400  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1401  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1402  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1403 
1404  // divide block work by [M, N]
1405  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1406 
1407  const auto block_work_idx =
1408  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1409 
1410  if(!block_2_ctile_map.ValidCTileIndex(
1411  block_work_idx,
1412  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1413  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1414  {
1415  return;
1416  }
1417 
1418  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1419  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1420 
1421  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1422  const index_t m_block_data_idx_on_grid =
1423  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1424 
1425  const index_t n_block_data_idx_on_grid =
1426  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1427 
1428  // lds max alignment
1429  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1430 
1431  // A matrix in LDS memory, dst of blockwise copy
1432  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1433 
1434  // B matrix in LDS memory, dst of blockwise copy
1435  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1436 
1437  // A matrix blockwise copy
1438  auto a_blockwise_copy =
1440  AElementwiseOperation,
1444  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1445  ABlockTransferThreadClusterArrangeOrder,
1446  ADataType,
1447  ADataType,
1448  decltype(a_grid_desc_ak0_m_ak1),
1449  decltype(a_block_desc_ak0_m_ak1),
1450  ABlockTransferSrcAccessOrder,
1452  ABlockTransferSrcVectorDim,
1453  2,
1454  ABlockTransferSrcScalarPerVector,
1455  ABlockTransferDstScalarPerVector_AK1,
1456  1,
1457  1,
1458  AThreadTransferSrcResetCoordinateAfterRun,
1459  true,
1460  BlockwiseGemmPipe::GlobalBufferNum>(
1461  a_grid_desc_ak0_m_ak1,
1462  make_multi_index(0, m_block_data_idx_on_grid, 0),
1463  problem.a_element_op_,
1464  a_block_desc_ak0_m_ak1,
1465  make_multi_index(0, 0, 0),
1467 
1468  // B matrix blockwise copy
1469  auto b_blockwise_copy =
1471  BElementwiseOperation,
1475  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1476  BBlockTransferThreadClusterArrangeOrder,
1477  BDataType,
1478  BDataType,
1479  decltype(b_grid_desc_bk0_n_bk1),
1480  decltype(b_block_desc_bk0_n_bk1),
1481  BBlockTransferSrcAccessOrder,
1483  BBlockTransferSrcVectorDim,
1484  2,
1485  BBlockTransferSrcScalarPerVector,
1486  BBlockTransferDstScalarPerVector_BK1,
1487  1,
1488  1,
1489  BThreadTransferSrcResetCoordinateAfterRun,
1490  true,
1491  BlockwiseGemmPipe::GlobalBufferNum>(
1492  b_grid_desc_bk0_n_bk1,
1493  make_multi_index(0, n_block_data_idx_on_grid, 0),
1494  problem.b_element_op_,
1495  b_block_desc_bk0_n_bk1,
1496  make_multi_index(0, 0, 0),
1498 
1499  // LDS allocation for A and B: be careful of alignment
1500  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1501  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1502 
1503  // Cast after lds
1504  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1505  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1506 
1507  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1508  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
1509  sizeof(ADataType) /
1510  APackedSize),
1511  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1512 
1513  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1514  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1515 
1516  // Blockwise GEMM pipeline
1517  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1518  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1519  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1520 
1521  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1522  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1523  KPerBlock);
1524 
1525  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1526  a_block_desc_ak0_m_ak1,
1527  a_blockwise_copy,
1528  a_grid_buf,
1529  a_block_buf,
1530  a_block_slice_copy_step,
1531  b_grid_desc_bk0_n_bk1,
1532  b_block_desc_bk0_n_bk1,
1533  b_blockwise_copy,
1534  b_grid_buf,
1535  b_block_buf,
1536  b_block_slice_copy_step,
1537  c_thread_buf,
1538  num_k_block_main_loop);
1539 
1540  // shuffle C and write out
1541  {
1542  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1543  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1544  "wrong!");
1545 
1546  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1547  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1548 
1549  // TODO: hacky, fix it!
1550  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1551  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1552 
1553  // TODO: hacky, fix it!
1554  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1555  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1556  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1557 
1558  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1559  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1560  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1561  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1562  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1563  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1564  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1565  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1566 
1567  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1569 
1570  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1571  static_cast<CShuffleDataType*>(p_shared),
1572  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1573 
1574  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1575  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1576  make_tuple(
1579  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1580  M1, // M1 = MWave
1581  M2, // M2 * M3 * M4 = MPerXdl
1582  M3,
1583  M4)),
1586  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1587  N1, // N1 = NWave
1588  N2))), // N2 = NPerXdl
1590  make_tuple(
1592 
1593  // calculate origin of thread output tensor on global memory
1594  // blockwise GEMM c matrix starting index
1595  const auto c_thread_mtx_on_block =
1596  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1597 
1598  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1599  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1600 
1601  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1603  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1605  make_tuple(Sequence<0>{}));
1606 
1607  const auto m_thread_data_on_block_idx =
1608  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1609  make_multi_index(m_thread_data_on_block));
1610 
1611  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1615  make_tuple(Sequence<0>{}));
1616 
1617  const auto n_thread_data_on_block_idx =
1618  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1619  make_multi_index(n_thread_data_on_block));
1620 
1622  const auto& vpgr_to_lds_element_op = [&] {
1623  if constexpr(DoElementwiseBeforeCShuffle)
1624  {
1625  return problem.c_element_op_;
1626  }
1627  else
1628  {
1629  return pass_through;
1630  }
1631  };
1632  const auto& lds_to_global_element_op = [&] {
1633  if constexpr(!DoElementwiseBeforeCShuffle)
1634  {
1635  return problem.c_element_op_;
1636  }
1637  else
1638  {
1639  return pass_through;
1640  }
1641  };
1642 
1643  // shuffle: threadwise copy C from VGPR to LDS
1644  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1645  AccDataType,
1646  CShuffleDataType,
1647  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1648  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1649  conditional_t<DoElementwiseBeforeCShuffle,
1650  CElementwiseOperation,
1652  Sequence<CShuffleMXdlPerWavePerShuffle,
1653  CShuffleNXdlPerWavePerShuffle,
1654  I1,
1655  I1,
1656  M2,
1657  I1,
1658  M4,
1659  I1>,
1661  7,
1662  1,
1664  1,
1665  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1666  make_multi_index(0,
1667  0,
1668  m_thread_data_on_block_idx[I1],
1669  n_thread_data_on_block_idx[I1],
1670  m_thread_data_on_block_idx[I2],
1671  m_thread_data_on_block_idx[I3],
1672  m_thread_data_on_block_idx[I4],
1673  n_thread_data_on_block_idx[I2]),
1674  vpgr_to_lds_element_op()};
1675 
1676  // shuffle: blockwise copy C from LDS to global
1677  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1678  ThisThreadBlock, // ThreadGroup
1679  conditional_t<!DoElementwiseBeforeCShuffle,
1680  CElementwiseOperation,
1682  CGlobalMemoryDataOperation, // DstInMemOp,
1683  Sequence<1,
1684  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1685  1,
1686  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1687  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1688  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1689  CShuffleDataType, // typename SrcData,
1690  CDataType, // typename DstData,
1691  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1692  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1693  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1694  3, // index_t VectorDim,
1695  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1696  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1697  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1698  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1699  make_multi_index(0, 0, 0, 0),
1700  c_grid_desc_mblock_mperblock_nblock_nperblock,
1701  make_multi_index(block_m_id, 0, block_n_id, 0),
1702  lds_to_global_element_op()};
1703 
1704  // space filling curve for threadwise C in VGPR
1705  constexpr auto sfc_c_vgpr =
1708  Sequence<CShuffleMXdlPerWavePerShuffle,
1709  CShuffleNXdlPerWavePerShuffle,
1710  1,
1711  1,
1712  M2,
1713  1,
1714  M4,
1715  1>>{};
1716 
1717  // space filling curve for shuffled blockwise C in global mem
1718  constexpr auto sfc_c_global =
1721  Sequence<1,
1722  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1723  1,
1724  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1725 
1726  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1727 
1728  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1729 
1730  static_for<0, num_access, 1>{}([&](auto access_id) {
1731  // make sure it's safe to write to LDS
1732  block_sync_lds();
1733 
1734  // each thread write its data from VGPR to LDS
1735  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1736  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1737  c_thread_buf,
1738  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1739  c_shuffle_block_buf);
1740 
1741  // make sure it's safe to read from LDS
1742  block_sync_lds();
1743 
1744  // each block copy its data from LDS to global
1745  c_shuffle_block_copy_lds_to_global.Run(
1746  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1747  c_shuffle_block_buf,
1748  c_grid_desc_mblock_mperblock_nblock_nperblock,
1749  c_grid_buf);
1750 
1751  if constexpr(access_id < num_access - 1)
1752  {
1753  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1754 
1755  // move on C
1756  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1757  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1758  }
1759  });
1760  }
1761  }
1762 
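// Illustrative sketch (standalone HIP device code; index parameters are hypothetical) of the
// staging pattern used in the "shuffle C and write out" block above: each accumulator takes a
// round trip through LDS so that the thread which computed it (in the strided MFMA output layout)
// and the thread which stores it (in a contiguous, vectorizable layout) can differ, with a
// barrier on each side of the LDS tile reuse.
__device__ void c_shuffle_roundtrip_example(float acc,
                                            float* lds_tile,
                                            float* c_global,
                                            int scatter_idx, // where this thread's acc lands in the tile
                                            int linear_idx)  // contiguous element this thread stores
{
    __syncthreads();                             // previous users of the LDS tile are done
    lds_tile[scatter_idx] = acc;                 // VGPR -> LDS in the MFMA output layout
    __syncthreads();                             // the whole shuffled tile is now visible
    c_global[linear_idx] = lds_tile[linear_idx]; // LDS -> global in the contiguous layout
}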
1763  template <bool HasMainKBlockLoop,
1764  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1765  TailNumber TailNum = TailNumber::Odd>
1766  __device__ static void Run(const ADataType* p_a_grid,
1767  const BDataType* p_b_grid,
1768  CDataType* p_c_grid,
1769  void* p_shared,
1770  const Problem& problem)
1771  {
1772  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1773  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1774  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1775  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1776  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1777  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1778  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1780  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1781 
1782  Run<decltype(a_grid_desc_ak0_m_ak1),
1783  decltype(b_grid_desc_bk0_n_bk1),
1784  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1785  HasMainKBlockLoop,
1786  CGlobalMemoryDataOperation,
1787  TailNum>(p_a_grid,
1788  p_b_grid,
1789  p_c_grid,
1790  p_shared,
1791  problem,
1792  a_grid_desc_ak0_m_ak1,
1793  b_grid_desc_bk0_n_bk1,
1794  c_grid_desc_mblock_mperblock_nblock_nperblock);
1795  }
1796 
1797  template <typename AGridDesc_AK0_M_K1,
1798  typename BGridDesc_BK0_N_K1,
1799  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1800  bool HasMainKBlockLoop,
1801  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1802  TailNumber TailNum = TailNumber::Odd>
1803  __device__ static void Run_2Lds(const ADataType* p_a_grid,
1804  const BDataType* p_b_grid,
1805  CDataType* p_c_grid,
1806  void* p_shared_0,
1807  void* p_shared_1,
1808  const Problem& problem,
1809  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1810  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1811  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1812  c_grid_desc_mblock_mperblock_nblock_nperblock)
1813  {
1814  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1815  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1816  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1817  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1818  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1819  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1820 
1821  // divide block work by [M, N]
1822  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1823 
1824  const auto block_work_idx =
1825  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1826 
1827  if(!block_2_ctile_map.ValidCTileIndex(
1828  block_work_idx,
1829  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1830  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1831  {
1832  return;
1833  }
1834 
1835  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1836  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1837 
1838  // HACK: this forces m/n_block_data_idx_on_grid into SGPR
1839  const index_t m_block_data_idx_on_grid =
1840  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1841 
1842  const index_t n_block_data_idx_on_grid =
1843  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1844 
1845  // lds max alignment
1846  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1847 
1848  // A matrix in LDS memory, dst of blockwise copy
1849  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1850 
1851  // B matrix in LDS memory, dst of blockwise copy
1852  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1853 
1854  // A matrix blockwise copy
1855  auto a_blockwise_copy =
1857  AElementwiseOperation,
1861  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1862  ABlockTransferThreadClusterArrangeOrder,
1863  ADataType,
1864  ADataType,
1865  decltype(a_grid_desc_ak0_m_ak1),
1866  decltype(a_block_desc_ak0_m_ak1),
1867  ABlockTransferSrcAccessOrder,
1869  ABlockTransferSrcVectorDim,
1870  2,
1871  ABlockTransferSrcScalarPerVector,
1872  ABlockTransferDstScalarPerVector_AK1,
1873  1,
1874  1,
1875  AThreadTransferSrcResetCoordinateAfterRun,
1876  true,
1877  BlockwiseGemmPipe::GlobalBufferNum>(
1878  a_grid_desc_ak0_m_ak1,
1879  make_multi_index(0, m_block_data_idx_on_grid, 0),
1880  problem.a_element_op_,
1881  a_block_desc_ak0_m_ak1,
1882  make_multi_index(0, 0, 0),
1884 
1885  // B matrix blockwise copy
1886  auto b_blockwise_copy =
1888  BElementwiseOperation,
1892  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1893  BBlockTransferThreadClusterArrangeOrder,
1894  BDataType,
1895  BDataType,
1896  decltype(b_grid_desc_bk0_n_bk1),
1897  decltype(b_block_desc_bk0_n_bk1),
1898  BBlockTransferSrcAccessOrder,
1900  BBlockTransferSrcVectorDim,
1901  2,
1902  BBlockTransferSrcScalarPerVector,
1903  BBlockTransferDstScalarPerVector_BK1,
1904  1,
1905  1,
1906  BThreadTransferSrcResetCoordinateAfterRun,
1907  true,
1908  BlockwiseGemmPipe::GlobalBufferNum>(
1909  b_grid_desc_bk0_n_bk1,
1910  make_multi_index(0, n_block_data_idx_on_grid, 0),
1911  problem.b_element_op_,
1912  b_block_desc_bk0_n_bk1,
1913  make_multi_index(0, 0, 0),
1915 
1916  // LDS allocation for A and B: be careful of alignment
1917  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1918  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1919 
1920  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1921  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1922 
1923  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1924  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
1925  a_block_space_size_aligned * sizeof(ADataType)),
1926  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1927 
1928  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1929  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1930 
1931  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1932  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
1933  a_block_space_size_aligned * sizeof(ADataType)),
1934  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1935 
1936  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1937  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1938 
1939  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1940  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1941 
1942  // Blockwise GEMM pipeline
1943  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1944  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1945  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1946 
1947  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1948  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1949  KPerBlock);
1950 
1951  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1952  a_block_desc_ak0_m_ak1,
1953  a_blockwise_copy,
1954  a_grid_buf,
1955  a_block_bufs,
1956  a_block_slice_copy_step,
1957  b_grid_desc_bk0_n_bk1,
1958  b_block_desc_bk0_n_bk1,
1959  b_blockwise_copy,
1960  b_grid_buf,
1961  b_block_bufs,
1962  b_block_slice_copy_step,
1963  c_thread_buf,
1964  num_k_block_main_loop);
1965 
1966  // shuffle C and write out
1967  {
1968  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1969  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1970  "wrong!");
1971 
1972  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1973  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1974 
1975  // TODO: hacky, fix it!
1976  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1977  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1978 
1979  // TODO: hacky, fix it!
1980  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1981  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1982  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1983 
1984  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1985  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1986  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1987  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1988  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1989  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1990  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1991  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1992 
1993  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1995 
1996  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1997  static_cast<CShuffleDataType*>(p_shared_0),
1998  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1999 
2000  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
2001  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2002  make_tuple(
2003  make_freeze_transform(I0),
2004  make_unmerge_transform(make_tuple(
2005  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
2006  M1, // M1 = MWave
2007  M2, // M2 * M3 * M4 = MPerXdl
2008  M3,
2009  M4)),
2010  make_freeze_transform(I0),
2011  make_unmerge_transform(make_tuple(
2012  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2013  N1, // N1 = NWave
2014  N2))), // N2 = NPerXdl
2015  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2016  make_tuple(
2017  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2018 
2019  // calculate origin of thread output tensor on global memory
2020  // blockwise GEMM c matrix starting index
2021  const auto c_thread_mtx_on_block =
2022  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2023 
2024  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2025  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2026 
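 // The single-stage merge adaptors below split the flat block-local M and N offsets of this
 // thread back into (M0, M1, M2, M3, M4) and (N0, N1, N2) coordinates used to address the
 // shuffle buffers.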
2027  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2028  make_single_stage_tensor_adaptor(
2029  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2030  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2031  make_tuple(Sequence<0>{}));
2032 
2033  const auto m_thread_data_on_block_idx =
2034  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2035  make_multi_index(m_thread_data_on_block));
2036 
2037  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2038  make_single_stage_tensor_adaptor(
2039  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
2040  make_tuple(Sequence<0, 1, 2>{}),
2041  make_tuple(Sequence<0>{}));
2042 
2043  const auto n_thread_data_on_block_idx =
2044  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2045  make_multi_index(n_thread_data_on_block));
2046 
2047  // shuffle: threadwise copy C from VGPR to LDS
2048  auto c_thread_copy_vgpr_to_lds =
2049  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2050  CShuffleDataType,
2051  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2052  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2053  ck::tensor_operation::element_wise::PassThrough,
2054  Sequence<CShuffleMXdlPerWavePerShuffle,
2055  CShuffleNXdlPerWavePerShuffle,
2056  I1,
2057  I1,
2058  M2,
2059  I1,
2060  M4,
2061  I1>,
2062  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2063  7,
2064  1,
2065  InMemoryDataOperationEnum::Set,
2066  1,
2067  true>{
2068  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2069  make_multi_index(0,
2070  0,
2071  m_thread_data_on_block_idx[I1],
2072  n_thread_data_on_block_idx[I1],
2073  m_thread_data_on_block_idx[I2],
2074  m_thread_data_on_block_idx[I3],
2075  m_thread_data_on_block_idx[I4],
2076  n_thread_data_on_block_idx[I2]),
2077  ck::tensor_operation::element_wise::PassThrough{}};
2078 
2079  // shuffle: blockwise copy C from LDS to global
2080  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2081  ThisThreadBlock, // ThreadGroup
2082  CElementwiseOperation, // ElementwiseOperation,
2083  CGlobalMemoryDataOperation, // DstInMemOp,
2084  Sequence<1,
2085  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2086  1,
2087  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2088  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2089  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2090  CShuffleDataType, // typename SrcData,
2091  CDataType, // typename DstData,
2092  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2093  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2094  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2095  3, // index_t VectorDim,
2096  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2097  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2098  false> // bool ThreadTransferDstResetCoordinateAfterRun>
2099  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2100  make_multi_index(0, 0, 0, 0),
2101  c_grid_desc_mblock_mperblock_nblock_nperblock,
2102  make_multi_index(block_m_id, 0, block_n_id, 0),
2103  problem.c_element_op_};
2104 
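 // Two matching space-filling curves: one walks this thread's accumulators in VGPR, the other
 // walks the block's C output tile in global memory, both in chunks of
 // CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle, so the static_for loop below
 // can advance both sides in lockstep (access counts are asserted equal).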
2105  // space filling curve for threadwise C in VGPR
2106  constexpr auto sfc_c_vgpr =
2107  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2108  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2109  Sequence<CShuffleMXdlPerWavePerShuffle,
2110  CShuffleNXdlPerWavePerShuffle,
2111  1,
2112  1,
2113  M2,
2114  1,
2115  M4,
2116  1>>{};
2117 
2118  // space filling curve for shuffled blockwise C in global mem
2119  constexpr auto sfc_c_global =
2120  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2121  Sequence<0, 2, 1, 3>,
2122  Sequence<1,
2123  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2124  1,
2125  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2126 
2127  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2128 
2129  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
2130 
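 // static_for fully unrolls the copy loop at compile time; the last iteration skips the
 // destination-window move since there is no further C tile to advance to.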
2131  static_for<0, num_access, 1>{}([&](auto access_id) {
2132  // make sure it's safe to write to LDS
2133  block_sync_lds();
2134 
2135  // each thread write its data from VGPR to LDS
2136  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2137  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2138  c_thread_buf,
2139  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2140  c_shuffle_block_buf);
2141 
2142  // make sure it's safe to read from LDS
2143  block_sync_lds();
2144 
2145  // each block copy its data from LDS to global
2146  c_shuffle_block_copy_lds_to_global.Run(
2147  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2148  c_shuffle_block_buf,
2149  c_grid_desc_mblock_mperblock_nblock_nperblock,
2150  c_grid_buf);
2151 
2152  if constexpr(access_id < num_access - 1)
2153  {
2154  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2155 
2156  // move on C
2157  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2158  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2159  }
2160  });
2161  }
2162  }
2163 
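 // Convenience overload: rebuilds the A/B/C grid descriptors from the Problem sizes and
 // strides, then forwards to the descriptor-taking Run_2Lds overload defined above.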
2164  template <bool HasMainKBlockLoop,
2165  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2166  TailNumber TailNum = TailNumber::Odd>
2167  __device__ static void Run_2Lds(const ADataType* p_a_grid,
2168  const BDataType* p_b_grid,
2169  CDataType* p_c_grid,
2170  void* p_shared_0,
2171  void* p_shared_1,
2172  const Problem& problem)
2173  {
2174  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2175  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2176  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2177  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2178  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2179  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2180 
2181  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2182  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2183  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2184 
2185  Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2186  decltype(b_grid_desc_bk0_n_bk1),
2187  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2188  HasMainKBlockLoop,
2189  CGlobalMemoryDataOperation,
2190  TailNum>(p_a_grid,
2191  p_b_grid,
2192  p_c_grid,
2193  p_shared_0,
2194  p_shared_1,
2195  problem,
2196  a_grid_desc_ak0_m_ak1,
2197  b_grid_desc_bk0_n_bk1,
2198  c_grid_desc_mblock_mperblock_nblock_nperblock);
2199  }
2200 };
2201 
2202 } // namespace ck
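
A minimal, hypothetical host-side sketch (not part of this header) of how the HasMainKBlockLoop template argument of kernel_gemm_xdl_cshuffle_v3_2lds is typically resolved before launch. The fully specialized GridwiseGemm type, the launch configuration, and the stream are assumptions supplied by the caller (e.g. derived from Gridwise::CalculateGridSize and the gridwise GEMM's BlockSize parameter); a complete driver would also dispatch on CalculateKBlockLoopTailNum and pick an atomic operation for split-K reduction.

// Hedged sketch only: `Gridwise` stands for a fully specialized GridwiseGemm_xdl_cshuffle_v3,
// and its Argument is assumed to expose the Problem sizes (M, N, K, KBatch), as in this file.
#include <hip/hip_runtime.h>
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"

template <typename Gridwise>
void launch_gemm_xdl_cshuffle_v3_2lds(typename Gridwise::Argument karg,
                                      dim3 grid_dim,
                                      dim3 block_dim,
                                      hipStream_t stream)
{
    // Assume no split-K accumulation into global memory for this sketch.
    constexpr auto mem_op = ck::InMemoryDataOperationEnum::Set;

    // HasMainKBlockLoop is a compile-time kernel parameter, so the runtime answer from
    // CalculateHasMainKBlockLoop selects between two instantiations of the same kernel.
    if(Gridwise::CalculateHasMainKBlockLoop(karg.K))
    {
        hipLaunchKernelGGL((ck::kernel_gemm_xdl_cshuffle_v3_2lds<Gridwise, true, mem_op>),
                           grid_dim, block_dim, 0, stream, karg);
    }
    else
    {
        hipLaunchKernelGGL((ck::kernel_gemm_xdl_cshuffle_v3_2lds<Gridwise, false, mem_op>),
                           grid_dim, block_dim, 0, stream, karg);
    }
}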