26 template <
typename GridwiseGemm,
 
   27           bool HasMainKBlockLoop,
 
   32 #if CK_USE_LAUNCH_BOUNDS 
   38 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) 
   39     __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   41     auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg);
 
   43     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   44         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
   45         karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
   46         karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
 
   54 template <
typename GridwiseGemm,
 
   55           bool HasMainKBlockLoop,
 
   60 #if CK_USE_LAUNCH_BOUNDS 
   66 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) 
   69     __shared__ 
char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   70     __shared__ 
char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   72     auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg);
 
   74     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   75         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
   76         karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
   77         karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
 
  191 template <
typename ALayout,
 
  196           typename AccDataType,
 
  197           typename CShuffleDataType,
 
  199           typename AElementwiseOperation,
 
  200           typename BElementwiseOperation,
 
  201           typename CElementwiseOperation,
 
  213           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  214           typename ABlockTransferThreadClusterArrangeOrder,
 
  215           typename ABlockTransferSrcAccessOrder,
 
  216           index_t ABlockTransferSrcVectorDim,
 
  217           index_t ABlockTransferSrcScalarPerVector,
 
  218           index_t ABlockTransferDstScalarPerVector_AK1,
 
  219           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  221           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  222           typename BBlockTransferThreadClusterArrangeOrder,
 
  223           typename BBlockTransferSrcAccessOrder,
 
  224           index_t BBlockTransferSrcVectorDim,
 
  225           index_t BBlockTransferSrcScalarPerVector,
 
  226           index_t BBlockTransferDstScalarPerVector_BK1,
 
  227           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  229           index_t CShuffleMXdlPerWavePerShuffle,
 
  230           index_t CShuffleNXdlPerWavePerShuffle,
 
  231           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  232           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
 
  235           typename ComputeTypeA                       = CDataType,
 
  236           typename ComputeTypeB                       = ComputeTypeA,
 
  237           bool PermuteA                               = 
false,
 
  238           bool PermuteB                               = 
false,
 
  239           bool DoElementwiseBeforeCShuffle            = 
false>
 
  264           KPerBlock < 128 && MPerXdl == 16))
 
  315         auto K_t = K_Batch * KPerBlock;
 
  316         return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
 
  321         auto K_t = K_Batch * KPerBlock;
 
  322         return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
 
  327         auto K_t = K_Batch * KPerBlock;
 
  328         return (K + K_t - 1) / K_t * KPerBlock;
 
  334         auto K_t                = K_Batch * KReadVec;
 
  335         return (K + K_t - 1) / K_t * KReadVec;
 
  348     template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, 
typename TileDesc_K0_MN_K1>
 
  366         const auto a_grid_desc_mraw_kraw = [&]() {
 
  367             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  371             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  379         if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
 
  380                      GemmSpec == GemmSpecialization::MNKPadding)
 
  383             const auto a_grid_desc_m_k =
 
  397             return a_grid_desc_ak0_m_ak1;
 
  399         else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 
  400                           GemmSpec == GemmSpecialization::MNPadding)
 
  404                 a_grid_desc_mraw_kraw,
 
  410             return a_grid_desc_ak0_m_ak1;
 
  412         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  413                           GemmSpec == GemmSpecialization::NKPadding)
 
  417                 a_grid_desc_mraw_kraw,
 
  429             return a_grid_desc_ak0_m_ak1;
 
  435                 a_grid_desc_mraw_kraw,
 
  441             return a_grid_desc_ak0_m_ak1;
 
  448         const auto b_grid_desc_nraw_kraw = [&]() {
 
  462                         GemmSpec != GemmSpecialization::Default),
 
  463                       "pk_i4_t does not support padding");
 
  465         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
 
  466                      GemmSpec == GemmSpecialization::MNKPadding)
 
  469             const auto b_grid_desc_n_k =
 
  483             return b_grid_desc_bk0_n_bk1;
 
  485         else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 
  486                           GemmSpec == GemmSpecialization::MNPadding)
 
  490                 b_grid_desc_nraw_kraw,
 
  496             return b_grid_desc_bk0_n_bk1;
 
  498         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  499                           GemmSpec == GemmSpecialization::MKPadding)
 
  503                 b_grid_desc_nraw_kraw,
 
  515             return b_grid_desc_bk0_n_bk1;
 
  519             if constexpr(!PermuteB)
 
  523                     b_grid_desc_nraw_kraw,
 
  529                 return b_grid_desc_bk0_n_bk1;
 
  535                 constexpr 
