36 template <typename GridwiseGemm,
37 bool HasMainKBlockLoop,
42 #if CK_USE_LAUNCH_BOUNDS
49 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
51 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
53 GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
54 karg.p_sorted_token_ids,
55 karg.p_sorted_expert_ids,
57 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
58 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
73 template <typename GridwiseGemm,
74 bool HasMainKBlockLoop,
79 #if CK_USE_LAUNCH_BOUNDS
86 __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
87 __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];
89 auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
91 GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
92 karg.p_sorted_token_ids,
93 karg.p_sorted_expert_ids,
95 karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
96 karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
112 template <typename ALayout,
118 typename AccDataType,
119 typename CShuffleDataType,
122 typename AElementwiseOperation,
123 typename BElementwiseOperation,
124 typename CElementwiseOperation,
139 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
140 typename ABlockTransferThreadClusterArrangeOrder,
141 typename ABlockTransferSrcAccessOrder,
142 index_t ABlockTransferSrcVectorDim,
143 index_t ABlockTransferSrcScalarPerVector,
144 index_t ABlockTransferDstScalarPerVector_AK1,
145 bool AThreadTransferSrcResetCoordinateAfterRun,
147 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
148 typename BBlockTransferThreadClusterArrangeOrder,
149 typename BBlockTransferSrcAccessOrder,
150 index_t BBlockTransferSrcVectorDim,
151 index_t BBlockTransferSrcScalarPerVector,
152 index_t BBlockTransferDstScalarPerVector_BK1,
153 bool BThreadTransferSrcResetCoordinateAfterRun,
155 index_t CShuffleMXdlPerWavePerShuffle,
156 index_t CShuffleNXdlPerWavePerShuffle,
157 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
158 typename CDEShuffleBlockTransferScalarPerVectors,
161 index_t ActivationOperation = 0,
162 bool NSwizzle = false,
163 bool IsInputGemm = true,
164 bool MulRoutedWeight = true,
166 typename ComputeTypeA = CDataType,
167 typename ComputeTypeB = ComputeTypeA,
168 typename LDSTypeA = ADataType,
169 typename LDSTypeB = BDataType>
185 CDEShuffleBlockTransferScalarPerVectors{}[I0];
223 return static_cast<const DDataType*>(nullptr);
250 const index_t gridx = NSwizzle ? nblock * mblock : nblock;
251 const index_t gridy = NSwizzle ? 1 : mblock;
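The two lines above choose the launch grid for the tile map. A minimal standalone sketch of the same choice, with illustrative names (mblock/nblock are the per-dimension tile counts; nothing here is taken verbatim from the header):

    #include <utility>

    // With N-swizzling all M x N tiles are flattened onto gridDim.x and gridDim.y is 1;
    // otherwise gridDim.x walks the N tiles and gridDim.y walks the M tiles.
    std::pair<int, int> moe_grid_size(int mblock, int nblock, bool n_swizzle)
    {
        const int gridx = n_swizzle ? nblock * mblock : nblock;
        const int gridy = n_swizzle ? 1 : mblock;
        return {gridx, gridy};
    }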
281 auto K_t = K_Batch * KPerBlock;
282 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
287 auto K_t = K_Batch * KPerBlock;
288 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
293 auto K_t = K_Batch * KPerBlock;
294 return (K + K_t - 1) / K_t * KPerBlock;
300 auto K_t = K_Batch * KReadVec;
301 return (K + K_t - 1) / K_t * KReadVec;
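Each helper above applies the same ceil-divide-then-scale pattern to split K across K_Batch and pad it to whole block tiles. A standalone sketch of the AK0 variant, with illustrative numbers (the concrete K/K_Batch/KPerBlock/AK1 values below are assumptions, not defaults from the header):

    // ceil(K / (K_Batch * KPerBlock)) block tiles per split-K batch,
    // each contributing KPerBlock / AK1 AK0 steps.
    int ak0_padded(int K, int K_Batch, int KPerBlock, int AK1Value)
    {
        const int K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }
    // e.g. K = 960, K_Batch = 2, KPerBlock = 128, AK1 = 8:
    // ceil(960 / 256) = 4, so AK0 = 4 * (128 / 8) = 64.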
314 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
330 IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
332 const auto a_grid_desc_mraw_kraw = [&]() {
333 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
337 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
345 if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
346 GemmSpec == GemmSpecialization::MNKPadding)
349 const auto a_grid_desc_m_k =
363 return a_grid_desc_ak0_m_ak1;
365 else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
366 GemmSpec == GemmSpecialization::MNPadding)
370 a_grid_desc_mraw_kraw,
376 return a_grid_desc_ak0_m_ak1;
378 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
379 GemmSpec == GemmSpecialization::NKPadding)
383 a_grid_desc_mraw_kraw,
395 return a_grid_desc_ak0_m_ak1;
401 a_grid_desc_mraw_kraw,
407 return a_grid_desc_ak0_m_ak1;
416 make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
422 const auto b_grid_desc_nraw_kraw = [&]() {
436 GemmSpec != GemmSpecialization::Default),
437 "pk_i4_t does not support padding");
439 if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
440 GemmSpec == GemmSpecialization::MNKPadding)
443 const auto b_grid_desc_n_k =
457 return b_grid_desc_bk0_n_bk1;
459 else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
460 GemmSpec == GemmSpecialization::MNPadding)
464 b_grid_desc_nraw_kraw,
470 return b_grid_desc_bk0_n_bk1;
472 else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
473 GemmSpec == GemmSpecialization::MKPadding)
477 b_grid_desc_nraw_kraw,
489 return b_grid_desc_bk0_n_bk1;
495 b_grid_desc_nraw_kraw,
501 return b_grid_desc_bk0_n_bk1;
505 template <typename ABlockDesc_AK0_M_AK1>
506 __host__ __device__ static constexpr auto
509 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
511 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
514 template <typename BBlockDesc_BK0_N_BK1>
515 __host__ __device__ static constexpr auto
518 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
521 template <typename ELayout>
523 IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
525 const auto c_grid_desc_mraw_nraw = [&]() {
544 template <typename DLayout>
545 __host__ __device__ static auto
548 const auto c_grid_desc_mraw_nraw = [&]() {
573 return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
578 template <typename DsGridDesc>
580 const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
585 ds_grid_desc_m_n[i], MBlock, NBlock);
601 std::array<index_t, NumDTensor> StrideDs_,
629 std::cout << "problem {" << "NumTokens:" << NumTokens << ", " << "TopK:" << TopK << ", "
630 << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
633 << "KRead:" << KRead << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0
634 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock << ", "
635 << "NBlock: " << NBlock << "}" << std::endl;
665 const index_t* p_sorted_expert_ids_,
666 const index_t* p_max_token_id_,
667 const ADataType* p_a_grid_,
668 const BDataType* p_b_grid_,
669 std::array<const void*, NumDTensor> p_ds_grid_,
670 CDataType* p_c_grid_,
678 std::array<index_t, NumDTensor> StrideDs_,
683 AElementwiseOperation a_element_op_,
684 BElementwiseOperation b_element_op_,
685 CElementwiseOperation c_element_op_)
715 p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
739 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
743 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
748 if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
752 else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
758 if(k_id < karg.KBatch - 1)
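SplitKBatchOffset (partially shown above) advances each blockIdx.z batch to its own slice of K and lets the last batch absorb the remainder. A minimal sketch of that bookkeeping, assuming a row-major A so the offset is a plain element offset along K; the struct and function names below are illustrative only:

    struct SplitKOffsetsSketch
    {
        long a_offset;     // element offset into A where this split-K batch starts reading
        long k_this_batch; // how much of K this batch is responsible for
    };

    SplitKOffsetsSketch split_k_offset(long K, long KRead, int k_id, int k_batch)
    {
        SplitKOffsetsSketch r{};
        r.a_offset     = k_id * KRead;
        r.k_this_batch = (k_id < k_batch - 1) ? KRead : K - k_id * KRead; // last batch takes the remainder
        return r;
    }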
775 if constexpr(ABlockLdsExtraM)
785 constexpr auto a_lds_block_desc =
797 return a_lds_block_desc_permuted;
804 constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
805 constexpr auto M1 = MPerBlock / M0;
807 constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
808 constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
809 constexpr auto KThreadRead = 64 / MPerXdl;
810 constexpr auto K0PerThreadRead = AK0Number / KThreadRead;
812 constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
814 : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
815 constexpr auto KThreadReadPerm =
816 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
817 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
821 constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
823 : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
825 : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));
831 Number<kfold * M0 / mpair>{},
850 a_lds_block_desc_permuted,
872 a_lds_block_desc_unmerged,
875 Number<KThreadWrite / kfold / KThreadReadPerm>{},
884 return a_lds_block_desc_ak0_m_ak1;
897 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
899 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
906 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
924 ABlockTransferSrcScalarPerVector,
925 BBlockTransferSrcScalarPerVector,
947 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
950 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
953 constexpr auto c_block_size =
954 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
957 c_block_size * sizeof(CShuffleDataType));
963 static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
964 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
965 "Invalid tuning param!");
973 if(!(karg.M % MPerBlock == 0))
976 std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
977 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
991 if(!(karg.N % NPerBlock == 0))
994 std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
995 << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
1009 auto K_t = karg.KBatch * KPerBlock;
1010 if(!(karg.K % K_t == 0))
1013 std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
1014 << karg.K << " " << __FILE__ << ":" << __LINE__
1015 << ", in function: " << __func__ << std::endl;
1024 auto K_t = karg.KBatch * KReadVec;
1026 if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
1034 if(karg.K % ABlockTransferSrcScalarPerVector != 0)
1037 std::cout << "Arg K (" << karg.K
1038 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1039 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1040 << __LINE__ << ", in function: " << __func__ << std::endl;
1048 if(karg.M % ABlockTransferSrcScalarPerVector != 0)
1051 std::cout << "Arg M (" << karg.M
1052 << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
1053 << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1054 << __LINE__ << ", in function: " << __func__ << std::endl;
1063 if(karg.N % BBlockTransferSrcScalarPerVector != 0)
1066 std::cout << "Arg N (" << karg.N
1067 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1068 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1069 << __LINE__ << ", in function: " << __func__ << std::endl;
1077 if(karg.K % BBlockTransferSrcScalarPerVector != 0)
1080 std::cout << "Arg K (" << karg.K
1081 << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
1082 << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
1083 << __LINE__ << ", in function: " << __func__ << std::endl;
1095 std::cout << "Arg N (" << karg.N
1096 << ") value is not a multiple of "
1097 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1099 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1110 std::cout << "Arg M (" << karg.M
1111 << ") value is not a multiple of "
1112 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
1114 << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
1123 const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
1125 if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
1136 const index_t num_loop = K / KPerBlock;
1138 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
1143 const index_t num_loop = K / KPerBlock;
1145 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
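CalculateHasMainKBlockLoop and CalculateKBlockLoopTailNum above just forward K / KPerBlock to the selected pipeline. A rough standalone sketch of the kind of criterion involved, mirroring the num_k_loop <= PrefetchStages check earlier in CheckValidity; the exact rule lives in the pipeline class, so treat this as an assumption:

    bool block_has_hot_loop_sketch(int K, int KPerBlock, int prefetch_stages)
    {
        const int num_loop = K / KPerBlock;
        return num_loop > prefetch_stages; // only loops beyond the prefetch depth form a steady-state hot loop
    }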
1148 template <typename CGridDesc>
1150 const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
1159 return c_grid_desc_mblock_mperblock_nblock_nperblock;
1167 template <bool HasMainKBlockLoop,
1171 const index_t* p_sorted_expert_ids,
1172 const index_t* p_max_token_id,
1173 const ADataType* p_a_grid,
1174 const BDataType* p_b_grid,
1176 CDataType* p_c_grid,
1181 AElementwiseOperation a_element_op,
1182 BElementwiseOperation b_element_op,
1183 CElementwiseOperation c_element_op)
1193 const auto b_grid_desc_bpreshuffled =
1195 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1213 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1216 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1218 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1219 if(expert_block_id * MPerBlock >= max_token_id)
1222 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1223 const auto block_mn = [&]() -> std::pair<int, int> {
1224 if constexpr(NSwizzle)
1226 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1228 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1229 const index_t expert_swizzle =
1230 ecnt > 0 ? ecnt : 1;
1231 const index_t bid_new = blockIdx.x - prefix_block;
1232 const index_t nid = __builtin_amdgcn_readfirstlane(
1233 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1235 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1240 return {blockIdx.x, blockIdx.y};
1243 const index_t block_n_id = block_mn.first;
1244 const index_t block_m_id = block_mn.second;
1246 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1249 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1250 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1251 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1252 constexpr auto AKThreads = AK0Threads * AK1Threads;
1253 constexpr auto AMRepeats = MPerBlock / AMThreads;
1254 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
1256 if(token_pos >= max_token_id || token0 >= problem.NumTokens)
1260 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
1261 index_t token_offset = fused_token & 0xffffff;
1262 if constexpr(!IsInputGemm)
1264 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1266 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
1269 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
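The gather loop above decodes the sorted token ids: the low 24 bits hold the token index and the high 8 bits hold the token's top-k slot, and for the second (non-input) GEMM the row is expanded to token * TopK + slot before being scaled by K. A standalone sketch of that decode (the function name is illustrative):

    #include <cstdint>

    long gather_row(std::uint32_t fused_token, int TopK, bool is_input_gemm)
    {
        const std::uint32_t token = fused_token & 0xffffff; // low 24 bits: token index
        const std::uint32_t slot  = fused_token >> 24;      // high 8 bits: which of the token's top-k experts
        return is_input_gemm ? static_cast<long>(token)
                             : static_cast<long>(token) * TopK + slot;
    }
    // gather_offsets(m0) above is this row index times problem.K (an element offset into A).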
1270 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
1275 const index_t n_block_data_idx_on_grid =
1276 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
1278 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1279 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1280 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1282 b_grid_desc_bpreshuffled.GetElementSpaceSize());
1284 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1285 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
1286 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1287 p_b_scale_grid + expert_id * expert_scale_stride,
1288 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1299 AElementwiseOperation,
1303 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1304 ABlockTransferThreadClusterArrangeOrder,
1307 decltype(a_grid_desc_ak0_m_ak1),
1308 decltype(a_block_desc_ak0_m_ak1),
1309 ABlockTransferSrcAccessOrder,
1311 ABlockTransferSrcVectorDim,
1313 ABlockTransferSrcScalarPerVector,
1314 ABlockTransferDstScalarPerVector_AK1,
1317 AThreadTransferSrcResetCoordinateAfterRun,
1321 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
1324 a_block_desc_ak0_m_ak1,
1331 auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
1332 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1337 decltype(b_grid_desc_bpreshuffled),
1338 decltype(b_block_desc_bk0_n_bk1),
1342 BBlockTransferSrcScalarPerVector,
1343 BThreadTransferSrcResetCoordinateAfterRun,
1344 true>(b_grid_desc_bpreshuffled,
1352 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1353 static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1359 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1361 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1362 decltype(c_thread_buf) c_thread_buf_up;
1364 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1365 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
1368 constexpr index_t ScaleSliceSizeM = MXdlPerWave;
1377 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
1378 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
1379 auto a_thread_offset =
1390 const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
1392 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
1397 p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
1398 index_t token_offset = fused_token & 0xffffff;
1399 if constexpr(!IsInputGemm)
1401 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1403 scale_gather_offsets(m0) =
1407 auto a_scale_thread_copy =
1410 decltype(a_scale_grid_desc_am_ak),
1411 decltype(a_scale_thread_desc),
1421 auto b_scale_thread_copy =
1424 decltype(b_scale_grid_desc_bn_ak),
1425 decltype(b_scale_thread_desc),
1432 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
1435 constexpr auto a_scale_thread_slice_copy_step =
1437 constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
1440 if constexpr(IsInputGemm)
1442 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
1443 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1445 b_grid_desc_bpreshuffled.GetElementSpaceSize());
1449 decltype(b_grid_desc_bpreshuffled),
1450 decltype(b_block_desc_bk0_n_bk1),
1454 BBlockTransferSrcScalarPerVector,
1455 BThreadTransferSrcResetCoordinateAfterRun,
1456 true>(b_grid_desc_bpreshuffled,
1462 p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
1463 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
1464 p_b_scale_grid_up + expert_id * expert_scale_stride,
1465 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
1466 auto b_scale_thread_copy_up =
1469 decltype(b_scale_grid_desc_bn_ak),
1470 decltype(b_scale_thread_desc),
1477 b_scale_grid_desc_bn_ak,
1480 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1481 a_grid_desc_ak0_m_ak1,
1482 a_block_desc_ak0_m_ak1,
1486 a_block_slice_copy_step,
1488 b_grid_desc_bpreshuffled,
1489 b_block_desc_bk0_n_bk1,
1491 b_blockwise_copy_up,
1495 b_block_slice_copy_step,
1497 c_scale_thread_desc,
1501 a_scale_grid_desc_am_ak,
1502 a_scale_thread_desc,
1503 a_scale_thread_copy,
1505 a_scale_thread_slice_copy_step,
1507 b_scale_grid_desc_bn_ak,
1508 b_scale_thread_desc,
1509 b_scale_thread_copy,
1510 b_scale_thread_copy_up,
1512 b_scale_grid_buf_up,
1513 b_scale_thread_slice_copy_step,
1515 num_k_block_main_loop);
1519 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
1520 a_grid_desc_ak0_m_ak1,
1521 a_block_desc_ak0_m_ak1,
1525 a_block_slice_copy_step,
1527 b_grid_desc_bpreshuffled,
1528 b_block_desc_bk0_n_bk1,
1532 b_block_slice_copy_step,
1534 c_scale_thread_desc,
1537 a_scale_grid_desc_am_ak,
1538 a_scale_thread_desc,
1539 a_scale_thread_copy,
1541 a_scale_thread_slice_copy_step,
1543 b_scale_grid_desc_bn_ak,
1544 b_scale_thread_desc,
1545 b_scale_thread_copy,
1547 b_scale_thread_slice_copy_step,
1549 num_k_block_main_loop);
1554 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1555 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1558 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1562 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
1563 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1567 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
1568 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
1570 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
1571 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
1572 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
1573 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
1574 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
1575 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
1576 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
1577 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
1579 static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
1580 static_assert(M0 * M1 * M2 == MPerBlock);
1581 static_assert(N4 == 4);
1588 if constexpr(MulRoutedWeight)
1590 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
1591 topk_weight = p_ds_grid[I0][m_pos];
1596 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
1599 if constexpr(IsInputGemm)
1603 float gate = c_thread_buf[cidx];
1604 float up = c_thread_buf_up[cidx];
1605 if constexpr(MulRoutedWeight)
1607 gate = gate * topk_weight;
1608 up = up * topk_weight;
1616 c_thread_buf(cidx) = gate * up;
1620 float gate = c_thread_buf[cidx];
1621 float up = c_thread_buf_up[cidx];
1622 if constexpr(MulRoutedWeight)
1624 gate = gate * topk_weight;
1625 up = up * topk_weight;
1633 c_thread_buf(cidx) = gate * up;
1638 if constexpr(MulRoutedWeight)
1640 c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
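For the input (gate/up) GEMM the epilogue above combines two accumulators per element, each optionally pre-scaled by the routed top-k weight; the lines elided between them presumably apply the selected activation (silu_and_mul or gelu_and_mul, per the ActivationOperation parameter) to the gate value before the multiply. A standalone sketch, using SiLU as the assumed activation:

    #include <cmath>

    float gate_up_combine(float gate, float up, float topk_weight, bool mul_routed_weight)
    {
        if(mul_routed_weight)
        {
            gate *= topk_weight;
            up   *= topk_weight;
        }
        const float activated = gate / (1.0f + std::exp(-gate)); // SiLU(x) = x * sigmoid(x)
        return activated * up;
    }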
1648 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1651 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1652 static_cast<CShuffleDataType*>(p_shared),
1653 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1656 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1676 const auto c_thread_mtx_on_block =
1677 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
1679 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
1680 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
1682 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
1688 const auto m_thread_data_on_block_idx =
1689 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
1692 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
1698 const auto n_thread_data_on_block_idx =
1699 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
1703 auto c_thread_copy_vgpr_to_lds =
1706 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1707 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
1709 Sequence<CShuffleMXdlPerWavePerShuffle,
1710 CShuffleNXdlPerWavePerShuffle,
1723 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1726 m_thread_data_on_block_idx[I1],
1727 n_thread_data_on_block_idx[I1],
1728 m_thread_data_on_block_idx[I2],
1729 n_thread_data_on_block_idx[I2],
1730 n_thread_data_on_block_idx[I3],
1731 n_thread_data_on_block_idx[I4]),
1734 using EDataType = CDataType;
1739 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
1746 const DDataType* ptr_ = p_ds_grid[i];
1749 return make_dynamic_buffer<AddressSpaceEnum::Global>(
1750 ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
1756 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1758 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
1763 tie(c_shuffle_block_buf),
1765 { return ds_grid_buf[i]; },
1769 const auto idx_c_ds_block_begin =
1779 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
1780 c_grid_desc_mblock_mperblock_nblock_nperblock;
1782 using CDEBlockTransferCluster =
1783 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
1784 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
1785 constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1;
1790 decltype(c_ds_desc_refs),
1791 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
1792 CElementwiseOperation,
1796 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1798 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1799 CDEBlockTransferCluster,
1805 CDEShuffleBlockTransferScalarPerVectors,
1817 idx_c_ds_block_begin,
1818 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1822 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1823 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1825 constexpr auto sfc_c_vgpr =
1828 Sequence<CShuffleMXdlPerWavePerShuffle,
1829 CShuffleNXdlPerWavePerShuffle,
1837 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1840 constexpr auto sfc_cde_block =
1844 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1846 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1848 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
1849 constexpr auto EMThreads =
1850 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
1851 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
1852 constexpr auto ENThreads =
1853 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
1858 auto dstidx = sfc_cde_block.GetIndex(access_id);
1860 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
1862 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
1863 index_t token_offset = fused_token & 0xffffff;
1864 if constexpr(IsInputGemm)
1866 token_offset = token_offset * problem.TopK + (fused_token >> 24);
1868 scatter_offsets(m0) = token_offset * problem.N;
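The scatter side mirrors the gather: output rows for the input GEMM go to the expanded (token, top-k slot) position, i.e. row = token * TopK + slot, and the element offset into C is row * N. A standalone sketch (name and types illustrative):

    #include <cstdint>

    long scatter_elem_offset(std::uint32_t fused_token, int TopK, int N, bool is_input_gemm)
    {
        long row = fused_token & 0xffffff;          // token index in the low 24 bits
        if(is_input_gemm)
            row = row * TopK + (fused_token >> 24); // expand by the top-k slot for the first GEMM
        return row * N;
    }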
1874 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1875 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1877 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
1878 c_shuffle_block_buf);
1884 cde_block_copy_lds_and_global.Run(
1887 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1891 if constexpr(access_id < num_access - 1)
1893 constexpr auto cde_lds_and_global_step =
1894 sfc_cde_block.GetForwardStep(access_id);
1898 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
1899 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
1903 cde_block_copy_lds_and_global.MoveDstSliceWindow(
1904 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
1906 cde_lds_and_global_step);
1912 template <bool HasMainKBlockLoop,
1916 const index_t* p_sorted_expert_ids,
1917 const index_t* p_max_token_id,
1918 const ADataType* p_a_grid,
1919 const BDataType* p_b_grid,
1921 CDataType* p_c_grid,
1927 AElementwiseOperation a_element_op,
1928 BElementwiseOperation b_element_op,
1929 CElementwiseOperation c_element_op)
1939 const auto b_grid_desc_bpreshuffled =
1941 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
1958 const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
1961 const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
1962 const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
1963 if(expert_block_id * MPerBlock >= max_token_id)
1966 __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
1967 const auto block_mn = [&]() -> std::pair<int, int> {
1968 if constexpr(NSwizzle)
1970 const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
1972 const index_t ecnt = p_max_token_id[2 + expert_id] - ecnt_prefix;
1973 const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
1974 const index_t bid_new = blockIdx.x - prefix_block;
1975 const index_t nid = __builtin_amdgcn_readfirstlane(
1976 bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
1978 __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
1983 return {blockIdx.x, blockIdx.y};
1986 const index_t block_n_id = block_mn.first;
1987 const index_t block_m_id = block_mn.second;
1990 __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
1993 constexpr auto AMThreads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
1994 constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
1995 constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
1996 constexpr auto AKThreads = AK0Threads * AK1Threads;
1997 constexpr auto AMRepeats = MPerBlock / AMThreads;
1998 const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
2000 if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
2006 const index_t fused_token = p_sorted_token_ids[token_pos + m0];
2007 index_t token_offset = fused_token & 0xffffff;
2008 if constexpr(!IsInputGemm)
2010 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2012 gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
2015 __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
2016 const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
2020 const index_t n_block_data_idx_on_grid =
2021 __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);
2023 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2024 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
2025 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2027 b_grid_desc_bpreshuffled.GetElementSpaceSize());
2029 const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2030 p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
2031 const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2032 p_b_scale_grid + expert_id * expert_scale_stride,
2033 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2044 AElementwiseOperation,
2048 ABlockTransferThreadClusterLengths_AK0_M_AK1,
2049 ABlockTransferThreadClusterArrangeOrder,
2052 decltype(a_grid_desc_ak0_m_ak1),
2053 decltype(a_block_desc_ak0_m_ak1),
2054 ABlockTransferSrcAccessOrder,
2056 ABlockTransferSrcVectorDim,
2058 ABlockTransferSrcScalarPerVector,
2059 ABlockTransferDstScalarPerVector_AK1,
2062 AThreadTransferSrcResetCoordinateAfterRun,
2066 BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
2069 a_block_desc_ak0_m_ak1,
2076 auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2077 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2078 auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
2079 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
2080 auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
2085 decltype(b_grid_desc_bpreshuffled),
2086 decltype(b_block_desc_bk0_n_bk1),
2090 BBlockTransferSrcScalarPerVector,
2091 BThreadTransferSrcResetCoordinateAfterRun,
2092 true>(b_grid_desc_bpreshuffled,
2100 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2101 static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2102 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2103 static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
2104 auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
2110 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
2112 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
2113 decltype(c_thread_buf) c_thread_buf_up;
2115 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
2116 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
2120 constexpr index_t ScaleSliceSizeM = MXdlPerWave;
2129 constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
2130 constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
2131 auto a_thread_offset =
2142 const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
2144 if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
2149 p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
2150 index_t token_offset = fused_token & 0xffffff;
2151 if constexpr(!IsInputGemm)
2153 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2155 scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) *
2159 auto a_scale_thread_copy =
2162 decltype(a_scale_grid_desc_am_ak),
2163 decltype(a_scale_thread_desc),
2173 auto b_scale_thread_copy =
2176 decltype(b_scale_grid_desc_bn_ak),
2177 decltype(b_scale_thread_desc),
2184 b_scale_grid_desc_bn_ak, make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));
2187 constexpr auto a_scale_thread_slice_copy_step =
2189 constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
2192 if constexpr(IsInputGemm)
2194 const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
2195 const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2197 b_grid_desc_bpreshuffled.GetElementSpaceSize());
2201 decltype(b_grid_desc_bpreshuffled),
2202 decltype(b_block_desc_bk0_n_bk1),
2206 BBlockTransferSrcScalarPerVector,
2207 BThreadTransferSrcResetCoordinateAfterRun,
2208 true>(b_grid_desc_bpreshuffled,
2214 p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
2215 const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
2216 p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
2217 b_scale_grid_desc_bn_ak.GetElementSpaceSize());
2218 auto b_scale_thread_copy_up =
2221 decltype(b_scale_grid_desc_bn_ak),
2222 decltype(b_scale_thread_desc),
2229 b_scale_grid_desc_bn_ak,
2232 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2233 a_grid_desc_ak0_m_ak1,
2234 a_block_desc_ak0_m_ak1,
2238 a_block_slice_copy_step,
2239 b_grid_desc_bpreshuffled,
2240 b_block_desc_bk0_n_bk1,
2242 b_blockwise_copy_up,
2246 b_block_slice_copy_step,
2247 c_scale_thread_desc,
2250 a_scale_grid_desc_am_ak,
2251 a_scale_thread_desc,
2252 a_scale_thread_copy,
2254 a_scale_thread_slice_copy_step,
2255 b_scale_grid_desc_bn_ak,
2256 b_scale_thread_desc,
2257 b_scale_thread_copy,
2258 b_scale_thread_copy_up,
2260 b_scale_grid_buf_up,
2261 b_scale_thread_slice_copy_step,
2262 num_k_block_main_loop);
2266 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
2267 a_grid_desc_ak0_m_ak1,
2268 a_block_desc_ak0_m_ak1,
2272 a_block_slice_copy_step,
2273 b_grid_desc_bpreshuffled,
2274 b_block_desc_bk0_n_bk1,
2278 b_block_slice_copy_step,
2279 c_scale_thread_desc,
2281 a_scale_grid_desc_am_ak,
2282 a_scale_thread_desc,
2283 a_scale_thread_copy,
2285 a_scale_thread_slice_copy_step,
2286 b_scale_grid_desc_bn_ak,
2287 b_scale_thread_desc,
2288 b_scale_thread_copy,
2290 b_scale_thread_slice_copy_step,
2291 num_k_block_main_loop);
2297 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
2298 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
2301 constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
2305 constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
2306 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2310 constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
2311 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();
2313 constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
2314 constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
2315 constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
2316 constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
2317 constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
2318 constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
2319 constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
2320 constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);
2322 static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
2323 static_assert(M0 * M1 * M2 == MPerBlock);
2324 static_assert(N4 == 4);
2331 if constexpr(MulRoutedWeight)
2333 const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
2334 topk_weight = p_ds_grid[I0][m_pos];
2339 blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
2342 if constexpr(IsInputGemm)
2346 float gate = c_thread_buf[cidx];
2347 float up = c_thread_buf_up[cidx];
2348 if constexpr(MulRoutedWeight)
2350 gate = gate * topk_weight;
2351 up = up * topk_weight;
2359 c_thread_buf(cidx) = gate * up;
2363 float gate = c_thread_buf[cidx];
2364 float up = c_thread_buf_up[cidx];
2365 if constexpr(MulRoutedWeight)
2367 gate = gate * topk_weight;
2368 up = up * topk_weight;
2376 c_thread_buf(cidx) = gate * up;
2381 if constexpr(MulRoutedWeight)
2383 c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
2392 constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
2395 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
2396 static_cast<CShuffleDataType*>(p_shared),
2397 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2400 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
2420 const auto c_thread_mtx_on_block =
2421 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
2423 const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
2424 const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
2426 const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
2432 const auto m_thread_data_on_block_idx =
2433 m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
2436 const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
2442 const auto n_thread_data_on_block_idx =
2443 n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
2447 auto c_thread_copy_vgpr_to_lds =
2450 decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2451 decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
2453 Sequence<CShuffleMXdlPerWavePerShuffle,
2454 CShuffleNXdlPerWavePerShuffle,
2467 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2470 m_thread_data_on_block_idx[I1],
2471 n_thread_data_on_block_idx[I1],
2472 m_thread_data_on_block_idx[I2],
2473 n_thread_data_on_block_idx[I2],
2474 n_thread_data_on_block_idx[I3],
2475 n_thread_data_on_block_idx[I4]),
2478 using EDataType = CDataType;
2483 const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
2489 return make_dynamic_buffer<AddressSpaceEnum::Global>(
2490 p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
2496 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
2498 { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
2503 tie(c_shuffle_block_buf),
2505 { return ds_grid_buf[i]; },
2509 const auto idx_c_ds_block_begin =
2519 const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
2520 c_grid_desc_mblock_mperblock_nblock_nperblock;
2522 using CDEBlockTransferCluster =
2523 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
2524 const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
2525 constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 1;
2530 decltype(c_ds_desc_refs),
2531 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
2532 CElementwiseOperation,
2536 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2538 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
2539 CDEBlockTransferCluster,
2545 CDEShuffleBlockTransferScalarPerVectors,
2557 idx_c_ds_block_begin,
2558 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2562 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
2563 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
2565 constexpr auto sfc_c_vgpr =
2568 Sequence<CShuffleMXdlPerWavePerShuffle,
2569 CShuffleNXdlPerWavePerShuffle,
2577 constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
2580 constexpr auto sfc_cde_block =
2584 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
2586 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
2588 static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
2589 constexpr auto EMThreads =
2590 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
2591 constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
2592 constexpr auto ENThreads =
2593 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
2599 auto dstidx = sfc_cde_block.GetIndex(access_id);
2601 block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
2603 const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
2604 index_t token_offset = fused_token & 0xffffff;
2605 if constexpr(IsInputGemm)
2607 token_offset = token_offset * problem.TopK + (fused_token >> 24);
2609 scatter_offsets(m0) = token_offset * problem.N;
2615 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2616 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
2618 c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
2619 c_shuffle_block_buf);
2625 cde_block_copy_lds_and_global.Run(
2628 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2632 if constexpr(access_id < num_access - 1)
2634 constexpr auto cde_lds_and_global_step =
2635 sfc_cde_block.GetForwardStep(access_id);
2639 cde_block_copy_lds_and_global.MoveSrcSliceWindow(
2640 c_ds_desc_refs, i + I1, cde_lds_and_global_step);
2644 cde_block_copy_lds_and_global.MoveDstSliceWindow(
2645 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
2647 cde_lds_and_global_step);
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_warp_local_1d_id()
Definition: get_id.hpp:23
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
InMemoryDataOperationEnum
Definition: ck.hpp:275
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
__global__ void kernel_moe_gemm(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:46
int64_t long_index_t
Definition: ck.hpp:298
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
Activation
Definition: gridwise_moe_gemm.hpp:31
@ silu_and_mul
Definition: gridwise_moe_gemm.hpp:33
@ gelu_and_mul
Definition: gridwise_moe_gemm.hpp:32
__host__ constexpr __device__ auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:320
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto concat_tuple_of_reference(const Tuple< X &... > &tx, const Tuple< Y &... > &ty)
Definition: tuple_helper.hpp:42
constexpr bool is_same_v
Definition: type.hpp:283
typename sequence_merge< Sx, Sy >::type sequence_merge_t
Definition: sequence.hpp:925
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:297
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
constexpr auto BlockGemmBlockMoeScaleBPreshufflePipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_moe_blockscale_b_preshuffle_selector.hpp:37
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
__global__ void kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_moe_gemm.hpp:81
Definition: gridwise_moe_gemm_blockscale.hpp:663
const index_t * p_sorted_token_ids
Definition: gridwise_moe_gemm_blockscale.hpp:719
CDataType * p_c_grid
Definition: gridwise_moe_gemm_blockscale.hpp:725
const BScaleType * p_b_scale_grid
Definition: gridwise_moe_gemm_blockscale.hpp:728
const index_t * p_max_token_id
Definition: gridwise_moe_gemm_blockscale.hpp:721
DsGridPointer p_ds_grid
Definition: gridwise_moe_gemm_blockscale.hpp:724
const CElementwiseOperation c_element_op
Definition: gridwise_moe_gemm_blockscale.hpp:732
__host__ Argument(const index_t *p_sorted_token_ids_, const index_t *p_sorted_expert_ids_, const index_t *p_max_token_id_, const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, CDataType *p_c_grid_, index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, const AScaleType *p_a_scale_grid_, const BScaleType *p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_)
Definition: gridwise_moe_gemm_blockscale.hpp:664
const ADataType * p_a_grid
Definition: gridwise_moe_gemm_blockscale.hpp:722
const AElementwiseOperation a_element_op
Definition: gridwise_moe_gemm_blockscale.hpp:730
const index_t * p_sorted_expert_ids
Definition: gridwise_moe_gemm_blockscale.hpp:720
const BDataType * p_b_grid
Definition: gridwise_moe_gemm_blockscale.hpp:723
const AScaleType * p_a_scale_grid
Definition: gridwise_moe_gemm_blockscale.hpp:727
const BElementwiseOperation b_element_op
Definition: gridwise_moe_gemm_blockscale.hpp:731
Definition: gridwise_moe_gemm_blockscale.hpp:593
index_t K
Definition: gridwise_moe_gemm_blockscale.hpp:642
__host__ __device__ Problem(index_t NumTokens_, index_t TopK_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_moe_gemm_blockscale.hpp:594
index_t TopK
Definition: gridwise_moe_gemm_blockscale.hpp:639
index_t BK0Shuffled
Definition: gridwise_moe_gemm_blockscale.hpp:658
index_t NPadded
Definition: gridwise_moe_gemm_blockscale.hpp:649
index_t StrideB
Definition: gridwise_moe_gemm_blockscale.hpp:644
__host__ void Print() const
Definition: gridwise_moe_gemm_blockscale.hpp:627
index_t BK0
Definition: gridwise_moe_gemm_blockscale.hpp:653
index_t BN0Shuffled
Definition: gridwise_moe_gemm_blockscale.hpp:657
index_t KRead
Definition: gridwise_moe_gemm_blockscale.hpp:650
index_t N
Definition: gridwise_moe_gemm_blockscale.hpp:641
index_t StrideC
Definition: gridwise_moe_gemm_blockscale.hpp:646
index_t KBatch
Definition: gridwise_moe_gemm_blockscale.hpp:647
index_t MBlock
Definition: gridwise_moe_gemm_blockscale.hpp:654
index_t KPadded
Definition: gridwise_moe_gemm_blockscale.hpp:651
index_t NumTokens
Definition: gridwise_moe_gemm_blockscale.hpp:638
index_t StrideA
Definition: gridwise_moe_gemm_blockscale.hpp:643
index_t AK0
Definition: gridwise_moe_gemm_blockscale.hpp:652
index_t M
Definition: gridwise_moe_gemm_blockscale.hpp:640
index_t MPadded
Definition: gridwise_moe_gemm_blockscale.hpp:648
index_t NBlock
Definition: gridwise_moe_gemm_blockscale.hpp:655
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_moe_gemm_blockscale.hpp:645
Definition: gridwise_moe_gemm_blockscale.hpp:736
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_moe_gemm_blockscale.hpp:737
index_t a_k_split_offset
Definition: gridwise_moe_gemm_blockscale.hpp:768
index_t b_k_split_offset
Definition: gridwise_moe_gemm_blockscale.hpp:769
Definition: gridwise_moe_gemm_blockscale.hpp:171
__host__ static __device__ auto MakeCGridDescriptor_M_N(IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
Definition: gridwise_moe_gemm_blockscale.hpp:522
static constexpr index_t KPack
Definition: gridwise_moe_gemm_blockscale.hpp:196
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_moe_gemm_blockscale.hpp:888
__host__ static __device__ auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
Definition: gridwise_moe_gemm_blockscale.hpp:411
static constexpr auto AK1Number
Definition: gridwise_moe_gemm_blockscale.hpp:189
__host__ static __device__ auto CalculateBK0Shuffled(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:269
static constexpr auto BK1Number
Definition: gridwise_moe_gemm_blockscale.hpp:190
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_moe_gemm_blockscale.hpp:230
static __device__ void Run_2Lds(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, void *p_shared1, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_gemm_blockscale.hpp:1915
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:1134
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_moe_gemm_blockscale.hpp:772
__host__ static __device__ auto CalculateNPadded(index_t N)
Definition: gridwise_moe_gemm_blockscale.hpp:260
__host__ static __device__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm_blockscale.hpp:285
static constexpr auto I3
Definition: gridwise_moe_gemm_blockscale.hpp:178
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_moe_gemm_blockscale.hpp:961
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_moe_gemm_blockscale.hpp:315
static constexpr index_t KLane
Definition: gridwise_moe_gemm_blockscale.hpp:209
remove_cvref_t< decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, ScaleBlockM, ScaleBlockN, ScaleBlockK, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack, IsInputGemm >())> BlockwiseGemmPipe
Definition: gridwise_moe_gemm_blockscale.hpp:937
float AScaleType
Definition: gridwise_moe_gemm_blockscale.hpp:172
static constexpr auto AK0Number
Definition: gridwise_moe_gemm_blockscale.hpp:187
static constexpr index_t KRepeat
Definition: gridwise_moe_gemm_blockscale.hpp:211
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
Definition: gridwise_moe_gemm_blockscale.hpp:329
static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_gemm_blockscale.hpp:1149
static constexpr auto I2
Definition: gridwise_moe_gemm_blockscale.hpp:177
static constexpr index_t APackedSize
Definition: gridwise_moe_gemm_blockscale.hpp:232
__host__ static __device__ auto CalculateBN0Shuffled(index_t N)
Definition: gridwise_moe_gemm_blockscale.hpp:265
__host__ static __device__ auto CalculateKPadded(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:274
static constexpr auto I4
Definition: gridwise_moe_gemm_blockscale.hpp:179
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_moe_gemm_blockscale.hpp:567
__host__ static __device__ auto CalculateMPadded(index_t M)
Definition: gridwise_moe_gemm_blockscale.hpp:255
static constexpr auto I6
Definition: gridwise_moe_gemm_blockscale.hpp:181
static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock
Definition: gridwise_moe_gemm_blockscale.hpp:184
__host__ static constexpr __device__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_moe_gemm_blockscale.hpp:1141
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_moe_gemm_blockscale.hpp:516
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_moe_gemm_blockscale.hpp:246
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_moe_gemm_blockscale.hpp:228
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_moe_gemm_blockscale.hpp:895
float BScaleType
Definition: gridwise_moe_gemm_blockscale.hpp:173
static constexpr auto I7
Definition: gridwise_moe_gemm_blockscale.hpp:182
__host__ static __device__ auto CalculateMBlock(index_t M)
Definition: gridwise_moe_gemm_blockscale.hpp:304
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_moe_gemm_blockscale.hpp:579
static constexpr index_t NWave
Definition: gridwise_moe_gemm_blockscale.hpp:213
static constexpr index_t NLane
Definition: gridwise_moe_gemm_blockscale.hpp:212
static constexpr index_t NumDTensor
Definition: gridwise_moe_gemm_blockscale.hpp:193
__host__ static __device__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm_blockscale.hpp:297
__host__ static __device__ auto CalculateNBlock(index_t N)
Definition: gridwise_moe_gemm_blockscale.hpp:309
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_moe_gemm_blockscale.hpp:939
__host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm_blockscale.hpp:291
static constexpr auto I0
Definition: gridwise_moe_gemm_blockscale.hpp:175
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_moe_gemm_blockscale.hpp:507
__host__ static __device__ auto MakeDGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_moe_gemm_blockscale.hpp:546
static constexpr auto BlockSizeNumber
Definition: gridwise_moe_gemm_blockscale.hpp:191
static constexpr index_t KGroup
Definition: gridwise_moe_gemm_blockscale.hpp:198
static constexpr index_t BPackedSize
Definition: gridwise_moe_gemm_blockscale.hpp:239
static constexpr auto I5
Definition: gridwise_moe_gemm_blockscale.hpp:180
static __device__ void Run(const index_t *p_sorted_token_ids, const index_t *p_sorted_expert_ids, const index_t *p_max_token_id, const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, CDataType *p_c_grid, const AScaleType *p_a_scale_grid, const BScaleType *p_b_scale_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: gridwise_moe_gemm_blockscale.hpp:1170
static constexpr auto MakeDsGridPointer()
Definition: gridwise_moe_gemm_blockscale.hpp:217
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_moe_gemm_blockscale.hpp:419
static constexpr auto I1
Definition: gridwise_moe_gemm_blockscale.hpp:176
static constexpr auto BK0Number
Definition: gridwise_moe_gemm_blockscale.hpp:188
static constexpr index_t SortedTileSize
Definition: gridwise_moe_gemm_blockscale.hpp:215
__host__ static __device__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_moe_gemm_blockscale.hpp:279
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(0, 0, 0, 0, {}))> DsGridDesc_M_N
Definition: gridwise_moe_gemm_blockscale.hpp:590
Definition: xdlops_gemm.hpp:942
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1388
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1343
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1382
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1_gather.hpp:48
Definition: thread_group_tensor_slice_transfer_v7r3_scatter.hpp:51
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: threadwise_tensor_slice_transfer.hpp:440
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
Definition: tuple.hpp:117
Definition: integral_constant.hpp:20
Definition: data_type.hpp:197
Definition: functional2.hpp:33
Definition: device_base.hpp:51
Definition: unary_element_wise_operation.hpp:981
Definition: unary_element_wise_operation.hpp:308
Definition: unary_element_wise_operation.hpp:1023