36 template <
typename GridwiseGemm,
 
   37           bool HasMainKBlockLoop,
 
   42 #if CK_USE_LAUNCH_BOUNDS 
   49     __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   51     auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
   53     GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   54         karg.p_sorted_token_ids,
 
   55         karg.p_sorted_expert_ids,
 
   57         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
   58         karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
   71 template <
typename GridwiseGemm,
 
   72           bool HasMainKBlockLoop,
 
   77 #if CK_USE_LAUNCH_BOUNDS 
   84     __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   85     __shared__ 
char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   87     auto splitk_batch_offset = 
typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
 
   89     GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
 
   90         karg.p_sorted_token_ids,
 
   91         karg.p_sorted_expert_ids,
 
   93         karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
   94         karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
 
  108 template <
typename ALayout,
 
  114           typename AccDataType,
 
  115           typename CShuffleDataType,
 
  118           typename AElementwiseOperation,
 
  119           typename BElementwiseOperation,
 
  120           typename CElementwiseOperation,
 
  132           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  133           typename ABlockTransferThreadClusterArrangeOrder,
 
  134           typename ABlockTransferSrcAccessOrder,
 
  135           index_t ABlockTransferSrcVectorDim,
 
  136           index_t ABlockTransferSrcScalarPerVector,
 
  137           index_t ABlockTransferDstScalarPerVector_AK1,
 
  138           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  140           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  141           typename BBlockTransferThreadClusterArrangeOrder,
 
  142           typename BBlockTransferSrcAccessOrder,
 
  143           index_t BBlockTransferSrcVectorDim,
 
  144           index_t BBlockTransferSrcScalarPerVector,
 
  145           index_t BBlockTransferDstScalarPerVector_BK1,
 
  146           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  148           index_t CShuffleMXdlPerWavePerShuffle,
 
  149           index_t CShuffleNXdlPerWavePerShuffle,
 
  150           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  151           typename CDEShuffleBlockTransferScalarPerVectors,
 
  154           index_t ActivationOperation                 = 0,
 
  155           bool NSwizzle                               = 
false,
 
  156           bool IsInputGemm                            = 
true,
 
  157           bool MulRoutedWeight                        = 
true,
 
  158           bool PerTokenQuant                          = 
false,
 
  160           typename ComputeTypeA                       = CDataType,
 
  161           typename ComputeTypeB                       = ComputeTypeA,
 
  162           typename LDSTypeA                           = ADataType,
 
  163           typename LDSTypeB                           = BDataType>
 
  176         CDEShuffleBlockTransferScalarPerVectors{}[
I0];
 
  217                 return static_cast<const DDataType*
>(
nullptr);
 
  244         const index_t gridx  = NSwizzle ? nblock * mblock : nblock;
 
  245         const index_t gridy  = NSwizzle ? 1 : mblock;
 
  276         auto K_t = K_Batch * KPerBlock;
 
  277         return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
 
  282         auto K_t = K_Batch * KPerBlock;
 
  283         return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
 
  288         auto K_t = K_Batch * KPerBlock;
 
  289         return (K + K_t - 1) / K_t * KPerBlock;
 
  295         auto K_t                = K_Batch * KReadVec;
 
  296         return (K + K_t - 1) / K_t * KReadVec;
 
  309     template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, 
typename TileDesc_K0_MN_K1>
 
  325         IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
 
  327         const auto a_grid_desc_mraw_kraw = [&]() {
 
  328             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  332             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  340         if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
 
  341                      GemmSpec == GemmSpecialization::MNKPadding)
 
  344             const auto a_grid_desc_m_k =
 
  358             return a_grid_desc_ak0_m_ak1;
 
  360         else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
 
  361                           GemmSpec == GemmSpecialization::MNPadding)
 
  365                 a_grid_desc_mraw_kraw,
 
  371             return a_grid_desc_ak0_m_ak1;
 
  373         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  374                           GemmSpec == GemmSpecialization::NKPadding)
 
  378                 a_grid_desc_mraw_kraw,
 
  390             return a_grid_desc_ak0_m_ak1;
 
  396                 a_grid_desc_mraw_kraw,
 
  402             return a_grid_desc_ak0_m_ak1;
 
  411             make_tuple(
NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, 
I1));
 
  417         const auto b_grid_desc_nraw_kraw = [&]() {
 
  431                         GemmSpec != GemmSpecialization::Default),
 
  432                       "pk_i4_t does not support padding");
 
  434         if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
 
  435                      GemmSpec == GemmSpecialization::MNKPadding)
 
  438             const auto b_grid_desc_n_k =
 
  452             return b_grid_desc_bk0_n_bk1;
 
  454         else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
 
  455                           GemmSpec == GemmSpecialization::MNPadding)
 
  459                 b_grid_desc_nraw_kraw,
 
  465             return b_grid_desc_bk0_n_bk1;
 
  467         else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
 
  468                           GemmSpec == GemmSpecialization::MKPadding)
 
  472                 b_grid_desc_nraw_kraw,
 
  484             return b_grid_desc_bk0_n_bk1;
 
  490                 b_grid_desc_nraw_kraw,
 
  496             return b_grid_desc_bk0_n_bk1;
 
  500     template <
typename ABlockDesc_AK0_M_AK1>
 
  501     __host__ __device__ 
static constexpr 
auto 
  504         constexpr 
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
 
  506         return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
 
  509     template <
typename BBlockDesc_BK0_N_BK1>
 
  510     __host__ __device__ 
static constexpr 
auto 
  513         return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
 
  516     template <
typename ELayout>
 
  518         IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
 
  520         const auto c_grid_desc_mraw_nraw = [&]() {
 
  539     template <
typename DLayout>
 
  540     __host__ __device__ 
static auto 
  543         const auto c_grid_desc_mraw_nraw = [&]() {
 
  568                 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
 
  573     template <
typename DsGr
idDesc>
 
  575         const DsGridDesc& ds_grid_desc_m_n, 
index_t MBlock, 
index_t NBlock)
 
  580                     ds_grid_desc_m_n[i], MBlock, NBlock);
 
  594                                     std::array<index_t, NumDTensor> StrideDs_,
 
  622             std::cout << 
"problem {" << 
"NumTokens:" << 
NumTokens << 
", " << 
"TopK:" << 
TopK << 
", " 
  623                       << 
