20 template <
typename GridwiseGemm,
 
   21           bool HasMainKBlockLoop,
 
   26 #if CK_USE_LAUNCH_BOUNDS 
   31 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__)) 
   32 #if defined(__gfx11__) 
   36                    (std::is_same_v<c_data_type, ck::half_t> ||
 
   37                     std::is_same_v<c_data_type, ck::bhalf_t>)))
 
   40         __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   42         auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg);
 
   44         GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   45             karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
   46             karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
   47             karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
 
   50 #if defined(__gfx11__) 
  161 template <
typename ALayout,
 
  166           typename AccDataType,
 
  167           typename CShuffleDataType,
 
  169           typename AElementwiseOperation,
 
  170           typename BElementwiseOperation,
 
  171           typename CElementwiseOperation,
 
  183           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  184           typename ABlockTransferThreadClusterArrangeOrder,
 
  185           typename ABlockTransferSrcAccessOrder,
 
  186           index_t ABlockTransferSrcVectorDim,
 
  187           index_t ABlockTransferSrcScalarPerVector,
 
  188           index_t ABlockTransferDstScalarPerVector_AK1,
 
  189           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  191           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  192           typename BBlockTransferThreadClusterArrangeOrder,
 
  193           typename BBlockTransferSrcAccessOrder,
 
  194           index_t BBlockTransferSrcVectorDim,
 
  195           index_t BBlockTransferSrcScalarPerVector,
 
  196           index_t BBlockTransferDstScalarPerVector_BK1,
 
  197           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  199           index_t CShuffleMRepeatPerShuffle,
 
  200           index_t CShuffleNRepeatPerShuffle,
 
  201           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  202           index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
 
  205           typename ComputeTypeA,
 
  206           typename ComputeTypeB,
 
  269         auto K_t = K_Batch * KPerBlock;
 
  270         return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
 
  275         auto K_t = K_Batch * KPerBlock;
 
  276         return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
 
  281         auto K_t = K_Batch * KPerBlock;
 
  282         return (K + K_t - 1) / K_t * KPerBlock;
 
  288         auto K_t                = K_Batch * KReadVec;
 
  289         return (K + K_t - 1) / K_t * KReadVec;
 
  302     template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, 
typename BlockDesc>
 
  306         constexpr 
auto K0 = BlockDesc{}.GetLength(
I0);
 
  307         constexpr 
auto K1 = BlockDesc{}.GetLength(
I2);
 
  309         constexpr 
auto KRow = 
I2;
 
  311         constexpr 
auto KRow = 
I1;
 
  326         const auto a_grid_desc_mraw_kraw = [&]() {
 
  327             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  331             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  339         if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
 
  340                      GemmSpec == GemmSpecialization::MNKPadding)
 
  343             const auto a_grid_desc_m_k =
 
  357             return a_grid_desc_ak0_m_ak1;
 
  359         else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 
  360                           GemmSpec == GemmSpecialization::MNPadding)
 
  364                 a_grid_desc_mraw_kraw,
 
  370             return a_grid_desc_ak0_m_ak1;
 
  372         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  373                           GemmSpec == GemmSpecialization::NKPadding)
 
  377                 a_grid_desc_mraw_kraw,
 
  389             return a_grid_desc_ak0_m_ak1;
 
  393             static_assert(!PermuteA, 
"PermuteA is not supported");
 
  397                 a_grid_desc_mraw_kraw,
 
  403             return a_grid_desc_ak0_m_ak1;
 
  410         const auto b_grid_desc_nraw_kraw = [&]() {
 
  424                         GemmSpec != GemmSpecialization::Default),
 
  425                       "pk_i4_t does not support padding");
 
  427         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
 
  428                      GemmSpec == GemmSpecialization::MNKPadding)
 
  431             const auto b_grid_desc_n_k =
 
  445             return b_grid_desc_bk0_n_bk1;
 
  447         else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 
  448                           GemmSpec == GemmSpecialization::MNPadding)
 
  452                 b_grid_desc_nraw_kraw,
 
  458             return b_grid_desc_bk0_n_bk1;
 
  460         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  461                           GemmSpec == GemmSpecialization::MKPadding)
 
  465                 b_grid_desc_nraw_kraw,
 
  477             return b_grid_desc_bk0_n_bk1;
 
  481             if constexpr(!PermuteB)
 
  485                     b_grid_desc_nraw_kraw,
 
  491                 return b_grid_desc_bk0_n_bk1;
 
  497                 constexpr 