index_t BK01 = KPerBlock / BK1Value;
 
  536                 const index_t BK0_     = StrideB / BK1Value;
 
  537                 const index_t BK00     = BK0_ / BK01;
 
  539                 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
 
  543                     b_grid_desc_bk00_n_bk01_bk1_permute,
 
  550                 return b_grid_desc_bk0_n_bk1_permute;
 
  555     template <
typename ABlockDesc_AK0_M_AK1>
 
  556     __host__ __device__ 
static constexpr 
auto 
  559         constexpr 
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
 
  561         return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
 
  564     template <
typename BBlockDesc_BK0_N_BK1>
 
  565     __host__ __device__ 
static constexpr 
auto 
  568         constexpr 
index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
 
  570         return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
 
  573     __host__ __device__ 
static auto 
  576         const auto c_grid_desc_mraw_nraw = [&]() {
 
  596         if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
 
  597                      GemmSpec == GemmSpecialization::MNKPadding)
 
  606         else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 
  607                           GemmSpec == GemmSpecialization::MKPadding)
 
  611                 c_grid_desc_mraw_nraw,
 
  616         else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 
  617                           GemmSpec == GemmSpecialization::NKPadding)
 
  621                 c_grid_desc_mraw_nraw,
 
  629             return c_grid_desc_mraw_nraw;
 
  643                          AElementwiseOperation a_element_op,
 
  644                          BElementwiseOperation b_element_op,
 
  645                          CElementwiseOperation c_element_op)
 
  669             std::cout << 
"problem {" 
  678                       << 
"KRead:" << 
KRead << 
", " 
  680                       << 
"AK0:" << 
AK0 << 
", " 
  681                       << 
"BK0:" << 
BK0 << 
", " 
  682                       << 
"MBlock: " << 
MBlock << 
", " 
  683                       << 
"NBlock: " << 
NBlock << 
"}" << std::endl;
 
  710                           const BDataType* p_b_grid_,
 
  711                           CDataType* p_c_grid_,
 
  719                           bool is_reduce_                    = 
false,
 
  720                           AElementwiseOperation 
a_element_op = AElementwiseOperation{},
 
  721                           BElementwiseOperation 
b_element_op = BElementwiseOperation{},
 
  722                           CElementwiseOperation 
c_element_op = CElementwiseOperation{})
 
  761             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  765             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  770             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 
  774             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 
  776                 if constexpr(!PermuteB)
 
  782                     const int k0_offset = karg.
KRead * karg.
N;
 
  787             if(blockIdx.z < 
static_cast<uint32_t
>(karg.
KBatch - 1))
 
  827             constexpr 
auto MLdsLayer        = LdsSize < 1 ? 1 : LdsSize;
 
  842                 a_lds_block_desc_permuted,
 
  850                 a_lds_block_desc_ak0_mldslayer_m_ak1,
 
  858             return a_lds_block_desc_ak0_m_ak1;
 
  865             constexpr 
auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
  866             constexpr 
auto M1 = MPerBlock / M0;
 
  868             constexpr 
auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
  869             constexpr 
auto K0PerThreadWrite = 
AK0Number / KThreadWrite;
 
  870             constexpr 
auto KThreadRead      = 64 / MPerXdl;
 
  871             constexpr 