"M:" << 
M << 
", " << 
"N:" << 
N << 
", " << 
"K:" << 
K << 
", " 
  626                       << 
"KRead:" << 
KRead << 
", " << 
"KP:" << 
KPadded << 
", " << 
"AK0:" << 
AK0 
  627                       << 
", " << 
"BK0:" << 
BK0 << 
", " << 
"MBlock: " << 
MBlock << 
", " 
  628                       << 
"NBlock: " << 
NBlock << 
"}" << std::endl;
 
  658                           const index_t* p_sorted_expert_ids_,
 
  659                           const index_t* p_max_token_id_,
 
  660                           const ADataType* p_a_grid_,
 
  661                           const BDataType* p_b_grid_,
 
  662                           std::array<const void*, NumDTensor> p_ds_grid_,
 
  663                           CDataType* p_c_grid_,
 
  671                           std::array<index_t, NumDTensor> StrideDs_,
 
  674                           AElementwiseOperation a_element_op_,
 
  675                           BElementwiseOperation b_element_op_,
 
  676                           CElementwiseOperation c_element_op_)
 
  704                 p_ds_grid(i) = 
static_cast<const DDataType_*
>(p_ds_grid_[i]);
 
  725             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
 
  729             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
 
  734             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
 
  738             else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 
  744             if(k_id < karg.
KBatch - 1)
 
  761         if constexpr(ABlockLdsExtraM)
 
  771             constexpr 
auto a_lds_block_desc =
 
  783             return a_lds_block_desc_permuted;
 
  790             constexpr 
auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
  791             constexpr 
auto M1 = MPerBlock / M0;
 
  793             constexpr 
auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
  794             constexpr 
auto K0PerThreadWrite = 
AK0Number / KThreadWrite;
 
  795             constexpr 
auto KThreadRead      = 64 / MPerXdl;
 
  796             constexpr 
auto K0PerThreadRead  = 
AK0Number / KThreadRead;
 
  798             constexpr 
auto kfold = (
AK1Number * M0 * 
sizeof(LDSTypeA) > 128)
 
  800                                        : 128 / (
AK1Number * M0 * 
sizeof(LDSTypeA));
 
  801             constexpr 
auto KThreadReadPerm =
 
  802                 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
 
  803                     ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
 
  807             constexpr 
auto mpair = (
AK1Number * MPerXdl * 
sizeof(LDSTypeA) > 128)
 
  809                                        : ((128 / (
AK1Number * MPerXdl * 
sizeof(LDSTypeA))) > M0
 
  811                                               : 128 / (
AK1Number * MPerXdl * 
sizeof(LDSTypeA)));
 
  817                            Number<kfold * M0 / mpair>{},
 
  836                 a_lds_block_desc_permuted,
 
  858                 a_lds_block_desc_unmerged,
 
  861                                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
 
  870             return a_lds_block_desc_ak0_m_ak1;
 
  883         constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
  885         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
  892         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
 
  910                                 ABlockTransferSrcScalarPerVector,
 
  911                                 BBlockTransferSrcScalarPerVector,
 
  930             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
  933         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
  936         constexpr 
auto c_block_size =
 
  937             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
  940                          c_block_size * 
sizeof(CShuffleDataType));
 
  946         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
 
  947                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
 
  948                       "Invalid tuning param!");
 
  956             if(!(karg.
M % MPerBlock == 0))
 
  959                 std::cout << 
"Arg M value is not a multiple of MPerBlock! M: " << karg.
M << 
" " 
  960                           << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
  974             if(!(karg.
N % NPerBlock == 0))
 
  977                 std::cout << 
"Arg N value is not a multiple of NPerBlock! N: " << karg.
N << 
" " 
  978                           << __FILE__ << 
":" << __LINE__ << 
", in function: " << __func__
 
  992             auto K_t = karg.
KBatch * KPerBlock;
 
  993             if(!(karg.
K % K_t == 0))
 
  996                 std::cout << 
"Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: " 
  997                           << karg.
K << 
" " << __FILE__ << 
":" << __LINE__
 
  998                           << 
", in function: " << __func__ << std::endl;
 
 1007             auto K_t                = karg.
KBatch * KReadVec;
 
 1009             if((KReadPadSplited * (karg.
KBatch - 1)) >= karg.
K)
 
 1017             if(karg.
K % ABlockTransferSrcScalarPerVector != 0)
 
 1020                 std::cout << 
"Arg K (" << karg.
K 
 1021                           << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1022                           << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1023                           << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1031             if(karg.
M % ABlockTransferSrcScalarPerVector != 0)
 
 1034                 std::cout << 
"Arg M (" << karg.
M 
 1035                           << 
") value is not a multiple of ABlockTransferSrcScalarPerVector (" 
 1036                           << ABlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1037                           << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1046             if(karg.
N % BBlockTransferSrcScalarPerVector != 0)
 
 1049                 std::cout << 
"Arg N (" << karg.
N 
 1050                           << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1051                           << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1052                           << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1060             if(karg.
K % BBlockTransferSrcScalarPerVector != 0)
 
 1063                 std::cout << 
"Arg K (" << karg.
K 
 1064                           << 
") value is not a multiple of BBlockTransferSrcScalarPerVector (" 
 1065                           << BBlockTransferSrcScalarPerVector << 
" )! " << __FILE__ << 
":" 
 1066                           << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1078                 std::cout << 
"Arg N (" << karg.
N 
 1079                           << 
") value is not a multiple of " 
 1080                              "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1082                           << 
":" << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1093                 std::cout << 
"Arg M (" << karg.
M 
 1094                           << 
") value is not a multiple of " 
 1095                              "CShuffleBlockTransferScalarPerVector_NPerBlock (" 
 1097                           << 
":" << __LINE__ << 
", in function: " << __func__ << std::endl;
 
 1106         const auto num_k_loop = karg.
AK0 / (KPerBlock / AK1Value);
 
 1108         if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
 
 1119         const index_t num_loop = K / KPerBlock;
 
 1121         return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
 
 1126         const index_t num_loop = K / KPerBlock;
 
 1128         return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
 
 1131     template <
typename CGr
idDesc>
 
 1133         const CGridDesc& c_grid_desc_m_n, 
