Composable Kernel: gridwise_gemm_wmma_cshuffle_v3.hpp Source File

1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/utility/env.hpp"
17 
18 namespace ck {
19 
20 template <typename GridwiseGemm,
21  bool HasMainKBlockLoop,
22  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
23  index_t MinimumOccupancy = 1,
24  TailNumber TailNum = TailNumber::Full>
25 __global__ void
26 #if CK_USE_LAUNCH_BOUNDS
27  __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
28 #endif
29  kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
30 {
31 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
32 #if defined(__gfx11__)
33  // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
34  using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
35  if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
36  (std::is_same_v<c_data_type, ck::half_t> ||
37  std::is_same_v<c_data_type, ck::bhalf_t>)))
38  {
39 #endif
40  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
41 
42  auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
43 
44  GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
45  karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
46  karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
47  karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
48  p_shared,
49  karg);
50 #if defined(__gfx11__)
51  }
52 #endif
53 #else
54  ignore = karg;
55 #endif
56 }
57 
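// A rough host-side launch sketch (illustrative only): the device-op wrappers normally drive
// this kernel, and "Gemm" plus the pointer/size variables below are placeholders for a fully
// specialized GridwiseGemm type and the caller's buffers, not names from this file.
//
//   auto karg = typename Gemm::Argument{p_a, p_b, p_c, M, N, K,
//                                       StrideA, StrideB, StrideC, KBatch};
//   if(Gemm::CheckValidity(karg))
//   {
//       const auto [gdx, gdy, gdz] = Gemm::CalculateGridSize(M, N, KBatch); // gdz == KBatch
//       // Gemm::CalculateHasMainKBlockLoop(...) and Gemm::CalculateKBlockLoopTailNum(...)
//       // select the HasMainKBlockLoop / TailNum template arguments; the matching
//       // instantiation of kernel_gemm_wmma_cshuffle_v3 is then launched with
//       // dim3(gdx, gdy, gdz) blocks of BlockSize threads.
//   }
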
161 template <typename ALayout,
162  typename BLayout,
163  typename CLayout,
164  typename ADataType,
165  typename BDataType,
166  typename AccDataType,
167  typename CShuffleDataType,
168  typename CDataType,
169  typename AElementwiseOperation,
170  typename BElementwiseOperation,
171  typename CElementwiseOperation,
173  index_t BlockSize,
174  index_t MPerBlock,
175  index_t NPerBlock,
176  index_t KPerBlock,
177  index_t AK1Value,
178  index_t BK1Value,
179  index_t MPerWmma,
180  index_t NPerWmma,
181  index_t MRepeat,
182  index_t NRepeat,
183  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
184  typename ABlockTransferThreadClusterArrangeOrder,
185  typename ABlockTransferSrcAccessOrder,
186  index_t ABlockTransferSrcVectorDim,
187  index_t ABlockTransferSrcScalarPerVector,
188  index_t ABlockTransferDstScalarPerVector_AK1,
189  bool AThreadTransferSrcResetCoordinateAfterRun,
190  index_t ABlockLdsExtraM,
191  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
192  typename BBlockTransferThreadClusterArrangeOrder,
193  typename BBlockTransferSrcAccessOrder,
194  index_t BBlockTransferSrcVectorDim,
195  index_t BBlockTransferSrcScalarPerVector,
196  index_t BBlockTransferDstScalarPerVector_BK1,
197  bool BThreadTransferSrcResetCoordinateAfterRun,
198  index_t BBlockLdsExtraN,
199  index_t CShuffleMRepeatPerShuffle,
200  index_t CShuffleNRepeatPerShuffle,
201  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
202  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
203  BlockGemmPipelineScheduler BlkGemmPipeSched,
204  BlockGemmPipelineVersion BlkGemmPipelineVer,
205  typename ComputeTypeA,
206  typename ComputeTypeB,
207  bool PermuteA,
208  bool PermuteB>
210 {
211  static constexpr auto I0 = Number<0>{};
212  static constexpr auto I1 = Number<1>{};
213  static constexpr auto I2 = Number<2>{};
214  static constexpr auto I3 = Number<3>{};
215  static constexpr auto I4 = Number<4>{};
216  static constexpr auto I5 = Number<5>{};
217  static constexpr auto I6 = Number<6>{};
218  static constexpr auto I7 = Number<7>{};
219 
220  // K1 should be Number<...>
221  static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
222  static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
223  static constexpr auto AK1Number = Number<AK1Value>{};
224  static constexpr auto BK1Number = Number<BK1Value>{};
225 
226  static constexpr index_t KPack = math::max(
229  .k_per_wmma);
230 
232 
233  static constexpr index_t APackedSize = []() {
235  return 2;
236  else
237  return 1;
238  }();
239 
240  static constexpr index_t BPackedSize = []() {
242  return 2;
243  else
244  return 1;
245  }();
246 
247  __host__ static auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
248  {
249  return std::make_tuple(Block2CTileMap::CalculateGridSize(M, N), 1, KBatch);
250  }
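// Grid layout, for orientation: grid.x collapses the 2-D tiling of the M x N output
// (Block2CTileMap::CalculateGridSize(M, N), typically ceil(M / MPerBlock) *
// ceil(N / NPerBlock)), grid.y is unused (1), and grid.z enumerates the split-K batches.
// With hypothetical sizes MPerBlock = NPerBlock = 128, M = 512, N = 256 and KBatch = 2 this
// would give (4 * 2, 1, 2) = (8, 1, 2) workgroups.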
251 
252  __host__ static auto CalculateMPadded(index_t M)
253  {
254  return math::integer_least_multiple(M, MPerBlock);
255  }
256 
257  __host__ static auto CalculateNPadded(index_t N)
258  {
259  return math::integer_least_multiple(N, NPerBlock);
260  }
261 
262  __host__ static auto CalculateKPadded(index_t K)
263  {
264  return math::integer_divide_ceil(K, KPerBlock) * KPerBlock;
265  }
266 
267  __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
268  {
269  auto K_t = K_Batch * KPerBlock;
270  return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
271  }
272 
273  __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
274  {
275  auto K_t = K_Batch * KPerBlock;
276  return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
277  }
278 
279  __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
280  {
281  auto K_t = K_Batch * KPerBlock;
282  return (K + K_t - 1) / K_t * KPerBlock;
283  }
284 
285  __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
286  {
287  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
288  auto K_t = K_Batch * KReadVec;
289  return (K + K_t - 1) / K_t * KReadVec;
290  }
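// Worked split-K example (hypothetical numbers): with AK1 = BK1 = 8 the read vector is
// KReadVec = lcm(8, 8) = 8; for K = 1000 and K_Batch = 4, K_t = 32 and
// KRead = ceil(1000 / 32) * 8 = 32 * 8 = 256. The first K_Batch - 1 batches each consume
// 256 elements of K and the last one consumes 1000 - 3 * 256 = 232 (SplitKBatchOffset below
// rewrites karg.K per workgroup accordingly).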
291 
292  __host__ static auto CalculateMBlock(index_t M)
293  {
294  return math::integer_divide_ceil(M, MPerBlock);
295  }
296 
297  __host__ static auto CalculateNBlock(index_t N)
298  {
299  return math::integer_divide_ceil(N, NPerBlock);
300  }
301 
302  template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
303  __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
304  {
305  // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1
306  constexpr auto K0 = BlockDesc{}.GetLength(I0);
307  constexpr auto K1 = BlockDesc{}.GetLength(I2);
308 #ifdef __gfx12__
309  constexpr auto KRow = I2;
310 #else
311  constexpr auto KRow = I1;
312 #endif
314  BlockDesc{},
321  }
322 
323  __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
324  index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
325  {
326  const auto a_grid_desc_mraw_kraw = [&]() {
327  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
328  {
329  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
330  }
331  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
332  {
333  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
334  }
335  }();
336 
338 
339  if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
340  GemmSpec == GemmSpecialization::MNKPadding)
341  {
342  // pad both M and K
343  const auto a_grid_desc_m_k =
344  transform_tensor_descriptor(a_grid_desc_mraw_kraw,
346  make_right_pad_transform(K, KPad - K)),
349 
350  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
351  a_grid_desc_m_k,
356 
357  return a_grid_desc_ak0_m_ak1;
358  }
359  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
360  GemmSpec == GemmSpecialization::MNPadding)
361  {
362  // pad M, but not K
363  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
364  a_grid_desc_mraw_kraw,
366  make_right_pad_transform(M, MPad - M)),
369 
370  return a_grid_desc_ak0_m_ak1;
371  }
372  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
373  GemmSpec == GemmSpecialization::NKPadding)
374  {
375  // pad K, but not M
376  const auto a_grid_desc_m_k = transform_tensor_descriptor(
377  a_grid_desc_mraw_kraw,
381 
382  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
383  a_grid_desc_m_k,
388 
389  return a_grid_desc_ak0_m_ak1;
390  }
391  else
392  {
393  static_assert(!PermuteA, "PermuteA is not supported");
394 
395  // not pad M or K
396  const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
397  a_grid_desc_mraw_kraw,
402 
403  return a_grid_desc_ak0_m_ak1;
404  }
405  }
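// The returned descriptor views A as a 3-D [AK0, M, AK1] tensor, i.e. K (padded to KPad when
// GemmSpec requires it) is factored as AK0 * AK1, with AK1 the vector-access width along K.
// As a hypothetical example without split-K, K = 64 and AK1Value = 8 give AK0 = 8, so
// element (m, k) is addressed as (k / 8, m, k % 8).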
406 
407  __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
408  index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
409  {
410  const auto b_grid_desc_nraw_kraw = [&]() {
412  {
413  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
414  }
416  {
417  return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
418  }
419  }();
420 
422 
423  static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
424  GemmSpec != GemmSpecialization::Default),
425  "pk_i4_t does not support padding");
426 
427  if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
428  GemmSpec == GemmSpecialization::MNKPadding)
429  {
430  // pad both N and K
431  const auto b_grid_desc_n_k =
432  transform_tensor_descriptor(b_grid_desc_nraw_kraw,
434  make_right_pad_transform(K, KPad - K)),
437 
438  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
439  b_grid_desc_n_k,
444 
445  return b_grid_desc_bk0_n_bk1;
446  }
447  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
448  GemmSpec == GemmSpecialization::MNPadding)
449  {
450  // pad N, but not K
451  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
452  b_grid_desc_nraw_kraw,
454  make_right_pad_transform(N, NPad - N)),
457 
458  return b_grid_desc_bk0_n_bk1;
459  }
460  else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
461  GemmSpec == GemmSpecialization::MKPadding)
462  {
463  // pad K, but not N
464  const auto b_grid_desc_n_k = transform_tensor_descriptor(
465  b_grid_desc_nraw_kraw,
469 
470  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
471  b_grid_desc_n_k,
476 
477  return b_grid_desc_bk0_n_bk1;
478  }
479  else
480  {
481  if constexpr(!PermuteB)
482  {
483  // not pad N or K
484  const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
485  b_grid_desc_nraw_kraw,
490 
491  return b_grid_desc_bk0_n_bk1;
492  }
493  else
494  {
495  // Pre-shuffled Weight
496  // BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
497  constexpr index_t BK01 = KPerBlock / BK1Value;
498  const index_t BK0_ = StrideB / BK1Value;
499  const index_t BK00 = BK0_ / BK01;
500 
501  const auto b_grid_desc_bk00_n_bk01_bk1_permute =
502  make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));
503 
504  const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
505  b_grid_desc_bk00_n_bk01_bk1_permute,
508  make_pass_through_transform(BK1Value)),
511 
512  return b_grid_desc_bk0_n_bk1_permute;
513  }
514  }
515  }
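// B is likewise viewed as [BK0, N, BK1] for the non-permuted layouts. For PermuteB the
// weights are expected pre-shuffled as BGlobal[K / KPerBlock, N, KPerBlock / BK1, BK1]
// (see the comment above), and StrideB / BK1Value is taken as the total BK0; with
// hypothetical KPerBlock = 64 and BK1Value = 8 this gives BK01 = 8 inner K0 steps per block
// and BK00 = BK0_ / 8 outer K-blocks, which are merged back into a flat BK0 index.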
516 
517  template <typename ABlockDesc_AK0_M_AK1>
518  __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&)
519  {
520  constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
521 
522  return MakeWmmaTileDescriptor<MRepeat, MWaves, MPerWmma>(ABlockDesc_AK0_M_AK1{});
523  }
524 
525  template <typename BBlockDesc_BK0_N_BK1>
526  __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&)
527  {
528  constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
529 
530  return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
531  }
532 
533  __host__ __device__ static auto
535  {
536  const auto c_grid_desc_mraw_nraw = [&]() {
538  {
539  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
540  }
542  {
543  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
544  }
545  }();
546 
547  // pad M and N
548  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
550  make_right_pad_transform(N, NPad - N)),
553  // TODO: Investigate why this path is not used in the original
554  // gridwise_gemm_xdl_cshuffle_v3.hpp
555 #if 0
557 
558  if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
559  GemmSpec == GemmSpecialization::MNKPadding)
560  {
561  // pad M and N
562  return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
564  make_right_pad_transform(N, NPad - N)),
567  }
568  else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
569  GemmSpec == GemmSpecialization::MKPadding)
570  {
571  // pad M, but not N
573  c_grid_desc_mraw_nraw,
577  }
578  else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
579  GemmSpec == GemmSpecialization::NKPadding)
580  {
581  // pad N, but not M
583  c_grid_desc_mraw_nraw,
587  }
588  else
589  {
590  // not pad M or N
591  return c_grid_desc_mraw_nraw;
592  }
593 #endif
594  }
595 
596  struct Problem
597  {
598  __host__ Problem(index_t M_,
599  index_t N_,
600  index_t K_,
601  index_t StrideA_,
602  index_t StrideB_,
603  index_t StrideC_,
604  index_t KBatch_)
605  : M{M_},
606  N{N_},
607  K{K_},
608  StrideA{StrideA_},
609  StrideB{StrideB_},
610  StrideC{StrideC_},
611  KBatch{KBatch_},
614  KRead{CalculateKRead(K_, KBatch_)},
615  KPadded{CalculateKPadded(K_, KBatch_)},
616  AK0{CalculateAK0Padded(K_, KBatch_)},
617  BK0{CalculateBK0Padded(K_, KBatch_)},
618  MBlock{CalculateMBlock(M_)},
620  {
621  }
622 
623  __host__ void Print() const
624  {
625  std::cout << "problem {"
626  << "M:" << M << ", "
627  << "N:" << N << ", "
628  << "K:" << K << ", "
629  << "SA:" << StrideA << ", "
630  << "SB:" << StrideB << ", "
631  << "SC:" << StrideC << ", "
632  << "MP:" << MPadded << ", "
633  << "NP:" << NPadded << ", "
634  << "KRead:" << KRead << ", "
635  << "KP:" << KPadded << ", "
636  << "AK0:" << AK0 << ", "
637  << "BK0:" << BK0 << ", "
638  << "MBlock: " << MBlock << ", "
639  << "NBlock: " << NBlock << "}" << std::endl;
640  }
641 
657  };
658 
659  // Argument
661  {
662  __host__ Argument(const ADataType* p_a_grid_,
663  const BDataType* p_b_grid_,
664  CDataType* p_c_grid_,
665  index_t M_,
666  index_t N_,
667  index_t K_,
668  index_t StrideA_,
669  index_t StrideB_,
670  index_t StrideC_,
671  index_t k_batch_,
672  bool is_reduce_ = false)
673  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
674  p_a_grid{p_a_grid_},
675  p_b_grid{p_b_grid_},
676  p_c_grid{p_c_grid_},
677  is_reduce(is_reduce_)
678  {
679  }
680 
681  __host__ __device__ inline bool IsReduceAdd() const
682  {
683  return (Problem::KBatch > 1) && is_reduce;
684  }
685 
686  __host__ __device__ inline bool IsAtomicAdd() const
687  {
688  return (Problem::KBatch > 1) && (!is_reduce);
689  }
690 
691  const ADataType* p_a_grid;
692  const BDataType* p_b_grid;
693  CDataType* p_c_grid;
694  bool is_reduce;
695  };
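// Split-K write-out as encoded above: with KBatch > 1 and is_reduce == false the partial C
// results of the batches are combined in place (IsAtomicAdd(), i.e. the kernel is typically
// instantiated with CGlobalMemoryDataOperation == AtomicAdd); with is_reduce == true each
// batch writes its own M x N slice at offset blockIdx.z * M * N (see c_reduce_offset in
// SplitKBatchOffset) for a separate reduction pass. With KBatch == 1 both predicates are
// false and C is simply written out.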
696 
698  {
699 
700  __device__ SplitKBatchOffset(Argument& karg)
701  {
702  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
703  {
704  a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
705  }
706  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
707  {
708  a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
709  }
710 
711  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
712  {
713  b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
714  }
715  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
716  {
717  if constexpr(!PermuteB)
718  {
719  b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
720  }
721  else
722  {
723  const int k0_offset = karg.KRead * karg.N;
724  b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
725  }
726  }
727 
728  if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
729  {
730  karg.K = karg.KRead;
731  }
732  else
733  {
734  karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
735  }
736 
737  if(karg.IsReduceAdd())
738  {
739  c_reduce_offset = blockIdx.z * karg.M * karg.N;
740  }
741  else
742  {
743  c_reduce_offset = 0;
744  }
745  }
746 
750  };
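// Offsets in a concrete, hypothetical setting: row-major A, column-major B (non-permuted),
// KRead = 256 and blockIdx.z = 2 give a_k_split_offset = b_k_split_offset = 2 * 256 = 512
// elements (divided by the packed size for packed types such as pk_i4_t, where two values
// share a byte); for the strided layouts the offset is scaled by StrideA / StrideB instead.
// karg.K is also shrunk to this batch's share of K, matching CalculateKRead above.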
751 
752  __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
753  {
754  // A matrix in LDS memory, dst of blockwise copy
755  if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
756  {
757  // Writing the data into LDS this way causes bank conflicts, but in v4 the whole main
758  // loop is available to hide them, and it may reduce VALU work in address computation.
762  }
763  // The xor tensor transformation requires extra, unnecessary VGPR usage and can cause
764  // register spills in some cases.
766  {
767  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(ADataType) / APackedSize;
768  constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
769  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
770  make_tuple(
771  AK0Number * Number<MLdsLayer>{}, Number<MPerBlock / MLdsLayer>{}, AK1Number),
773 
774  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
775  a_lds_block_desc,
781 
782  constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
783  a_lds_block_desc_permuted,
789 
790  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
791  a_lds_block_desc_ak0_mldslayer_m_ak1,
798 
799  return a_lds_block_desc_ak0_m_ak1;
800  }
801  else // ColumnMajor A
802  {
803  // The kfold and mpair dimensions are not always required.
804  // More dimensions in merge_transform make it harder for the compiler to generate an
805  // immarg offset.
806  constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
807  constexpr auto M1 = MPerBlock / M0;
808 
809  constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
810  constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
811  constexpr auto KThreadRead = 64 / MPerWmma;
812  constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
813 
814  constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
815  ? 1
816  : 128 / (AK1Number * M0 * sizeof(ADataType));
817  constexpr auto KThreadReadPerm =
818  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
819  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
820  : KThreadRead;
821 
822  // 1 <= mpair <= M0
823  constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128)
824  ? 1
825  : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0
826  ? M0
827  : 128 / (AK1Number * MPerWmma * sizeof(ADataType)));
828 
829  constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
833  Number<kfold * M0 / mpair>{},
834  Number<mpair>{},
835  AK1Number));
836 
837  constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
838  a_lds_block_desc,
839  make_tuple(
843  make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
846  make_tuple(
848  make_tuple(
850 
851  constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
852  a_lds_block_desc_permuted,
853  make_tuple(
861  Sequence<1>{},
862  Sequence<2>{},
863  Sequence<3>{},
864  Sequence<4>{},
865  Sequence<5>{}),
867  Sequence<2>{},
868  Sequence<0, 3>{},
869  Sequence<4, 5>{},
870  Sequence<6>{},
871  Sequence<7>{}));
872 
873  constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
874  a_lds_block_desc_unmerged,
877  Number<KThreadWrite / kfold / KThreadReadPerm>{},
878  Number<kfold>{},
885 
886  return a_lds_block_desc_ak0_m_ak1;
887  }
888  }
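// The MLdsLayer path above interleaves several M-slices into one 32-bank x 4-byte (128 B)
// LDS row to reduce read bank conflicts; e.g. a hypothetical KPerBlock = 32 with a 2-byte
// ADataType gives LdsSize = 32 * 4 / 32 / 2 = 2, so two M-layers share a row and the
// permuted (xor-with-modulo) descriptor staggers their K0 offsets.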
889 
890  __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
891  {
892  // B matrix in LDS memory, dst of blockwise copy
893  if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4)
894  {
895  // Writing the data into LDS this way causes bank conflicts, but in v4 the whole main
896  // loop is available to hide them, and it may reduce VALU work in address computation.
900  }
902  {
903  // NLdsLayer * K0 acts as the logical bank dimension
904  constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(BDataType) / BPackedSize;
905  constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
906  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
907  make_tuple(
908  BK0Number * Number<NLdsLayer>{}, Number<NPerBlock / NLdsLayer>{}, BK1Number),
910 
911  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
912  b_lds_block_desc,
918 
919  constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
920  b_lds_block_desc_permuted,
926 
927  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
928  b_lds_block_desc_bk0_nldslayer_n_bk1,
935 
936  return b_lds_block_desc_bk0_n_bk1;
937  }
938  else // RowMajor B
939  {
940  constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
941  constexpr auto N1 = NPerBlock / N0;
942 
943  constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
944  constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
945  constexpr auto KThreadRead = 64 / NPerWmma;
946  constexpr auto K0PerThreadRead = BK0Number / KThreadRead;
947 
948  constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
949  ? 1
950  : 128 / (BK1Number * N0 * sizeof(BDataType));
951  constexpr auto KThreadReadPerm =
952  (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
953  ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
954  : KThreadRead;
955 
956  // 1 <= npair <= N0
957  constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128)
958  ? 1
959  : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0
960  ? N0
961  : 128 / (BK1Number * NPerWmma * sizeof(BDataType)));
962 
963  constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
967  Number<kfold * N0 / npair>{},
968  Number<npair>{},
969  BK1Number));
970 
971  constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
972  b_lds_block_desc,
973  make_tuple(
977  make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
980  make_tuple(
982  make_tuple(
984 
985  constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
986  b_lds_block_desc_permuted,
987  make_tuple(
995  Sequence<1>{},
996  Sequence<2>{},
997  Sequence<3>{},
998  Sequence<4>{},
999  Sequence<5>{}),
1001  Sequence<2>{},
1002  Sequence<0, 3>{},
1003  Sequence<4, 5>{},
1004  Sequence<6>{},
1005  Sequence<7>{}));
1006 
1007  constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
1008  b_lds_block_desc_unmerged,
1011  Number<KThreadWrite / kfold / KThreadReadPerm>{},
1012  Number<kfold>{},
1019 
1020  return b_lds_block_desc_bk0_n_bk1;
1021  }
1022  }
1023 
1024  __host__ __device__ static constexpr auto
1025  // *Caution: here "repeat" means the shuffle repeat
1027  {
1028  constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
1029  constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
1030 
1031  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
1033  make_tuple(I1,
1035  I1,
1037 
1038  return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
1039  }
1040 
1042  decltype(BlockGemmPipeline_Selector<
1043  BlkGemmPipelineVer,
1044  BlkGemmPipeSched,
1045  BlockSize,
1046  ADataType,
1047  BDataType,
1048  ComputeTypeA,
1049  ComputeTypeB,
1050  AccDataType,
1053  ABlockTransferSrcScalarPerVector,
1054  BBlockTransferSrcScalarPerVector,
1055  MPerBlock,
1056  NPerBlock,
1057  KPerBlock,
1058  MPerWmma,
1059  NPerWmma,
1060  MRepeat,
1061  NRepeat,
1062  KPack>())>;
1063 
1064  __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
1065  {
1066  // LDS allocation for A and B: be careful of alignment
1067  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1068  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1069 
1070  // lds max alignment
1071  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1072 
1073  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1074  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1075 
1076  constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
1077  b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
1078 
1079  // LDS allocation for C shuffle in LDS
1080  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
1082 
1083  constexpr auto c_block_size =
1084  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
1085  .GetElementSpaceSize();
1086 
1087  return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
1088  b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
1089  c_block_size * sizeof(CShuffleDataType));
1090  }
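// Worked example with a hypothetical tile: MPerBlock = NPerBlock = 128, KPerBlock = 32,
// AK1 = BK1 = 8, fp16 A/B and fp32 CShuffleDataType. Ignoring the extra-LDS padding options,
// the A and B tiles are each 4 * 128 * 8 = 4096 elements (8 KiB in fp16), so the A+B stage
// needs 16 KiB, while the C-shuffle stage reuses the same LDS and needs
// CShuffleMRepeatPerShuffle * MWave * MPerWmma * CShuffleNRepeatPerShuffle * NWave * NPerWmma
// floats; the larger of the two byte counts is returned.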
1091 
1092  // block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
1093  __host__ static constexpr bool CheckValidity(const Argument& karg)
1094  {
1095  static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
1096  (NPerBlock % (NPerWmma * NRepeat)) == 0,
1097  "Invalid tuning param!");
1098 
1099  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding ||
1104  {
1105  if(!(karg.M % MPerBlock == 0))
1106  {
1107  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1108  {
1109  std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
1110  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1111  << std::endl;
1112  }
1113  return false;
1114  }
1115  }
1116 
1117  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::NPadding ||
1122  {
1123  if(!(karg.N % NPerBlock == 0))
1124  {
1125  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1126  {
1127  std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
1128  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1129  << std::endl;
1130  }
1131  return false;
1132  }
1133  }
1134 
1135  if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::KPadding ||
1139  {
1140 
1141  auto K_t = karg.KBatch * KPerBlock;
1142  if(!(karg.K % K_t == 0))
1143  {
1144  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1145  {
1146  std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1147  << karg.K << " " << __FILE__ << ":" << __LINE__
1148  << ", in function: " << __func__ << std::endl;
1149  }
1150  return false;
1151  }
1152  }
1153  else
1154  {
1155  constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
1156  auto K_t = karg.KBatch * KReadVec;
1157  auto KReadPadSplited = math::integer_divide_ceil(karg.K, K_t) * KReadVec;
1158  if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1159  {
1160  return false;
1161  }
1162  }
1163 
1165  {
1166  if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1167  {
1168  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1169  {
1170  std::cout << "Arg K (" << karg.K
1171  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1172  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1173  << __LINE__ << ", in function: " << __func__ << std::endl;
1174  }
1175  return false;
1176  }
1177  }
1178  else
1179  {
1180  if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1181  {
1182  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1183  {
1184  std::cout << "Arg M (" << karg.M
1185  << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1186  << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1187  << __LINE__ << ", in function: " << __func__ << std::endl;
1188  }
1189  return false;
1190  }
1191  }
1192 
1194  {
1195  if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1196  {
1197  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1198  {
1199  std::cout << "Arg N (" << karg.N
1200  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1201  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1202  << __LINE__ << ", in function: " << __func__ << std::endl;
1203  }
1204  return false;
1205  }
1206  }
1207  else
1208  {
1209  if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1210  {
1211  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1212  {
1213  std::cout << "Arg K (" << karg.K
1214  << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1215  << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1216  << __LINE__ << ", in function: " << __func__ << std::endl;
1217  }
1218  return false;
1219  }
1220  }
1221 
1223  {
1224  if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1225  {
1226  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1227  {
1228  std::cout << "Arg N (" << karg.N
1229  << ") value is not a multiple of "
1230  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1231  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1232  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1233  << std::endl;
1234  }
1235  return false;
1236  }
1237  }
1238  else
1239  {
1240  if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
1241  {
1242  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1243  {
1244  std::cout << "Arg M (" << karg.M
1245  << ") value is not a multiple of "
1246  "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1247  << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
1248  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1249  << std::endl;
1250  }
1251  return false;
1252  }
1253  }
1254 
1255  if constexpr(!(is_same<remove_cvref_t<CDataType>, half_t>::value ||
1256  is_same<remove_cvref_t<CDataType>, float>::value ||
1259  {
1260  if(!karg.IsReduceAdd())
1261  {
1262  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1263  {
1264  std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
1265  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1266  << std::endl;
1267  }
1268  if(karg.KBatch > 1)
1269  {
1270  return false;
1271  }
1272  }
1273  }
1274 
1275  // check gridwise gemm pipeline
1276  const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1277 
1278  if constexpr(BlkGemmPipelineVer != BlockGemmPipelineVersion::v1)
1279  {
1280  if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1281  {
1282  return false;
1283  }
1284  }
1285 
1286  // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)
1287  return true;
1288  }
1289 
1290  __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
1291  {
1292  const index_t num_loop = K / KPerBlock;
1293 
1294  return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1295  }
1296 
1297  __host__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
1298  {
1299  const index_t num_loop = K / KPerBlock;
1300 
1301  return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
1302  }
1303 
1304  template <typename CGridDesc>
1305  __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
1306  const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1307  {
1308  const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
1309  c_grid_desc_m_n,
1314 
1315  return c_grid_desc_mblock_mperblock_nblock_nperblock;
1316  }
1317 
1318  // return block_id to C matrix tile idx (m0, n0) mapping
1319  // if arch = gfx942
1321  // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
1322 
1323  template <typename AGridDesc_AK0_M_K1,
1324  typename BGridDesc_BK0_N_K1,
1325  typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
1326  bool HasMainKBlockLoop,
1327  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1328  TailNumber TailNum = TailNumber::Odd>
1329  __device__ static void Run(const ADataType* p_a_grid,
1330  const BDataType* p_b_grid,
1331  CDataType* p_c_grid,
1332  void* p_shared,
1333  const Problem& problem,
1334  const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1335  const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1336  const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1337  c_grid_desc_mblock_mperblock_nblock_nperblock)
1338  {
1339  const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1340  p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1341  const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1342  p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1343  auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1344  p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1345 
1346  const AElementwiseOperation a_element_op{};
1347  const BElementwiseOperation b_element_op{};
1348  const CElementwiseOperation c_element_op{};
1349 
1350  // divide block work by [M, N]
1351  const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
1352 
1353  const auto block_work_idx =
1354  block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
1355 
1356  if(!block_2_ctile_map.ValidCTileIndex(
1357  block_work_idx,
1358  make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
1359  c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
1360  {
1361  return;
1362  }
1363 
1364  const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
1365  const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
1366 
1367  // HACK: this forces m/n_block_data_idx_on_grid into SGPRs
1368  const index_t m_block_data_idx_on_grid =
1369  __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1370 
1371  const index_t n_block_data_idx_on_grid =
1372  __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1373 
1374  // lds max alignment
1375  constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);
1376 
1377  // A matrix in LDS memory, dst of blockwise copy
1378  constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
1379 
1380  // B matrix in LDS memory, dst of blockwise copy
1381  constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();
1382 
1383  // A matrix blockwise copy
1384  auto a_blockwise_copy =
1386  AElementwiseOperation,
1390  ABlockTransferThreadClusterLengths_AK0_M_AK1,
1391  ABlockTransferThreadClusterArrangeOrder,
1392  ADataType,
1393  ADataType,
1394  decltype(a_grid_desc_ak0_m_ak1),
1395  decltype(a_block_desc_ak0_m_ak1),
1396  ABlockTransferSrcAccessOrder,
1398  ABlockTransferSrcVectorDim,
1399  2,
1400  ABlockTransferSrcScalarPerVector,
1401  ABlockTransferDstScalarPerVector_AK1,
1402  1,
1403  1,
1404  AThreadTransferSrcResetCoordinateAfterRun,
1405  true,
1406  BlockwiseGemmPipe::GlobalBufferNum>(
1407  a_grid_desc_ak0_m_ak1,
1408  make_multi_index(0, m_block_data_idx_on_grid, 0),
1409  a_element_op,
1410  a_block_desc_ak0_m_ak1,
1411  make_multi_index(0, 0, 0),
1413 
1414  // B matrix blockwise copy
1415  auto b_blockwise_copy =
1417  BElementwiseOperation,
1421  BBlockTransferThreadClusterLengths_BK0_N_BK1,
1422  BBlockTransferThreadClusterArrangeOrder,
1423  BDataType,
1424  BDataType,
1425  decltype(b_grid_desc_bk0_n_bk1),
1426  decltype(b_block_desc_bk0_n_bk1),
1427  BBlockTransferSrcAccessOrder,
1429  BBlockTransferSrcVectorDim,
1430  2,
1431  BBlockTransferSrcScalarPerVector,
1432  BBlockTransferDstScalarPerVector_BK1,
1433  1,
1434  1,
1435  BThreadTransferSrcResetCoordinateAfterRun,
1436  true,
1437  BlockwiseGemmPipe::GlobalBufferNum>(
1438  b_grid_desc_bk0_n_bk1,
1439  make_multi_index(0, n_block_data_idx_on_grid, 0),
1440  b_element_op,
1441  b_block_desc_bk0_n_bk1,
1442  make_multi_index(0, 0, 0),
1444 
1445  // LDS allocation for A and B: be careful of alignment
1446  constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
1447  a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1448 
1449  // Cast the raw LDS pointer to the A/B element types
1450  auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1451  static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1452 
1453  auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1454  reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
1455  sizeof(ADataType) /
1456  APackedSize),
1457  b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1458 
1459  constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
1460  constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
1461 
1462  // Blockwise GEMM pipeline
1463  static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1464  auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
1465  auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1466 
1467  const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1468  (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1469  KPerBlock);
1470 
1471  blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1472  a_block_desc_ak0_m_ak1,
1473  a_blockwise_copy,
1474  a_grid_buf,
1475  a_block_buf,
1476  a_block_slice_copy_step,
1477  b_grid_desc_bk0_n_bk1,
1478  b_block_desc_bk0_n_bk1,
1479  b_blockwise_copy,
1480  b_grid_buf,
1481  b_block_buf,
1482  b_block_slice_copy_step,
1483  c_thread_buf,
1484  num_k_block_main_loop);
1485 
1486  // shuffle C and write out
1487  {
1488  // C mapping within a single thread.
1489  constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
1490  blockwise_gemm_pipeline
1491  .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
1492 
1493  // C mapping within a single block
1494  constexpr auto
1495  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
1496  blockwise_gemm_pipeline
1497  .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
1498 
1499  constexpr auto MWave =
1500  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
1501  .GetLength(I1);
1502  constexpr auto MSubGroup =
1503  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
1504  .GetLength(I2);
1505  constexpr auto NWave =
1506  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
1507  .GetLength(I4);
1508  constexpr auto NThreadPerSubGroup =
1509  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
1510  .GetLength(I5);
1511  constexpr auto MAccVgprs =
1512  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
1513  .GetLength(I6);
1514 
1515  // LDS descriptor for the shuffle; the C tile is shuffled and written out in MRepeat x NRepeat steps
1516  constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
1518 
1519  auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1520  static_cast<CShuffleDataType*>(p_shared),
1521  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
1522  .GetElementSpaceSize());
1523 
1524  constexpr auto
1525  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
1527  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1528  make_tuple(
1531  Number<CShuffleMRepeatPerShuffle>{}, // MRepeat per shuffle repeat
1532  MWave, // MWave
1533  MSubGroup, // MSubGroup * MAccVgprs = MPerWmma
1534  MAccVgprs)),
1537  Number<CShuffleNRepeatPerShuffle>{}, // NRepeat per shuffle repeat
1538  NWave, // NWave
1539  NThreadPerSubGroup))), // NThreadPerSubGroup = NPerWmma
1543  Sequence<>{},
1544  Sequence<3, 4, 5>{}));
1545 
1546  // calculate origin of thread output tensor on global memory
1547  // blockwise GEMM c matrix starting index
1548  const auto c_thread_mtx_on_block =
1549  blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0);
1550 
1551  const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1552  const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1553 
1554  const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
1556  MRepeat, MWave, MSubGroup, MAccVgprs))),
1558  make_tuple(Sequence<0>{}));
1559 
1560  const auto m_thread_data_on_block_idx =
1561  m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
1562  .CalculateBottomIndex(make_multi_index(m_thread_data_on_block));
1563 
1564  const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
1566  NRepeat, NWave, NThreadPerSubGroup))),
1568  make_tuple(Sequence<0>{}));
1569 
1570  const auto n_thread_data_on_block_idx =
1571  n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor
1572  .CalculateBottomIndex(make_multi_index(n_thread_data_on_block));
1573 
1574  // shuffle: threadwise copy C from VGPR to LDS
1575  auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
1576  AccDataType,
1577  CShuffleDataType,
1578  decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
1579  decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
1581  Sequence<CShuffleMRepeatPerShuffle,
1582  I1,
1583  I1,
1584  CShuffleNRepeatPerShuffle,
1585  I1,
1586  I1,
1587  MAccVgprs>,
1589  6,
1590  1, // vector write pixel
1592  1,
1593  true>{
1594  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1595  make_multi_index(0,
1596  m_thread_data_on_block_idx[I1],
1597  m_thread_data_on_block_idx[I2],
1598  0,
1599  n_thread_data_on_block_idx[I1],
1600  n_thread_data_on_block_idx[I2],
1601  m_thread_data_on_block_idx[I3]),
1603 
1604  // shuffle: blockwise copy C from LDS to global
1605  auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
1606  ThisThreadBlock, // ThreadGroup
1607  CElementwiseOperation, // ElementwiseOperation,
1608  CGlobalMemoryDataOperation, // DstInMemOp,
1609  Sequence<1,
1610  CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1611  1,
1612  CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths,
1613  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1614  Sequence<0, 1, 2, 3>, // typename ThreadClusterArrangeOrder,
1615  CShuffleDataType, // typename SrcData,
1616  CDataType, // typename DstData,
1617  decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
1618  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1619  Sequence<0, 1, 2, 3>, // typename DimAccessOrder,
1620  3, // index_t VectorDim,
1621  CShuffleBlockTransferScalarPerVector_NPerBlock, // index_t ScalarPerVector,
1622  true, // bool ThreadTransferSrcResetCoordinateAfterRun,
1623  false> // bool ThreadTransferDstResetCoordinateAfterRun>
1624  {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1625  make_multi_index(0, 0, 0, 0),
1626  c_grid_desc_mblock_mperblock_nblock_nperblock,
1627  make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0),
1628  c_element_op};
1629 
1630  // space filling curve for local reg & global memory
1631  // space filling curve for threadwise C in VGPR
1632  constexpr auto sfc_c_vgpr =
1635  Sequence<CShuffleMRepeatPerShuffle,
1636  1,
1637  1,
1638  CShuffleNRepeatPerShuffle,
1639  1,
1640  1,
1641  MAccVgprs>>{};
1642 
1643  // space filling curve for shuffled blockwise C in global mem
1644  constexpr auto sfc_c_global =
1647  Sequence<1,
1648  CShuffleMRepeatPerShuffle * MWave * MPerWmma,
1649  1,
1650  CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
1651 
1652  constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1653 
1654  static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
1655 
1656  static_for<0, num_access, 1>{}([&](auto access_id) {
1657  // make sure it's safe to write to LDS
1658  block_sync_lds();
1659 
1660  // each thread write its data from VGPR to LDS
1661  c_thread_copy_vgpr_to_lds.Run(
1662  c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1663  sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1664  c_thread_buf,
1665  c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
1666  c_shuffle_block_buf);
1667 
1668  // make sure it's safe to read from LDS
1669  block_sync_lds();
1670 
1671  // each block copy its data from LDS to global
1672  c_shuffle_block_copy_lds_to_global.Run(
1673  c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
1674  c_shuffle_block_buf,
1675  c_grid_desc_mblock_mperblock_nblock_nperblock,
1676  c_grid_buf);
1677 
1678  if constexpr(access_id < num_access - 1)
1679  {
1680  constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1681 
1682  // move on C
1683  c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1684  c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
1685  }
1686  });
1687  }
1688  }
1689 
1690  template <bool HasMainKBlockLoop,
1691  InMemoryDataOperationEnum CGlobalMemoryDataOperation,
1692  TailNumber TailNum = TailNumber::Odd>
1693  __device__ static void Run(const ADataType* p_a_grid,
1694  const BDataType* p_b_grid,
1695  CDataType* p_c_grid,
1696  void* p_shared,
1697  const Problem& problem)
1698  {
1699  const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
1700  problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
1701  const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
1702  problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
1703  const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
1704  problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
1705  const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1707  c_grid_desc_m_n, problem.MBlock, problem.NBlock);
1708 
1709  Run<decltype(a_grid_desc_ak0_m_ak1),
1710  decltype(b_grid_desc_bk0_n_bk1),
1711  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1712  HasMainKBlockLoop,
1713  CGlobalMemoryDataOperation,
1714  TailNum>(p_a_grid,
1715  p_b_grid,
1716  p_c_grid,
1717  p_shared,
1718  problem,
1719  a_grid_desc_ak0_m_ak1,
1720  b_grid_desc_bk0_n_bk1,
1721  c_grid_desc_mblock_mperblock_nblock_nperblock);
1722  }
1723 };
1724 
1725 } // namespace ck