auto K0PerThreadRead  = 
AK0Number / KThreadRead;
 
  873             constexpr 
auto kfold = (
AK1Number * M0 * 
sizeof(ADataType) > 128)
 
  875                                        : 128 / (
AK1Number * M0 * 
sizeof(ADataType));
 
  876             constexpr 
auto KThreadReadPerm =
 
  877                 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
 
  878                     ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
 
  882             constexpr 
auto mpair = (
AK1Number * MPerXdl * 
sizeof(ADataType) > 128)
 
  884                                        : ((128 / (
AK1Number * MPerXdl * 
sizeof(ADataType))) > M0
 
  886                                               : 128 / (
AK1Number * MPerXdl * 
sizeof(ADataType)));
 
  892                            Number<kfold * M0 / mpair>{},
 
  911                 a_lds_block_desc_permuted,
 
  933                 a_lds_block_desc_unmerged,
 
  936                                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
 
  945             return a_lds_block_desc_ak0_m_ak1;
 
  964             constexpr 
index_t NLdsLayer     = LdsSize < 1 ? 1 : LdsSize;
 
  979                 b_lds_block_desc_permuted,
 
  987                 b_lds_block_desc_bk0_nldslayer_n_bk1,
 
  995             return b_lds_block_desc_bk0_n_bk1;
 
  999             constexpr 
auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
 
 1000             constexpr 
auto N1 = NPerBlock / N0;
 
 1002             constexpr 
auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
 
 1003             constexpr 
auto K0PerThreadWrite = 
BK0Number / KThreadWrite;
 
 1004             constexpr 
auto KThreadRead      = 64 / NPerXdl;
 
 1005             constexpr 
auto K0PerThreadRead  = 
BK0Number / KThreadRead;
 
 1007             constexpr 
auto kfold = (
BK1Number * N0 * 
sizeof(BDataType) > 128)
 
 1009                                        : 128 / (
BK1Number * N0 * 
sizeof(BDataType));
 
 1010             constexpr 
auto KThreadReadPerm =
 
 1011                 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
 
 1012                     ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
 
 1016             constexpr 
auto npair = (
BK1Number * NPerXdl * 
sizeof(BDataType) > 128)
 
 1018                                        : ((128 / (
BK1Number * NPerXdl * 
sizeof(BDataType))) > N0
 
 1020                                               : 128 / (
BK1Number * NPerXdl * 
sizeof(BDataType)));
 
 1026                            Number<kfold * N0 / npair>{},
 
 1045                 b_lds_block_desc_permuted,
 
 1067                 b_lds_block_desc_unmerged,
 
 1070                                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
 
 1079             return b_lds_block_desc_bk0_n_bk1;
 
 1085         constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 1086         constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
 1088         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1095         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
 
 1113                                 ABlockTransferSrcScalarPerVector,
 
 1114                                 BBlockTransferSrcScalarPerVector,
 
 1134             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1137             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
 
 1140         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1143         constexpr 
auto c_block_size =
 
 1144             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
 1147                           b_block_space_size_aligned * 
sizeof(BDataType) / 
BPackedSize),
 
 1148                          c_block_size * 
sizeof(CShuffleDataType));
 
 1154         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
 
 1155                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
 
 1156                       "Invalid tuning param!");
 
 1164             if(!(karg.
M % MPerBlock == 0))
 
 1168                     std::cout << 
"Arg M value is not a multiple of MPerBlock! M: " << karg.
M << 
" " 
 1169                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1182             if(!(karg.
N % NPerBlock == 0))
 
 1186                     std::cout << 
"Arg N value is not a multiple of NPerBlock! N: " << karg.
N << 
" " 
 1187                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1200             auto K_t = karg.
KBatch * KPerBlock;
 
 1201             if(!(karg.
K % K_t == 0))
 
 1205                     std::cout << 
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " 
 1206                               << karg.
K << 
" " << __FILE__ << 
":" << __LINE__
 
 1207                               << 
