Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp Source File
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/env.hpp"
17 
18 namespace ck {
19 
20 // Currently we do not have an elegant way to put the single-LDS-buffer and double-LDS-buffer
21 // pipelines in the same kernel function. Blockers:
22 // 1. Two separate declarations of __shared__ pointers are the key to making sure data accesses
23 // operate on two LDS chunks.
24 // 2. Occupied __shared__ memory is not released until the whole shader ends, i.e. A/B and C may
25 // not use the same LDS buffer when we declare __shared__ inside blkgemmpipe.
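// A minimal host-side sketch of how the two entry points below could be selected (illustrative
// only; the flag `use_double_lds_pipeline` and the launch parameters are hypothetical and not
// part of this header):
//
//   if(use_double_lds_pipeline)
//       kernel_gemm_xdl_cshuffle_v3_2lds<GridwiseGemm, true, InMemoryDataOperationEnum::Set>
//           <<<grid_dim, block_dim, 0, stream>>>(karg);
//   else
//       kernel_gemm_xdl_cshuffle_v3<GridwiseGemm, true, InMemoryDataOperationEnum::Set>
//           <<<grid_dim, block_dim, 0, stream>>>(karg);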
26 template <typename GridwiseGemm,
27  bool HasMainKBlockLoop,
28  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
29  index_t MinimumOccupancy = 1,
30  TailNumber TailNum = TailNumber::Full>
31 __global__ void
32 #if CK_USE_LAUNCH_BOUNDS
33 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
34 #endif
35  // __attribute__((amdgpu_waves_per_eu(1, 1)))
36  kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
37 {
38 #if defined(__gfx9__)
39  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
40 
41  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
42 
43  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
44  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
45  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
46  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
47  p_shared,
48  karg);
49 #else
50  ignore = karg;
51 #endif // end of if (defined(__gfx9__))
52 }
53 
54 template <typename GridwiseGemm,
55  bool HasMainKBlockLoop,
56  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
57  index_t MinimumOccupancy = 1,
58  TailNumber TailNum = TailNumber::Full>
59 __global__ void
60 #if CK_USE_LAUNCH_BOUNDS
61 __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
62 #endif
63  // __attribute__((amdgpu_waves_per_eu(1, 1)))
64  kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
65 {
66 #if defined(__gfx9__)
 67  // Passing two LDS pointers is the key to telling the compiler that ds_read/ds_write
 68  // can operate on different LDS chunks at the same time without an ordering dependency
69  __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
70  __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
71 
72  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
73 
74  GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
75  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
76  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
77  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
78  p_shared_0,
79  p_shared_1,
80  karg);
81 #else
82  ignore = karg;
83 #endif // end of if (defined(__gfx9__))
84 }
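// Note (illustrative): because this entry point declares two __shared__ buffers of
// GridwiseGemm::GetSharedMemoryNumberOfByte() each, it uses twice the LDS of the single-buffer
// kernel above, e.g. 2 x 32 KiB = 64 KiB per workgroup for a hypothetical 32 KiB tile; this can
// lower occupancy in exchange for overlapping ds_read/ds_write on the two chunks.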
85 
191 template <typename ALayout,
192  typename BLayout,
193  typename CLayout,
194  typename ADataType,
195  typename BDataType,
196  typename AccDataType,
197  typename CShuffleDataType,
198  typename CDataType,
199  typename AElementwiseOperation,
200  typename BElementwiseOperation,
201  typename CElementwiseOperation,
203  index_t BlockSize,
204  index_t MPerBlock,
205  index_t NPerBlock,
206  index_t KPerBlock,
207  index_t AK1Value,
208  index_t BK1Value,
209  index_t MPerXdl,
210  index_t NPerXdl,
211  index_t MXdlPerWave,
212  index_t NXdlPerWave,
213  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
214  typename ABlockTransferThreadClusterArrangeOrder,
215  typename ABlockTransferSrcAccessOrder,
216  index_t ABlockTransferSrcVectorDim,
217  index_t ABlockTransferSrcScalarPerVector,
218  index_t ABlockTransferDstScalarPerVector_AK1,
219  bool AThreadTransferSrcResetCoordinateAfterRun,
220  index_t ABlockLdsExtraM,
221  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
222  typename BBlockTransferThreadClusterArrangeOrder,
223  typename BBlockTransferSrcAccessOrder,
224  index_t BBlockTransferSrcVectorDim,
225  index_t BBlockTransferSrcScalarPerVector,
226  index_t BBlockTransferDstScalarPerVector_BK1,
227  bool BThreadTransferSrcResetCoordinateAfterRun,
228  index_t BBlockLdsExtraN,
229  index_t CShuffleMXdlPerWavePerShuffle,
230  index_t CShuffleNXdlPerWavePerShuffle,
231  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
232  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
235  typename ComputeTypeA = CDataType,
236  typename ComputeTypeB = ComputeTypeA,
237  bool PermuteA = false,
238  bool PermuteB = false,
239  bool DoElementwiseBeforeCShuffle = false>
 240 struct GridwiseGemm_xdl_cshuffle_v3
 241 {
242  static constexpr auto I0 = Number<0>{};
243  static constexpr auto I1 = Number<1>{};
244  static constexpr auto I2 = Number<2>{};
245  static constexpr auto I3 = Number<3>{};
246  static constexpr auto I4 = Number<4>{};
247  static constexpr auto I5 = Number<5>{};
248  static constexpr auto I6 = Number<6>{};
249  static constexpr auto I7 = Number<7>{};
250 
251  // K1 should be Number<...>
252  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
253  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
254  static constexpr auto AK1Number = Number<AK1Value>{};
255  static constexpr auto BK1Number = Number<BK1Value>{};
256 
257  static constexpr auto lcm_AK1_BK1 = math::lcm(AK1Number, BK1Number);
258  static constexpr bool is_single_rate_mfma =
260  lcm_AK1_BK1 <= 4) ||
 262  // gfx950 double-rate mfma16x16 requires at least 128 KPerBlock to be consumed
264  KPerBlock < 128 && MPerXdl == 16))
265  ? true
266  : false;
267  static constexpr auto is_scale_mfma = false;
268  static constexpr index_t KPack =
270  MfmaSelector<ComputeTypeA,
271  MPerXdl,
272  NPerXdl,
273  ComputeTypeA,
275  is_scale_mfma>::selected_mfma.k_per_blk);
276 
278 
 279  static constexpr index_t APackedSize = []() {
 280  if constexpr(is_same_v<remove_cvref_t<ADataType>, pk_i4_t>)
 281  return 2;
 282  else
 283  return 1;
 284  }();
285 
 286  static constexpr index_t BPackedSize = []() {
 287  if constexpr(is_same_v<remove_cvref_t<BDataType>, pk_i4_t>)
 288  return 2;
 289  else
 290  return 1;
 291  }();
292 
293  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
294  {
295  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
296  }
297 
298  __host__ static auto CalculateMPadded(index_t M)
299  {
300  return math::integer_least_multiple(M, MPerBlock);
301  }
302 
303  __host__ static auto CalculateNPadded(index_t N)
304  {
305  return math::integer_least_multiple(N, NPerBlock);
306  }
307 
308  __host__ static auto CalculateKPadded(index_t K)
309  {
310  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
311  }
312 
313  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
314  {
315  auto K_t = K_Batch * KPerBlock;
316  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
317  }
318 
319  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
320  {
321  auto K_t = K_Batch * KPerBlock;
322  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
323  }
324 
325  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
326  {
327  auto K_t = K_Batch * KPerBlock;
328  return (K + K_t - 1) / K_t * KPerBlock;
329  }
330 
331  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
332  {
333  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
334  auto K_t = K_Batch * KReadVec;
335  return (K + K_t - 1) / K_t * KReadVec;
336  }
337 
338  __host__ static auto CalculateMBlock(index_t M)
339  {
340  return math::integer_divide_ceil(M, MPerBlock);
341  }
342 
343  __host__ static auto CalculateNBlock(index_t N)
344  {
345  return math::integer_divide_ceil(N, NPerBlock);
346  }
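// Worked example for the helpers above (hypothetical tile configuration, not prescribed by this
// header): with MPerBlock = 256, NPerBlock = 128, KPerBlock = 64, AK1Value = BK1Value = 8 and a
// problem of M = N = K = 1000 split into KBatch = 2:
//   MPadded = 1024, NPadded = 1024
//   CalculateKPadded(1000)      = 16 * 64             = 1024
//   CalculateAK0Padded(1000, 2) = ceil(1000/128) * 8  = 64   (BK0 likewise)
//   CalculateKPadded(1000, 2)   = ceil(1000/128) * 64 = 512  (per split-K batch)
//   CalculateKRead(1000, 2)     = ceil(1000/16) * 8   = 504  (per split-K batch)
//   MBlock = 4, NBlock = 8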
347 
348  template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
349  __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
350  {
351  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
352  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
353 
355  TileDesc_K0_MN_K1{},
361  }
362 
363  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
364  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
365  {
366  const auto a_grid_desc_mraw_kraw = [&]() {
367  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
368  {
369  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
370  }
371  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
372  {
373  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
374  }
375  }();
376 
378 
379  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
380  GemmSpec == GemmSpecialization::MNKPadding)
381  {
382  // pad both M and K
383  const auto a_grid_desc_m_k =
384  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
386  make_right_pad_transform(K, KPad - K)),
389 
390  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
391  a_grid_desc_m_k,
396 
397  return a_grid_desc_ak0_m_ak1;
398  }
399  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
400  GemmSpec == GemmSpecialization::MNPadding)
401  {
402  // pad M, but not K
403  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
404  a_grid_desc_mraw_kraw,
406  make_right_pad_transform(M, MPad - M)),
409 
410  return a_grid_desc_ak0_m_ak1;
411  }
412  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
413  GemmSpec == GemmSpecialization::NKPadding)
414  {
415  // pad K, but not M
416  const auto a_grid_desc_m_k = transform_tensor_descriptor(
417  a_grid_desc_mraw_kraw,
421 
422  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
423  a_grid_desc_m_k,
428 
429  return a_grid_desc_ak0_m_ak1;
430  }
431  else
432  {
433  // not pad M or K
434  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
435  a_grid_desc_mraw_kraw,
440 
441  return a_grid_desc_ak0_m_ak1;
442  }
443  }
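// Illustrative shape (hypothetical numbers): for row-major A with M = MPad = 1024,
// K = KPad = 1024, StrideA = 1024 and AK1Value = 8, the returned view has lengths
// (AK0, M, AK1) = (128, 1024, 8); the padding branches only change how M and K are extended to
// MPad/KPad before the (K -> AK0, AK1) split.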
444 
445  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
446  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
447  {
448  const auto b_grid_desc_nraw_kraw = [&]() {
450  {
451  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
452  }
454  {
455  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
456  }
457  }();
458 
460 
461  static_assert(!(is_same_v<remove_cvref_t<ADataType>, pk_i4_t> &&
462  GemmSpec != GemmSpecialization::Default),
463  "pk_i4_t does not support padding");
464 
465  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
466  GemmSpec == GemmSpecialization::MNKPadding)
467  {
468  // pad both N and K
469  const auto b_grid_desc_n_k =
470  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
472  make_right_pad_transform(K, KPad - K)),
475 
476  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
477  b_grid_desc_n_k,
482 
483  return b_grid_desc_bk0_n_bk1;
484  }
485  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
486  GemmSpec == GemmSpecialization::MNPadding)
487  {
488  // pad N, but not K
489  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
490  b_grid_desc_nraw_kraw,
492  make_right_pad_transform(N, NPad - N)),
495 
496  return b_grid_desc_bk0_n_bk1;
497  }
498  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
499  GemmSpec == GemmSpecialization::MKPadding)
500  {
501  // pad K, but not N
502  const auto b_grid_desc_n_k = transform_tensor_descriptor(
503  b_grid_desc_nraw_kraw,
507 
508  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
509  b_grid_desc_n_k,
514 
515  return b_grid_desc_bk0_n_bk1;
516  }
517  else
518  {
519  if constexpr(!PermuteB)
520  {
521  // not pad N or K
522  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
523  b_grid_desc_nraw_kraw,
528 
529  return b_grid_desc_bk0_n_bk1;
530  }
531  else
532  {
533  // Pre-shuffled Weight
534  // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
535  constexpr index_t BK01 = KPerBlock / BK1Value;
536  const index_t BK0_ = StrideB / BK1Value;
537  const index_t BK00 = BK0_ / BK01;
538 
539  const auto b_grid_desc_bk00_n_bk01_bk1_permute =
540  make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
541 
542  const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
543  b_grid_desc_bk00_n_bk01_bk1_permute,
546  make_pass_through_transform(BK1Value)),
549 
550  return b_grid_desc_bk0_n_bk1_permute;
551  }
552  }
553  }
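// Illustrative PermuteB shape (hypothetical numbers, assuming StrideB encodes the packed K extent
// of the pre-shuffled layout): with KPerBlock = 64, BK1Value = 8 and StrideB = 1024,
// BK01 = 8, BK0_ = 128, BK00 = 16, so B is viewed as [BK00, N, BK01, BK1] = [16, N, 8, 8] before
// being merged back into a (BK0, N, BK1) descriptor.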
554 
555  template <typename ABlockDesc_AK0_M_AK1>
556  __host__ __device__ static constexpr auto
557  MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
558  {
559  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
560 
561  return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
562  }
563 
564  template <typename BBlockDesc_BK0_N_BK1>
565  __host__ __device__ static constexpr auto
566  MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
567  {
568  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
569 
570  return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
571  }
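// Illustrative wave counts (hypothetical tile parameters): with MPerBlock = 256, MXdlPerWave = 4,
// MPerXdl = 32 the A-side MWaves = 256 / (4 * 32) = 2; with NPerBlock = 128, NXdlPerWave = 2,
// NPerXdl = 32 the B-side NWaves = 128 / (2 * 32) = 2, i.e. a 2 x 2 wave grid per block.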
572 
 573  __host__ __device__ static auto
 574  MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
 575  {
576  const auto c_grid_desc_mraw_nraw = [&]() {
578  {
579  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
580  }
582  {
583  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
584  }
585  }();
586 
587  // pad M and N
588  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
590  make_right_pad_transform(N, NPad - N)),
593 #if 0
595 
596  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
597  GemmSpec == GemmSpecialization::MNKPadding)
598  {
599  // pad M and N
600  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
602  make_right_pad_transform(N, NPad - N)),
605  }
606  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
607  GemmSpec == GemmSpecialization::MKPadding)
608  {
609  // pad M, but not N
611  c_grid_desc_mraw_nraw,
615  }
616  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
617  GemmSpec == GemmSpecialization::NKPadding)
618  {
619  // pad N, but not M
621  c_grid_desc_mraw_nraw,
625  }
626  else
627  {
628  // not pad M or N
629  return c_grid_desc_mraw_nraw;
630  }
631 #endif
632  }
633 
634  struct Problem
635  {
636  __host__ Problem(index_t M_,
637  index_t N_,
638  index_t K_,
639  index_t StrideA_,
640  index_t StrideB_,
641  index_t StrideC_,
642  index_t KBatch_,
643  AElementwiseOperation a_element_op,
644  BElementwiseOperation b_element_op,
645  CElementwiseOperation c_element_op)
646  : M{M_},
647  N{N_},
648  K{K_},
649  StrideA{StrideA_},
650  StrideB{StrideB_},
651  StrideC{StrideC_},
652  KBatch{KBatch_},
655  KRead{CalculateKRead(K_, KBatch_)},
656  KPadded{CalculateKPadded(K_, KBatch_)},
657  AK0{CalculateAK0Padded(K_, KBatch_)},
658  BK0{CalculateBK0Padded(K_, KBatch_)},
659  MBlock{CalculateMBlock(M_)},
660  NBlock{CalculateNBlock(N_)},
661  a_element_op_{a_element_op},
662  b_element_op_{b_element_op},
663  c_element_op_{c_element_op}
664  {
665  }
666 
667  __host__ void Print() const
668  {
669  std::cout << "problem {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
670  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SC:" << StrideC
671  << ", " << "MP:" << MPadded << ", " << "NP:" << NPadded << ", "
672  << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
673  << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
674  << "NBlock: " << NBlock << "}" << std::endl;
675  }
676 
692  AElementwiseOperation a_element_op_;
693  BElementwiseOperation b_element_op_;
694  CElementwiseOperation c_element_op_;
695  };
696 
697  // Argument
 698  struct Argument : public Problem
 699  {
700  __host__ Argument(const ADataType* p_a_grid_,
701  const BDataType* p_b_grid_,
702  CDataType* p_c_grid_,
703  index_t M_,
704  index_t N_,
705  index_t K_,
706  index_t StrideA_,
707  index_t StrideB_,
708  index_t StrideC_,
709  index_t k_batch_,
710  bool is_reduce_ = false,
711  AElementwiseOperation a_element_op = AElementwiseOperation{},
712  BElementwiseOperation b_element_op = BElementwiseOperation{},
713  CElementwiseOperation c_element_op = CElementwiseOperation{})
714  : Problem{M_,
715  N_,
716  K_,
717  StrideA_,
718  StrideB_,
719  StrideC_,
720  k_batch_,
721  a_element_op,
722  b_element_op,
723  c_element_op},
724  p_a_grid{p_a_grid_},
725  p_b_grid{p_b_grid_},
726  p_c_grid{p_c_grid_},
727  is_reduce(is_reduce_)
728  {
729  }
730 
731  __host__ __device__ inline bool IsReduceAdd() const
732  {
733  return (Problem::KBatch > 1) && is_reduce;
734  }
735 
736  __host__ __device__ inline bool IsAtomicAdd() const
737  {
738  return (Problem::KBatch > 1) && (!is_reduce);
739  }
740 
741  const ADataType* p_a_grid;
742  const BDataType* p_b_grid;
743  CDataType* p_c_grid;
744  bool is_reduce;
745  };
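// Illustrative construction (hypothetical pointers and sizes; the element-wise operations fall
// back to their default-constructed values):
//   auto karg = Argument{p_a, p_b, p_c,
//                        /*M*/ 4096, /*N*/ 4096, /*K*/ 4096,
//                        /*StrideA*/ 4096, /*StrideB*/ 4096, /*StrideC*/ 4096,
//                        /*k_batch*/ 1};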
746 
 747  struct SplitKBatchOffset
 748  {
749 
750  __device__ SplitKBatchOffset(Argument& karg)
751  {
752  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
753  {
754  a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
755  }
756  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
757  {
758  a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
759  }
760 
761  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
762  {
763  b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
764  }
765  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
766  {
767  if constexpr(!PermuteB)
768  {
769  b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
770  }
771  else
772  {
773  const int k0_offset = karg.KRead * karg.N;
774  b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
775  }
776  }
777 
778  if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
779  {
780  karg.K = karg.KRead;
781  }
782  else
783  {
784  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
785  }
786 
787  if(karg.IsReduceAdd())
788  {
789  c_reduce_offset = blockIdx.z * karg.M * karg.N;
790  }
791  else
792  {
793  c_reduce_offset = 0;
794  }
795  }
796 
800  };
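// Worked example (continuing the hypothetical split above: K = 1000, KBatch = 2, KRead = 504,
// APackedSize = BPackedSize = 1): for the block with blockIdx.z = 1, row-major A and column-major
// B give a_k_split_offset = 1 * 504 = 504 elements and b_k_split_offset = 504 elements; since
// this is the last batch, its K is clipped to 1000 - 504 * 1 = 496. When IsReduceAdd() holds, the
// partial C tile is additionally offset by blockIdx.z * M * N.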
801 
802  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
803  {
804  // A matrix in LDS memory, dst of blockwise copy
805  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
806  {
 807  // Bank conflicts occur when writing the data into LDS, but don't worry: in v4 the whole loop
 808  // hides them, and it may give some benefit from fewer VALU instructions in address computation.
812  }
 813  // The XOR tensor transformation requests more (unnecessary) VGPR usage and can cause register
 814  // spills in some cases.
816  {
817  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
818  constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
819  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
820  make_tuple(
821  AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
823 
824  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
825  a_lds_block_desc,
831 
832  constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
833  a_lds_block_desc_permuted,
839 
840  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
841  a_lds_block_desc_ak0_mldslayer_m_ak1,
848 
849  return a_lds_block_desc_ak0_m_ak1;
850  }
851  else // ColumnMajor A
852  {
 853  // The kfold and mpair dimensions are not always required.
 854  // More dimensions in merge_transform increase the difficulty of generating immarg offsets
 855  // for the compiler.
856  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
857  constexpr auto M1 = MPerBlock / M0;
858 
859  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
860  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
861  constexpr auto KThreadRead = 64 / MPerXdl;
862  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
863 
864  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
865  ? 1
866  : 128 / (AK1Number * M0 * sizeof(ADataType));
867  constexpr auto KThreadReadPerm =
868  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
869  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
870  : KThreadRead;
871 
 872  // 1 <= mpair <= M0
873  constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
874  ? 1
875  : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
876  ? M0
877  : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));
878 
879  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
883  Number<kfold * M0 / mpair>{},
884  Number<mpair>{},
885  AK1Number));
886 
887  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
888  a_lds_block_desc,
889  make_tuple(
893  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
896  make_tuple(
898  make_tuple(
900 
901  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
902  a_lds_block_desc_permuted,
903  make_tuple(
911  Sequence<1>{},
912  Sequence<2>{},
913  Sequence<3>{},
914  Sequence<4>{},
915  Sequence<5>{}),
917  Sequence<2>{},
918  Sequence<0, 3>{},
919  Sequence<4, 5>{},
920  Sequence<6>{},
921  Sequence<7>{}));
922 
923  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
924  a_lds_block_desc_unmerged,
927  Number<KThreadWrite / kfold / KThreadReadPerm>{},
928  Number<kfold>{},
935 
936  return a_lds_block_desc_ak0_m_ak1;
937  }
938  }
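// Worked example for the MLdsLayer folding factor above (hypothetical element type and tile;
// the 32 * 4 factor corresponds to 32 LDS banks of 4 bytes each): with KPerBlock = 64 and a
// 2-byte ADataType (APackedSize = 1),
//   LdsSize = 32 * 4 / 64 / 2 / 1 = 1  ->  MLdsLayer = 1 (no extra folding);
// with KPerBlock = 32 the same formula gives 2, so two M-slices are folded along K to better
// fill the 128-byte LDS bank width.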
939 
940  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
941  {
942  // B matrix in LDS memory, dst of blockwise copy
943  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
944  {
 945  // Bank conflicts occur when writing the data into LDS, but don't worry: in v4 the whole loop
 946  // hides them, and it may give some benefit from fewer VALU instructions in address computation.
950  }
952  {
953  // NLdsLayer * K0 as logical Bank
954  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
955  constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
956  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
957  make_tuple(
958  BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
960 
961  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
962  b_lds_block_desc,
968 
969  constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
970  b_lds_block_desc_permuted,
976 
977  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
978  b_lds_block_desc_bk0_nldslayer_n_bk1,
985 
986  return b_lds_block_desc_bk0_n_bk1;
987  }
988  else // RowMajor B
989  {
990  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
991  constexpr auto N1 = NPerBlock / N0;
992 
993  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
994  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
995  constexpr auto KThreadRead = 64 / NPerXdl;
996  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
997 
998  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
999  ? 1
1000  : 128 / (BK1Number * N0 * sizeof(BDataType));
1001  constexpr auto KThreadReadPerm =
1002  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
1003  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
1004  : KThreadRead;
1005 
 1006  // 1 <= npair <= N0
1007  constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
1008  ? 1
1009  : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
1010  ? N0
1011  : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));
1012 
1013  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
1017  Number<kfold * N0 / npair>{},
1018  Number<npair>{},
1019  BK1Number));
1020 
1021  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
1022  b_lds_block_desc,
1023  make_tuple(
1027  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
1030  make_tuple(
1032  make_tuple(
1034 
1035  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
1036  b_lds_block_desc_permuted,
1037  make_tuple(
1045  Sequence<1>{},
1046  Sequence<2>{},
1047  Sequence<3>{},
1048  Sequence<4>{},
1049  Sequence<5>{}),
1051  Sequence<2>{},
1052  Sequence<0, 3>{},
1053  Sequence<4, 5>{},
1054  Sequence<6>{},
1055  Sequence<7>{}));
1056 
1057  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1058  b_lds_block_desc_unmerged,
1061  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1062  Number<kfold>{},
1069 
1070  return b_lds_block_desc_bk0_n_bk1;
1071  }
1072  }
1073 
 1074  __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
 1075  {
1076  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1077  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1078 
1079  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1081  make_tuple(I1,
1083  I1,
1085 
1086  return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
1087  }
1088 
1091  BlkGemmPipelineVer,
1092  BlkGemmPipeSched,
1093  BlockSize,
1094  ADataType,
1095  BDataType,
1096  ComputeTypeA,
1097  AccDataType,
1104  ABlockTransferSrcScalarPerVector,
1105  BBlockTransferSrcScalarPerVector,
1106  MPerBlock,
1107  NPerBlock,
1108  KPerBlock,
1109  MPerXdl,
1110  NPerXdl,
1111  MXdlPerWave,
1112  NXdlPerWave,
1113  KPack>())>;
1114 
1115  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1116  {
1117  // LDS allocation for A and B: be careful of alignment
1118  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1119  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1120 
1121  // lds max alignment
1122  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1123 
1124  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1125  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1126 
1127  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1128  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1129 
1130  // LDS allocation for C shuffle in LDS
1131  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1133 
1134  constexpr auto c_block_size =
1135  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
1136 
1137  return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
1138  b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
1139  c_block_size * sizeof(CShuffleDataType));
1140  }
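// Illustrative sizing (hypothetical configuration, ignoring the ABlockLdsExtraM/BBlockLdsExtraN
// padding and the LDS-layer permutations): with MPerBlock = 256, NPerBlock = 128, KPerBlock = 64,
// AK1 = BK1 = 8 and 2-byte A/B data types,
//   A block: (64/8) * 256 * 8 = 16384 elements -> 32 KiB
//   B block: (64/8) * 128 * 8 =  8192 elements -> 16 KiB
// so the A+B term is 48 KiB, and the function returns max(48 KiB, C-shuffle bytes).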
1141 
 1142  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1143  __host__ static constexpr bool CheckValidity(const Argument& karg)
1144  {
1145  static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
1146  (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
1147  "Invalid tuning param!");
1148 
1149  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1154  {
1155  if(!(karg.M % MPerBlock == 0))
1156  {
1157  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1158  {
1159  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1160  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1161  << std::endl;
1162  }
1163  return false;
1164  }
1165  }
1166 
1167  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1172  {
1173  if(!(karg.N % NPerBlock == 0))
1174  {
1175  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1176  {
1177  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1178  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1179  << std::endl;
1180  }
1181  return false;
1182  }
1183  }
1184 
1185  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1189  {
1190 
1191  auto K_t = karg.KBatch * KPerBlock;
1192  if(!(karg.K % K_t == 0))
1193  {
1194  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1195  {
1196  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1197  << karg.K << " " << __FILE__ << ":" << __LINE__
1198  << ", in function: " << __func__ << std::endl;
1199  }
1200  return false;
1201  }
1202  }
1203  else
1204  {
1205  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1206  auto K_t = karg.KBatch * KReadVec;
1207  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1208  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1209  {
1210  return false;
1211  }
1212  }
1213 
1215  {
1216  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1217  {
1218  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1219  {
1220  std::cout << "Arg K (" << karg.K
1221  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1222  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1223  << __LINE__ << ", in function: " << __func__ << std::endl;
1224  }
1225  return false;
1226  }
1227  }
1228  else
1229  {
1230  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1231  {
1232  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1233  {
1234  std::cout << "Arg M (" << karg.M
1235  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1236  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1237  << __LINE__ << ", in function: " << __func__ << std::endl;
1238  }
1239  return false;
1240  }
1241  }
1242 
1244  {
1245  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1246  {
1247  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1248  {
1249  std::cout << "Arg N (" << karg.N
1250  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1251  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1252  << __LINE__ << ", in function: " << __func__ << std::endl;
1253  }
1254  return false;
1255  }
1256  }
1257  else
1258  {
1259  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1260  {
1261  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1262  {
1263  std::cout << "Arg K (" << karg.K
1264  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1265  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1266  << __LINE__ << ", in function: " << __func__ << std::endl;
1267  }
1268  return false;
1269  }
1270  }
1271 
1273  {
1274  if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1275  {
1276  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1277  {
1278  std::cout << "Arg N (" << karg.N
1279  << ") value is not a multiple of "
1280  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1281  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1282  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1283  << std::endl;
1284  }
1285  return false;
1286  }
1287  }
1288  else
1289  {
1290  if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1291  {
1292  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1293  {
1294  std::cout << "Arg M (" << karg.M
1295  << ") value is not a multiple of "
1296  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1297  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1298  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1299  << std::endl;
1300  }
1301  return false;
1302  }
1303  }
1304 
1305  if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
1306  is_same<remove_cvref_t<CDataType>, float>::value ||
1309  {
1310  if(!karg.IsReduceAdd())
1311  {
1312  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1313  {
 1314  std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
 1315  << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1316  }
1317  if(karg.KBatch > 1)
1318  {
1319  return false;
1320  }
1321  }
1322  }
1323 
1324  // check gridwise gemm pipeline
1325  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1326 
1327  if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1328  {
1329  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1330  {
1331  return false;
1332  }
1333  }
1334 
1335  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1336  return true;
1337  }
1338 
1339  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1340  {
1341  const index_t num_loop = K / KPerBlock;
1342 
1343  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1344  }
1345 
1346  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1347  {
1348  const index_t num_loop = K / KPerBlock;
1349 
1350  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1351  }
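// Illustrative host-side use of the two helpers above (hypothetical; the actual dispatch lives in
// the device-op layer and maps the runtime results onto the kernel's template parameters):
//   const index_t k_per_batch  = CalculateKPadded(karg.K, karg.KBatch);
//   const bool    has_hot_loop = CalculateHasMainKBlockLoop(k_per_batch);
//   const auto    tail_number  = CalculateKBlockLoopTailNum(k_per_batch);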
1352 
1353  template <typename CGridDesc>
1354  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1355  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1356  {
1357  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1358  c_grid_desc_m_n,
1363 
1364  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1365  }
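// Illustrative result (continuing the hypothetical 1000 x 1000 example above): with
// MPerBlock = 256 and NPerBlock = 128, MBlock = 4 and NBlock = 8, so the C grid descriptor is
// reshaped from (M, N) into (MBlock, MPerBlock, NBlock, NPerBlock) = (4, 256, 8, 128).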
1366 
1367  // return block_id to C matrix tile idx (m0, n0) mapping
1368  // if arch = gfx942
1370  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1371 
1372  template <typename AGridDesc_AK0_M_K1,
1373  typename BGridDesc_BK0_N_K1,
1374  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1375  bool HasMainKBlockLoop,
1376  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1377  TailNumber TailNum = TailNumber::Odd>
1378  __device__ static void Run(const ADataType* p_a_grid,
1379  const BDataType* p_b_grid,
1380  CDataType* p_c_grid,
1381  void* p_shared,
1382  const Problem& problem,
1383  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1384  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1385  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1386  c_grid_desc_mblock_mperblock_nblock_nperblock)
1387  {
1388  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1389  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1390  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1391  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1392  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1393  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1394 
1395  // divide block work by [M, N]
1396  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1397 
1398  const auto block_work_idx =
1399  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1400 
1401  if(!block_2_ctile_map.ValidCTileIndex(
1402  block_work_idx,
1403  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1404  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1405  {
1406  return;
1407  }
1408 
1409  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1410  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1411 
 1412  // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
1413  const index_t m_block_data_idx_on_grid =
1414  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1415 
1416  const index_t n_block_data_idx_on_grid =
1417  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1418 
1419  // lds max alignment
1420  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1421 
1422  // A matrix in LDS memory, dst of blockwise copy
1423  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1424 
1425  // B matrix in LDS memory, dst of blockwise copy
1426  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1427 
1428  // A matrix blockwise copy
1429  auto a_blockwise_copy =
1431  AElementwiseOperation,
1435  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1436  ABlockTransferThreadClusterArrangeOrder,
1437  ADataType,
1438  ADataType,
1439  decltype(a_grid_desc_ak0_m_ak1),
1440  decltype(a_block_desc_ak0_m_ak1),
1441  ABlockTransferSrcAccessOrder,
1443  ABlockTransferSrcVectorDim,
1444  2,
1445  ABlockTransferSrcScalarPerVector,
1446  ABlockTransferDstScalarPerVector_AK1,
1447  1,
1448  1,
1449  AThreadTransferSrcResetCoordinateAfterRun,
1450  true,
1451  BlockwiseGemmPipe::GlobalBufferNum>(
1452  a_grid_desc_ak0_m_ak1,
1453  make_multi_index(0, m_block_data_idx_on_grid, 0),
1454  problem.a_element_op_,
1455  a_block_desc_ak0_m_ak1,
1456  make_multi_index(0, 0, 0),
1458 
1459  // B matrix blockwise copy
1460  auto b_blockwise_copy =
1462  BElementwiseOperation,
1466  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1467  BBlockTransferThreadClusterArrangeOrder,
1468  BDataType,
1469  BDataType,
1470  decltype(b_grid_desc_bk0_n_bk1),
1471  decltype(b_block_desc_bk0_n_bk1),
1472  BBlockTransferSrcAccessOrder,
1474  BBlockTransferSrcVectorDim,
1475  2,
1476  BBlockTransferSrcScalarPerVector,
1477  BBlockTransferDstScalarPerVector_BK1,
1478  1,
1479  1,
1480  BThreadTransferSrcResetCoordinateAfterRun,
1481  true,
1482  BlockwiseGemmPipe::GlobalBufferNum>(
1483  b_grid_desc_bk0_n_bk1,
1484  make_multi_index(0, n_block_data_idx_on_grid, 0),
1485  problem.b_element_op_,
1486  b_block_desc_bk0_n_bk1,
1487  make_multi_index(0, 0, 0),
1489 
1490  // LDS allocation for A and B: be careful of alignment
1491  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1492  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1493 
1494  // Cast after lds
1495  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1496  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1497 
1498  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1499  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
1500  sizeof(ADataType) /
1501  APackedSize),
1502  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1503 
1504  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1505  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1506 
1507  // Blockwise GEMM pipeline
1508  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1509  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1510  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1511 
1512  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1513  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1514  KPerBlock);
1515 
1516  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1517  a_block_desc_ak0_m_ak1,
1518  a_blockwise_copy,
1519  a_grid_buf,
1520  a_block_buf,
1521  a_block_slice_copy_step,
1522  b_grid_desc_bk0_n_bk1,
1523  b_block_desc_bk0_n_bk1,
1524  b_blockwise_copy,
1525  b_grid_buf,
1526  b_block_buf,
1527  b_block_slice_copy_step,
1528  c_thread_buf,
1529  num_k_block_main_loop);
1530 
1531  // shuffle C and write out
1532  {
1533  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1534  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1535  "wrong!");
1536 
1537  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1538  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1539 
1540  // TODO: hacky, fix it!
1541  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1542  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1543 
1544  // TODO: hacky, fix it!
1545  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1546  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1547  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1548 
1549  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1550  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1551  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1552  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1553  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1554  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1555  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1556  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1557 
1558  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1560 
1561  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1562  static_cast<CShuffleDataType*>(p_shared),
1563  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1564 
1565  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1566  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1567  make_tuple(
1570  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1571  M1, // M1 = MWave
1572  M2, // M2 * M3 * M4 = MPerXdl
1573  M3,
1574  M4)),
1577  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
1578  N1, // N1 = NWave
1579  N2))), // N2 = NPerXdl
1581  make_tuple(
1583 
1584  // calculate origin of thread output tensor on global memory
1585  // blockwise GEMM c matrix starting index
1586  const auto c_thread_mtx_on_block =
1587  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1588 
1589  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1590  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1591 
1592  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1594  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
1596  make_tuple(Sequence<0>{}));
1597 
1598  const auto m_thread_data_on_block_idx =
1599  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1600  make_multi_index(m_thread_data_on_block));
1601 
1602  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1606  make_tuple(Sequence<0>{}));
1607 
1608  const auto n_thread_data_on_block_idx =
1609  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1610  make_multi_index(n_thread_data_on_block));
1611 
1613  const auto& vpgr_to_lds_element_op = [&] {
1614  if constexpr(DoElementwiseBeforeCShuffle)
1615  {
1616  return problem.c_element_op_;
1617  }
1618  else
1619  {
1620  return pass_through;
1621  }
1622  };
1623  const auto& lds_to_global_element_op = [&] {
1624  if constexpr(!DoElementwiseBeforeCShuffle)
1625  {
1626  return problem.c_element_op_;
1627  }
1628  else
1629  {
1630  return pass_through;
1631  }
1632  };
1633 
1634  // shuffle: threadwise copy C from VGPR to LDS
1635  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1636  AccDataType,
1637  CShuffleDataType,
1638  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1639  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1640  conditional_t<DoElementwiseBeforeCShuffle,
1641  CElementwiseOperation,
1643  Sequence<CShuffleMXdlPerWavePerShuffle,
1644  CShuffleNXdlPerWavePerShuffle,
1645  I1,
1646  I1,
1647  M2,
1648  I1,
1649  M4,
1650  I1>,
1652  7,
1653  1,
1655  1,
1656  true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1657  make_multi_index(0,
1658  0,
1659  m_thread_data_on_block_idx[I1],
1660  n_thread_data_on_block_idx[I1],
1661  m_thread_data_on_block_idx[I2],
1662  m_thread_data_on_block_idx[I3],
1663  m_thread_data_on_block_idx[I4],
1664  n_thread_data_on_block_idx[I2]),
1665  vpgr_to_lds_element_op()};
1666 
1667  // shuffle: blockwise copy C from LDS to global
1668  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1669  ThisThreadBlock, // ThreadGroup
1670  conditional_t<!DoElementwiseBeforeCShuffle,
1671  CElementwiseOperation,
1673  CGlobalMemoryDataOperation, // DstInMemOp,
1674  Sequence<1,
1675  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1676  1,
1677  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
1678  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1679  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1680  CShuffleDataType, // typename SrcData,
1681  CDataType, // typename DstData,
1682  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1683  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1684  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1685  3, // index_t VectorDim,
1686  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1687  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1688  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1689  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1690  make_multi_index(0, 0, 0, 0),
1691  c_grid_desc_mblock_mperblock_nblock_nperblock,
1692  make_multi_index(block_m_id, 0, block_n_id, 0),
1693  lds_to_global_element_op()};
1694 
1695  // space filling curve for threadwise C in VGPR
1696  constexpr auto sfc_c_vgpr =
1699  Sequence<CShuffleMXdlPerWavePerShuffle,
1700  CShuffleNXdlPerWavePerShuffle,
1701  1,
1702  1,
1703  M2,
1704  1,
1705  M4,
1706  1>>{};
1707 
1708  // space filling curve for shuffled blockwise C in global mem
1709  constexpr auto sfc_c_global =
1712  Sequence<1,
1713  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1714  1,
1715  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1716 
1717  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1718 
1719  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1720 
1721  static_for<0, num_access, 1>{}([&](auto access_id) {
1722  // make sure it's safe to write to LDS
1723  block_sync_lds();
1724 
1725  // each thread write its data from VGPR to LDS
1726  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1727  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1728  c_thread_buf,
1729  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1730  c_shuffle_block_buf);
1731 
1732  // make sure it's safe to read from LDS
1733  block_sync_lds();
1734 
1735  // each block copy its data from LDS to global
1736  c_shuffle_block_copy_lds_to_global.Run(
1737  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1738  c_shuffle_block_buf,
1739  c_grid_desc_mblock_mperblock_nblock_nperblock,
1740  c_grid_buf);
1741 
1742  if constexpr(access_id < num_access - 1)
1743  {
1744  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1745 
1746  // move on C
1747  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1748  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1749  }
1750  });
1751  }
1752  }
1753 
1754  template <bool HasMainKBlockLoop,
1755  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1756  TailNumber TailNum = TailNumber::Odd>
1757  __device__ static void Run(const ADataType* p_a_grid,
1758  const BDataType* p_b_grid,
1759  CDataType* p_c_grid,
1760  void* p_shared,
1761  const Problem& problem)
1762  {
1763  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1764  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1765  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1766  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1767  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1768  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1769  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1771  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1772 
1773  Run<decltype(a_grid_desc_ak0_m_ak1),
1774  decltype(b_grid_desc_bk0_n_bk1),
1775  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1776  HasMainKBlockLoop,
1777  CGlobalMemoryDataOperation,
1778  TailNum>(p_a_grid,
1779  p_b_grid,
1780  p_c_grid,
1781  p_shared,
1782  problem,
1783  a_grid_desc_ak0_m_ak1,
1784  b_grid_desc_bk0_n_bk1,
1785  c_grid_desc_mblock_mperblock_nblock_nperblock);
1786  }
1787 
1788  template <typename AGridDesc_AK0_M_K1,
1789  typename BGridDesc_BK0_N_K1,
1790  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1791  bool HasMainKBlockLoop,
1792  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1793  TailNumber TailNum = TailNumber::Odd>
1794  __device__ static void Run_2Lds(const ADataType* p_a_grid,
1795  const BDataType* p_b_grid,
1796  CDataType* p_c_grid,
1797  void* p_shared_0,
1798  void* p_shared_1,
1799  const Problem& problem,
1800  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1801  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1802  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1803  c_grid_desc_mblock_mperblock_nblock_nperblock)
1804  {
1805  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1806  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1807  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1808  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1809  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1810  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1811 
1812  // divide block work by [M, N]
1813  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1814 
1815  const auto block_work_idx =
1816  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1817 
1818  if(!block_2_ctile_map.ValidCTileIndex(
1819  block_work_idx,
1820  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1821  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1822  {
1823  return;
1824  }
1825 
1826  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1827  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1828 
 1829  // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
1830  const index_t m_block_data_idx_on_grid =
1831  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1832 
1833  const index_t n_block_data_idx_on_grid =
1834  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1835 
1836  // lds max alignment
1837  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1838 
1839  // A matrix in LDS memory, dst of blockwise copy
1840  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1841 
1842  // B matrix in LDS memory, dst of blockwise copy
1843  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1844 
1845  // A matrix blockwise copy
1846  auto a_blockwise_copy =
1848  AElementwiseOperation,
1852  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1853  ABlockTransferThreadClusterArrangeOrder,
1854  ADataType,
1855  ADataType,
1856  decltype(a_grid_desc_ak0_m_ak1),
1857  decltype(a_block_desc_ak0_m_ak1),
1858  ABlockTransferSrcAccessOrder,
1860  ABlockTransferSrcVectorDim,
1861  2,
1862  ABlockTransferSrcScalarPerVector,
1863  ABlockTransferDstScalarPerVector_AK1,
1864  1,
1865  1,
1866  AThreadTransferSrcResetCoordinateAfterRun,
1867  true,
1868  BlockwiseGemmPipe::GlobalBufferNum>(
1869  a_grid_desc_ak0_m_ak1,
1870  make_multi_index(0, m_block_data_idx_on_grid, 0),
1871  problem.a_element_op_,
1872  a_block_desc_ak0_m_ak1,
1873  make_multi_index(0, 0, 0),
1875 
1876  // B matrix blockwise copy
1877  auto b_blockwise_copy =
1879  BElementwiseOperation,
1883  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1884  BBlockTransferThreadClusterArrangeOrder,
1885  BDataType,
1886  BDataType,
1887  decltype(b_grid_desc_bk0_n_bk1),
1888  decltype(b_block_desc_bk0_n_bk1),
1889  BBlockTransferSrcAccessOrder,
1891  BBlockTransferSrcVectorDim,
1892  2,
1893  BBlockTransferSrcScalarPerVector,
1894  BBlockTransferDstScalarPerVector_BK1,
1895  1,
1896  1,
1897  BThreadTransferSrcResetCoordinateAfterRun,
1898  true,
1899  BlockwiseGemmPipe::GlobalBufferNum>(
1900  b_grid_desc_bk0_n_bk1,
1901  make_multi_index(0, n_block_data_idx_on_grid, 0),
1902  problem.b_element_op_,
1903  b_block_desc_bk0_n_bk1,
1904  make_multi_index(0, 0, 0),
1906 
1907  // LDS allocation for A and B: be careful of alignment
1908  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1909  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1910 
1911  auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1912  static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1913 
1914  auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1915  bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
1916  a_block_space_size_aligned * sizeof(ADataType)),
1917  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1918 
1919  auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1920  static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1921 
1922  auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1923  bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
1924  a_block_space_size_aligned * sizeof(ADataType)),
1925  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1926 
1927  auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
1928  auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
1929 
1930  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1931  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1932 
1933  // Blockwise GEMM pipeline
1934  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1935  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1936  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1937 
1938  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1939  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1940  KPerBlock);
1941 
1942  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1943  a_block_desc_ak0_m_ak1,
1944  a_blockwise_copy,
1945  a_grid_buf,
1946  a_block_bufs,
1947  a_block_slice_copy_step,
1948  b_grid_desc_bk0_n_bk1,
1949  b_block_desc_bk0_n_bk1,
1950  b_blockwise_copy,
1951  b_grid_buf,
1952  b_block_bufs,
1953  b_block_slice_copy_step,
1954  c_thread_buf,
1955  num_k_block_main_loop);
1956 
1957  // shuffle C and write out
1958  {
1959  static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1960  NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1961  "wrong!");
1962 
1963  constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1964  constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1965 
1966  // TODO: hacky, fix it!
1967  constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1968  blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1969 
1970  // TODO: hacky, fix it!
1971  // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths
1972  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1973  blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1974 
1975  constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
1976  constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
1977  constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
1978  constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
1979  constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
1980  constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
1981  constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
1982  constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
1983 
1984  constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1986 
1987  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1988  static_cast<CShuffleDataType*>(p_shared_0),
1989  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1990 
1991  constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
1992  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1993  make_tuple(
1994  make_freeze_transform(I0),
1995  make_unmerge_transform(make_tuple(
1996  Number<CShuffleMXdlPerWavePerShuffle>{}, // M0 (MXdlPerWave) per shuffle
1997  M1, // M1 = MWave
1998  M2, // M2 * M3 * M4 = MPerXdl
1999  M3,
2000  M4)),
2001  make_freeze_transform(I0),
2002  make_unmerge_transform(make_tuple(
2003  Number<CShuffleNXdlPerWavePerShuffle>{}, // N0 (NXdlPerWave) per shuffle
2004  N1, // N1 = NWave
2005  N2))), // N2 = NPerXdl
2006  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
2007  make_tuple(
2008  Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{}));
2009 
2010  // calculate origin of thread output tensor on global memory
2011  // blockwise GEMM c matrix starting index
2012  const auto c_thread_mtx_on_block =
2013  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2014 
2015  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2016  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2017 
2018  const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
2019  make_single_stage_tensor_adaptor(
2020  make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
2021  make_tuple(Sequence<0, 1, 2, 3, 4>{}),
2022  make_tuple(Sequence<0>{}));
2023 
2024  const auto m_thread_data_on_block_idx =
2025  m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
2026  make_multi_index(m_thread_data_on_block));
2027 
2028  const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
2029  make_single_stage_tensor_adaptor(
2030  make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
2031  make_tuple(Sequence<0, 1, 2>{}),
2032  make_tuple(Sequence<0>{}));
2033 
2034  const auto n_thread_data_on_block_idx =
2035  n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
2036  make_multi_index(n_thread_data_on_block));
2037 
2038  // shuffle: threadwise copy C from VGPR to LDS
2039  auto c_thread_copy_vgpr_to_lds =
2040  ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
2041  CShuffleDataType,
2042  decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2043  decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
2044  ck::tensor_operation::element_wise::PassThrough,
2045  Sequence<CShuffleMXdlPerWavePerShuffle,
2046  CShuffleNXdlPerWavePerShuffle,
2047  I1,
2048  I1,
2049  M2,
2050  I1,
2051  M4,
2052  I1>,
2053  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2054  7,
2055  1,
2056  InMemoryDataOperationEnum::Set,
2057  1,
2058  true>{
2059  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2060  make_multi_index(0,
2061  0,
2062  m_thread_data_on_block_idx[I1],
2063  n_thread_data_on_block_idx[I1],
2064  m_thread_data_on_block_idx[I2],
2065  m_thread_data_on_block_idx[I3],
2066  m_thread_data_on_block_idx[I4],
2067  n_thread_data_on_block_idx[I2]),
2068  ck::tensor_operation::element_wise::PassThrough{}};
2069 
2070  // shuffle: blockwise copy C from LDS to global
2071  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
2072  ThisThreadBlock, // ThreadGroup
2073  CElementwiseOperation, // ElementwiseOperation,
2074  CGlobalMemoryDataOperation, // DstInMemOp,
2075  Sequence<1,
2076  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2077  1,
2078  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // BlockSliceLengths,
2079  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
2080  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
2081  CShuffleDataType, // typename SrcData,
2082  CDataType, // typename DstData,
2083  decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2084  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2085  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
2086  3, // index_t VectorDim,
2087  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
2088  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
2089  false> // bool ThreadTransferDstResetCoordinateAfterRun>
2090  {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2091  make_multi_index(0, 0, 0, 0),
2092  c_grid_desc_mblock_mperblock_nblock_nperblock,
2093  make_multi_index(block_m_id, 0, block_n_id, 0),
2094  problem.c_element_op_};
2095 
2096  // space filling curve for threadwise C in VGPR
2097  constexpr auto sfc_c_vgpr =
2098  SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
2099  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
2100  Sequence<CShuffleMXdlPerWavePerShuffle,
2101  CShuffleNXdlPerWavePerShuffle,
2102  1,
2103  1,
2104  M2,
2105  1,
2106  M4,
2107  1>>{};
2108 
2109  // space filling curve for shuffled blockwise C in global mem
2110  constexpr auto sfc_c_global =
2111  SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
2112  Sequence<0, 2, 1, 3>,
2113  Sequence<1,
2114  CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2115  1,
2116  CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2117 
2118  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2119 
2120  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
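 // Both curves advance one CShuffle tile per access: the VGPR curve steps through
 // CShuffleMXdlPerWavePerShuffle x CShuffleNXdlPerWavePerShuffle xdl sub-tiles of the
 // accumulator, while the global curve steps through the matching
 // (CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl) x
 // (CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl) windows of the C block,
 // so the two access counts must agree.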
2121 
2122  static_for<0, num_access, 1>{}([&](auto access_id) {
2123  // make sure it's safe to write to LDS
2124  block_sync_lds();
2125 
2126  // each thread writes its data from VGPR to LDS
2127  c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2128  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2129  c_thread_buf,
2130  c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
2131  c_shuffle_block_buf);
2132 
2133  // make sure it's safe to read from LDS
2134  block_sync_lds();
2135 
2136  // each block copies its data from LDS to global
2137  c_shuffle_block_copy_lds_to_global.Run(
2138  c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2139  c_shuffle_block_buf,
2140  c_grid_desc_mblock_mperblock_nblock_nperblock,
2141  c_grid_buf);
2142 
2143  if constexpr(access_id < num_access - 1)
2144  {
2145  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
2146 
2147  // move on C
2148  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
2149  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
2150  }
2151  });
2152  }
2153  }
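 // The Run_2Lds overload below rebuilds the A (AK0 x M x AK1), B (BK0 x N x BK1) and
 // C (MBlock x MPerBlock x NBlock x NPerBlock) grid descriptors from the Problem and
 // forwards to the descriptor-taking overload above.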
2154 
2155  template <bool HasMainKBlockLoop,
2156  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
2157  TailNumber TailNum = TailNumber::Odd>
2158  __device__ static void Run_2Lds(const ADataType* p_a_grid,
2159  const BDataType* p_b_grid,
2160  CDataType* p_c_grid,
2161  void* p_shared_0,
2162  void* p_shared_1,
2163  const Problem& problem)
2164  {
2165  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
2166  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
2167  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
2168  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
2169  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
2170  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
2171 
2172  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
2173  MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
2174  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
2175 
2176  Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
2177  decltype(b_grid_desc_bk0_n_bk1),
2178  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
2179  HasMainKBlockLoop,
2180  CGlobalMemoryDataOperation,
2181  TailNum>(p_a_grid,
2182  p_b_grid,
2183  p_c_grid,
2184  p_shared_0,
2185  p_shared_1,
2186  problem,
2187  a_grid_desc_ak0_m_ak1,
2188  b_grid_desc_bk0_n_bk1,
2189  c_grid_desc_mblock_mperblock_nblock_nperblock);
2190  }
2191 };
2192 
2193 } // namespace ck
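For orientation, the sketch below shows one way a host wrapper might drive this gridwise GEMM. It is a minimal illustration under stated assumptions, not the library's device-op code: Gridwise stands for a fully specialized gridwise GEMM type from this header (the long template argument list is omitted), block_size is assumed to match that type's BlockSize template argument, KBatch is fixed to 1 so plain stores (InMemoryDataOperationEnum::Set) are valid, and the kernel's default TailNumber is kept. A production dispatcher, such as Composable Kernel's own device ops, additionally switches on CalculateKBlockLoopTailNum and uses atomic-add or reduction for split-K.

#include <hip/hip_runtime.h>

#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp"

// Hypothetical host-side launcher (illustrative only).
template <typename Gridwise, typename ADataType, typename BDataType, typename CDataType>
void launch_gemm_xdl_cshuffle_v3(const ADataType* p_a,
                                 const BDataType* p_b,
                                 CDataType* p_c,
                                 ck::index_t M,
                                 ck::index_t N,
                                 ck::index_t K,
                                 ck::index_t StrideA,
                                 ck::index_t StrideB,
                                 ck::index_t StrideC,
                                 int block_size, // must equal Gridwise's BlockSize parameter
                                 hipStream_t stream)
{
    constexpr ck::index_t KBatch = 1; // no split-K in this sketch

    typename Gridwise::Argument arg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, KBatch};

    if(!Gridwise::CheckValidity(arg))
    {
        return; // problem shape not supported by this instance
    }

    const ck::index_t grid_size = Gridwise::CalculateGridSize(M, N, KBatch);

    // With KBatch == 1 every C tile is written exactly once, so plain stores suffice;
    // split-K would normally dispatch with InMemoryDataOperationEnum::AtomicAdd instead.
    if(Gridwise::CalculateHasMainKBlockLoop(K))
    {
        ck::kernel_gemm_xdl_cshuffle_v3<Gridwise, true, ck::InMemoryDataOperationEnum::Set>
            <<<grid_size, block_size, 0, stream>>>(arg);
    }
    else
    {
        ck::kernel_gemm_xdl_cshuffle_v3<Gridwise, false, ck::InMemoryDataOperationEnum::Set>
            <<<grid_size, block_size, 0, stream>>>(arg);
    }
}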
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
int32_t int32_t
Definition: integer.hpp:10
Definition: ck.hpp:266
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:275
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
constexpr auto BlockGemmPipeline_Selector()
Definition: blockwise_gemm_pipeline_wmma_selector.hpp:31
_Float16 half_t
Definition: data_type.hpp:30
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
ushort bhalf_t
Definition: data_type.hpp:29
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:59
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:140
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:297
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: block_to_ctile_map.hpp:270
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:282
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:699
const BElementwiseOperation b_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:639
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:700
const BDataType * p_b_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:742
CDataType * p_c_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:743
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:731
const AElementwiseOperation a_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:638
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:736
const ADataType * p_a_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:741
const CElementwiseOperation c_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:640
bool is_reduce
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:744
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:635
index_t N
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:678
index_t NPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:685
index_t KBatch
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:683
index_t StrideA
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:680
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:636
CElementwiseOperation c_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:694
index_t BK0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:689
index_t M
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:677
index_t NBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:691
index_t MPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:684
index_t K
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:679
index_t StrideB
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:681
index_t KPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:687
index_t StrideC
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:682
index_t MBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:690
BElementwiseOperation b_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:693
index_t AK0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:688
index_t KRead
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:686
AElementwiseOperation a_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:692
__host__ void Print() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:667
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:748
index_t a_k_split_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:797
index_t b_k_split_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:798
__device__ SplitKBatchOffset(Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:750
index_t c_reduce_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:799
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:241
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:566
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:445
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:331
static constexpr auto is_scale_mfma
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:267
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:308
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:298
static constexpr auto BK1Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:255
static constexpr index_t APackedSize
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:279
static constexpr bool is_single_rate_mfma
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:258
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1346
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:277
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1354
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:313
static constexpr auto I2
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:244
static constexpr index_t KPack
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:268
static constexpr auto lcm_AK1_BK1
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:257
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:349
static constexpr auto I7
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:249
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1113
static constexpr auto I5
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:247
static constexpr auto AK1Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:254
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1143
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1794
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:293
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:338
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:2158
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:303
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:557
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:319
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1115
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:940
static constexpr index_t BPackedSize
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:286
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1378
static constexpr auto I6
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:248
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:802
static constexpr auto I1
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:243
static constexpr auto I0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:242
static constexpr auto I3
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:245
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1757
static constexpr auto I4
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:246
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:325
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:363
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:343
static constexpr auto BK0Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:253
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1074
static constexpr auto AK0Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:252
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1339
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:574
Definition: xdlops_gemm.hpp:942
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:20
Definition: type.hpp:177
Definition: data_type.hpp:197
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: unary_element_wise_operation.hpp:308
#define CK_ENV(name)
Definition: env.hpp:129