index_t BK01 = KPerBlock / BK1Value;
 
  498                 const index_t BK0_     = StrideB / BK1Value;
 
  499                 const index_t BK00     = BK0_ / BK01;
 
  501                 const auto b_grid_desc_bk00_n_bk01_bk1_permute =
 
  505                     b_grid_desc_bk00_n_bk01_bk1_permute,
 
  512                 return b_grid_desc_bk0_n_bk1_permute;
 
  517     template <
typename ABlockDesc_AK0_M_AK1>
 
  520         constexpr 
index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
 
  522         return MakeWmmaTileDescriptor<MRepeat, MWaves, MPerWmma>(ABlockDesc_AK0_M_AK1{});
 
  525     template <
typename BBlockDesc_BK0_N_BK1>
 
  528         constexpr 
index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
 
  530         return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
 
  533     __host__ __device__ 
static auto 
  536         const auto c_grid_desc_mraw_nraw = [&]() {
 
  558         if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
 
  559                      GemmSpec == GemmSpecialization::MNKPadding)
 
  568         else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 
  569                           GemmSpec == GemmSpecialization::MKPadding)
 
  573                 c_grid_desc_mraw_nraw,
 
  578         else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 
  579                           GemmSpec == GemmSpecialization::NKPadding)
 
  583                 c_grid_desc_mraw_nraw,
 
  591             return c_grid_desc_mraw_nraw;
 
  625             std::cout << 
"problem {" 
  634                       << 
"KRead:" << 
KRead << 
", " 
  636                       << 
"AK0:" << 
AK0 << 
", " 
  637                       << 
"BK0:" << 
BK0 << 
", " 
  638                       << 
"MBlock: " << 
MBlock << 
", " 
  639                       << 
"NBlock: " << 
NBlock << 
"}" << std::endl;
 
  663                           const BDataType* p_b_grid_,
 
  664                           CDataType* p_c_grid_,
 
  672                           bool is_reduce_ = 
false)
 
  673             : 
Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
 
  702             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  706             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  711             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 
  715             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 
  717                 if constexpr(!PermuteB)
 
  723                     const int k0_offset = karg.
KRead * karg.
N;
 
  728             if(blockIdx.z < 
static_cast<uint32_t
>(karg.
KBatch - 1))
 
  768             constexpr 
auto MLdsLayer        = LdsSize < 1 ? 1 : LdsSize;
 
  783                 a_lds_block_desc_permuted,
 
  791                 a_lds_block_desc_ak0_mldslayer_m_ak1,
 
  799             return a_lds_block_desc_ak0_m_ak1;
 
  806             constexpr 
auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
  807             constexpr 
auto M1 = MPerBlock / M0;
 
  809             constexpr 
auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
  810             constexpr 
auto K0PerThreadWrite = 
AK0Number / KThreadWrite;
 
  811             constexpr 
auto KThreadRead      = 64 / MPerWmma;
 
  812             constexpr 
auto K0PerThreadRead  = 
AK0Number / KThreadRead;
 
  814             constexpr 
auto kfold = (
AK1Number * M0 * 
sizeof(ADataType) > 128)
 
  816                                        : 128 / (
AK1Number * M0 * 
sizeof(ADataType));
 
  817             constexpr 
auto KThreadReadPerm =
 
  818                 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
 
  819                     ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
 
  823             constexpr 