", in function: " << __func__ << std::endl;
 
 1215             auto K_t                = karg.
KBatch * KReadVec;
 
 1217             if((KReadPadSplited * (karg.
KBatch - 1)) >= karg.
K)
 
 1225             if(karg.
K % ABlockTransferSrcScalarPerVector != 0)
 
 1229                     std::cout << 
"Arg K (" << karg.
K 
 1230                               << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1231                               << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1232                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1239             if(karg.
M % ABlockTransferSrcScalarPerVector != 0)
 
 1243                     std::cout << 
"Arg M (" << karg.
M 
 1244                               << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1245                               << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1246                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1254             if(karg.
N % BBlockTransferSrcScalarPerVector != 0)
 
 1258                     std::cout << 
"Arg N (" << karg.
N 
 1259                               << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1260                               << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1261                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1268             if(karg.
K % BBlockTransferSrcScalarPerVector != 0)
 
 1272                     std::cout << 
"Arg K (" << karg.
K 
 1273                               << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1274                               << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1275                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1283             if(karg.
N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
 
 1287                     std::cout << 
"Arg N (" << karg.
N 
 1288                               << 
") value is not a multiple of " 
 1289                                  "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1290                               << CShuffleBlockTransferScalarPerVector_NPerBlock << 
" )! " 
 1291                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1299             if(karg.
M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
 
 1303                     std::cout << 
"Arg M (" << karg.
M 
 1304                               << 
") value is not a multiple of " 
 1305                                  "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1306                               << CShuffleBlockTransferScalarPerVector_NPerBlock << 
" )! " 
 1307                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1323                     std::cout << 
" KBatch: " << karg.
KBatch << 
" > 1 is not support yet" << __FILE__
 
 1324                               << 
":" << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1334         const auto num_k_loop = karg.
AK0 / (KPerBlock / AK1Value);
 
 1338             if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
 
 1350         const index_t num_loop = K / KPerBlock;
 
 1352         return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
 
 1357         const index_t num_loop = K / KPerBlock;
 
 1359         return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
 
 1362     template <
typename CGr
idDesc>
 
 1364         const CGridDesc& c_grid_desc_m_n, 
index_t MBlock, 
index_t NBlock)
 
 1373         return c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 1381     template <
typename AGridDesc_AK0_M_K1,
 
 1382               typename BGridDesc_BK0_N_K1,
 
 1383               typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
 
 1384               bool HasMainKBlockLoop,
 
 1387     __device__ 
static void Run(
const ADataType* p_a_grid,
 
 1388                                const BDataType* p_b_grid,
 
 1389                                CDataType* p_c_grid,
 
 1392                                const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
 
 1393                                const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
 
 1394                                const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
 
 1395                                    c_grid_desc_mblock_mperblock_nblock_nperblock)
 
 1397         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1398             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1399         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1400             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1401         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1402             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1407         const auto block_work_idx =
 
 1410         if(!block_2_ctile_map.ValidCTileIndex(
 
 1412                make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
 
 1413                           c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
 
 1418         const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
 
 1419         const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
 
 1422         const index_t m_block_data_idx_on_grid =
 
 1423             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
 
 1425         const index_t n_block_data_idx_on_grid =
 
 1426             __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
 
 1438         auto a_blockwise_copy =
 
 1440                                                 AElementwiseOperation,
 
 1444                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 1445                                                 ABlockTransferThreadClusterArrangeOrder,
 
 1448                                                 decltype(a_grid_desc_ak0_m_ak1),
 
 1449                                                 decltype(a_block_desc_ak0_m_ak1),
 
 1450                                                 ABlockTransferSrcAccessOrder,
 
 1452                                                 ABlockTransferSrcVectorDim,
 
 1454                                                 ABlockTransferSrcScalarPerVector,
 
 1455                                                 ABlockTransferDstScalarPerVector_AK1,
 
 1458                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
 1460                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1461                 a_grid_desc_ak0_m_ak1,
 
 1464                 a_block_desc_ak0_m_ak1,
 
 1469         auto b_blockwise_copy =
 
 1471                                                 BElementwiseOperation,
 
 1475                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
 1476                                                 BBlockTransferThreadClusterArrangeOrder,
 
 1479                                                 decltype(b_grid_desc_bk0_n_bk1),
 
 1480                                                 decltype(b_block_desc_bk0_n_bk1),
 
 1481                                                 BBlockTransferSrcAccessOrder,
 
 1483                                                 BBlockTransferSrcVectorDim,
 
 1485                                                 BBlockTransferSrcScalarPerVector,
 
 1486                                                 BBlockTransferDstScalarPerVector_BK1,
 
 1489                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
 1491                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1492                 b_grid_desc_bk0_n_bk1,
 
 1495                 b_block_desc_bk0_n_bk1,
 
 1501             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1504         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1505             static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1507         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1508             reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) + a_block_space_size_aligned *
 
 1511             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1517         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
 1519         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 1521         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 1522             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 1525         blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
 
 1526                                                                          a_block_desc_ak0_m_ak1,
 
 1530                                                                          a_block_slice_copy_step,
 
 1531                                                                          b_grid_desc_bk0_n_bk1,
 
 1532                                                                          b_block_desc_bk0_n_bk1,
 
 1536                                                                          b_block_slice_copy_step,
 
 1538                                                                          num_k_block_main_loop);
 
 1542             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 1543                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 1546             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 1547             constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
 1550             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 1551                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1555             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 1556                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1558             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
 1559             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
 1560             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
 1561             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
 1562             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
 1563             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
 1564             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
 1565             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
 1567             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1570             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1571                 static_cast<CShuffleDataType*
>(p_shared),
 
 1572                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1575                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 1595             const auto c_thread_mtx_on_block =
 
 1596                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 1598             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 1599             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 1601             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 1607             const auto m_thread_data_on_block_idx =
 
 1608                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 1611             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 1617             const auto n_thread_data_on_block_idx =
 
 1618                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 1622             const auto& vpgr_to_lds_element_op = [&] {
 
 1623                 if constexpr(DoElementwiseBeforeCShuffle)
 
 1629                     return pass_through;
 
 1632             const auto& lds_to_global_element_op = [&] {
 
 1633                 if constexpr(!DoElementwiseBeforeCShuffle)
 
 1639                     return pass_through;
 
 1647                 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 1648                 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 1650                               CElementwiseOperation,
 
 1652                 Sequence<CShuffleMXdlPerWavePerShuffle,
 
 1653                          CShuffleNXdlPerWavePerShuffle,
 
 1665                 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1668                                        m_thread_data_on_block_idx[
I1],
 
 1669                                        n_thread_data_on_block_idx[
I1],
 
 1670                                        m_thread_data_on_block_idx[
I2],
 
 1671                                        m_thread_data_on_block_idx[
I3],
 
 1672                                        m_thread_data_on_block_idx[
I4],
 
 1673                                        n_thread_data_on_block_idx[
I2]),
 
 1674                       vpgr_to_lds_element_op()};
 
 1680                               CElementwiseOperation,
 
 1682                 CGlobalMemoryDataOperation, 
 
 1684                          CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 1686                          CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, 
 
 1687                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
 1691                 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 1692                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1695                 CShuffleBlockTransferScalarPerVector_NPerBlock, 
 
 1698                 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 1700                  c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1702                  lds_to_global_element_op()};
 
 1705             constexpr 