index_t MBlock, 
index_t NBlock)
 
 1142         return c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 1150     template <
bool HasMainKBlockLoop,
 
 1154                                const index_t* p_sorted_expert_ids,
 
 1155                                const index_t* p_max_token_id,
 
 1156                                const ADataType* p_a_grid,
 
 1157                                const BDataType* p_b_grid,
 
 1159                                CDataType* p_c_grid,
 
 1162                                AElementwiseOperation a_element_op,
 
 1163                                BElementwiseOperation b_element_op,
 
 1164                                CElementwiseOperation c_element_op)
 
 1174         const auto b_grid_desc_bpreshuffled =
 
 1176         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
 
 1182         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1185         const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
 
 1187         const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.
NBlock : blockIdx.y;
 
 1188         if(expert_block_id * MPerBlock >= max_token_id)
 
 1191             __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
 
 1192         const auto block_mn = [&]() -> std::pair<int, int> {
 
 1193             if constexpr(NSwizzle)
 
 1195                 const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
 
 1197                 const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
 
 1198                 const index_t expert_swizzle =
 
 1199                     ecnt > 0 ? ecnt : 1; 
 
 1200                 const index_t bid_new = blockIdx.x - prefix_block;
 
 1201                 const index_t nid     = __builtin_amdgcn_readfirstlane(
 
 1202                     bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
 
 1204                     __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
 
 1209                 return {blockIdx.x, blockIdx.y};
 
 1213         const index_t block_n_id = block_mn.first;
 
 1214         const index_t block_m_id = block_mn.second;
 
 1216             __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
 
 1219         constexpr 
auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
 1220         constexpr 
auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
 1221         constexpr 
auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
 
 1222         constexpr 
auto AKThreads  = AK0Threads * AK1Threads;
 
 1223         constexpr 
auto AMRepeats  = MPerBlock / AMThreads;
 
 1224         const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
 
 1226         if(token_pos >= max_token_id || token0 >= problem.
NumTokens)
 
 1230             const index_t fused_token = p_sorted_token_ids[token_pos + m0];
 
 1231             index_t token_offset      = fused_token & 0xffffff;
 
 1232             if constexpr(!IsInputGemm)
 
 1234                 token_offset = token_offset * problem.
TopK + (fused_token >> 24);
 
 1236             gather_offsets(m0) = 
static_cast<IndexType
>(token_offset) * problem.
K;
 
 1238         const IndexType expert_stride =
 
 1239             __builtin_amdgcn_readfirstlane(problem.
N * problem.
K * (IsInputGemm ? 2 : 1));
 
 1240         const IndexType expert_offset = expert_id * expert_stride / 
BPackedSize;
 
 1242         const index_t n_block_data_idx_on_grid =
 
 1243             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
 
 1245         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1246             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1247         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1248             p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
 
 1258             AElementwiseOperation,
 
 1262             ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 1263             ABlockTransferThreadClusterArrangeOrder,
 
 1266             decltype(a_grid_desc_ak0_m_ak1),
 
 1267             decltype(a_block_desc_ak0_m_ak1),
 
 1268             ABlockTransferSrcAccessOrder,
 
 1270             ABlockTransferSrcVectorDim,
 
 1272             ABlockTransferSrcScalarPerVector,
 
 1273             ABlockTransferDstScalarPerVector_AK1,
 
 1276             AThreadTransferSrcResetCoordinateAfterRun,
 
 1280             BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
 
 1283                                                 a_block_desc_ak0_m_ak1,
 
 1290         auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
 
 1291             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 1296             decltype(b_grid_desc_bpreshuffled),
 
 1297             decltype(b_block_desc_bk0_n_bk1),
 
 1301             BBlockTransferSrcScalarPerVector,
 
 1302             BThreadTransferSrcResetCoordinateAfterRun,
 
 1303             true>(b_grid_desc_bpreshuffled,
 
 1311         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1312             static_cast<LDSTypeA*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1318         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
 1320         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 1321         decltype(c_thread_buf) c_thread_buf_up;
 
 1325                                   c_thread_buf.num_of_v_,
 
 1326                                   c_thread_buf.s_per_v,
 
 1330         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 1331             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 1333         if constexpr(IsInputGemm)
 
 1335             const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / 
BPackedSize;
 
 1336             const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1337                 p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
 
 1341                 decltype(b_grid_desc_bpreshuffled),
 
 1342                 decltype(b_block_desc_bk0_n_bk1),
 
 1346                 BBlockTransferSrcScalarPerVector,
 
 1347                 BThreadTransferSrcResetCoordinateAfterRun,
 
 1348                 true>(b_grid_desc_bpreshuffled,
 
 1354             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 1355                 a_grid_desc_ak0_m_ak1,
 
 1356                 a_block_desc_ak0_m_ak1,
 
 1360                 a_block_slice_copy_step,
 
 1361                 b_grid_desc_bpreshuffled,
 
 1363                 b_blockwise_copy_up,
 
 1367                 b_block_slice_copy_step,
 
 1370                 num_k_block_main_loop);
 
 1374             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 1375                 a_grid_desc_ak0_m_ak1,
 
 1376                 a_block_desc_ak0_m_ak1,
 
 1380                 a_block_slice_copy_step,
 
 1381                 b_grid_desc_bpreshuffled,
 
 1385                 b_block_slice_copy_step,
 
 1387                 num_k_block_main_loop);
 
 1392             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 1393                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 1396             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 1399             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 1400                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1404             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 1405                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 1407             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
 1408             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
 1409             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
 1410             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
 1411             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
 1412             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
 1413             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
 1414             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
 1417             const float* p_sorted_weights_0 = p_ds_grid[
I0];
 
 1418             const float* p_scale_b          = p_ds_grid[
I1];
 
 1420             static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
 
 1421             static_assert(M4 == 4);
 
 1425             if(p_sorted_weights_0 != 
nullptr && p_scale_b != 
nullptr)
 
 1427                 if constexpr(PerTokenQuant)
 
 1429                     constexpr 
index_t scale_stride = (IsInputGemm ? 2 : 1);
 
 1430                     p_scale_b += expert_id * problem.
N * scale_stride + block_n_id * NPerBlock +
 
 1435                     p_scale_b += expert_id;
 
 1441                     const float scale_b = p_scale_b[n0 * 
NWave * NPerXdl * PerTokenQuant];
 
 1444                             const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
 
 1445                                                   m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
 
 1446                             if constexpr(PerTokenQuant)
 
 1449                                     *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
 
 1450                                         p_sorted_token_ids + m_pos);
 
 1452                             if constexpr(MulRoutedWeight)
 
 1454                                 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
 
 1455                                     p_ds_grid[
I2] + m_pos);
 
 1458                                 float scale_a = [&]() {
 
 1459                                     if constexpr(PerTokenQuant)
 
 1462                                         const index_t token_offset = fused_token & 0xffffff;
 
 1464                                                    ? p_sorted_weights_0[IsInputGemm
 
 1474                                         return p_sorted_weights_0[0];
 
 1478                                     blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
 
 1481                                 if constexpr(IsInputGemm) 
 
 1485                                         const float scale_up =
 
 1486                                             p_scale_b[(n0 * 
NWave * NPerXdl + problem.
N) *
 
 1488                                         float gate = scale_a * scale_b * c_thread_buf[cidx];
 
 1489                                         float up   = scale_a * scale_up * c_thread_buf_up[cidx];
 
 1490                                         if constexpr(MulRoutedWeight)
 
 1492                                             gate = gate * topk_weights.AsType<
float>()[m4];
 
 1493                                             up   = up * topk_weights.AsType<
float>()[m4];
 
 1501                                         c_thread_buf_fp32(cidx) = gate * up;
 
 1505                                         const float scale_up =
 
 1506                                             p_scale_b[(n0 * 
NWave * NPerXdl + problem.
N) *
 
 1508                                         float gate = scale_a * scale_b * c_thread_buf[cidx];
 
 1509                                         float up   = scale_a * scale_up * c_thread_buf_up[cidx];
 
 1510                                         if constexpr(MulRoutedWeight)
 
 1512                                             gate = gate * topk_weights.AsType<
float>()[m4];
 
 1513                                             up   = up * topk_weights.AsType<
float>()[m4];
 
 1521                                         c_thread_buf_fp32(cidx) = gate * up;
 
 1526                                     c_thread_buf_fp32(cidx) =
 
 1527                                         scale_a * scale_b * c_thread_buf[cidx];
 
 1528                                     if constexpr(MulRoutedWeight)
 
 1530                                         c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
 
 1531                                                                   topk_weights.AsType<
float>()[m4];
 
 1545                             const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
 
 1546                                                   m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
 
 1547                             if constexpr(MulRoutedWeight)
 
 1549                                 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
 
 1550                                     p_ds_grid[
I2] + m_pos);
 
 1554                                     blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
 
 1558                                 if constexpr(IsInputGemm) 
 
 1562                                         float gate = c_thread_buf[cidx];
 
 1563                                         float up   = c_thread_buf_up[cidx];
 
 1564                                         if constexpr(MulRoutedWeight)
 
 1566                                             gate = gate * topk_weights.AsType<
float>()[m4];
 
 1567                                             up   = up * topk_weights.AsType<
float>()[m4];
 
 1570                                         c_thread_buf_fp32(cidx) = gate * up;
 
 1574                                         float gate = c_thread_buf[cidx];
 
 1575                                         float up   = c_thread_buf_up[cidx];
 
 1576                                         if constexpr(MulRoutedWeight)
 
 1578                                             gate = gate * topk_weights.AsType<
float>()[m4];
 
 1579                                             up   = up * topk_weights.AsType<
float>()[m4];
 
 1582                                         c_thread_buf_fp32(cidx) = gate * up;
 
 1587                                     c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
 
 1588                                     if constexpr(MulRoutedWeight)
 
 1590                                         c_thread_buf_fp32(cidx) = topk_weights.AsType<
float>()[m4] *
 
 1591                                                                   c_thread_buf_fp32[cidx];
 
 1600             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 1603             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 1604                 static_cast<CShuffleDataType*
>(p_shared),
 
 1605                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1608                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 1628             const auto c_thread_mtx_on_block =
 
 1629                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 1631             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 1632             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 1634             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 1640             const auto m_thread_data_on_block_idx =
 
 1641                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 1644             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 1650             const auto n_thread_data_on_block_idx =
 
 1651                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 1655             auto c_thread_copy_vgpr_to_lds =
 
 1658                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 1659                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 1661                                                    Sequence<CShuffleMXdlPerWavePerShuffle,
 
 1662                                                             CShuffleNXdlPerWavePerShuffle,
 
 1675                     c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1678                                      m_thread_data_on_block_idx[
I1],
 
 1679                                      n_thread_data_on_block_idx[
I1],
 
 1680                                      m_thread_data_on_block_idx[
I2],
 
 1681                                      m_thread_data_on_block_idx[
I3],
 
 1682                                      m_thread_data_on_block_idx[
I4],
 
 1683                                      n_thread_data_on_block_idx[
I2]),
 
 1686             using EDataType = CDataType;
 
 1691             const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1697                     return make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1698                         p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
 
 1704                 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 1706                              { 
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
 
 1711                 tie(c_shuffle_block_buf),
 
 1713                              { 
return ds_grid_buf[i]; },
 
 1717             const auto idx_c_ds_block_begin =
 
 1727             const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1728                 c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 1730             using CDEBlockTransferCluster =
 
 1731                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
 
 1732             const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
 
 1733             constexpr 
index_t scatter_weight_idx  = 3; 
 
 1738                    decltype(c_ds_desc_refs),
 
 1739                    decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
 
 1740                    CElementwiseOperation,
 
 1744                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 1746                             CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>, 
 
 1747                    CDEBlockTransferCluster,
 
 1753                    CDEShuffleBlockTransferScalarPerVectors,
 
 1765                      idx_c_ds_block_begin,
 
 1766                      tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1770             auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1771                 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 1772             constexpr 
auto sfc_c_vgpr =
 
 1775                                   Sequence<CShuffleMXdlPerWavePerShuffle,
 
 1776                                            CShuffleNXdlPerWavePerShuffle,
 
 1784             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 1787             constexpr 
auto sfc_cde_block =
 
 1791                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 1793                                            CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>>{};
 
 1795             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), 
"wrong!");
 
 1796             constexpr 
auto EMThreads =
 
 1797                 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
 
 1798             constexpr 
auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
 
 1799             constexpr 
auto ENThreads =
 
 1800                 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
 
 1805                 auto dstidx = sfc_cde_block.GetIndex(access_id);
 
 1807                     block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
 
 1809                     const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
 
 1810                     IndexType token_offset    = fused_token & 0xffffff;
 
 1811                     if constexpr(IsInputGemm)
 
 1813                         token_offset = token_offset * problem.
TopK + (fused_token >> 24);
 
 1815                     scatter_offsets(m0) = 
static_cast<IndexType
>(token_offset) * problem.
N;
 
 1821                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1822                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 1824                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1825                                               c_shuffle_block_buf);
 
 1831                 cde_block_copy_lds_and_global.Run(
 
 1834                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1838                 if constexpr(access_id < num_access - 1)
 
 1840                     constexpr 
auto cde_lds_and_global_step =
 
 1841                         sfc_cde_block.GetForwardStep(access_id);
 
 1845                         cde_block_copy_lds_and_global.MoveSrcSliceWindow(
 
 1846                             c_ds_desc_refs, i + 
I1, cde_lds_and_global_step);
 
 1850                     cde_block_copy_lds_and_global.MoveDstSliceWindow(
 
 1851                         tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 1853                         cde_lds_and_global_step);
 
 1859     template <
bool HasMainKBlockLoop,
 
 1863                                     const index_t* p_sorted_expert_ids,
 
 1864                                     const index_t* p_max_token_id,
 
 1865                                     const ADataType* p_a_grid,
 
 1866                                     const BDataType* p_b_grid,
 
 1868                                     CDataType* p_c_grid,
 
 1872                                     AElementwiseOperation a_element_op,
 
 1873                                     BElementwiseOperation b_element_op,
 
 1874                                     CElementwiseOperation c_element_op)
 
 1884         const auto b_grid_desc_bpreshuffled =
 
 1886         const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
 
 1892         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
 1895         const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
 
 1897         const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.
