template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
#if CK_USE_LAUNCH_BOUNDS
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
#if CK_USE_LAUNCH_BOUNDS
    if constexpr(GridwiseGemm::template IsValidCompilationParameter<CGlobalMemoryDataOperation>())
        __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
        __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

        GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_sorted_token_ids,
            karg.p_sorted_expert_ids,
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
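// Unlike kernel_moe_mxgemm above, the _2lds entry point allocates two LDS buffers
// (p_shared_0 / p_shared_1) so that Run_2Lds can double-buffer the A tile between them.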
template <typename ALayout,
          typename AScaleDataType,
          typename BScaleDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          index_t ActivationOperation = 0,
          bool NSwizzle              = false,
          bool IsInputGemm           = true,
          bool MulRoutedWeight       = true,
          typename ComputeTypeA      = ADataType,
          typename ComputeTypeB      = BDataType>
        CDEShuffleBlockTransferScalarPerVectors{}[I0];
                  "A scale pack data type too large!");
                  "B scale pack data type too large!");
                return static_cast<const DDataType*>(nullptr);
        const index_t gridx = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy = NSwizzle ? 1 : mblock;
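        // With NSwizzle the whole (M, N) tile space is flattened into gridDim.x so the
        // kernel can re-map blocks across experts at runtime; otherwise x walks N tiles
        // and y walks M tiles directly.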
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);

        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);

        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;

        auto K_t = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
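        // Each helper rounds K up with a ceiling division, (K + K_t - 1) / K_t, so every
        // split-K batch covers a whole number of KPerBlock (or KReadVec) chunks; the AK0/BK0
        // variants then express that padded length in units of AK1/BK1 vectors.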
    template <index_t MNXdlPerWave,
              typename TileDesc_K0_MN_K1>
        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)

        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
            const auto a_grid_desc_m_k =
            return a_grid_desc_ak0_m_ak1;
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
                a_grid_desc_mraw_kraw,
            return a_grid_desc_ak0_m_ak1;
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
                a_grid_desc_mraw_kraw,
            return a_grid_desc_ak0_m_ak1;
                a_grid_desc_mraw_kraw,
                a_grid_desc_ak0_m_ak1,
                a_grid_desc_permuted,
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);
        const auto b_grid_desc_nraw_kraw = [&]() {
                          GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");
                          GemmSpec != GemmSpecialization::Default),
                      "f4x2_pk_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
            const auto b_grid_desc_n_k =
            return b_grid_desc_bk0_n_bk1;
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
                b_grid_desc_nraw_kraw,
            return b_grid_desc_bk0_n_bk1;
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
                b_grid_desc_nraw_kraw,
            return b_grid_desc_bk0_n_bk1;
                b_grid_desc_nraw_kraw,
                b_grid_desc_bk0_n_bk1,
                b_grid_desc_permuted,
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl, true>(
            ABlockDesc_AK0_M_AK1{});

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl, false>(
            BBlockDesc_BK0_N_BK1{});
    template <typename ELayout>
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)

        const auto c_grid_desc_mraw_nraw = [&]() {

    template <typename DLayout>
    __host__ __device__ static auto
        const auto c_grid_desc_mraw_nraw = [&]() {

            return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);

    template <typename DsGridDesc>
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
                ds_grid_desc_m_n[i], MBlock, NBlock);
                 std::array<index_t, NumDTensor> StrideDs_,

            std::cout << "problem {" << "NumTokens:" << NumTokens << ", "
                      << "TopK:" << TopK << ", "
                      << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
                      << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " << "MBlock: " << MBlock
                      << ", " << "NBlock: " << NBlock << "}" << std::endl;
                 const index_t* p_sorted_expert_ids_,
                 const index_t* p_max_token_id_,
                 const ADataType* p_a_grid_,
                 const AScaleDataType* p_a_scale_grid_,
                 const BDataType* p_b_grid_,
                 const BScaleDataType* p_b_scale_grid_,
                 std::array<const void*, NumDTensor> p_ds_grid_,
                 CDataType* p_c_grid_,
                 std::array<index_t, NumDTensor> StrideDs_,
                 AElementwiseOperation a_element_op_,
                 BElementwiseOperation b_element_op_,
                 CElementwiseOperation c_element_op_)
                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)

            if(k_id < karg.KBatch - 1)
        constexpr index_t MWave    = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t WaveSize = BlockSize / (MWave * NWave);

            constexpr auto a_lds_block_desc =
            return a_lds_block_desc_permuted;

            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                       : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

                Number<kfold * M0 / mpair>{},
                a_lds_block_desc_permuted,
                a_lds_block_desc_unmerged,
                Number<KThreadWrite / kfold / KThreadReadPerm>{},

            return a_lds_block_desc_ak0_m_ak1;
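            // The fold/permute transforms above rearrange the A tile in LDS (rows folded in
            // 128-byte groups and xor-permuted) so that MFMA-ordered reads avoid LDS bank
            // conflicts; several branches of the ternaries are elided in this listing.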
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;

                                  ABlockTransferSrcScalarPerVector,
                                  BBlockTransferSrcScalarPerVector,

            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max(a_block_space_size_aligned * sizeof(ADataType),
                         c_block_size * sizeof(CShuffleDataType));
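        // The single LDS allocation is reused: the A staging tile occupies it during the main
        // loop and the C-shuffle tile during the epilogue, so its size is the max of the two.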
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
                      "KPerBlock should be a multiple of ScaleBlockSize");
        if(!(karg.M % MPerBlock == 0))
                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__

        if(!(karg.N % NPerBlock == 0))
                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__

            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;

            auto K_t = karg.KBatch * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)

            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;

            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;

            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;

            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;

                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__

                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)

        const index_t num_loop = K / KPerBlock;
        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);

        const index_t num_loop = K / KPerBlock;
        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);

    template <typename CGridDesc>
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    template <bool HasMainKBlockLoop,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const AScaleDataType* p_a_scale_grid,
                               const BDataType* p_b_grid,
                               const BScaleDataType* p_b_scale_grid,
                               CDataType* p_c_grid,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
        const auto b_grid_desc_bpreshuffled =
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        const auto Padded_Scale_M =
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
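        // When NSwizzle is on, blocks within an expert's span are re-mapped: the walk covers
        // 8 consecutive N tiles, then steps through the expert's M blocks (ecnt of them)
        // before advancing to the next group of 8 N tiles, improving B-tile reuse in cache.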
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
                const index_t ecnt_prefix    = p_max_token_id[1 + expert_id];
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
            return {blockIdx.x, blockIdx.y};

        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
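        // Each entry of p_sorted_token_ids packs the top-k slot in its upper 8 bits and the
        // token id in the lower 24; the gather offsets computed here are per-row offsets into
        // the A (activation) matrix.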
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) *

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
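        // For the input (gate+up) GEMM each expert stores two N x K weight matrices back to
        // back, hence the factor of 2 in expert_stride and expert_scale_stride; the
        // up-projection pointers below start at expert_stride / 2.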
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            1>(a_grid_desc_ak0_m_ak1,
               a_block_desc_ak0_m_ak1,

        auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        auto b_blockwise_copy =
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            b_grid_desc_bpreshuffled,

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
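        // Because B is preshuffled into the MFMA register layout, its tiles are loaded
        // straight into VGPR ping/pong buffers and never staged through LDS; only the A tile
        // goes through the shared-memory buffer above.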
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;
                                     c_thread_buf.num_of_v_,
                                     c_thread_buf.s_per_v,

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /

        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        auto thread_offset_shuffled =
        auto a_thread_offset_m = waveId_m;

        const index_t token_scale_pos = block_m_id * MPerBlock;
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)

            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            true>(a_scale_grid_desc_am_ak,

        auto b_thread_offset_n = waveId_n;

            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            true>(b_scale_grid_desc_bn_ak,
        if constexpr(IsInputGemm)
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            auto b_blockwise_copy_up =
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                b_grid_desc_bpreshuffled,

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());

                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                b_scale_grid_desc_bn_ak,
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                b_blockwise_copy_up,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                b_scale_grid_buf_up,
                num_k_block_main_loop);

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                num_k_block_main_loop);
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
        static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
                          CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
        static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
        static_assert(M5 == 4);

                    const index_t m_pos = block_m_id * MPerBlock +
                                          m0 * M2 * M1 * M3 * M4 * M5 +
                                          m1 * M2 * M3 * M4 * M5 +
                                          imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
                    if constexpr(MulRoutedWeight)
                            *c_style_pointer_cast<const vector_type<float, M5>*>(
                                p_ds_grid[I2] + m_pos);

                            blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));

                        if constexpr(IsInputGemm)
                            if constexpr(ActivationOperation ==
                                float gate = c_thread_buf[cidx];
                                float up   = c_thread_buf_up[cidx];
                                if constexpr(MulRoutedWeight)
                                    gate = gate * topk_weights.AsType<float>()[m5];
                                    up   = up * topk_weights.AsType<float>()[m5];
                                c_thread_buf_fp32(cidx) = gate * up;

                                float gate = c_thread_buf[cidx];
                                float up   = c_thread_buf_up[cidx];
                                if constexpr(MulRoutedWeight)
                                    gate = gate * topk_weights.AsType<float>()[m5];
                                    up   = up * topk_weights.AsType<float>()[m5];
                                c_thread_buf_fp32(cidx) = gate * up;

                            c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                            if constexpr(MulRoutedWeight)
                                c_thread_buf_fp32(cidx) =
                                    topk_weights.AsType<float>()[m5] * c_thread_buf_fp32[cidx];
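                        // For the gate+up GEMM the two accumulator halves are fused as
                        // activation(gate) * up (the silu/gelu call on `gate` is elided in
                        // this listing); MulRoutedWeight additionally scales the result by
                        // the per-token routing weights streamed in through Ds[2].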
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,

        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(

            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            CShuffleNXdlPerWavePerShuffle / NXdlPack,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  m_thread_data_on_block_idx[I1],
                  n_thread_data_on_block_idx[I1],
                  m_thread_data_on_block_idx[I2],
                  n_thread_data_on_block_idx[I2],
                  m_thread_data_on_block_idx[I3],
                  m_thread_data_on_block_idx[I4],
                  m_thread_data_on_block_idx[I5],
                  n_thread_data_on_block_idx[I3]),
        using EDataType = CDataType;

        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());

                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },

                tie(c_shuffle_block_buf),
                { return ds_grid_buf[i]; },

        const auto idx_c_ds_block_begin =

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;

            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            CDEShuffleBlockTransferScalarPerVectors,
            idx_c_ds_block_begin,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        constexpr auto sfc_c_vgpr =
                       Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                       CShuffleNXdlPerWavePerShuffle / NXdlPack,

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
            auto dstidx = sfc_cde_block.GetIndex(access_id);
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
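            // Output rows are scattered back through the same sorted-token mapping used for
            // the gather: for the input (gate+up) GEMM each token owns one output row per
            // top-k slot, hence token_offset * TopK + slot.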
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);

            cde_block_copy_lds_and_global.Run(
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

            if constexpr(access_id < num_access - 1)
                constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);

                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);

                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    cde_lds_and_global_step);
    template <bool HasMainKBlockLoop,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const AScaleDataType* p_a_scale_grid,
                                    const BDataType* p_b_grid,
                                    const BScaleDataType* p_b_scale_grid,
                                    CDataType* p_c_grid,
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
        const auto b_grid_desc_bpreshuffled =
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        const auto Padded_Scale_M =
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);

        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
                const index_t ecnt_prefix    = p_max_token_id[1 + expert_id];
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
            return {blockIdx.x, blockIdx.y};

        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) *

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave / NXdlPack);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bpreshuffled.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            1>(a_grid_desc_ak0_m_ak1,
               a_block_desc_ak0_m_ak1,

        auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        auto b_blockwise_copy =
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            b_grid_desc_bpreshuffled,

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
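        // Unlike the single-LDS Run, here the A tile also ping-pongs between the two
        // shared-memory allocations, letting the next tile's load overlap the current MFMA work.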
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;
                                     c_thread_buf.num_of_v_,
                                     c_thread_buf.s_per_v,

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /

        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        auto thread_offset_shuffled =
        auto a_thread_offset_m = waveId_m;

        const index_t token_scale_pos = block_m_id * MPerBlock;
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)

            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            true>(a_scale_grid_desc_am_ak,

        auto b_thread_offset_n = waveId_n;

            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            true>(b_scale_grid_desc_bn_ak,
        if constexpr(IsInputGemm)
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            auto b_blockwise_copy_up =
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                b_grid_desc_bpreshuffled,

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());

                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                b_scale_grid_desc_bn_ak,
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                b_blockwise_copy_up,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                b_scale_grid_buf_up,
                num_k_block_main_loop);

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                num_k_block_main_loop);
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
        static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
                          CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
        constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
        constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
        constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
        static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
        static_assert(M5 == 4);

                    const index_t m_pos = block_m_id * MPerBlock +
                                          m0 * M2 * M1 * M3 * M4 * M5 +
                                          m1 * M2 * M3 * M4 * M5 +
                                          imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
                    if constexpr(MulRoutedWeight)
                            *c_style_pointer_cast<const vector_type<float, M5>*>(
                                p_ds_grid[I2] + m_pos);

                            blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));

                        if constexpr(IsInputGemm)
                            if constexpr(ActivationOperation ==
                                float gate = c_thread_buf[cidx];
                                float up   = c_thread_buf_up[cidx];
                                if constexpr(MulRoutedWeight)
                                    gate = gate * topk_weights.AsType<float>()[m5];
                                    up   = up * topk_weights.AsType<float>()[m5];
                                c_thread_buf_fp32(cidx) = gate * up;

                                float gate = c_thread_buf[cidx];
                                float up   = c_thread_buf_up[cidx];
                                if constexpr(MulRoutedWeight)
                                    gate = gate * topk_weights.AsType<float>()[m5];
                                    up   = up * topk_weights.AsType<float>()[m5];
                                c_thread_buf_fp32(cidx) = gate * up;

                            c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                            if constexpr(MulRoutedWeight)
                                c_thread_buf_fp32(cidx) =
                                    topk_weights.AsType<float>()[m5] * c_thread_buf_fp32[cidx];
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared_0),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,

        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(

            decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
            CShuffleNXdlPerWavePerShuffle / NXdlPack,
            Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
            true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                  m_thread_data_on_block_idx[I1],
                  n_thread_data_on_block_idx[I1],
                  m_thread_data_on_block_idx[I2],
                  n_thread_data_on_block_idx[I2],
                  m_thread_data_on_block_idx[I3],
                  m_thread_data_on_block_idx[I4],
                  m_thread_data_on_block_idx[I5],
                  n_thread_data_on_block_idx[I3]),
        using EDataType = CDataType;

        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());

                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },

                tie(c_shuffle_block_buf),
                { return ds_grid_buf[i]; },

        const auto idx_c_ds_block_begin =

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = 3;

            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            CDEShuffleBlockTransferScalarPerVectors,
            idx_c_ds_block_begin,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        constexpr auto sfc_c_vgpr =
                       Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                       CShuffleNXdlPerWavePerShuffle / NXdlPack,

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
            auto dstidx = sfc_cde_block.GetIndex(access_id);
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                IndexType token_offset    = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                          c_shuffle_block_buf);

            cde_block_copy_lds_and_global.Run(
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

            if constexpr(access_id < num_access - 1)
                constexpr auto cde_lds_and_global_step = sfc_cde_block.GetForwardStep(access_id);

                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);

                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    cde_lds_and_global_step);