auto mpair = (
AK1Number * MPerWmma * 
sizeof(ADataType) > 128)
 
  825                                        : ((128 / (
AK1Number * MPerWmma * 
sizeof(ADataType))) > M0
 
  827                                               : 128 / (
AK1Number * MPerWmma * 
sizeof(ADataType)));
 
  833                            Number<kfold * M0 / mpair>{},
 
  852                 a_lds_block_desc_permuted,
 
  874                 a_lds_block_desc_unmerged,
 
  877                                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
 
  886             return a_lds_block_desc_ak0_m_ak1;
 
  905             constexpr 
index_t NLdsLayer     = LdsSize < 1 ? 1 : LdsSize;
 
  920                 b_lds_block_desc_permuted,
 
  928                 b_lds_block_desc_bk0_nldslayer_n_bk1,
 
  936             return b_lds_block_desc_bk0_n_bk1;
 
  940             constexpr 
auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
 
  941             constexpr 
auto N1 = NPerBlock / N0;
 
  943             constexpr 
auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
 
  944             constexpr 
auto K0PerThreadWrite = 
BK0Number / KThreadWrite;
 
  945             constexpr 
auto KThreadRead      = 64 / NPerWmma;
 
  946             constexpr 
auto K0PerThreadRead  = 
BK0Number / KThreadRead;
 
  948             constexpr 
auto kfold = (
BK1Number * N0 * 
sizeof(BDataType) > 128)
 
  950                                        : 128 / (
BK1Number * N0 * 
sizeof(BDataType));
 
  951             constexpr 
auto KThreadReadPerm =
 
  952                 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
 
  953                     ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
 
  957             constexpr 
auto npair = (
BK1Number * NPerWmma * 
sizeof(BDataType) > 128)
 
  959                                        : ((128 / (
BK1Number * NPerWmma * 
sizeof(BDataType))) > N0
 
  961                                               : 128 / (
BK1Number * NPerWmma * 
sizeof(BDataType)));
 
  967                            Number<kfold * N0 / npair>{},
 
  986                 b_lds_block_desc_permuted,
 
 1008                 b_lds_block_desc_unmerged,
 
 1011                                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
 
 1020             return b_lds_block_desc_bk0_n_bk1;
 
 1024     __host__ __device__ 
static constexpr 
auto 
 1028         constexpr 
index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
 
 1029         constexpr 
index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
 
 1031         constexpr 
auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
 
 1038         return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
 
 1053                  ABlockTransferSrcScalarPerVector,
 
 1054                  BBlockTransferSrcScalarPerVector,
 
 1074             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1077             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
 
 1080         constexpr 
auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
 
 1083         constexpr 
auto c_block_size =
 
 1084             c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
 
 1085                 .GetElementSpaceSize();
 
 1088                           b_block_space_size_aligned * 
sizeof(BDataType) / 
BPackedSize),
 
 1089                          c_block_size * 
sizeof(CShuffleDataType));
 
 1095         static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
 
 1096                           (NPerBlock % (NPerWmma * NRepeat)) == 0,
 
 1097                       "Invalid tuning param!");
 
 1105             if(!(karg.
M % MPerBlock == 0))
 
 1109                     std::cout << 
"Arg M value is not a multiple of MPerBlock! M: " << karg.
M << 
" " 
 1110                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1123             if(!(karg.
N % NPerBlock == 0))
 
 1127                     std::cout << 
"Arg N value is not a multiple of NPerBlock! N: " << karg.
N << 
" " 
 1128                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1141             auto K_t = karg.
KBatch * KPerBlock;
 
 1142             if(!(karg.
K % K_t == 0))
 
 1146                     std::cout << 
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " 
 1147                               << karg.
K << 
" " << __FILE__ << 
":" << __LINE__
 
 1148                               << 
", in function: " << __func__ << std::endl;
 
 1156             auto K_t                = karg.
KBatch * KReadVec;
 
 1158             if((KReadPadSplited * (karg.
KBatch - 1)) >= karg.
K)
 
 1166             if(karg.
K % ABlockTransferSrcScalarPerVector != 0)
 
 1170                     std::cout << 
