template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          /* ..., InMemoryDataOperationEnum CGlobalMemoryDataOperation,
             TailNumber TailNum (remaining parameters elided) */>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    // ... (__launch_bounds__ attribute elided)
#endif
    kernel_moe_gemm(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        // ...
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        /* ... remaining arguments elided ... */);
#endif
}
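// The _2lds variant below has the same structure but allocates two LDS buffers
// (p_shared / p_shared1) so the gridwise GEMM can ping-pong A tiles between
// them, overlapping global loads with MFMA compute (see Run_2Lds further down).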
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          /* ..., InMemoryDataOperationEnum CGlobalMemoryDataOperation,
             TailNumber TailNum (remaining parameters elided) */>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    // ... (__launch_bounds__ attribute elided)
#endif
    kernel_moe_gemm_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        // ...
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        /* ... remaining arguments elided ... */);
#endif
}
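// Template-parameter groups of the gridwise MoE GEMM below, roughly:
//  - ABlockTransfer*: the tiled global->LDS copy of A (thread cluster shape,
//    access order, per-thread vector widths);
//  - BBlockTransfer*: the global->VGPR copy of the preshuffled B;
//  - CShuffle* / CDEShuffle*: the LDS-based C shuffle epilogue;
//  - NSwizzle / IsInputGemm / MulRoutedWeight: MoE-specific behavior switches.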
template <typename ALayout,
          // ... (B/Ds/C layouts, data types, tile-size constants elided)
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          // ...
          index_t ActivationOperation = 0,
          bool NSwizzle              = false,
          bool IsInputGemm           = true,
          bool MulRoutedWeight       = true,
          // ...
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          typename LDSTypeA     = ADataType,
          typename LDSTypeB     = BDataType>
// struct definition follows (struct name elided in this listing)
{
    static constexpr auto CShuffleBlockTransferScalarPerVector_NPerBlock =
        CDEShuffleBlockTransferScalarPerVectors{}[I0];

    // ...

    static constexpr auto MakeDsGridPointer()
    {
        // ... (one typed null pointer generated per D tensor)
        return static_cast<const DDataType*>(nullptr);
    }
    static __host__ auto CalculateGridSize(index_t M, index_t N)
    {
        // ...
        const index_t gridx = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy = NSwizzle ? 1 : mblock;
        // ...
    }

    __host__ __device__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    __host__ __device__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    __host__ __device__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        // ... (KReadVec setup elided)
        auto K_t = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }
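    // Example (hypothetical numbers): K = 1000, KPerBlock = 64, K_Batch = 4
    // gives K_t = 256 and CalculateKPadded(1000, 4) = ceil(1000/256) * 64 = 256,
    // i.e. each split-K batch covers a 256-wide, padded slice of K, so that
    // 4 * 256 = 1024 >= 1000. With AK1Value = 8, CalculateAK0Padded(1000, 4)
    // = 4 * (64 / 8) = 32 measures the same slice in AK1-wide chunks.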
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        // ...
    }
    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                // ...
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                // ...
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            const auto a_grid_desc_m_k = /* ... (M and K right-padded) */;
            // ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
    }
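    // All four branches produce the same 3-D view (AK0, M, AK1): AK1 is the
    // per-load vector granularity and AK0 * AK1 covers the K range owned by one
    // split-K batch; only the amount of M/K right-padding differs per
    // GemmSpecialization. E.g. (hypothetical) M = 100, K = 64, MPerBlock = 32,
    // AK1 = 8: MPadding pads M to 128 while K stays 64, giving (8, 128, 8).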
    __host__ __device__ static auto MakeBGridDescriptor_Preshuffled(index_t N0, index_t K0)
    {
        // ...
        return make_naive_tensor_descriptor(
            /* ... (lengths elided) */,
            make_tuple(NWave * K0 * NkSwizzleNumber, K0 * NkSwizzleNumber, NkSwizzleNumber, I1));
    }
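    // The strides (NWave*K0*NkSwizzleNumber, K0*NkSwizzleNumber, NkSwizzleNumber, 1)
    // describe B already permuted into the per-wave, per-XDL-lane order the MFMA
    // pipeline consumes (hence "Preshuffled"), which is why the kernel below can
    // read B with plain vector loads into VGPRs instead of staging it through LDS.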
    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            // ...
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            const auto b_grid_desc_n_k = /* ... (N and K right-padded) */;
            // ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
    }
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWave, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
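    // The MMA tile descriptor splits the MN dimension of a (K0, MN, K1) block
    // descriptor into (repeat, wave, lane) = (MNXdlPerWave, MNWaves, MNPerXdl):
    // A uses <MXdlPerWave, MWaves, MPerXdl>, B uses <NXdlPerWave, NWave, NPerXdl>.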
    template <typename ELayout>
    __host__ __device__ static auto MakeCGridDescriptor_M_N(
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ...
        }();
        // ... (GemmSpec-dependent padding, as for A/B above)
    }

    template <typename DLayout>
    __host__ __device__ static auto MakeDGridDescriptor_M_N(
        index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ...
        }();
        // ...
    }

    __host__ __device__ static auto MakeDsGridDescriptor_M_N(
        index_t M, index_t MPad, index_t N, index_t NPad, std::array<index_t, NumDTensor> StrideDs)
    {
        return generate_tuple(
            [&](auto i) {
                using DLayout = /* ... (i-th layout from DsLayout; elided) */;
                return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);
            },
            Number<NumDTensor>{});
    }
    template <typename DsGridDesc>
    static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        return generate_tuple(
            [&](auto i) {
                return MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                    ds_grid_desc_m_n[i], MBlock, NBlock);
            },
            Number<NumDTensor>{});
    }
    // ...
    struct Problem
    {
        __host__ __device__ Problem(index_t NumTokens_,
                                    index_t TopK_,
                                    index_t M_,
                                    index_t N_,
                                    index_t K_,
                                    index_t StrideA_,
                                    index_t StrideB_,
                                    std::array<index_t, NumDTensor> StrideDs_,
                                    index_t StrideC_,
                                    index_t KBatch_)
        // ... (member initializers elided)
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      // ...
                      << "TopK:" << TopK << ", "
                      // ...
                      << "KRead:" << KRead << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        // ... (NumTokens, TopK, M/N/K, strides, padded sizes, KBatch, KRead,
        //      AK0/BK0, MBlock/NBlock, BN0Shuffled/BK0Shuffled members elided)
    };
    // ...
    struct Argument
    {
        __host__ Argument(const index_t* p_sorted_token_ids_,
                          const index_t* p_sorted_expert_ids_,
                          const index_t* p_max_token_id_,
                          const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,
                          CDataType* p_c_grid_,
                          index_t NumTokens_,
                          index_t TopK_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          std::array<index_t, NumDTensor> StrideDs_,
                          index_t StrideC_,
                          const AScaleType* p_a_scale_grid_,
                          const BScaleType* p_b_scale_grid_,
                          index_t k_batch_,
                          AElementwiseOperation a_element_op_,
                          BElementwiseOperation b_element_op_,
                          CElementwiseOperation c_element_op_)
        // ... (member initializers elided)
        {
            // ...
            p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
            // ...
        }

        // ... (pointer and Problem members elided)
    };
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg, index_t k_id)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                // ...
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                // ...
            }

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                // ...
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                // ...
            }

            if(k_id < karg.KBatch - 1)
            {
                // ...
            }
            // ...
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
    };
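    // a_k_split_offset / b_k_split_offset shift the A/B base pointers to this
    // z-block's K slice; the layout branches pick element- vs. row-stride
    // offsets, and the last K batch (k_id == KBatch - 1) presumably absorbs
    // the K remainder.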
    static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        if constexpr(ABlockLdsExtraM)
        {
            // ... (plain padded LDS layout; elided)
        }
        else if constexpr(/* ... */)
        {
            constexpr auto a_lds_block_desc = /* ... */;
            // ... (xor-permuted variant; elided)
            return a_lds_block_desc_permuted;
        }
        else
        {
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(LDSTypeA));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(LDSTypeA) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(LDSTypeA))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(LDSTypeA)));

            // Transform chain (elided): a_lds_block_desc -> a_lds_block_desc_permuted
            // -> a_lds_block_desc_unmerged -> a_lds_block_desc_ak0_m_ak1, built over
            // dimensions such as Number<kfold * M0 / mpair>{} and
            // Number<KThreadWrite / kfold / KThreadReadPerm>{}.
            return a_lds_block_desc_ak0_m_ak1;
        }
    }
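    // The 128s above appear to be one full LDS bank row (32 banks x 4 bytes):
    // kfold and mpair fold the K/M dimensions so one permuted row spans at most
    // 128 bytes, which keeps the MFMA-order reads free of bank conflicts
    // without padding the tile.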
    static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = /* ... */;

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }

    // ...
    using BlockwiseGemmPipe =
        remove_cvref_t<decltype(BlockGemmBlockMoeScaleBPreshufflePipeline_Selector<
                                /* ... */
                                ABlockTransferSrcScalarPerVector,
                                BBlockTransferSrcScalarPerVector,
                                /* ... */>())>;

    static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
    {
        // ...
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max(a_block_space_size_aligned * sizeof(LDSTypeA),
                         c_block_size * sizeof(CShuffleDataType));
    }
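    // LDS is reused across phases: the same p_shared allocation holds the A
    // tile (LDSTypeA) during the main loop and the C shuffle tile
    // (CShuffleDataType) during the epilogue, hence the max over the two sizes.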
    static constexpr __host__ bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        // ... (GemmSpec-dependent guards around the M/N checks elided)
        if(!(karg.M % MPerBlock == 0))
        {
            std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            return false;
        }

        if(!(karg.N % NPerBlock == 0))
        {
            std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            return false;
        }

        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                          << karg.K << " " << __FILE__ << ":" << __LINE__
                          << ", in function: " << __func__ << std::endl;
                return false;
            }
        }

        {
            auto K_t = karg.KBatch * KReadVec;
            // ... (KReadPadSplited derived from K_t; elided)
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
        {
            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }
        else
        {
            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! " << __FILE__
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }

        // ...
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        {
            return false;
        }

        // ...
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    __host__ __device__ static constexpr TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
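    // CalculateHasMainKBlockLoop / CalculateKBlockLoopTailNum classify the K
    // loop once on the host so the matching Run<HasMainKBlockLoop, ..., TailNum>
    // instantiation can be launched. Host-side sketch (hypothetical, for
    // orientation only; not part of this header):
    //
    //   GridwiseGemm::Argument karg{/* pointers, sizes, element ops */};
    //   if(GridwiseGemm::CheckValidity(karg))
    //   {
    //       const bool has_loop = GridwiseGemm::CalculateHasMainKBlockLoop(karg.K);
    //       // pick kernel_moe_gemm or kernel_moe_gemm_2lds instantiated with
    //       // has_loop and CalculateKBlockLoopTailNum(karg.K); grid dims come
    //       // from CalculateGridSize(karg.M, karg.N) plus karg.KBatch in z.
    //   }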
    template <typename CGridDesc>
    static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            /* ... */);

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void Run(const index_t* p_sorted_token_ids,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               DsGridPointer& p_ds_grid,
                               CDataType* p_c_grid,
                               const AScaleType* p_a_scale_grid,
                               const BScaleType* p_b_scale_grid,
                               void* p_shared,
                               const Problem& problem,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
    {
        // ... (A grid and A/B scale descriptors built here; elided)
        const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled(/* ... */);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
            /* ... */);
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);

        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ... (prefix_block derived from ecnt_prefix; elided)
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                const index_t mid =
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                return {nid, mid};
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            return;
        StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
        static_for<0, AMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        });
        const index_t expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
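        // p_sorted_token_ids entries pack (topk_slot << 24) | token_id; e.g. a
        // raw entry 0x02000005 addresses token 5 in its topk slot 2. For the
        // down GEMM (!IsInputGemm) rows expand to token * TopK + slot, matching
        // the TopK-replicated activations; expert_stride doubles when
        // IsInputGemm because each expert stores gate and up weights back to back.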
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            /* ... */);

        // ...
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            /* ... */
            b_grid_desc_bpreshuffled.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        // ... (a_block_desc_ak0_m_ak1 and copy steps defined here; elided)
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
            ThisThreadBlock,
            AElementwiseOperation,
            /* ... */
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            /* ... */
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            /* ... */
            ABlockTransferSrcVectorDim,
            /* ... */
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            /* ... */
            AThreadTransferSrcResetCoordinateAfterRun,
            /* ... */
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                /* ... */
                                                a_block_desc_ak0_m_ak1,
                                                /* ... */);

        auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
            /* ... */
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            /* ... */
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  /* ... */);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        constexpr index_t ScaleSliceSizeM = MXdlPerWave;
        // ... (scale thread descriptors c/a/b_scale_thread_desc; elided)
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
        auto a_thread_offset     = /* ... */;

        // ...
        const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
            return;

        StaticallyIndexedArray<IndexType, ScaleSliceSizeM> scale_gather_offsets;
        static_for<0, ScaleSliceSizeM, 1>{}([&](auto m0) {
            const index_t fused_token =
                p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
            index_t token_offset = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            scale_gather_offsets(m0) = /* ... */;
        });

        auto a_scale_thread_copy =
            ThreadwiseTensorSliceTransfer_v2</* ... */
                                             decltype(a_scale_grid_desc_am_ak),
                                             decltype(a_scale_thread_desc),
                                             /* ... */>(
                a_scale_grid_desc_am_ak, /* ... */);

        auto b_scale_thread_copy =
            ThreadwiseTensorSliceTransfer_v2</* ... */
                                             decltype(b_scale_grid_desc_bn_ak),
                                             decltype(b_scale_thread_desc),
                                             /* ... */>(
                b_scale_grid_desc_bn_ak,
                make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

        constexpr auto a_scale_thread_slice_copy_step = /* ... */;
        constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
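        // Scale handling: A scales are gathered per sorted token row (ScaleBlockM
        // granularity), B scales start at this block's N tile
        // (block_n_id * NPerBlock / ScaleBlockN), and both are advanced by the
        // ScaleSliceSizeK-wide copy steps inside the pipeline's K loop.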
        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                /* ... */
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
                /* ... */
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                /* ... */
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      /* ... */);
            const BScaleType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            auto b_scale_thread_copy_up =
                ThreadwiseTensorSliceTransfer_v2</* ... */
                                                 decltype(b_scale_grid_desc_bn_ak),
                                                 decltype(b_scale_thread_desc),
                                                 /* ... */>(
                    b_scale_grid_desc_bn_ak,
                    /* ... */);

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                // ...
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                b_scale_thread_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                // ...
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                // ...
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                // ...
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                // ...
                b_scale_thread_slice_copy_step,
                // ...
                num_k_block_main_loop);
        }
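        // With IsInputGemm the pipeline is driven with two B streams (gate and
        // up weights plus their scales) and fills both c_thread_buf and
        // c_thread_buf_up in one pass over K; otherwise the single-stream Run
        // populates only c_thread_buf.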
        // shuffle C and write out
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      "wrong!");

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
        constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

        static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
        static_assert(M0 * M1 * M2 == MPerBlock);
        static_assert(N4 == 4);
        // (surrounding static_for nest over thread-tile indices m0/m1/m2/... elided)
        {
            float topk_weight;
            if constexpr(MulRoutedWeight)
            {
                const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
                topk_weight         = p_ds_grid[I0][m_pos];
            }
            // ...
            const auto cidx =
                blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(/* ... */);

            if constexpr(IsInputGemm)
            {
                if constexpr(ActivationOperation == Activation::silu_and_mul)
                {
                    float gate = c_thread_buf[cidx];
                    float up   = c_thread_buf_up[cidx];
                    if constexpr(MulRoutedWeight)
                    {
                        gate = gate * topk_weight;
                        up   = up * topk_weight;
                    }
                    // ... (SiLU applied to gate; elided)
                    c_thread_buf(cidx) = gate * up;
                }
                else // gelu_and_mul
                {
                    float gate = c_thread_buf[cidx];
                    float up   = c_thread_buf_up[cidx];
                    if constexpr(MulRoutedWeight)
                    {
                        gate = gate * topk_weight;
                        up   = up * topk_weight;
                    }
                    // ... (GELU applied to gate; elided)
                    c_thread_buf(cidx) = gate * up;
                }
            }
            else
            {
                if constexpr(MulRoutedWeight)
                {
                    c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
                }
            }
        }
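        // The two IsInputGemm branches differ only in the elided pointwise op,
        // per the Activation enum (gelu_and_mul / silu_and_mul); e.g.
        // silu_and_mul presumably computes out = silu(gate) * up with
        // silu(x) = x / (1 + exp(-x)), optionally pre-scaled by the routed
        // topk weight as shown above.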
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            /* ... */);

        // ...
        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
            make_single_stage_tensor_adaptor(/* ... */);

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
            make_single_stage_tensor_adaptor(/* ... */);

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            /* ... */
            decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
            decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
            /* ... */
            Sequence<CShuffleMXdlPerWavePerShuffle,
                     CShuffleNXdlPerWavePerShuffle,
                     /* ... */>,
            /* ... */>(
            c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            make_multi_index(/* ... */,
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             n_thread_data_on_block_idx[I2],
                             n_thread_data_on_block_idx[I3],
                             n_thread_data_on_block_idx[I4]),
            /* ... */);
        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(/* ... */);

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                const DDataType* ptr_ = p_ds_grid[i];
                // ...
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    ptr_, ds_grid_desc_m_n[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie(
                [&](auto i) -> const auto& {
                    return ds_grid_desc_mblock_mperblock_nblock_nperblock[i];
                },
                Number<NumDTensor>{}));
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie(
                [&](auto i) -> const auto& { return ds_grid_buf[i]; },
                Number<NumDTensor>{}));

        const auto idx_c_ds_block_begin = /* ... */;

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1;

        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
            ThisThreadBlock,
            /* ... */
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            /* ... */
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            /* ... */
            CDEShuffleBlockTransferScalarPerVectors,
            /* ... */>(
            c_ds_desc_refs,
            idx_c_ds_block_begin,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
            /* ... */);

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve</* ... */
                              Sequence<CShuffleMXdlPerWavePerShuffle,
                                       CShuffleNXdlPerWavePerShuffle,
                                       /* ... */>>{};

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
            SpaceFillingCurve</* ... */
                              Sequence<1,
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       1,
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);

        StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
        static_for<0, num_access, 1>{}([&](auto access_id) {
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            const index_t c_token_pos =
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            static_for<0, EMRepeats, 1>{}([&](auto m0) {
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                index_t token_offset      = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = token_offset * problem.N;
            });

            block_sync_lds();

            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_thread_buf,
                                          c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          c_shuffle_block_buf);

            block_sync_lds();

            cde_block_copy_lds_and_global.Run(
                c_ds_desc_refs,
                c_ds_buf_refs,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(c_grid_buf),
                scatter_offsets);

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);

                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                });

                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    I0,
                    cde_lds_and_global_step);
            }
        });
    }
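    // Run_2Lds below repeats the Run flow, but double-buffers the A tile across
    // p_shared/p_shared1 and the preshuffled B across two VGPR buffers, so the
    // pipeline can prefetch tile k+1 while computing tile k.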
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void Run_2Lds(const index_t* p_sorted_token_ids,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    DsGridPointer& p_ds_grid,
                                    CDataType* p_c_grid,
                                    const AScaleType* p_a_scale_grid,
                                    const BScaleType* p_b_scale_grid,
                                    void* p_shared,
                                    void* p_shared1,
                                    const Problem& problem,
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
    {
        // ... (A grid and A/B scale descriptors built here; elided)
        const auto b_grid_desc_bpreshuffled = MakeBGridDescriptor_Preshuffled(/* ... */);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
            /* ... */);
        // ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            return;
        const index_t expert_id =
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);

        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
            {
                const index_t ecnt_prefix = p_max_token_id[1 + expert_id];
                // ... (prefix_block derived from ecnt_prefix; elided)
                const index_t ecnt           = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new        = blockIdx.x - prefix_block;
                const index_t nid            = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                const index_t mid =
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
                return {nid, mid};
            }
            else
            {
                return {blockIdx.x, blockIdx.y};
            }
        }();
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
        const index_t token0 =
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id ||
           /* ... */)
            return;
        StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
        static_for<0, AMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
        });
        const index_t expert_stride =
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            /* ... */);

        // ...
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NXdlPerWave);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            /* ... */
            b_grid_desc_bpreshuffled.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + expert_id * expert_scale_stride,
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
        // ... (a_block_desc_ak0_m_ak1 and copy steps defined here; elided)
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1_gather<
            ThisThreadBlock,
            AElementwiseOperation,
            /* ... */
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            /* ... */
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            /* ... */
            ABlockTransferSrcVectorDim,
            /* ... */
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            /* ... */
            AThreadTransferSrcResetCoordinateAfterRun,
            /* ... */
            BlockwiseGemmPipe::GlobalBufferNum>(a_grid_desc_ak0_m_ak1,
                                                /* ... */
                                                a_block_desc_ak0_m_ak1,
                                                /* ... */);

        auto b_block_buf_ping = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_buf_pong = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        auto b_blockwise_copy = ThreadwiseTensorSliceTransfer_v2<
            /* ... */
            decltype(b_grid_desc_bpreshuffled),
            decltype(b_block_desc_bk0_n_bk1),
            /* ... */
            BBlockTransferSrcScalarPerVector,
            BThreadTransferSrcResetCoordinateAfterRun,
            true>(b_grid_desc_bpreshuffled,
                  /* ... */);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
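        // Ping/pong pairs: a_block_bufs in LDS and b_block_bufs in VGPRs let the
        // 2-LDS pipeline alternate its read and write buffers each K step
        // instead of synchronizing twice per tile on a single buffer.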
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        constexpr index_t ScaleSliceSizeM = MXdlPerWave;
        // ... (scale thread descriptors c/a/b_scale_thread_desc; elided)
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
        auto a_thread_offset     = /* ... */;

        // ...
        const index_t token_scale_pos = block_m_id * MPerBlock / ScaleBlockM;
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)
            return;

        StaticallyIndexedArray<IndexType, ScaleSliceSizeM> scale_gather_offsets;
        static_for<0, ScaleSliceSizeM, 1>{}([&](auto m0) {
            const index_t fused_token =
                p_sorted_token_ids[token_scale_pos + m0 * MPerXdl * MWaves + a_thread_offset];
            index_t token_offset = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
            {
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            }
            scale_gather_offsets(m0) = static_cast<IndexType>(token_offset) * /* ... */;
        });

        auto a_scale_thread_copy =
            ThreadwiseTensorSliceTransfer_v2</* ... */
                                             decltype(a_scale_grid_desc_am_ak),
                                             decltype(a_scale_thread_desc),
                                             /* ... */>(
                a_scale_grid_desc_am_ak, /* ... */);

        auto b_scale_thread_copy =
            ThreadwiseTensorSliceTransfer_v2</* ... */
                                             decltype(b_scale_grid_desc_bn_ak),
                                             decltype(b_scale_thread_desc),
                                             /* ... */>(
                b_scale_grid_desc_bn_ak,
                make_multi_index(block_n_id * NPerBlock / ScaleBlockN, 0));

        constexpr auto a_scale_thread_slice_copy_step = /* ... */;
        constexpr auto b_scale_thread_slice_copy_step = make_multi_index(0, ScaleSliceSizeK);
        if constexpr(IsInputGemm)
        {
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2 / BPackedSize;
            const auto b_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                /* ... */
                b_grid_desc_bpreshuffled.GetElementSpaceSize());
            auto b_blockwise_copy_up = ThreadwiseTensorSliceTransfer_v2<
                /* ... */
                decltype(b_grid_desc_bpreshuffled),
                decltype(b_block_desc_bk0_n_bk1),
                /* ... */
                BBlockTransferSrcScalarPerVector,
                BThreadTransferSrcResetCoordinateAfterRun,
                true>(b_grid_desc_bpreshuffled,
                      /* ... */);
            const BScaleType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / BPackedSize;
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / BPackedSize,
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());
            auto b_scale_thread_copy_up =
                ThreadwiseTensorSliceTransfer_v2</* ... */
                                                 decltype(b_scale_grid_desc_bn_ak),
                                                 decltype(b_scale_thread_desc),
                                                 /* ... */>(
                    b_scale_grid_desc_bn_ak,
                    /* ... */);

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_blockwise_copy_up,
                // ...
                b_block_slice_copy_step,
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                // ...
                b_scale_grid_buf_up,
                b_scale_thread_slice_copy_step,
                num_k_block_main_loop);
        }
        else
        {
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, NumKBlockPerScale, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                // ...
                a_block_slice_copy_step,
                b_grid_desc_bpreshuffled,
                b_block_desc_bk0_n_bk1,
                // ...
                b_block_slice_copy_step,
                c_scale_thread_desc,
                // ...
                a_scale_grid_desc_am_ak,
                a_scale_thread_desc,
                a_scale_thread_copy,
                // ...
                a_scale_thread_slice_copy_step,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_desc,
                b_scale_thread_copy,
                // ...
                b_scale_thread_slice_copy_step,
                num_k_block_main_loop);
        }
        // shuffle C and write out
        static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                          NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                      "wrong!");

        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        // ...

        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4 =
            blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp =
            blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I5);
        constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I6);
        constexpr auto N4 = c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4_tmp.GetLength(I7);

        static_assert(N0 * N1 * N2 * N3 * N4 == NPerBlock);
        static_assert(M0 * M1 * M2 == MPerBlock);
        static_assert(N4 == 4);
        // (surrounding static_for nest over thread-tile indices m0/m1/m2/... elided)
        {
            float topk_weight;
            if constexpr(MulRoutedWeight)
            {
                const index_t m_pos = block_m_id * MPerBlock + m0 * M1 * M2 + m1 * M2 + m2;
                topk_weight         = p_ds_grid[I0][m_pos];
            }
            // ...
            const auto cidx =
                blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(/* ... */);

            if constexpr(IsInputGemm)
            {
                if constexpr(ActivationOperation == Activation::silu_and_mul)
                {
                    float gate = c_thread_buf[cidx];
                    float up   = c_thread_buf_up[cidx];
                    if constexpr(MulRoutedWeight)
                    {
                        gate = gate * topk_weight;
                        up   = up * topk_weight;
                    }
                    // ... (SiLU applied to gate; elided)
                    c_thread_buf(cidx) = gate * up;
                }
                else // gelu_and_mul
                {
                    float gate = c_thread_buf[cidx];
                    float up   = c_thread_buf_up[cidx];
                    if constexpr(MulRoutedWeight)
                    {
                        gate = gate * topk_weight;
                        up   = up * topk_weight;
                    }
                    // ... (GELU applied to gate; elided)
                    c_thread_buf(cidx) = gate * up;
                }
            }
            else
            {
                if constexpr(MulRoutedWeight)
                {
                    c_thread_buf(cidx) = c_thread_buf[cidx] * topk_weight;
                }
            }
        }
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

        auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<CShuffleDataType*>(p_shared),
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4 = transform_tensor_descriptor(
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
            /* ... */);

        // ...
        const auto c_thread_mtx_on_block =
            blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_adaptor =
            make_single_stage_tensor_adaptor(/* ... */);

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor =
            make_single_stage_tensor_adaptor(/* ... */);

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_n3_n4_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
            /* ... */
            decltype(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4),
            decltype(c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4),
            /* ... */
            Sequence<CShuffleMXdlPerWavePerShuffle,
                     CShuffleNXdlPerWavePerShuffle,
                     /* ... */>,
            /* ... */>(
            c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
            make_multi_index(/* ... */,
                             m_thread_data_on_block_idx[I1],
                             n_thread_data_on_block_idx[I1],
                             m_thread_data_on_block_idx[I2],
                             n_thread_data_on_block_idx[I2],
                             n_thread_data_on_block_idx[I3],
                             n_thread_data_on_block_idx[I4]),
            /* ... */);
        using EDataType = CDataType;
        // ...
        const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(/* ... */);

        const auto ds_grid_buf = generate_tuple(
            [&](auto i) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
            },
            Number<NumDTensor>{});

        const auto c_ds_desc_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
            generate_tie(
                [&](auto i) -> const auto& {
                    return ds_grid_desc_mblock_mperblock_nblock_nperblock[i];
                },
                Number<NumDTensor>{}));
        const auto c_ds_buf_refs = concat_tuple_of_reference(
            tie(c_shuffle_block_buf),
            generate_tie(
                [&](auto i) -> const auto& { return ds_grid_buf[i]; },
                Number<NumDTensor>{}));

        const auto idx_c_ds_block_begin = /* ... */;

        const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
            c_grid_desc_mblock_mperblock_nblock_nperblock;

        using CDEBlockTransferCluster =
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
        const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
        constexpr index_t scatter_weight_idx  = IsInputGemm ? 1 : 1;

        auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
            ThisThreadBlock,
            /* ... */
            decltype(c_ds_desc_refs),
            decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
            CElementwiseOperation,
            /* ... */
            Sequence<1,
                     CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                     1,
                     CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
            CDEBlockTransferCluster,
            /* ... */
            CDEShuffleBlockTransferScalarPerVectors,
            /* ... */>(
            c_ds_desc_refs,
            idx_c_ds_block_begin,
            tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
            /* ... */);

        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
        constexpr auto sfc_c_vgpr =
            SpaceFillingCurve</* ... */
                              Sequence<CShuffleMXdlPerWavePerShuffle,
                                       CShuffleNXdlPerWavePerShuffle,
                                       /* ... */>>{};

        constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

        constexpr auto sfc_cde_block =
            SpaceFillingCurve</* ... */
                              Sequence<1,
                                       CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                       1,
                                       CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

        static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");

        constexpr auto EMThreads =
            CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
        constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
        constexpr auto ENThreads =
            CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);

        StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;
        static_for<0, num_access, 1>{}([&](auto access_id) {
            auto dstidx = sfc_cde_block.GetIndex(access_id);
            const index_t c_token_pos =
                block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
            static_for<0, EMRepeats, 1>{}([&](auto m0) {
                const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                index_t token_offset      = fused_token & 0xffffff;
                if constexpr(IsInputGemm)
                {
                    token_offset = token_offset * problem.TopK + (fused_token >> 24);
                }
                scatter_offsets(m0) = token_offset * problem.N;
            });

            block_sync_lds();

            c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                          c_thread_buf,
                                          c_block_desc_m0_n0_m1_n1_m2_n2_n3_n4,
                                          c_shuffle_block_buf);

            block_sync_lds();

            cde_block_copy_lds_and_global.Run(
                c_ds_desc_refs,
                c_ds_buf_refs,
                tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                tie(c_grid_buf),
                scatter_offsets);

            if constexpr(access_id < num_access - 1)
            {
                constexpr auto cde_lds_and_global_step =
                    sfc_cde_block.GetForwardStep(access_id);

                static_for<0, NumDTensor, 1>{}([&](auto i) {
                    cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                        c_ds_desc_refs, i + I1, cde_lds_and_global_step);
                });

                cde_block_copy_lds_and_global.MoveDstSliceWindow(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                    I0,
                    cde_lds_and_global_step);
            }
        });
    }
};