auto sfc_c_vgpr =
 
 1708                                   Sequence<CShuffleMXdlPerWavePerShuffle,
 
 1709                                            CShuffleNXdlPerWavePerShuffle,
 
 1718             constexpr 
auto sfc_c_global =
 
 1722                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 1724                                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
 
 1726             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 1728             static_assert(num_access == sfc_c_global.GetNumOfAccess(), 
"wrong!");
 
 1735                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1736                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 1738                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1739                                               c_shuffle_block_buf);
 
 1745                 c_shuffle_block_copy_lds_to_global.Run(
 
 1746                     c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 1747                     c_shuffle_block_buf,
 
 1748                     c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1751                 if constexpr(access_id < num_access - 1)
 
 1753                     constexpr 
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
 
 1756                     c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
 
 1757                         c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
 
 1763     template <
bool HasMainKBlockLoop,
 
 1766     __device__ 
static void Run(
const ADataType* p_a_grid,
 
 1767                                const BDataType* p_b_grid,
 
 1768                                CDataType* p_c_grid,
 
 1778         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1782         Run<decltype(a_grid_desc_ak0_m_ak1),
 
 1783             decltype(b_grid_desc_bk0_n_bk1),
 
 1784             decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1786             CGlobalMemoryDataOperation,
 
 1792                      a_grid_desc_ak0_m_ak1,
 
 1793                      b_grid_desc_bk0_n_bk1,
 
 1794                      c_grid_desc_mblock_mperblock_nblock_nperblock);
 
 1797     template <
typename AGridDesc_AK0_M_K1,
 
 1798               typename BGridDesc_BK0_N_K1,
 
 1799               typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
 
 1800               bool HasMainKBlockLoop,
 
 1803     __device__ 
static void Run_2Lds(
const ADataType* p_a_grid,
 
 1804                                     const BDataType* p_b_grid,
 
 1805                                     CDataType* p_c_grid,
 
 1809                                     const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
 
 1810                                     const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
 
 1811                                     const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
 
 1812                                         c_grid_desc_mblock_mperblock_nblock_nperblock)
 
 1814         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1815             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1816         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1817             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1818         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1819             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1824         const auto block_work_idx =
 
 1827         if(!block_2_ctile_map.ValidCTileIndex(
 
 1829                make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
 
 1830                           c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
 
 1835         const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
 
 1836         const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
 
 1839         const index_t m_block_data_idx_on_grid =
 
 1840             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
 
 1842         const index_t n_block_data_idx_on_grid =
 
 1843             __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
 
 1855         auto a_blockwise_copy =
 
 1857                                                 AElementwiseOperation,
 
 1861                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 1862                                                 ABlockTransferThreadClusterArrangeOrder,
 
 1865                                                 decltype(a_grid_desc_ak0_m_ak1),
 
 1866                                                 decltype(a_block_desc_ak0_m_ak1),
 
 1867                                                 ABlockTransferSrcAccessOrder,
 
 1869                                                 ABlockTransferSrcVectorDim,
 
 1871                                                 ABlockTransferSrcScalarPerVector,
 
 1872                                                 ABlockTransferDstScalarPerVector_AK1,
 
 1875                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
 1877                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1878                 a_grid_desc_ak0_m_ak1,
 
 1881                 a_block_desc_ak0_m_ak1,
 
 1886         auto b_blockwise_copy =
 
 1888                                                 BElementwiseOperation,
 
 1892                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
 1893                                                 BBlockTransferThreadClusterArrangeOrder,
 
 1896                                                 decltype(b_grid_desc_bk0_n_bk1),
 
 1897                                                 decltype(b_block_desc_bk0_n_bk1),
 
 1898                                                 BBlockTransferSrcAccessOrder,
 
 1900                                                 BBlockTransferSrcVectorDim,
 
 1902                                                 BBlockTransferSrcScalarPerVector,
 
 1903                                                 BBlockTransferDstScalarPerVector_BK1,
 
 1906                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
 1908                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1909                 b_grid_desc_bk0_n_bk1,
 
 1912                 b_block_desc_bk0_n_bk1,
 
 1918             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1920         auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1921             static_cast<ADataType*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1923         auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1924             bit_cast<BDataType*>(
static_cast<char*
>(p_shared_0) +
 
 1925                                  a_block_space_size_aligned * 
sizeof(ADataType)),
 
 1926             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1928         auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1929             static_cast<ADataType*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1931         auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1932             bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
 
 1933                                  a_block_space_size_aligned * 
sizeof(ADataType)),
 
 1934             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1936         auto a_block_bufs = 
make_tuple(a_block_buf_ping, a_block_buf_pong);
 
 1937         auto b_block_bufs = 
make_tuple(b_block_buf_ping, b_block_buf_pong);
 
 1943         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
 1945         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 1947         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 1948             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 1951         blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
 
 1952                                                                          a_block_desc_ak0_m_ak1,
 
 1956                                                                          a_block_slice_copy_step,
 
 1957                                                                          b_grid_desc_bk0_n_bk1,
 
 1958                                                                          b_block_desc_bk0_n_bk1,
 
 1962                                                                          b_block_slice_copy_step,
 
 1964                                                                          num_k_block_main_loop);
 
 1968             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 1969                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 1972             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 1973             constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
 1976             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 1977                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1981             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 1982                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1984             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
 1985             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
 1986             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
 1987             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
 1988             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
 1989             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
 1990             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
 1991             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
 1993             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1996             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1997                 static_cast<CShuffleDataType*
>(p_shared_0),
 
 1998                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 2001                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2021             const auto c_thread_mtx_on_block =
 
 2022                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 2024             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 2025             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 2027             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 2033             const auto m_thread_data_on_block_idx =
 
 2034                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 2037             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 2043             const auto n_thread_data_on_block_idx =
 
 2044                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 2048             auto c_thread_copy_vgpr_to_lds =
 
 2051                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2052                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2054                                                    Sequence<CShuffleMXdlPerWavePerShuffle,
 
 2055                                                             CShuffleNXdlPerWavePerShuffle,
 
 2068                     c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2071                                      m_thread_data_on_block_idx[
I1],
 
 2072                                      n_thread_data_on_block_idx[
I1],
 
 2073                                      m_thread_data_on_block_idx[
I2],
 
 2074                                      m_thread_data_on_block_idx[
I3],
 
 2075                                      m_thread_data_on_block_idx[
I4],
 
 2076                                      n_thread_data_on_block_idx[
I2]),
 
 2082                 CElementwiseOperation,      
 
 2083                 CGlobalMemoryDataOperation, 
 
 2085                          CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2087                          CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, 
 
 2088                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
 2092                 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 2093                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2096                 CShuffleBlockTransferScalarPerVector_NPerBlock, 
 
 2099                 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2101                  c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 2106             constexpr 