NBlock : blockIdx.y;
 
 1898         if(expert_block_id * MPerBlock >= max_token_id)
 
 1901             __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
 
 1902         const auto block_mn = [&]() -> std::pair<int, int> {
 
 1903             if constexpr(NSwizzle)
 
 1905                 const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
 
 1907                 const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
 
 1908                 const index_t expert_swizzle =
 
 1909                     ecnt > 0 ? ecnt : 1; 
 
 1910                 const index_t bid_new = blockIdx.x - prefix_block;
 
 1911                 const index_t nid     = __builtin_amdgcn_readfirstlane(
 
 1912                     bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
 
 1914                     __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
 
 1919                 return {blockIdx.x, blockIdx.y};
 
 1923         const index_t block_n_id = block_mn.first;
 
 1924         const index_t block_m_id = block_mn.second;
 
 1926             __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
 
 1929         constexpr 
auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
 
 1930         constexpr 
auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
 
 1931         constexpr 
auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I2);
 
 1932         constexpr 
auto AKThreads  = AK0Threads * AK1Threads;
 
 1933         constexpr 
auto AMRepeats  = MPerBlock / AMThreads;
 
 1934         const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
 
 1936         if(token_pos >= max_token_id || token0 >= problem.
NumTokens)
 
 1940             const index_t fused_token = p_sorted_token_ids[token_pos + m0];
 
 1941             index_t token_offset      = fused_token & 0xffffff;
 
 1942             if constexpr(!IsInputGemm)
 
 1944                 token_offset = token_offset * problem.
TopK + (fused_token >> 24);
 
 1946             gather_offsets(m0) = 
static_cast<IndexType
>(token_offset) * problem.
K;
 
 1948         const IndexType expert_stride =
 
 1949             __builtin_amdgcn_readfirstlane(problem.
N * problem.
K * (IsInputGemm ? 2 : 1));
 
 1950         const IndexType expert_offset = expert_id * expert_stride / 
BPackedSize;
 
 1952         const index_t n_block_data_idx_on_grid =
 
 1953             __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
 
 1955         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1956             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
 1957         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 1958             p_b_grid + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
 
 1969             AElementwiseOperation,
 
 1973             ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
 1974             ABlockTransferThreadClusterArrangeOrder,
 
 1977             decltype(a_grid_desc_ak0_m_ak1),
 
 1978             decltype(a_block_desc_ak0_m_ak1),
 
 1979             ABlockTransferSrcAccessOrder,
 
 1981             ABlockTransferSrcVectorDim,
 
 1983             ABlockTransferSrcScalarPerVector,
 
 1984             ABlockTransferDstScalarPerVector_AK1,
 
 1987             AThreadTransferSrcResetCoordinateAfterRun,
 
 1991             2>(a_grid_desc_ak0_m_ak1,
 
 1994                a_block_desc_ak0_m_ak1,
 
 2001         auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
 
 2002             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 2003         auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
 
 2004             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
 2005         auto b_block_bufs = 
make_tuple(b_block_buf_ping, b_block_buf_pong);
 
 2010             decltype(b_grid_desc_bpreshuffled),
 
 2011             decltype(b_block_desc_bk0_n_bk1),
 
 2015             BBlockTransferSrcScalarPerVector,
 
 2016             BThreadTransferSrcResetCoordinateAfterRun,
 
 2017             true>(b_grid_desc_bpreshuffled,
 
 2025         auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2026             static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 2027         auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2028             static_cast<ADataType*
>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
 2029         auto a_block_bufs = 
make_tuple(a_block_buf_ping, a_block_buf_pong);
 
 2035         static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
 
 2037         auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
 
 2038         decltype(c_thread_buf) c_thread_buf_up;
 
 2042                                   c_thread_buf.num_of_v_,
 
 2043                                   c_thread_buf.s_per_v,
 
 2047         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
 2048             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
 2051         if constexpr(IsInputGemm)
 
 2053             const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / 
BPackedSize;
 
 2054             const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2055                 p_b_grid_up + expert_offset, b_grid_desc_bpreshuffled.GetElementSpaceSize());
 
 2059                 decltype(b_grid_desc_bpreshuffled),
 
 2060                 decltype(b_block_desc_bk0_n_bk1),
 
 2064                 BBlockTransferSrcScalarPerVector,
 
 2065                 BThreadTransferSrcResetCoordinateAfterRun,
 
 2066                 true>(b_grid_desc_bpreshuffled,
 
 2071             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 2072                 a_grid_desc_ak0_m_ak1,
 
 2073                 a_block_desc_ak0_m_ak1,
 
 2077                 a_block_slice_copy_step,
 
 2078                 b_grid_desc_bpreshuffled,
 
 2080                 b_blockwise_copy_up,
 
 2084                 b_block_slice_copy_step,
 
 2087                 num_k_block_main_loop);
 
 2092             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 2093                 a_grid_desc_ak0_m_ak1,
 
 2094                 a_block_desc_ak0_m_ak1,
 
 2098                 a_block_slice_copy_step,
 
 2099                 b_grid_desc_bpreshuffled,
 
 2103                 b_block_slice_copy_step,
 
 2105                 num_k_block_main_loop);
 
 2110             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 2111                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 2114             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
 2117             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 2118                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 2122             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 2123                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
 2125             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
 2126             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
 2127             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
 2128             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
 2129             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
 2130             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
 2131             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
 2132             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
 2135             const float* p_sorted_weights_0 = p_ds_grid[
I0];
 
 2136             const float* p_scale_b          = p_ds_grid[
I1];
 
 2138             static_assert(M0 * M1 * M2 * M3 * M4 == MPerBlock);
 
 2139             static_assert(M4 == 4);
 
 2143             if(p_sorted_weights_0 != 
nullptr && p_scale_b != 
nullptr)
 
 2145                 if constexpr(PerTokenQuant)
 
 2147                     constexpr 
index_t scale_stride = (IsInputGemm ? 2 : 1);
 
 2148                     p_scale_b += expert_id * problem.
N * scale_stride + block_n_id * NPerBlock +
 
 2153                     p_scale_b += expert_id;
 
 2159                     const float scale_b = p_scale_b[n0 * 
NWave * NPerXdl * PerTokenQuant];
 
 2162                             const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
 
 2163                                                   m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
 
 2164                             if constexpr(PerTokenQuant)
 
 2167                                     *c_style_pointer_cast<const vector_type<int32_t, M4>*>(
 
 2168                                         p_sorted_token_ids + m_pos);
 
 2170                             if constexpr(MulRoutedWeight)
 
 2172                                 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
 
 2173                                     p_ds_grid[
I2] + m_pos);
 
 2176                                 float scale_a = [&]() {
 
 2177                                     if constexpr(PerTokenQuant)
 
 2180                                         const index_t token_offset = fused_token & 0xffffff;
 
 2182                                                    ? p_sorted_weights_0[IsInputGemm
 
 2192                                         return p_sorted_weights_0[0];
 
 2196                                     blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
 
 2199                                 if constexpr(IsInputGemm) 
 
 2203                                         const float scale_up =
 
 2204                                             p_scale_b[(n0 * 
NWave * NPerXdl + problem.
N) *
 
 2206                                         float gate = scale_a * scale_b * c_thread_buf[cidx];
 
 2207                                         float up   = scale_a * scale_up * c_thread_buf_up[cidx];
 
 2208                                         if constexpr(MulRoutedWeight)
 
 2210                                             gate = gate * topk_weights.AsType<
float>()[m4];
 
 2211                                             up   = up * topk_weights.AsType<
float>()[m4];
 
 2219                                         c_thread_buf_fp32(cidx) = gate * up;
 
 2223                                         const float scale_up =
 
 2224                                             p_scale_b[(n0 * 
NWave * NPerXdl + problem.
N) *
 
 2226                                         float gate = scale_a * scale_b * c_thread_buf[cidx];
 
 2227                                         float up   = scale_a * scale_up * c_thread_buf_up[cidx];
 
 2228                                         if constexpr(MulRoutedWeight)
 
 2230                                             gate = gate * topk_weights.AsType<
float>()[m4];
 
 2231                                             up   = up * topk_weights.AsType<
float>()[m4];
 
 2239                                         c_thread_buf_fp32(cidx) = gate * up;
 
 2244                                     c_thread_buf_fp32(cidx) =
 
 2245                                         scale_a * scale_b * c_thread_buf[cidx];
 
 2246                                     if constexpr(MulRoutedWeight)
 
 2248                                         c_thread_buf_fp32(cidx) = c_thread_buf_fp32(cidx) *
 
 2249                                                                   topk_weights.AsType<
float>()[m4];
 
 2263                             const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 * M3 * M4 +
 
 2264                                                   m1 * M2 * M3 * M4 + m2 * M3 * M4 + m3 * M4;
 
 2265                             if constexpr(MulRoutedWeight)
 
 2267                                 topk_weights = *c_style_pointer_cast<const vector_type<float, M4>*>(
 
 2268                                     p_ds_grid[
I2] + m_pos);
 
 2272                                     blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
 
 2276                                 if constexpr(IsInputGemm) 
 
 2280                                         float gate = c_thread_buf[cidx];
 
 2281                                         float up   = c_thread_buf_up[cidx];
 
 2282                                         if constexpr(MulRoutedWeight)
 
 2284                                             gate = gate * topk_weights.AsType<
float>()[m4];
 
 2285                                             up   = up * topk_weights.AsType<
float>()[m4];
 
 2288                                         c_thread_buf_fp32(cidx) = gate * up;
 
 2292                                         float gate = c_thread_buf[cidx];
 
 2293                                         float up   = c_thread_buf_up[cidx];
 
 2294                                         if constexpr(MulRoutedWeight)
 
 2296                                             gate = gate * topk_weights.AsType<
float>()[m4];
 
 2297                                             up   = up * topk_weights.AsType<
float>()[m4];
 
 2300                                         c_thread_buf_fp32(cidx) = gate * up;
 
 2305                                     c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
 
 2306                                     if constexpr(MulRoutedWeight)
 
 2308                                         c_thread_buf_fp32(cidx) = topk_weights.AsType<
float>()[m4] *
 
 2309                                                                   c_thread_buf_fp32[cidx];
 
 2318             constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 2321             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2322                 static_cast<CShuffleDataType*
>(p_shared),
 
 2323                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 2326                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2346             const auto c_thread_mtx_on_block =
 
 2347                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
 2349             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
 2350             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
 2352             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 2358             const auto m_thread_data_on_block_idx =
 
 2359                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 2362             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 2368             const auto n_thread_data_on_block_idx =
 
 2369                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 2373             auto c_thread_copy_vgpr_to_lds =
 
 2376                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2377                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2379                                                    Sequence<CShuffleMXdlPerWavePerShuffle,
 
 2380                                                             CShuffleNXdlPerWavePerShuffle,
 
 2393                     c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2396                                      m_thread_data_on_block_idx[
I1],
 
 2397                                      n_thread_data_on_block_idx[
I1],
 
 2398                                      m_thread_data_on_block_idx[
I2],
 
 2399                                      m_thread_data_on_block_idx[
I3],
 
 2400                                      m_thread_data_on_block_idx[
I4],
 
 2401                                      n_thread_data_on_block_idx[
I2]),
 
 2404             using EDataType = CDataType;
 
 2409             const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2415                     return make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2416                         p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
 
 2422                 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 2424                              { 
return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
 
 2429                 tie(c_shuffle_block_buf),
 
 2431                              { 
return ds_grid_buf[i]; },
 
 2435             const auto idx_c_ds_block_begin =
 
 2445             const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2446                 c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 2448             using CDEBlockTransferCluster =
 
 2449                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
 
 2450             const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
 
 2451             constexpr 
index_t scatter_weight_idx  = 3; 
 
 2456                    decltype(c_ds_desc_refs),
 
 2457                    decltype(
tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
 
 2458                    CElementwiseOperation,
 
 2462                             CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2464                             CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>, 
 
 2465                    CDEBlockTransferCluster,
 
 2471                    CDEShuffleBlockTransferScalarPerVectors,
 
 2483                      idx_c_ds_block_begin,
 
 2484                      tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2488             auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2489                 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 2490             constexpr 
auto sfc_c_vgpr =
 
 2493                                   Sequence<CShuffleMXdlPerWavePerShuffle,
 
 2494                                            CShuffleNXdlPerWavePerShuffle,
 
 2502             constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 2505             constexpr 
auto sfc_cde_block =
 
 2509                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2511                                            CShuffleNXdlPerWavePerShuffle * 
NWave * NPerXdl>>{};
 
 2513             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), 
"wrong!");
 
 2514             constexpr 
auto EMThreads =
 
 2515                 CDEBlockTransferCluster{}.At(
I0) * CDEBlockTransferCluster{}.At(
I1);
 
 2516             constexpr 
auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
 
 2517             constexpr 
auto ENThreads =
 
 2518                 CDEBlockTransferCluster{}.At(
I2) * CDEBlockTransferCluster{}.At(
I3);
 
 2523                 auto dstidx = sfc_cde_block.GetIndex(access_id);
 
 2525                     block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(
I1);
 
 2527                     const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
 
 2528                     IndexType token_offset    = fused_token & 0xffffff;
 
 2529                     if constexpr(IsInputGemm)
 
 2531                         token_offset = token_offset * problem.
TopK + (fused_token >> 24);
 
 2533                     scatter_offsets(m0) = 
static_cast<IndexType
>(token_offset) * problem.
N;
 
 2539                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2540                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 2542                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2543                                               c_shuffle_block_buf);
 
 2549                 cde_block_copy_lds_and_global.Run(
 
 2552                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2556                 if constexpr(access_id < num_access - 1)
 
 2558                     constexpr 
auto cde_lds_and_global_step =
 
 2559                         sfc_cde_block.GetForwardStep(access_id);
 
 2563                         cde_block_copy_lds_and_global.MoveSrcSliceWindow(
 
 2564                             c_ds_desc_refs, i + 
I1, cde_lds_and_global_step);
 
 2568                     cde_block_copy_lds_and_global.MoveDstSliceWindow(
 
 2569                         tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2571                         cde_lds_and_global_step);
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
 
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:23
 
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
 
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
 
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
 
InMemoryDataOperationEnum
Definition: ck.hpp:275
 
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
 
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
 
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
 
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
 
Activation
Definition: gridwise_moe_gemm.hpp:31
 
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
 
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
 
constexpr auto BlockGemmBPreshufflePipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_b_preshuffle_selector.hpp:41
 
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
 
constexpr bool is_same_v
Definition: type.hpp:283
 
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:297
 
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:81
 
Argument
Definition: gridwise_moe_gemm.hpp:656
 
const BDataType * p_b_grid
Definition: gridwise_moe_gemm.hpp:712
 
const index_t * p_sorted_token_ids
Definition: gridwise_moe_gemm.hpp:708
 
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_gemm.hpp:709
 
const AElementwiseOperation a_element_op
Definition: gridwise_moe_gemm.hpp:716
 
const ADataType * p_a_grid
Definition: gridwise_moe_gemm.hpp:711
 
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_gemm.hpp:657
 
const index_t * p_max_token_id
Definition: gridwise_moe_gemm.hpp:710
 
const BElementwiseOperation b_element_op
Definition: gridwise_moe_gemm.hpp:717
 
CDataType * p_c_grid
Definition: gridwise_moe_gemm.hpp:714
 
DsGridPointer p_ds_grid
Definition: gridwise_moe_gemm.hpp:713
 
const CElementwiseOperation c_element_op
Definition: gridwise_moe_gemm.hpp:718
 
Problem
Definition: gridwise_moe_gemm.hpp:586
 
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_moe_gemm.hpp:638
 
index_t NumTokens
Definition: gridwise_moe_gemm.hpp:631
 
index_t MBlock
Definition: gridwise_moe_gemm.hpp:647
 
index_t BK0Shuffled
Definition: gridwise_moe_gemm.hpp:651
 
index_t TopK
Definition: gridwise_moe_gemm.hpp:632
 
index_t K
Definition: gridwise_moe_gemm.hpp:635
 
__host__ void Print() const
Definition: gridwise_moe_gemm.hpp:620
 
index_t NPadded
Definition: gridwise_moe_gemm.hpp:642
 
index_t BK0
Definition: gridwise_moe_gemm.hpp:646
 
index_t KRead
Definition: gridwise_moe_gemm.hpp:643
 
index_t MPadded
Definition: gridwise_moe_gemm.hpp:641
 
index_t AK0
Definition: gridwise_moe_gemm.hpp:645
 
index_t StrideA
Definition: gridwise_moe_gemm.hpp:636
 
index_t StrideC
Definition: gridwise_moe_gemm.hpp:639
 
index_t M
Definition: gridwise_moe_gemm.hpp:633
 
index_t KBatch
Definition: gridwise_moe_gemm.hpp:640
 
index_t BN0Shuffled
Definition: gridwise_moe_gemm.hpp:650
 
__host__ __device__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_moe_gemm.hpp:587
 
index_t KPadded
Definition: gridwise_moe_gemm.hpp:644
 
index_t StrideB
Definition: gridwise_moe_gemm.hpp:637
 
index_t N
Definition: gridwise_moe_gemm.hpp:634
 
index_t NBlock
Definition: gridwise_moe_gemm.hpp:648
 
SplitKBatchOffset
Definition: gridwise_moe_gemm.hpp:722
 
index_t a_k_split_offset
Definition: gridwise_moe_gemm.hpp:754
 
index_t b_k_split_offset
Definition: gridwise_moe_gemm.hpp:755
 
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_moe_gemm.hpp:723
 
Definition: gridwise_moe_gemm.hpp:165
 
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_gemm.hpp:240
 
__host__ static __device__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm.hpp:292
 
static constexpr auto MakeDsGridPointer()
Definition: gridwise_moe_gemm.hpp:211
 
__host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm.hpp:286
 
static constexpr index_t KRepeat
Definition: gridwise_moe_gemm.hpp:204
 
remove_cvref_t< decltype(BlockGemmBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition: gridwise_moe_gemm.hpp:920
 
__host__ static __device__ auto CalculateNPadded(index_t N)
Definition: gridwise_moe_gemm.hpp:255
 
static constexpr index_t NLane
Definition: gridwise_moe_gemm.hpp:206
 
static constexpr auto I5
Definition: gridwise_moe_gemm.hpp:171
 
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_moe_gemm.hpp:541
 
static constexpr auto BK0Number
Definition: gridwise_moe_gemm.hpp:179
 
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_gemm.hpp:324
 
static constexpr index_t NumDTensor
Definition: gridwise_moe_gemm.hpp:184
 
__host__ static constexpr __device__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_gemm.hpp:1124
 
static constexpr auto I2
Definition: gridwise_moe_gemm.hpp:168
 
static constexpr index_t APackedSize
Definition: gridwise_moe_gemm.hpp:226
 
__host__ static __device__ auto CalculateMBlock(index_t M)
Definition: gridwise_moe_gemm.hpp:299
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_moe_gemm.hpp:224
 
__host__ static __device__ auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
Definition: gridwise_moe_gemm.hpp:406
 
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_gemm.hpp:414
 
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_moe_gemm.hpp:511
 
static constexpr auto I6
Definition: gridwise_moe_gemm.hpp:172
 
static constexpr auto I0
Definition: gridwise_moe_gemm.hpp:166
 
static constexpr index_t SortedTileSize
Definition: gridwise_moe_gemm.hpp:209
 
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_gemm.hpp:1117
 
static constexpr auto I1
Definition: gridwise_moe_gemm.hpp:167
 
static constexpr auto I4
Definition: gridwise_moe_gemm.hpp:170
 
static constexpr auto AK1Number
Definition: gridwise_moe_gemm.hpp:180
 
__host__ static __device__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm.hpp:274
 
__host__ static __device__ auto CalculateNBlock(index_t N)
Definition: gridwise_moe_gemm.hpp:304
 
static constexpr auto BK1Number
Definition: gridwise_moe_gemm.hpp:181
 
static constexpr auto BlockSizeNumber
Definition: gridwise_moe_gemm.hpp:182
 
static constexpr index_t BPackedSize
Definition: gridwise_moe_gemm.hpp:233
 
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
Definition: gridwise_moe_gemm.hpp:517
 
__host__ static __device__ auto CalculateBK0Shuffled(index_t K)
Definition: gridwise_moe_gemm.hpp:264
 
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_moe_gemm.hpp:222
 
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, void *p_shared1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_gemm.hpp:1862
 
__host__ static __device__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm.hpp:280
 
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_gemm.hpp:562
 
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm.hpp:944
 
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_moe_gemm.hpp:881
 
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_moe_gemm.hpp:502
 
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock
Definition: gridwise_moe_gemm.hpp:175
 
__host__ static __device__ auto CalculateMPadded(index_t M)
Definition: gridwise_moe_gemm.hpp:250
 
static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_gemm.hpp:1132
 
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_moe_gemm.hpp:922
 
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_moe_gemm.hpp:874
 
__host__ static __device__ auto CalculateBN0Shuffled(index_t N)
Definition: gridwise_moe_gemm.hpp:260
 
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_gemm.hpp:1153
 
static constexpr index_t KPack
Definition: gridwise_moe_gemm.hpp:187
 
static constexpr index_t NWave
Definition: gridwise_moe_gemm.hpp:207
 
static constexpr auto I3
Definition: gridwise_moe_gemm.hpp:169
 
__host__ static __device__ auto CalculateKPadded(index_t K)
Definition: gridwise_moe_gemm.hpp:269
 
static constexpr auto AK0Number
Definition: gridwise_moe_gemm.hpp:178
 
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_gemm.hpp:574
 
static constexpr index_t KGroup
Definition: gridwise_moe_gemm.hpp:192
 
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_moe_gemm.hpp:310
 
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_moe_gemm.hpp:758
 
static constexpr index_t KLane
Definition: gridwise_moe_gemm.hpp:189
 
static constexpr auto I7
Definition: gridwise_moe_gemm.hpp:173
 
Definition: xdlops_gemm.hpp:942
 
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1388
 
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1343
 
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1382
 
Definition: sequence.hpp:43
 
Definition: tensor_space_filling_curve.hpp:20
 
Definition: static_buffer.hpp:75
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1_gather.hpp:48
 
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
 
Definition: tuple.hpp:117
 
Definition: integral_constant.hpp:20
 
Definition: data_type.hpp:197
 
Definition: functional2.hpp:33
 
Definition: device_base.hpp:51
 
Definition: unary_element_wise_operation.hpp:981
 
Definition: unary_element_wise_operation.hpp:308
 
Definition: unary_element_wise_operation.hpp:1023