"Arg K (" << karg.
K 
 1171                               << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1172                               << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1173                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1180             if(karg.
M % ABlockTransferSrcScalarPerVector != 0)
 
 1184                     std::cout << 
"Arg M (" << karg.
M 
 1185                               << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1186                               << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1187                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1195             if(karg.
N % BBlockTransferSrcScalarPerVector != 0)
 
 1199                     std::cout << 
"Arg N (" << karg.
N 
 1200                               << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1201                               << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1202                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1209             if(karg.
K % BBlockTransferSrcScalarPerVector != 0)
 
 1213                     std::cout << 
"Arg K (" << karg.
K 
 1214                               << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1215                               << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1216                               << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1224             if(karg.
N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
 
 1228                     std::cout << 
"Arg N (" << karg.
N 
 1229                               << 
") value is not a multiple of " 
 1230                                  "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1231                               << CShuffleBlockTransferScalarPerVector_NPerBlock << 
" )! " 
 1232                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1240             if(karg.
M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
 
 1244                     std::cout << 
"Arg M (" << karg.
M 
 1245                               << 
") value is not a multiple of " 
 1246                                  "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1247                               << CShuffleBlockTransferScalarPerVector_NPerBlock << 
" )! " 
 1248                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1264                     std::cout << 
" KBatch: " << karg.
KBatch << 
" > 1 is not supported yet" 
 1265                               << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
 1276         const auto num_k_loop = karg.
AK0 / (KPerBlock / AK1Value);
 
 1280             if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
 
 1292         const index_t num_loop = K / KPerBlock;
 
 1294         return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
 
 1299         const index_t num_loop = K / KPerBlock;
 
 1301         return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
 
 1304     template <
typename CGr
idDesc>
 
 1306         const CGridDesc& c_grid_desc_m_n, 
index_t MBlock, 
index_t NBlock)
 
 1315         return c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 1323     template <
typename AGridDesc_AK0_M_K1,
 
 1324               typename BGridDesc_BK0_N_K1,
 
 1325               typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
 
 1326               bool HasMainKBlockLoop,
 
 1329     __device__ 
static void Run(
const ADataType* p_a_grid,
 
 1330                                const BDataType* p_b_grid,
 
 1331                                CDataType* p_c_grid,
 
 1334                                const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
 
 1335                                const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
 
 1336                                const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
 
 1337                                    c_grid_desc_mblock_mperblock_nblock_nperblock)
 
 1339         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1340             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1341         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1342             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1343         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1344             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1346         const AElementwiseOperation a_element_op{};
 
 1347         const BElementwiseOperation b_element_op{};
 
 1348         const CElementwiseOperation c_element_op{};
 
 1353         const auto block_work_idx =
 
 1356         if(!block_2_ctile_map.ValidCTileIndex(
 
 1358                make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
 
 1359                           c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
 
 1364         const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
 
 1365         const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
 
 1368         const index_t m_block_data_idx_on_grid =
 
 1369             __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
 
 1371         const index_t n_block_data_idx_on_grid =
 
 1372             __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
 
 1384         auto a_blockwise_copy =
 
 1386                                                 AElementwiseOperation,
 
 1390                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 1391                                                 ABlockTransferThreadClusterArrangeOrder,
 
 1394                                                 decltype(a_grid_desc_ak0_m_ak1),
 
 1395                                                 decltype(a_block_desc_ak0_m_ak1),
 
 1396                                                 ABlockTransferSrcAccessOrder,
 
 1398                                                 ABlockTransferSrcVectorDim,
 
 1400                                                 ABlockTransferSrcScalarPerVector,
 
 1401                                                 ABlockTransferDstScalarPerVector_AK1,
 
 1404                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
 1406                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1407                 a_grid_desc_ak0_m_ak1,
 
 1410                 a_block_desc_ak0_m_ak1,
 
 1415         auto b_blockwise_copy =
 
 1417                                                 BElementwiseOperation,
 
 1421                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
 1422                                                 BBlockTransferThreadClusterArrangeOrder,
 
 1425                                                 decltype(b_grid_desc_bk0_n_bk1),
 
 1426                                                 decltype(b_block_desc_bk0_n_bk1),
 
 1427                                                 BBlockTransferSrcAccessOrder,
 
 1429                                                 BBlockTransferSrcVectorDim,
 
 1431                                                 BBlockTransferSrcScalarPerVector,
 
 1432                                                 BBlockTransferDstScalarPerVector_BK1,
 
 1435                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
 1437                                                 BlockwiseGemmPipe::GlobalBufferNum>(
 
 1438                 b_grid_desc_bk0_n_bk1,
 
 1441                 b_block_desc_bk0_n_bk1,
 
 1447             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
 1450         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1451             static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1453         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1454             reinterpret_cast<BDataType*
>(
static_cast<char*
>(p_shared) + a_block_space_size_aligned *
 
 1457             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1463         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
 1465         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 1467         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 1468             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 1471         blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
 
 1472                                                                          a_block_desc_ak0_m_ak1,
 
 1476                                                                          a_block_slice_copy_step,
 
 1477                                                                          b_grid_desc_bk0_n_bk1,
 
 1478                                                                          b_block_desc_bk0_n_bk1,
 
 1482                                                                          b_block_slice_copy_step,
 
 1484                                                                          num_k_block_main_loop);
 
 1489             constexpr 
auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
 
 1490                 blockwise_gemm_pipeline
 
 1491                     .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
 
 1495                 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
 
 1496                     blockwise_gemm_pipeline
 
 1497                         .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();
 
 1499             constexpr 
auto MWave =
 
 1500                 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
 
 1502             constexpr 
auto MSubGroup =
 
 1503                 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
 
 1505             constexpr 
auto NWave =
 
 1506                 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
 
 1508             constexpr 
auto NThreadPerSubGroup =
 
 1509                 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
 
 1511             constexpr 
auto MAccVgprs =
 
 1512                 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
 
 1516             constexpr 
auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
 
 1519             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1520                 static_cast<CShuffleDataType*
>(p_shared),
 
 1521                 c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
 
 1522                     .GetElementSpaceSize());
 
 1525                 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
 
 1527                         c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
 
 1539                                 NThreadPerSubGroup))), 
 
 1548             const auto c_thread_mtx_on_block =
 
 1549                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0);
 
 1551             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 1552             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 1554             const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
 
 1556                                                      MRepeat, MWave, MSubGroup, MAccVgprs))),
 
 1560             const auto m_thread_data_on_block_idx =
 
 1561                 m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
 
 1564             const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
 
 1566                                                      NRepeat, NWave, NThreadPerSubGroup))),
 
 1570             const auto n_thread_data_on_block_idx =
 
 1571                 n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor
 
 1578                 decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
 
 1579                 decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
 
 1581                 Sequence<CShuffleMRepeatPerShuffle,
 
 1584                          CShuffleNRepeatPerShuffle,
 
 1594                 c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
 
 1596                                  m_thread_data_on_block_idx[
I1],
 
 1597                                  m_thread_data_on_block_idx[
I2],
 
 1599                                  n_thread_data_on_block_idx[
I1],
 
 1600                                  n_thread_data_on_block_idx[
I2],
 
 1601                                  m_thread_data_on_block_idx[
I3]),
 
 1607                 CElementwiseOperation,      
 
 1608                 CGlobalMemoryDataOperation, 
 
 1610                          CShuffleMRepeatPerShuffle * MWave * MPerWmma,
 
 1612                          CShuffleNRepeatPerShuffle * NWave * NPerWmma>, 
 
 1613                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
 1617                 decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
 
 1618                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1621                 CShuffleBlockTransferScalarPerVector_NPerBlock, 
 
 1624                 {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
 
 1626                  c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1632             constexpr 
auto sfc_c_vgpr =
 
 1635                                   Sequence<CShuffleMRepeatPerShuffle,
 
 1638                                            CShuffleNRepeatPerShuffle,
 
 1644             constexpr 
auto sfc_c_global =
 
 1648                                            CShuffleMRepeatPerShuffle * MWave * MPerWmma,
 
 1650                                            CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};
 
 1652             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 1654             static_assert(num_access == sfc_c_global.GetNumOfAccess(), 
"wrong!");
 
 1661                 c_thread_copy_vgpr_to_lds.Run(
 
 1662                     c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
 
 1663                     sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 1665                     c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
 
 1666                     c_shuffle_block_buf);
 
 1672                 c_shuffle_block_copy_lds_to_global.Run(
 
 1673                     c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
 
 1674                     c_shuffle_block_buf,
 
 1675                     c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1678                 if constexpr(access_id < num_access - 1)
 
 1680                     constexpr 
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
 
 1683                     c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
 
 1684                         c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
 
 1690     template <
bool HasMainKBlockLoop,
 
 1693     __device__ 
static void Run(
const ADataType* p_a_grid,
 
 1694                                const BDataType* p_b_grid,
 
 1695                                CDataType* p_c_grid,
 
 1705         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1709         Run<decltype(a_grid_desc_ak0_m_ak1),
 
 1710             decltype(b_grid_desc_bk0_n_bk1),
 
 1711             decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1713             CGlobalMemoryDataOperation,
 
 1719                      a_grid_desc_ak0_m_ak1,
 
 1720                      b_grid_desc_bk0_n_bk1,
 
 1721                      c_grid_desc_mblock_mperblock_nblock_nperblock);
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
 
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
int32_t int32_t
Definition: integer.hpp:10
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
InMemoryDataOperationEnum
Definition: ck.hpp:278
 
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:300
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
 
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
constexpr auto BlockGemmPipeline_Selector()
Definition: blockwise_gemm_pipeline_wmma_selector.hpp:31
 
_Float16 half_t
Definition: data_type.hpp:30
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
 
ushort bhalf_t
Definition: data_type.hpp:29
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
 
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
 
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:29
 
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
constexpr bool is_same_v
Definition: type.hpp:283
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:300
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
Definition: block_to_ctile_map.hpp:270
 
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:282
 
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:661
 
CDataType * p_c_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:693
 
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:662
 
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:694
 
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:681
 
const ADataType * p_a_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:691
 
const BDataType * p_b_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:692
 
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:686
 
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:597
 
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:642
 
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:652
 
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:650
 
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:656
 
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:644
 
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:623
 
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:643
 
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:653
 
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:654
 
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:648
 
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:649
 
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:655
 
index_t StrideA
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:645
 
index_t StrideB
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:646
 
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:598
 
index_t StrideC
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:647
 
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:651
 
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:698
 
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:749
 
index_t b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:748
 
index_t a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:747
 
__device__ SplitKBatchOffset(Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:700
 
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:210
 
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:252
 
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:323
 
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:224
 
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:534
 
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1062
 
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:240
 
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:273
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1305
 
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:752
 
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1297
 
__host__ static constexpr __device__ auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:526
 
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:279
 
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:217
 
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1064
 
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:285
 
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:216
 
__host__ static constexpr __device__ auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:518
 
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:218
 
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:233
 
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:215
 
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1093
 
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1026
 
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:247
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:231
 
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:212
 
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:211
 
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:214
 
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:257
 
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:221
 
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1290
 
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:223
 
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:292
 
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:267
 
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:222
 
static constexpr index_t KPack
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:226
 
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1329
 
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:213
 
__host__ static constexpr __device__ auto MakeWmmaTileDescriptor(const BlockDesc &)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:303
 
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:297
 
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:262
 
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:890
 
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:407
 
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:1693
 
Definition: sequence.hpp:43
 
Definition: tensor_space_filling_curve.hpp:20
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Definition: wmma_gemm.hpp:553
 
Definition: integral_constant.hpp:20
 
Definition: data_type.hpp:186
 
Definition: functional2.hpp:33
 
Definition: device_base.hpp:51
 
Definition: unary_element_wise_operation.hpp:308
 
#define CK_ENV(name)
Definition: env.hpp:128