auto sfc_c_vgpr =
 
 2109                                   Sequence<CShuffleMXdlPerWavePerShuffle,
 
 2110                                            CShuffleNXdlPerWavePerShuffle,
 
 2119             constexpr 
auto sfc_c_global =
 
 2123                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2125                                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
 
 2127             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 2129             static_assert(num_access == sfc_c_global.GetNumOfAccess(), 
"wrong!");
 
 2136                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2137                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 2139                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2140                                               c_shuffle_block_buf);
 
 2146                 c_shuffle_block_copy_lds_to_global.Run(
 
 2147                     c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2148                     c_shuffle_block_buf,
 
 2149                     c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 2152                 if constexpr(access_id < num_access - 1)
 
 2154                     constexpr 
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
 
 2157                     c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
 
 2158                         c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
 
 2164     template <
bool HasMainKBlockLoop,
 
 2167     __device__ 
static void Run_2Lds(
const ADataType* p_a_grid,
 
 2168                                     const BDataType* p_b_grid,
 
 2169                                     CDataType* p_c_grid,
 
 2181         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2185         Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
 
 2186                  decltype(b_grid_desc_bk0_n_bk1),
 
 2187                  decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2189                  CGlobalMemoryDataOperation,
 
 2196                           a_grid_desc_ak0_m_ak1,
 
 2197                           b_grid_desc_bk0_n_bk1,
 
 2198                           c_grid_desc_mblock_mperblock_nblock_nperblock);
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
 
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
int32_t int32_t
Definition: integer.hpp:10
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
InMemoryDataOperationEnum
Definition: ck.hpp:278
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
 
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
constexpr auto BlockGemmPipeline_Selector()
Definition: blockwise_gemm_pipeline_wmma_selector.hpp:31
 
_Float16 half_t
Definition: data_type.hpp:30
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
 
ushort bhalf_t
Definition: data_type.hpp:29
 
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:59
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
 
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
 
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
 
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
constexpr bool is_same_v
Definition: type.hpp:283
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:38
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:300
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
Definition: block_to_ctile_map.hpp:270
 
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:282
 
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:708
 
const BElementwiseOperation b_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:649
 
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CElementwiseOperation c_element_op=CElementwiseOperation{})
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:709
 
