template <typename GridwiseGemm,
          bool HasMainKBlockLoop,

#if CK_USE_LAUNCH_BOUNDS

#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
 
        karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
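    // SplitKBatchOffset (built from karg and blockIdx.z) steers each split-K
    // batch to a disjoint K-slice of A/B and of the block-scale tensors. A rough
    // sketch of the offset arithmetic, assuming row-major A and the KRead member
    // printed by Problem (illustrative, not the verbatim implementation):
    //
    //   a_k_split_offset       = k_id * karg.KRead;
    //   a_scale_k_split_offset = k_id * karg.KRead / ScaleBlockSize;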
 
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,

#if CK_USE_LAUNCH_BOUNDS

#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_sorted_token_ids,
        karg.p_sorted_expert_ids,
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_b_scale_grid + splitk_batch_offset.b_scale_k_split_offset,
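    // The _2Lds variant keeps two LDS tiles (p_shared_0 / p_shared_1) so the
    // pipeline can ping-pong: while the MFMA stage consumes one tile, the next
    // K-slice is loaded into the other, hiding global-to-LDS latency at the cost
    // of twice the LDS footprint.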
 
template <typename ALayout,
          typename AScaleDataType,
          typename BScaleDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          typename CDEShuffleBlockTransferScalarPerVectors,
          index_t ActivationOperation = 0,
          bool NSwizzle               = false,
          bool IsInputGemm            = true,
          bool MulRoutedWeight        = true,
          typename ComputeTypeA       = ADataType,
          typename ComputeTypeB       = BDataType>
 
        CDEShuffleBlockTransferScalarPerVectors{}[I0];

                return static_cast<const DDataType*>(nullptr);

        const index_t gridx  = NSwizzle ? nblock * mblock : nblock;
        const index_t gridy  = NSwizzle ? 1 : mblock;
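        // With NSwizzle the (mblock, nblock) tile space is flattened into
        // gridDim.x and unswizzled per block inside the kernel; otherwise nblock
        // and mblock map directly to gridDim.x and gridDim.y. E.g. with
        // mblock = 4, nblock = 8, KBatch = 2:
        //   NSwizzle = false: grid = (8, 4, 2)
        //   NSwizzle = true : grid = (32, 1, 2)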
 
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);

        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);

        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;

        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
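        // Each helper rounds K up to a multiple of K_Batch * KPerBlock (or
        // K_Batch * KReadVec) so every split-K batch works on the same padded
        // extent. Worked example for the AK0 case: K = 1000, K_Batch = 2,
        // KPerBlock = 256, AK1Value = 16:
        //   K_t = 512, (1000 + 511) / 512 = 2, AK0 = 2 * (256 / 16) = 32.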
 
    template <index_t MNXdlPerWave,
              typename TileDesc_K0_MN_K1>

        IndexType M, IndexType MPad, IndexType K, IndexType KPad, IndexType StrideA, IndexType AK0)
 
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
            const auto a_grid_desc_m_k =
            return a_grid_desc_ak0_m_ak1;
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
                a_grid_desc_mraw_kraw,
                a_grid_desc_ak0_m_ak1,
                a_grid_desc_permuted,
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
                a_grid_desc_mraw_kraw,
            return a_grid_desc_ak0_m_ak1;

                a_grid_desc_mraw_kraw,
                a_grid_desc_ak0_m_ak1,
                a_grid_desc_permuted,
 
        const auto b_grid_desc_nraw_kraw = [&]() {

                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");
                        (GemmSpec != GemmSpecialization::Default &&
                         GemmSpec != GemmSpecialization::MPadding)),
                      "f4x2_pk_t does not support K padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
            const auto b_grid_desc_n_k =
            return b_grid_desc_bk0_n_bk1;
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
                b_grid_desc_nraw_kraw,
            return b_grid_desc_bk0_n_bk1;
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
                b_grid_desc_nraw_kraw,
            return b_grid_desc_bk0_n_bk1;

                b_grid_desc_nraw_kraw,
                b_grid_desc_bk0_n_bk1,
                b_grid_desc_permuted,
 
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MXdlPack, MPerXdl>(
            ABlockDesc_AK0_M_AK1{});

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NXdlPack, NPerXdl>(
            BBlockDesc_BK0_N_BK1{});
 
    template <typename ELayout>
        IndexType M, IndexType MPad, IndexType N, IndexType NPad, IndexType StrideC)
        const auto c_grid_desc_mraw_nraw = [&]() {

    template <typename DLayout>
    __host__ __device__ static auto
        const auto c_grid_desc_mraw_nraw = [&]() {

                return MakeDGridDescriptor_M_N<DLayout>(M, MPad, N, NPad, StrideDs[i]);

    template <typename DsGridDesc>
        const DsGridDesc& ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
                    ds_grid_desc_m_n[i], MBlock, NBlock);
 
                         std::array<index_t, NumDTensor> StrideDs_,

            std::cout << "problem {"
                      << "TopK: " << TopK << ", "
                      << "KRead: " << KRead << ", "
                      << "AK0: " << AK0 << ", "
                      << "BK0: " << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
 
                          const index_t* p_sorted_expert_ids_,
                          const index_t* p_max_token_id_,
                          const ADataType* p_a_grid_,
                          const AScaleDataType* p_a_scale_grid_,
                          const BDataType* p_b_grid_,
                          const BScaleDataType* p_b_scale_grid_,
                          std::array<const void*, NumDTensor> p_ds_grid_,
                          CDataType* p_c_grid_,
                          std::array<index_t, NumDTensor> StrideDs_,
                          AElementwiseOperation a_element_op_,
                          BElementwiseOperation b_element_op_,
                          CElementwiseOperation c_element_op_)

                p_ds_grid(i) = static_cast<const DDataType_*>(p_ds_grid_[i]);
 
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)

            if(k_id < karg.KBatch - 1)
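            // All but the last split-K batch consume a full KRead slice; the last
            // one is clamped to whatever remains of K. Roughly:
            //
            //   if(k_id < karg.KBatch - 1) karg.K = karg.KRead;
            //   else                       karg.K = karg.K - karg.KRead * (karg.KBatch - 1);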
 
            constexpr auto a_lds_block_desc =
            return a_lds_block_desc_permuted;

            constexpr auto WaveSize = 64;
            constexpr auto M0       = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1       = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

                           Number<kfold * M0 / mpair>{},
                a_lds_block_desc_permuted,
                a_lds_block_desc_unmerged,
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
            return a_lds_block_desc_ak0_m_ak1;
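            // The fold/permute chain above interleaves AK1-wide vectors so a
            // wave's LDS reads in the XDL stage land in distinct banks, with kfold
            // and mpair capping any contiguous run at 128 bytes. Worked example
            // (illustrative, assuming 1-byte ADataType, AK1 = 16, M0 = 4):
            //   AK1 * M0 * sizeof(ADataType) = 64 <= 128  ->  kfold = 128 / 64 = 2.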
 
            constexpr auto b_lds_block_desc =
            return b_lds_block_desc_permuted;

            constexpr auto WaveSize = 64;
            constexpr auto N0       = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1       = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = WaveSize / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)

            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));

                           Number<kfold * N0 / npair>{},
                b_lds_block_desc_permuted,
                b_lds_block_desc_unmerged,
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
            return b_lds_block_desc_bk0_n_bk1;
 
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
 
                                ABlockTransferSrcScalarPerVector,
                                BBlockTransferSrcScalarPerVector,

            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        if constexpr(IsInputGemm)
            return math::max(a_block_space_size_aligned * sizeof(ADataType) +
                                 b_block_space_size_aligned * sizeof(BDataType) * 2,
                             c_block_size * sizeof(CShuffleDataType));

            return math::max((a_block_space_size_aligned * sizeof(ADataType) +
                              b_block_space_size_aligned * sizeof(BDataType)),
                             c_block_size * sizeof(CShuffleDataType));
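        // For the fused gate+up GEMM the B tile is held twice in LDS (the factor
        // 2), and the C-shuffle buffer reuses the same allocation, so the result
        // is a max rather than a sum. Worked example, assuming MPerBlock =
        // NPerBlock = KPerBlock = 128 and 1-byte A/B types:
        //   A = 16 KiB, B = 2 * 16 KiB -> 48 KiB, taken against
        //   c_block_size * sizeof(CShuffleDataType).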
 
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl) == 0),
                      "Invalid tuning param!");

        static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0,
                      "KPerBlock should be a multiple of ScaleBlockSize");
 
            if(!(karg.M % MPerBlock == 0))
                    std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
 
            if(!(karg.N % NPerBlock == 0))
                    std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
 
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
 
            auto K_t                = karg.KBatch * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
 
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;

            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
 
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;

            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
 
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__

                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
 
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)

        const index_t num_loop = K / KPerBlock;
        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);

        const index_t num_loop = K / KPerBlock;
        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
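    // The pipeline separates a "hot loop" (enough K steps to reach the software
    // pipeline's steady state) from short tails: num_k_loop at or below
    // PrefetchStages skips the main loop entirely, and BlockLoopTailNum picks
    // which epilogue variant drains the prefetch stages.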
 
    template <typename CGridDesc>
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
        return c_grid_desc_mblock_mperblock_nblock_nperblock;

                  "A scale pack data type too large!");
                  "B scale pack data type too large!");

    static_assert(is_same_v<AElementwiseOperation, tensor_operation::element_wise::PassThrough> &&
                      is_same_v<BElementwiseOperation, tensor_operation::element_wise::PassThrough>,
                  "A/B ElementwiseOperation should be PassThrough as load_to_lds is used!");
 
    template <bool HasMainKBlockLoop,
    __device__ static void Run(const index_t* p_sorted_token_ids,
                               const index_t* p_sorted_expert_ids,
                               const index_t* p_max_token_id,
                               const ADataType* p_a_grid,
                               const AScaleDataType* p_a_scale_grid,
                               const BDataType* p_b_grid,
                               const BScaleDataType* p_b_scale_grid,
                               CDataType* p_c_grid,
                               AElementwiseOperation a_element_op,
                               BElementwiseOperation b_element_op,
                               CElementwiseOperation c_element_op)
 
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);

        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
                const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
                const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new = blockIdx.x - prefix_block;
                const index_t nid     = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);

                return {blockIdx.x, blockIdx.y};
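            // NSwizzle path: within an expert's run of blocks, consecutive block
            // ids first sweep an 8-wide strip of N tiles, then step through the
            // expert's M tiles before moving to the next strip, so neighbouring
            // workgroups reuse the same B (weight) tiles while they are still
            // cache-resident. The fallback simply uses the plain 2D grid.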
 
        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
 
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
        StaticallyIndexedArray<IndexType, AMRepeats> gather_offsets;
        static_for<0, AMRepeats, 1>{}([&](auto m0) {
            const index_t fused_token = p_sorted_token_ids[token_pos + m0];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            gather_offsets(m0) = static_cast<IndexType>(token_offset);
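        // Each p_sorted_token_ids entry packs a 24-bit token id with the top-k
        // slot in the upper 8 bits. The down (output) GEMM gathers from the
        // expanded row token * TopK + slot, while the input GEMM reads the token
        // row directly; gather_offsets feeds the row-gathering A copy below.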
 
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) *

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
 
        auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_Gather_DirectLoad<
            Sequence<AK0Number, MPerBlock, AK1Number>,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            1>(a_grid_desc_ak0_m_ak1,
               a_block_desc_ak0_m_ak1,

        auto b_blockwise_copy =
                                                      Sequence<BK0Number, NPerBlock, BK1Number>,
                                                      BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                      BBlockTransferThreadClusterArrangeOrder,
                                                      decltype(b_grid_desc_bk0_n_bk1),
                                                      decltype(b_block_desc_bk0_n_bk1),
                                                      BBlockTransferSrcAccessOrder,
                                                      BBlockTransferSrcVectorDim,
                                                      BBlockTransferSrcScalarPerVector>(
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,

            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                         a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

                                  c_thread_buf.num_of_v_,
                                  c_thread_buf.s_per_v,

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
 
        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        auto thread_offset_shuffled =
        auto a_thread_offset_m = waveId_m;

        auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            true>(a_scale_grid_desc_am_ak,

        auto b_thread_offset_n = waveId_n;

        auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2<
            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            true>(b_scale_grid_desc_bn_ak,
 
        if constexpr(IsInputGemm)
                b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
            auto b_block_buf_up = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                             a_block_space_size_aligned * sizeof(ADataType) +
                                             b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());

            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

            auto b_blockwise_copy_up = ThreadGroupTensorSliceTransfer_DirectLoad<
                Sequence<BK0Number, NPerBlock, BK1Number>,
                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
                                                  b_block_desc_bk0_n_bk1,

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());

            auto b_scale_thread_copy_up = ThreadwiseTensorSliceTransfer_v2<
                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                b_scale_grid_desc_bn_ak,

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                b_blockwise_copy_up,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                b_scale_grid_buf_up,
                num_k_block_main_loop);
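            // Fused gate+up path: each expert's B tensor stores the gate and up
            // projection weights back to back, so the *_up copies start
            // expert_stride / 2 (and expert_scale_stride / 2 for the scales) past
            // the gate pointers and share the same pass over A, filling
            // c_thread_buf and c_thread_buf_up together.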
 
            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                num_k_block_main_loop);
 
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
            static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&
                              CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
            constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);
            constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);

            static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
            static_assert(M5 == 4);
 
            vector_type<float, 4> topk_weights;
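            // The per-token routed weights are read four at a time: M5 == 4
            // (asserted above) is the number of contiguous C rows a lane owns per
            // XDL tile, so one vector_type<float, 4> load covers them.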
 
            static_for<0, NXdlPerWave / NXdlPack, 1>{}([&](auto n0) {
                static_for<0, NXdlPack, 1>{}([&](auto inxdl) {
                    static_for<0, MXdlPerWave / MXdlPack, 1>{}([&](auto m0) {
                        static_for<0, MXdlPack, 1>{}([&](auto imxdl) {
                            static_for<0, M3, 1>{}([&](auto m3) {
                                const index_t m_pos = block_m_id * MPerBlock +
                                                      m0 * M2 * M1 * M3 * M4 * M5 +
                                                      m1 * M2 * M3 * M4 * M5 +
                                                      imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
                                if constexpr(MulRoutedWeight)
                                        *c_style_pointer_cast<const vector_type<float, M5>*>(
                                            p_ds_grid[I2] + m_pos);
                                static_for<0, M5, 1>{}([&](auto m5) {
                                        blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
                                            make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
                                    constexpr auto cidx = Number<c_offset>{};

                                    if constexpr(IsInputGemm)
                                        if constexpr(ActivationOperation ==
                                            float gate = c_thread_buf[cidx];
                                            float up   = c_thread_buf_up[cidx];
                                            if constexpr(MulRoutedWeight)
                                                gate = gate * topk_weights.AsType<float>()[m5];
                                                up   = up * topk_weights.AsType<float>()[m5];
                                            tensor_operation::element_wise::Silu{}(gate, gate);
                                            c_thread_buf_fp32(cidx) = gate * up;

                                            float gate = c_thread_buf[cidx];
                                            float up   = c_thread_buf_up[cidx];
                                            if constexpr(MulRoutedWeight)
                                                gate = gate * topk_weights.AsType<float>()[m5];
                                                up   = up * topk_weights.AsType<float>()[m5];
                                            tensor_operation::element_wise::Gelu{}(gate, gate);
                                            c_thread_buf_fp32(cidx) = gate * up;

                                        c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
                                        if constexpr(MulRoutedWeight)
                                            c_thread_buf_fp32(cidx) =
                                                topk_weights.AsType<float>()[m5] *
                                                c_thread_buf_fp32[cidx];
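            // Activation epilogue: for the fused gate+up GEMM the gate accumulator
            // passes through SiLU (or GELU), optionally pre-scaled by the routed
            // top-k weight, and is multiplied by the up accumulator, e.g.
            //   silu(x) = x / (1 + exp(-x)),  out = silu(gate) * up;
            // the down GEMM keeps the accumulator and only applies the weight.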
 
            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                        Number<CShuffleMXdlPerWavePerShuffle / MXdlPack>{},
                        Number<CShuffleNXdlPerWavePerShuffle / NXdlPack>{},
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                           Sequence<0, 2, 4, 6, 7, 8>{},
                           Sequence<1, 3, 5, 9>{}));
 
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(

            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
                         CShuffleNXdlPerWavePerShuffle / NXdlPack,
                Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                       m_thread_data_on_block_idx[I1],
                                       n_thread_data_on_block_idx[I1],
                                       m_thread_data_on_block_idx[I2],
                                       n_thread_data_on_block_idx[I2],
                                       m_thread_data_on_block_idx[I3],
                                       m_thread_data_on_block_idx[I4],
                                       m_thread_data_on_block_idx[I5],
                                       n_thread_data_on_block_idx[I3]),
 
            using EDataType = CDataType;

            const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
                    return make_dynamic_buffer<AddressSpaceEnum::Global>(
                        p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
                Number<NumDTensor>{});

                tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                             { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
                             Number<NumDTensor>{}));

                tie(c_shuffle_block_buf),
                             { return ds_grid_buf[i]; },
                             Number<NumDTensor>{}));

            const auto idx_c_ds_block_begin =
                                     Number<NumDTensor>{}));
 
            const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
                c_grid_desc_mblock_mperblock_nblock_nperblock;

            using CDEBlockTransferCluster =
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
            const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
            constexpr index_t scatter_weight_idx  = 3;

            auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
                   decltype(c_ds_desc_refs),
                   decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
                   CElementwiseOperation,
                   Sequence<static_cast<index_t>(EGlobalMemoryDataOperation)>,
                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                   CDEBlockTransferCluster,
                   Sequence<0, 1, 2, 3>,
                   Sequence<0, 1, 2, 3>,
                   Sequence<0, 1, 2, 3>,
                   CDEShuffleBlockTransferScalarPerVectors,
                     idx_c_ds_block_begin,
                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

            auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave / MXdlPack,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle / MXdlPack,
                                           CShuffleNXdlPerWavePerShuffle / NXdlPack,
            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            constexpr auto sfc_cde_block =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
 
            constexpr auto EMThreads =
                CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
            constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
            constexpr auto ENThreads =
                CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);

            static_for<0, num_access, 1>{}([&](auto access_id) {
                StaticallyIndexedArray<IndexType, EMRepeats> scatter_offsets;

                auto dstidx = sfc_cde_block.GetIndex(access_id);
                    block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
                static_for<0, EMRepeats, 1>{}([&](auto m0) {
                    const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
                    IndexType token_offset    = fused_token & 0xffffff;
                    if constexpr(IsInputGemm)
                        token_offset = token_offset * problem.TopK + (fused_token >> 24);
                    scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
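                // Mirror of the A-side gather: each output row is scattered
                // directly to its final token position, so no separate unpermute
                // kernel is needed. For the input GEMM every (token, top-k slot)
                // pair owns an expanded row, hence token_offset * problem.TopK +
                // slot before scaling by the row stride N.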
 
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                cde_block_copy_lds_and_global.Run(
                    tie(e_grid_desc_mblock_mperblock_nblock_nperblock),

                if constexpr(access_id < num_access - 1)
                    constexpr auto cde_lds_and_global_step =
                        sfc_cde_block.GetForwardStep(access_id);

                    static_for<0, NumDTensor, 1>{}([&](auto i) {
                        cde_block_copy_lds_and_global.MoveSrcSliceWindow(
                            c_ds_desc_refs, i + I1, cde_lds_and_global_step);

                    cde_block_copy_lds_and_global.MoveDstSliceWindow(
                        tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
                        cde_lds_and_global_step);
 
    template <bool HasMainKBlockLoop,
                                    const index_t* p_sorted_expert_ids,
                                    const index_t* p_max_token_id,
                                    const ADataType* p_a_grid,
                                    const AScaleDataType* p_a_scale_grid,
                                    const BDataType* p_b_grid,
                                    const BScaleDataType* p_b_scale_grid,
                                    CDataType* p_c_grid,
                                    AElementwiseOperation a_element_op,
                                    BElementwiseOperation b_element_op,
                                    CElementwiseOperation c_element_op)
 
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =

        const index_t max_token_id    = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
        const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
        if(expert_block_id * MPerBlock >= max_token_id)
            __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
        const auto block_mn = [&]() -> std::pair<int, int> {
            if constexpr(NSwizzle)
                const index_t ecnt_prefix  = p_max_token_id[1 + expert_id];
                const index_t ecnt         = p_max_token_id[2 + expert_id] - ecnt_prefix;
                const index_t expert_swizzle = ecnt > 0 ? ecnt : 1;
                const index_t bid_new = blockIdx.x - prefix_block;
                const index_t nid     = __builtin_amdgcn_readfirstlane(
                    bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
                    __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);

                return {blockIdx.x, blockIdx.y};

        const index_t block_n_id = block_mn.first;
        const index_t block_m_id = block_mn.second;
            __builtin_amdgcn_readfirstlane(p_sorted_token_ids[block_m_id * MPerBlock] & 0xffffff);
 
        constexpr auto AMThreads  = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
        constexpr auto AK0Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
        constexpr auto AK1Threads = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I2);
        constexpr auto AKThreads  = AK0Threads * AK1Threads;
        constexpr auto AMRepeats  = MPerBlock / AMThreads;
        const index_t token_pos   = block_m_id * MPerBlock + threadIdx.x / AKThreads;

        if(token_pos >= max_token_id || token0 >= problem.NumTokens)
            const index_t fused_token = p_sorted_token_ids[token_pos + m0 * AMThreads];
            index_t token_offset      = fused_token & 0xffffff;
            if constexpr(!IsInputGemm)
                token_offset = token_offset * problem.TopK + (fused_token >> 24);
            gather_offsets(m0) = static_cast<IndexType>(token_offset) * problem.K;
 
            __builtin_amdgcn_readfirstlane(problem.N * problem.K * (IsInputGemm ? 2 : 1));
        const index_t expert_scale_stride = __builtin_amdgcn_readfirstlane(
            problem.N * (IsInputGemm ? 2 : 1) *

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid + expert_id * expert_stride, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

        const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
        const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_scale_grid + (expert_id * expert_scale_stride) / sizeof(BScaleDataType),
            b_scale_grid_desc_bn_ak.GetElementSpaceSize());
 
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            decltype(a_grid_desc_ak0_m_ak1),
            decltype(a_block_desc_ak0_m_ak1),
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            1>(a_grid_desc_ak0_m_ak1,
               a_block_desc_ak0_m_ak1,

        auto b_blockwise_copy =
                                                      BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                      BBlockTransferThreadClusterArrangeOrder,
                                                      decltype(b_grid_desc_bk0_n_bk1),
                                                      decltype(b_block_desc_bk0_n_bk1),
                                                      BBlockTransferSrcAccessOrder,
                                                      BBlockTransferSrcVectorDim,
                                                      BBlockTransferSrcScalarPerVector>(
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,

            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
 
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();
        decltype(c_thread_buf) c_thread_buf_up;

                                  c_thread_buf.num_of_v_,
                                  c_thread_buf.s_per_v,

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /

        const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        auto thread_offset_shuffled =
        auto a_thread_offset_m = waveId_m;

        const index_t token_scale_pos = block_m_id * MPerBlock;
        if(token_scale_pos >= max_token_id || token0 >= problem.NumTokens)

            decltype(a_scale_grid_desc_am_ak),
            decltype(BlockwiseGemmPipe::a_scale_thread_desc),
            true>(a_scale_grid_desc_am_ak,

        auto b_thread_offset_n = waveId_n;

            decltype(b_scale_grid_desc_bn_ak),
            decltype(BlockwiseGemmPipe::b_scale_thread_desc),
            true>(b_scale_grid_desc_bn_ak,
 
        if constexpr(IsInputGemm)
            const BDataType* p_b_grid_up = p_b_grid + expert_stride / 2;
            const auto b_grid_buf_up     = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_grid_up + expert_id * expert_stride,
                b_grid_desc_bk0_n_bk1.GetElementSpaceSize());

                b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
            auto b_block_buf_up_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                     a_block_space_size_aligned * sizeof(ADataType) +
                                     b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());
            auto b_block_buf_up_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                     a_block_space_size_aligned * sizeof(ADataType) +
                                     b_block_space_size_aligned * sizeof(BDataType)),
                b_block_desc_bk0_n_bk1.GetElementSpaceSize());

            auto b_block_bufs_up = make_tuple(b_block_buf_up_ping, b_block_buf_up_pong);

                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                BBlockTransferThreadClusterArrangeOrder,
                decltype(b_grid_desc_bk0_n_bk1),
                decltype(b_block_desc_bk0_n_bk1),
                BBlockTransferSrcAccessOrder,
                BBlockTransferSrcVectorDim,
                BBlockTransferSrcScalarPerVector>(b_grid_desc_bk0_n_bk1,
                                                  b_block_desc_bk0_n_bk1,

            const BScaleDataType* p_b_scale_grid_up =
                p_b_scale_grid + expert_scale_stride / 2 / sizeof(BScaleDataType);
            const auto b_scale_grid_buf_up = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_b_scale_grid_up + expert_id * expert_scale_stride / sizeof(BScaleDataType),
                b_scale_grid_desc_bn_ak.GetElementSpaceSize());

                decltype(b_scale_grid_desc_bn_ak),
                decltype(BlockwiseGemmPipe::b_scale_thread_desc),
                b_scale_grid_desc_bn_ak,

            blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
                a_grid_desc_ak0_m_ak1,
                a_block_desc_ak0_m_ak1,
                a_block_slice_copy_step,
                b_grid_desc_bk0_n_bk1,
                b_block_desc_bk0_n_bk1,
                b_blockwise_copy_up,
                b_block_slice_copy_step,
                a_scale_grid_desc_am_ak,
                a_scale_thread_copy,
                b_scale_grid_desc_bn_ak,
                b_scale_thread_copy,
                b_scale_thread_copy_up,
                b_scale_grid_buf_up,
                num_k_block_main_loop);
 
 2529             blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(
 
 2530                 a_grid_desc_ak0_m_ak1, 
 
 2531                 a_block_desc_ak0_m_ak1,
 
 2535                 a_block_slice_copy_step,
 
 2536                 b_grid_desc_bk0_n_bk1, 
 
 2537                 b_block_desc_bk0_n_bk1,
 
 2541                 b_block_slice_copy_step,
 
 2543                 a_scale_grid_desc_am_ak, 
 
 2544                 a_scale_thread_copy,
 
 2546                 b_scale_grid_desc_bn_ak, 
 
 2547                 b_scale_thread_copy,
 
 2549                 num_k_block_main_loop);
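
Structurally, the IsInputGemm branch drives one software-pipelined main loop that consumes a single A stream against two B streams (gate and up, each with its own copier, LDS buffers, and scale reader), while the else branch is the ordinary single-B pipeline. A toy, runnable analogue of that shape, where the "pipeline" is just a tiled loop (nothing here is the pipeline's real interface):

    #include <cstdio>
    #include <vector>

    // Stand-in for Run/Run_2Lds: accumulate a dot product tile by tile.
    float run_pipeline(const std::vector<float>& a, const std::vector<float>& b,
                       int k_per_block)
    {
        float acc = 0.f;
        for(std::size_t k0 = 0; k0 < a.size(); k0 += k_per_block) // main K loop
            for(int k = 0; k < k_per_block; ++k)
                acc += a[k0 + k] * b[k0 + k];
        return acc;
    }

    int main()
    {
        std::vector<float> a(64, 1.f), b_gate(64, 2.f), b_up(64, 3.f);
        // Gated path: two accumulators fed by the same A tiles.
        const float gate = run_pipeline(a, b_gate, 16);
        const float up   = run_pipeline(a, b_up, 16);
        std::printf("gate=%f up=%f fused=%f\n", gate, up, gate * up);
    }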
 
 2554             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
 2555                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
 2557             static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 &&

 2558                               CShuffleNXdlPerWavePerShuffle % NXdlPack == 0,
 
 2561             constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);

 2562             constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
 2565             constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
 2566                 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
 
 2570             constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
 2571                 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3();
 
 2573             constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);

 2574             constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);

 2575             constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);

 2576             constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);

 2577             constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);

 2578             constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);

 2579             constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);

 2580             constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

 2581             constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8);

 2582             constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9);
 
 2586             static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock);
 
 2587             static_assert(M5 == 4);
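
These asserts pin the M decomposition: the ten block-descriptor lengths must tile MPerBlock exactly, and the innermost M5 group is fixed at 4 so the routing weights can be fetched as a vector_type<float, 4>. A compile-time check with one plausible configuration (the concrete lengths are assumptions):

    int main()
    {
        // Illustrative decomposition; the kernel reads these from the descriptor.
        constexpr int M0 = 2, M1 = 2, M2 = 2, M3 = 2, M4 = 2, M5 = 4;
        constexpr int MPerBlock = 128;
        static_assert(M0 * M1 * M2 * M3 * M4 * M5 == MPerBlock, "M tiling must be exact");
        static_assert(M5 == 4, "innermost group matches the 4-wide weight vector load");
    }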
 
 2597                                 const index_t m_pos = block_m_id * MPerBlock +
 
 2598                                                       m0 * M2 * M1 * M3 * M4 * M5 +
 
 2599                                                       m1 * M2 * M3 * M4 * M5 +
 
 2600                                                       imxdl * M3 * M4 * M5 + m3 * M4 * M5 + m4 * M5;
 
 2601                                 if constexpr(MulRoutedWeight)
 
 2604                                         *c_style_pointer_cast<const vector_type<float, M5>*>(
 
 2605                                             p_ds_grid[I2] + m_pos);
 
 2609                                         blockwise_gemm_pipeline.GetCThreadDesc().CalculateOffset(
 
 2610                                             make_tuple(m0, n0, imxdl, inxdl, m3 * M5 + m5));
 
 2613                                     if constexpr(IsInputGemm) 
 
 2615                                         if constexpr(ActivationOperation ==
 
 2618                                             float gate = c_thread_buf[cidx];
 
 2619                                             float up   = c_thread_buf_up[cidx];
 
 2620                                             if constexpr(MulRoutedWeight)
 
 2622                                                 gate = gate * topk_weights.AsType<float>()[m5];

 2623                                                 up   = up * topk_weights.AsType<float>()[m5];
 
 2626                                             c_thread_buf_fp32(cidx) = gate * up;
 
 2630                                             float gate = c_thread_buf[cidx];
 
 2631                                             float up   = c_thread_buf_up[cidx];
 
 2632                                             if constexpr(MulRoutedWeight)
 
 2634                                                 gate = gate * topk_weights.AsType<float>()[m5];

 2635                                                 up   = up * topk_weights.AsType<float>()[m5];
 
 2638                                             c_thread_buf_fp32(cidx) = gate * up;
 
 2643                                         c_thread_buf_fp32(cidx) = c_thread_buf[cidx];
 
 2644                                         if constexpr(MulRoutedWeight)
 
 2646                                             c_thread_buf_fp32(cidx) =
 
 2647                                                 topk_weights.AsType<float>()[m5] *

 2648                                                 c_thread_buf_fp32[cidx];
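
The activation lines themselves are elided in this listing, but the intent of the two branches is clear from the enum values: silu_and_mul passes the gate value through SiLU before multiplying by up, gelu_and_mul uses GELU, and MulRoutedWeight pre-scales both operands by the token's routing weight. A host sketch of that epilogue (the exact GELU variant the kernel uses is an assumption):

    #include <cmath>
    #include <cstdio>

    float silu(float x) { return x / (1.f + std::exp(-x)); }
    float gelu(float x) { return 0.5f * x * (1.f + std::erf(x / std::sqrt(2.f))); }

    int main()
    {
        float gate = 1.5f, up = 0.75f;
        const float topk_weight = 0.4f;
        const bool  mul_routed_weight = true;

        if(mul_routed_weight) { gate *= topk_weight; up *= topk_weight; }

        std::printf("silu_and_mul: %f\n", silu(gate) * up); // gated output element
        std::printf("gelu_and_mul: %f\n", gelu(gate) * up);
    }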
 
 2658             constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
 2661             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
 2662                 static_cast<CShuffleDataType*>(p_shared_0),
 
 2663                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 2666                 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
 2692             const auto c_thread_mtx_on_block =
 
 2693                 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);
 
 2695             const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];

 2696             const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];
 
 2698             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
 2704             const auto m_thread_data_on_block_idx =
 
 2705                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
 2708             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
 2714             const auto n_thread_data_on_block_idx =
 
 2715                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
 2722                 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2723                 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
 2726                          CShuffleNXdlPerWavePerShuffle / NXdlPack,
 
 2735                 Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
 
 2740                 true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2743                                        m_thread_data_on_block_idx[I1],

 2744                                        n_thread_data_on_block_idx[I1],

 2745                                        m_thread_data_on_block_idx[I2],

 2746                                        n_thread_data_on_block_idx[I2],

 2747                                        m_thread_data_on_block_idx[I3],

 2748                                        m_thread_data_on_block_idx[I4],

 2749                                        m_thread_data_on_block_idx[I5],

 2750                                        n_thread_data_on_block_idx[I3]),
 
 2753             using EDataType = CDataType;
 
 2758             const auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2764                     return make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2765                         p_ds_grid[i], ds_grid_desc_m_n[i].GetElementSpaceSize());
 
 2771                 tie(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
 2773                     [&](auto i) -> const auto&

 2774                     { return ds_grid_desc_mblock_mperblock_nblock_nperblock[i]; },
 
 2779                 tie(c_shuffle_block_buf),
 
 2781                     [&](auto i) -> const auto&

 2782                     { return ds_grid_buf[i]; },
 
 2786             const auto idx_c_ds_block_begin =
 
 2796             const auto e_grid_desc_mblock_mperblock_nblock_nperblock =
 
 2797                 c_grid_desc_mblock_mperblock_nblock_nperblock;
 
 2799             using CDEBlockTransferCluster =
 
 2800                 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
 
 2801             const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
 
 2802             constexpr index_t scatter_weight_idx = 3;
 
 2807                 decltype(c_ds_desc_refs),
 
 2808                 decltype(tie(e_grid_desc_mblock_mperblock_nblock_nperblock)),
 
 2809                 CElementwiseOperation,
 
 2814                          CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2816                          CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, 
 
 2817                 CDEBlockTransferCluster,
 
 2823                 CDEShuffleBlockTransferScalarPerVectors,
 
 2835                   idx_c_ds_block_begin,
 
 2836                   tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2840             auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
 2841                 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
 2843             constexpr auto sfc_c_vgpr =
 
 2854                                   Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
 
 2856                                            CShuffleNXdlPerWavePerShuffle / NXdlPack,
 
 2866             constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
 2869             constexpr auto sfc_cde_block =
 
 2873                                            CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
 2875                                            CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
 
 2877             static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
 
 2878             constexpr auto EMThreads =
 2879                 CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);

 2880             constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;

 2881             constexpr auto ENThreads =
 2882                 CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
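
EMThreads and ENThreads split the CDE transfer cluster into the threads covering the M and N extents of one shuffle tile, and EMRepeats is how many output rows each thread handles per access. With an assumed cluster shape:

    #include <cstdio>

    int main()
    {
        // Assumed cluster lengths [MBlock, MPerBlock, NBlock, NPerBlock].
        const int cluster[4] = {1, 32, 1, 8};
        const int CShuffleMXdlPerWavePerShuffle = 1, MWave = 2, MPerXdl = 32;

        const int EMThreads = cluster[0] * cluster[1]; // threads along M
        const int ENThreads = cluster[2] * cluster[3]; // threads along N
        const int EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;

        std::printf("EMThreads=%d ENThreads=%d EMRepeats=%d\n",
                    EMThreads, ENThreads, EMRepeats);
    }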
 
 2887                 auto dstidx = sfc_cde_block.GetIndex(access_id);
 
 2889                     block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
 
 2891                     const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
 
 2892                     IndexType token_offset    = fused_token & 0xffffff;
 
 2893                     if constexpr(IsInputGemm)
 
 2895                         token_offset = token_offset * problem.TopK + (fused_token >> 24);
 
 2897                     scatter_offsets(m0) = static_cast<IndexType>(token_offset) * problem.N;
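
Per the masks above, a sorted token id packs the token index in its low 24 bits and the topk slot in the high 8; for the input GEMM the row index is expanded to token * TopK + slot, and the scatter offset is that row times the output width N. A standalone decode sketch:

    #include <cstdint>
    #include <cstdio>

    int main()
    {
        const std::uint32_t fused_token = (2u << 24) | 1234u; // slot 2, token 1234
        const int TopK = 8, N = 4096;                         // assumed problem sizes

        std::int64_t token_offset = fused_token & 0xffffff;   // low 24 bits
        const bool is_input_gemm = true;                      // IsInputGemm above
        if(is_input_gemm)
            token_offset = token_offset * TopK + (fused_token >> 24);

        std::printf("scatter row offset: %lld\n",
                    (long long)(token_offset * N));
    }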
 
 2903                 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2904                                               sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
 2906                                               c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 2907                                               c_shuffle_block_buf);
 
 2913                 cde_block_copy_lds_and_global.Run(
 
 2916                     tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2920                 if constexpr(access_id < num_access - 1)
 
 2922                     constexpr auto cde_lds_and_global_step =
 2923                         sfc_cde_block.GetForwardStep(access_id);
 
 2927                         cde_block_copy_lds_and_global.MoveSrcSliceWindow(
 
 2928                             c_ds_desc_refs, i + I1, cde_lds_and_global_step);
 
 2932                     cde_block_copy_lds_and_global.MoveDstSliceWindow(
 
 2933                         tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
 2935                         cde_lds_and_global_step);
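
Taken together, the epilogue loop has a fixed shape: for each space-filling-curve access, copy one shuffle tile from VGPRs into LDS, synchronize, stream it from LDS to global memory through the scatter offsets, and advance both slice windows before the next access. A skeleton of that control flow (the printf lines stand in for the copy objects constructed above):

    #include <cstdio>

    int main()
    {
        const int num_access = 4; // sfc_c_vgpr.GetNumOfAccess() in the listing
        for(int access_id = 0; access_id < num_access; ++access_id)
        {
            std::printf("access %d: vgpr->lds copy\n", access_id);      // c_thread_copy_vgpr_to_lds.Run
            std::printf("access %d: barrier\n", access_id);             // block_sync_lds()
            std::printf("access %d: lds->global scatter\n", access_id); // cde_block_copy_lds_and_global.Run
            if(access_id < num_access - 1)
            {
                // MoveSrcSliceWindow / MoveDstSliceWindow by GetForwardStep(access_id).
                std::printf("access %d: advance slice windows\n", access_id);
            }
        }
    }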
 