const BDataType * p_b_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:751
 
CDataType * p_c_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:752
 
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:740
 
const AElementwiseOperation a_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:648
 
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:745
 
const ADataType * p_a_grid
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:750
 
const CElementwiseOperation c_element_op
Definition: gridwise_gemm_xdl_cshuffle_v3_b_scale.hpp:650
 
bool is_reduce
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:753
 
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:635
 
index_t N
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:687
 
index_t NPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:694
 
index_t KBatch
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:692
 
index_t StrideA
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:689
 
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:636
 
CElementwiseOperation c_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:703
 
index_t BK0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:698
 
index_t M
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:686
 
index_t NBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:700
 
index_t MPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:693
 
index_t K
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:688
 
index_t StrideB
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:690
 
index_t KPadded
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:696
 
index_t StrideC
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:691
 
index_t MBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:699
 
BElementwiseOperation b_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:702
 
index_t AK0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:697
 
index_t KRead
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:695
 
AElementwiseOperation a_element_op_
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:701
 
__host__ void Print() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:667
 
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:757
 
index_t a_k_split_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:806
 
index_t b_k_split_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:807
 
__device__ SplitKBatchOffset(Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:759
 
index_t c_reduce_offset
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:808
 
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:241
 
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:566
 
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:445
 
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:331
 
static constexpr auto is_scale_mfma
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:267
 
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:308
 
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:298
 
static constexpr auto BK1Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:255
 
static constexpr index_t APackedSize
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:279
 
static constexpr bool is_single_rate_mfma
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:258
 
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1355
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:277
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1363
 
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:313
 
static constexpr auto I2
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:244
 
static constexpr index_t KPack
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:268
 
static constexpr auto lcm_AK1_BK1
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:257
 
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:349
 
static constexpr auto I7
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:249
 
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1122
 
static constexpr auto I5
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:247
 
static constexpr auto AK1Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:254
 
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1152
 
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1803
 
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:293
 
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:338
 
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:2167
 
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:303
 
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:557
 
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:319
 
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1124
 
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:949
 
static constexpr index_t BPackedSize
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:286
 
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1387
 
static constexpr auto I6
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:248
 
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:811
 
static constexpr auto I1
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:243
 
static constexpr auto I0
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:242
 
static constexpr auto I3
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:245
 
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1766
 
static constexpr auto I4
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:246
 
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:325
 
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:363
 
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:343
 
static constexpr auto BK0Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:253
 
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1083
 
static constexpr auto AK0Number
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:252
 
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1348
 
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:574
 
Definition: xdlops_gemm.hpp:942
 
Definition: sequence.hpp:43
 
Definition: tensor_space_filling_curve.hpp:20
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Definition: integral_constant.hpp:20
 
Definition: data_type.hpp:186
 
Definition: functional2.hpp:33
 
Definition: device_base.hpp:51
 
Definition: unary_element_wise_operation.hpp:308
 
#define CK_ENV(name)
Definition: env.hpp:128