template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
        p_shared,
        karg);
}
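// Editorial note (hedged): CGlobalMemoryDataOperation controls how the epilogue
// writes C. With KBatch > 1, each z-block contributes a partial K-product to the
// same C tile, so the usual CK convention is an atomic combine; with KBatch == 1 a
// plain store suffices. A minimal host-side sketch of that choice:
//
//   const auto c_mem_op = (karg.KBatch > 1) ? InMemoryDataOperationEnum::AtomicAdd
//                                           : InMemoryDataOperationEnum::Set;
//
// The chosen enum value is then baked into the kernel instantiation at compile time.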
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
    // Two LDS allocations so the pipeline can ping-pong: one tile is consumed by
    // the MFMA stage while the next tile is loaded into the other buffer.
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
        p_shared_0,
        p_shared_1,
        karg);
}
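// Minimal host-side launch sketch (editorial illustration, not part of this
// header). "Gridwise", the 256-thread block size, and "grid_size" are hypothetical;
// the real launch shape comes from the owning device op and its block-to-ctile map.
// Argument's constructor and CheckValidity are declared below; the z grid dimension
// carries KBatch, which the kernel reads back via blockIdx.z.
//
//   using Gridwise = GridwiseGemm_xdl_cshuffle_v3</* layouts, types, tile params */>;
//
//   typename Gridwise::Argument karg{p_a, p_b, p_c, M, N, K,
//                                    StrideA, StrideB, StrideC, /*k_batch_=*/2};
//   if(Gridwise::CheckValidity(karg))
//   {
//       kernel_gemm_xdl_cshuffle_v3<Gridwise,
//                                   /*HasMainKBlockLoop=*/true,
//                                   InMemoryDataOperationEnum::AtomicAdd>
//           <<<dim3(grid_size, 1, karg.KBatch), dim3(256), 0, stream>>>(karg);
//   }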
/// "Universal" GEMM kernel with SplitK support.
template <typename ALayout,
          // ... B/C layouts and A/B data types ...
          typename AccDataType,
          typename CShuffleDataType,
          // ...
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ... GemmSpec and tile sizes (BlockSize, MPerBlock, NPerBlock, KPerBlock,
          // AK1Value, BK1Value, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave) ...
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          // ...
          typename ComputeTypeA            = CDataType,
          typename ComputeTypeB            = ComputeTypeA,
          bool PermuteA                    = false,
          bool PermuteB                    = false,
          bool DoElementwiseBeforeCShuffle = false>
struct GridwiseGemm_xdl_cshuffle_v3
{
    // ... static index constants (I0..I7, AK0Number, AK1Number, BK0Number,
    // BK1Number, lcm_AK1_BK1) and the MFMA-rate/KPack selection; one surviving
    // fragment of that selection logic:
    //     ... KPerBlock < 128 && MPerXdl == 16))
    // ...
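    // Editorial worked example (hypothetical values): for a 256-thread block
    // loading an A tile of AK0 x MPerBlock x AK1 = 8 x 128 x 8 fp16 elements, a
    // consistent pairing of the transfer parameters above would be
    //   ABlockTransferThreadClusterLengths_AK0_M_AK1 = Sequence<4, 64, 1>
    //   ABlockTransferSrcVectorDim                   = 2  (vectorize along AK1)
    //   ABlockTransferSrcScalarPerVector             = 8  (8 x fp16 = 16-byte loads)
    //   ABlockTransferDstScalarPerVector_AK1         = 8
    // since 4 * 64 * 1 = 256 threads each move a (2, 2, 8)-element slice.
    static_assert(4 * 64 * 1 == 256 && 256 * 2 * 2 * 8 == 8 * 128 * 8,
                  "editorial check of the thread-cluster arithmetic above");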
    static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    static __host__ auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    static __host__ auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }
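    // Editorial worked example: K = 1000, K_Batch = 2, KPerBlock = 64, AK1Value = 8
    // gives K_t = 128 and ceil(1000 / 128) = 8 K-tiles per batch, hence
    // AK0 = 8 * (64 / 8) = 64 and a per-batch padded K of 8 * 64 = 512
    // (2 * 512 = 1024 >= 1000, i.e. 24 elements of padding at the tail).
    static_assert((1000 + 2 * 64 - 1) / (2 * 64) * (64 / 8) == 64 &&
                      (1000 + 2 * 64 - 1) / (2 * 64) * 64 == 512,
                  "editorial check of the padding arithmetic above");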
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        // ...
    }
    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            { /* K contiguous: strides (StrideA, 1) */ }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            { /* M contiguous: strides (1, StrideA) */ }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k = /* ... right-pad M to MPad, K to KPad ... */;
            // ... unmerge K into (AK0, AK1Value) ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M only
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K only
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // no padding
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
    }
    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            /* ... RowMajor/ColumnMajor BLayout strides ... */
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k = /* ... right-pad N to NPad, K to KPad ... */;
            // ... unmerge K into (BK0, BK1Value) ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N only
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw, /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K only
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw, /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            if constexpr(!PermuteB)
            {
                // no padding
                const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                    b_grid_desc_nraw_kraw, /* ... */);
                return b_grid_desc_bk0_n_bk1;
            }
            else
            {
                // pre-shuffled weight layout: split BK0 into (BK00, BK01)
                constexpr index_t BK01 = KPerBlock / BK1Value;
                const index_t BK0_     = StrideB / BK1Value;
                const index_t BK00     = BK0_ / BK01;

                const auto b_grid_desc_bk00_n_bk01_bk1_permute = /* ... */;

                const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
                    b_grid_desc_bk00_n_bk01_bk1_permute, /* ... */);

                return b_grid_desc_bk0_n_bk1_permute;
            }
        }
    }
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
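    // Editorial worked example: MPerBlock = 128, MXdlPerWave = 2, MPerXdl = 32
    // yields MWaves = 128 / (2 * 32) = 2, i.e. two waves tile the M extent and
    // each wave owns two 32-wide MFMA tiles; the N-side arithmetic is identical.
    static_assert(128 / (2 * 32) == 2, "editorial check of the MWaves arithmetic");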
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            /* ... RowMajor/ColumnMajor CLayout strides ... */
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            /* ... */
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M only
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N only
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else
        {
            // not padded
            return c_grid_desc_mraw_nraw;
        }
    }
    struct Problem
    {
        __host__ Problem(index_t M_, index_t N_, index_t K_,
                         index_t StrideA_, index_t StrideB_, index_t StrideC_,
                         index_t KBatch_,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op)
        // ... member initialization (M, N, K, strides, KBatch, padded extents,
        // KRead, AK0, BK0, MBlock, NBlock, elementwise ops) ...

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      // ... strides and padded extents ...
                      << "KRead:" << KRead << ", "
                      << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        // ... members ...
    };
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_, index_t N_, index_t K_,
                          index_t StrideA_, index_t StrideB_, index_t StrideC_,
                          index_t k_batch_,
                          bool is_reduce_                    = false,
                          AElementwiseOperation a_element_op = AElementwiseOperation{},
                          BElementwiseOperation b_element_op = BElementwiseOperation{},
                          CElementwiseOperation c_element_op = CElementwiseOperation{})
        // ... member initialization ...

        // members: p_a_grid, p_b_grid, p_c_grid, is_reduce, plus the
        // IsReduceAdd() / IsAtomicAdd() helpers
    };
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            { /* ... a_k_split_offset from blockIdx.z ... */ }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            { /* ... */ }

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            { /* ... b_k_split_offset from blockIdx.z ... */ }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                if constexpr(!PermuteB)
                { /* ... */ }
                else
                {
                    const int k0_offset = karg.KRead * karg.N;
                    /* ... */
                }
            }

            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
            { /* ... full KRead slice ... */ }
            else
            { /* ... last batch takes the remainder up to K ... */ }

            // ... c_reduce_offset ...
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t c_reduce_offset;
    };
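    // Editorial note (assumed, simplified shapes): the elided branches above give
    // z-block i the i-th K-slice, e.g. for row-major A (K contiguous)
    //
    //   a_k_split_offset = blockIdx.z * karg.KRead;
    //
    // and for column-major A the same index scaled by the stride. The final
    // blockIdx.z == KBatch - 1 branch trims KRead so the last slice stops at K,
    // and c_reduce_offset selects the output slab when partial results are
    // reduced instead of atomically added.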
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // ...
        if constexpr(/* ... xor-free layout, padded via MLdsLayer ... */)
        {
            constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            // ... a_lds_block_desc -> a_lds_block_desc_permuted
            //     -> a_lds_block_desc_ak0_mldslayer_m_ak1 -> a_lds_block_desc_ak0_m_ak1
            return a_lds_block_desc_ak0_m_ak1;
        }
        else
        {
            // XOR-permuted layout, derived from the thread-cluster shape
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            // fold K0 until a 128-byte line is filled
            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= mpair <= M0
            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(/* ... */
                           Number<kfold * M0 / mpair>{},
                           /* ... */));

            // ... xor transform -> a_lds_block_desc_permuted
            //     -> a_lds_block_desc_unmerged (unmerged with
            //        Number<KThreadWrite / kfold / KThreadReadPerm>{}, ...)
            //     -> a_lds_block_desc_ak0_m_ak1
            return a_lds_block_desc_ak0_m_ak1;
        }
    }
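    // Editorial worked example: ADataType = half_t (2 bytes), AK1Number = 8.
    // With M0 = 64, one write row spans 8 * 64 * 2 = 1024 bytes > 128, so
    // kfold = 1 (no folding is needed to fill a 128-byte line); with M0 = 4 it
    // spans 8 * 4 * 2 = 64 bytes and kfold = 128 / 64 = 2.
    static_assert(128 / (8 * 4 * 2) == 2, "editorial check of the kfold arithmetic");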
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // ...
        if constexpr(/* ... xor-free layout, padded via NLdsLayer ... */)
        {
            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            // ... b_lds_block_desc -> b_lds_block_desc_permuted
            //     -> b_lds_block_desc_bk0_nldslayer_n_bk1 -> b_lds_block_desc_bk0_n_bk1
            return b_lds_block_desc_bk0_n_bk1;
        }
        else
        {
            // XOR-permuted layout, mirroring the A-side derivation
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= npair <= N0
            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));

            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(/* ... */
                           Number<kfold * N0 / npair>{},
                           /* ... */));

            // ... xor transform -> b_lds_block_desc_permuted
            //     -> b_lds_block_desc_unmerged (unmerged with
            //        Number<KThreadWrite / kfold / KThreadReadPerm>{}, ...)
            //     -> b_lds_block_desc_bk0_n_bk1
            return b_lds_block_desc_bk0_n_bk1;
        }
    }
    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(/* ... */);

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
    using BlockwiseGemmPipe = remove_cvref_t<decltype(BlockGemmPipeline_Selector<
        /* pipeline version, scheduler, BlockSize, data/compute types,
           block descriptors and MMA tile descriptors, */
        ABlockTransferSrcScalarPerVector,
        BBlockTransferSrcScalarPerVector,
        /* MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl,
           MXdlPerWave, NXdlPerWave, KPack */>())>;

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS allocation for the C shuffle slab
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        // main-loop A+B tiles and the C shuffle slab reuse the same LDS: take the max
        return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
                          b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
                         c_block_size * sizeof(CShuffleDataType));
    }
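    // Editorial worked example (hypothetical fp16 tiles): an A tile of
    // 8 x 128 x 8 = 8192 elements and a B tile of the same size need
    // 2 * 8192 * 2 B = 32 KiB during the main loop; if the CShuffleDataType slab
    // holds 64 * 128 = 8192 elements (16 KiB), the function returns
    // max(32 KiB, 16 KiB) = 32 KiB, since the epilogue reuses the same LDS.
    static_assert(2 * (8192 * 2) == 32768 && 8192 * 2 == 16384,
                  "editorial check of the LDS budget arithmetic");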
    static constexpr __host__ bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        // (each message below is emitted only when CK logging is enabled)

        // ... when GemmSpec does not pad M:
        if(!(karg.M % MPerBlock == 0))
        {
            std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            return false;
        }

        // ... when GemmSpec does not pad N:
        if(!(karg.N % NPerBlock == 0))
        {
            std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            return false;
        }

        // ... when GemmSpec does not pad K:
        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                          << karg.K << " " << __FILE__ << ":" << __LINE__
                          << ", in function: " << __func__ << std::endl;
                return false;
            }
        }
        // ... otherwise, reject splits that would leave the last batch empty:
        {
            constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
            auto K_t                = karg.KBatch * KReadVec;
            auto KReadPadSplited    = (karg.K + K_t - 1) / K_t * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }

        // vector-access checks for A (along K for row-major, along M otherwise):
        if(karg.K % ABlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg K (" << karg.K
                      << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                      << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }
        // ...
        if(karg.M % ABlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg M (" << karg.M
                      << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                      << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }
        // the same checks for B (N or K, depending on layout):
        if(karg.N % BBlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg N (" << karg.N
                      << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                      << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }
        // ...
        if(karg.K % BBlockTransferSrcScalarPerVector != 0)
        {
            std::cout << "Arg K (" << karg.K
                      << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                      << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                      << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }
        // and for the C shuffle store (N or M, depending on layout):
        if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
        {
            std::cout << "Arg N (" << karg.N
                      << ") value is not a multiple of "
                         "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                      << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            return false;
        }
        // ...
        if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
        {
            std::cout << "Arg M (" << karg.M
                      << ") value is not a multiple of "
                         "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                      << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            return false;
        }

        if(/* ... KBatch > 1 while the output type lacks atomic-add support ... */)
        {
            std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
            return false;
        }

        // check gridwise gemm pipeline
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);

        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        {
            return false;
        }
        // ...
        return true;
    }
    static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
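    // Editorial worked example: K = 512, KPerBlock = 64 gives num_loop = 8; a
    // pipeline with PrefetchStages = 2 then runs a main (hot) loop plus a fixed
    // tail, and CalculateKBlockLoopTailNum picks the TailNumber variant for the
    // final prefetched stages. K = 128 would give num_loop = 2, which the
    // num_k_loop <= PrefetchStages guard in CheckValidity rejects.
    static_assert(512 / 64 == 8, "editorial check of the loop-count arithmetic");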
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_grid_desc_m_n,
                                                          index_t MBlock,
                                                          index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            transform_tensor_descriptor(c_grid_desc_m_n, /* ... */);

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem,
                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // ... build block_2_ctile_map, then:
        const auto block_work_idx = /* ... block_2_ctile_map.CalculateBottomIndex(...) ... */;

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // readfirstlane keeps the block-uniform indices in SGPRs
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // ... A/B block descriptors in LDS ...

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                /* ... */
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                /* ... */
                                                ABlockTransferSrcVectorDim,
                                                /* ... */
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                /* ... */
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                /* ... src origin, element op ... */
                a_block_desc_ak0_m_ak1,
                /* ... dst origin, element op ... */);

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                /* ... */
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                /* ... */
                                                BBlockTransferSrcVectorDim,
                                                /* ... */
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                /* ... */
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                /* ... */
                b_block_desc_bk0_n_bk1,
                /* ... */);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                         a_block_space_size_aligned * sizeof(ADataType) /
                                             APackedSize),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        // ... A/B block slice copy steps ...

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);
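        // Editorial note: num_k_block_main_loop = (AK0 * AK1) / KPerBlock, the
        // per-batch padded K divided by the K tile, e.g. AK0 = 64, AK1 = 8,
        // KPerBlock = 64 -> 512 / 64 = 8 pipeline iterations.
        static_assert(64 * 8 / 64 == 8, "editorial check of the main-loop count");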
        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // used only to query lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            // the C shuffle slab reuses the same LDS as the A/B tiles
            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                /* ... */);

            // calculate the origin of this thread's output tile
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(/* ... */);

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(/* ... */);

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // apply the C elementwise op either before the shuffle (VGPR->LDS) or
            // after it (LDS->global), controlled by DoElementwiseBeforeCShuffle
            const auto& vpgr_to_lds_element_op = [&] {
                if constexpr(DoElementwiseBeforeCShuffle)
                { /* ... return the C elementwise op ... */ }
                else
                {
                    return pass_through;
                }
            };
            const auto& lds_to_global_element_op = [&] {
                if constexpr(!DoElementwiseBeforeCShuffle)
                { /* ... return the C elementwise op ... */ }
                else
                {
                    return pass_through;
                }
            };

            // thread-wise copy: VGPR accumulators -> LDS shuffle buffer
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                /* ... */
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                /* ... */
                CElementwiseOperation,
                /* ... */
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         /* ... */>,
                /* ... */
                true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                      make_multi_index(/* ... */
                                       m_thread_data_on_block_idx[I1],
                                       n_thread_data_on_block_idx[I1],
                                       m_thread_data_on_block_idx[I2],
                                       m_thread_data_on_block_idx[I3],
                                       m_thread_data_on_block_idx[I4],
                                       n_thread_data_on_block_idx[I2]),
                      vpgr_to_lds_element_op()};

            // block-wise copy: LDS shuffle buffer -> global C
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                /* ... */
                CElementwiseOperation,
                /* ... */
                CGlobalMemoryDataOperation,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                /* ... */
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                /* ... */>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 /* ... */
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 /* ... */
                 lds_to_global_element_op()};

            // space-filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve</* ... */
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           /* ... */>>{};

            // space-filling curve for shuffled blockwise C in global memory
            constexpr auto sfc_c_global =
                SpaceFillingCurve</* ... */
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread writes its data from VGPR to LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // the block copies the shuffled tile from LDS to global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
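    // Editorial worked example: with MXdlPerWave = NXdlPerWave = 2 and
    // CShuffleMXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle = 1, the
    // epilogue above shuffles the C tile in (2/1) * (2/1) = 4 passes; the
    // static_assert on num_access verifies that the VGPR-side and global-side
    // space-filling curves agree on that count.
    static_assert((2 / 1) * (2 / 1) == 4, "editorial check of the access-count arithmetic");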
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem)
    {
        // ... build the A/B grid descriptors from the Problem fields ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(/* ... */);

        Run<decltype(a_grid_desc_ak0_m_ak1),
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            HasMainKBlockLoop,
            CGlobalMemoryDataOperation,
            TailNum>(p_a_grid,
                     p_b_grid,
                     p_c_grid,
                     p_shared,
                     problem,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
                     c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem,
                                    const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                    const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                    const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                        c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // ... block_2_ctile_map, as in Run():
        const auto block_work_idx = /* ... */;

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);

        // ... A/B block descriptors, then A and B blockwise copies constructed
        // exactly as in Run() (ThreadGroupTensorSliceTransfer_v4r1 with the
        // A*/B* transfer parameters and BlockwiseGemmPipe::GlobalBufferNum) ...

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        // ping buffers live in p_shared_0 ...
        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        // ... and pong buffers in p_shared_1
        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);

        // ... A/B block slice copy steps ...

        // Blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_bufs,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_bufs,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);
        // shuffle C and write out -- identical to the epilogue of Run() except that
        // the C shuffle LDS slab lives in p_shared_0
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // used only to query lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared_0),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                /* ... */);

            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(/* ... */);

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(/* ... */);

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // thread-wise copy: VGPR accumulators -> LDS shuffle buffer
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                /* ... */
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                /* ... */
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         /* ... */>,
                /* ... */>{
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(/* ... */
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3],
                                 m_thread_data_on_block_idx[I4],
                                 n_thread_data_on_block_idx[I2]),
                /* ... element op ... */};

            // block-wise copy: LDS shuffle buffer -> global C
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                /* ... */
                CElementwiseOperation,
                CGlobalMemoryDataOperation,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                /* ... */
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                /* ... */>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 /* ... */
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 /* ... */};

            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve</* ... */
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           /* ... */>>{};

            constexpr auto sfc_c_global =
                SpaceFillingCurve</* ... */
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");

            static_for<0, num_access, 1>{}([&](auto access_id) {
                block_sync_lds();

                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                block_sync_lds();

                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem)
    {
        // ... build the A/B grid descriptors from the Problem fields ...
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(/* ... */);

        Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
                 decltype(b_grid_desc_bk0_n_bk1),
                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                 HasMainKBlockLoop,
                 CGlobalMemoryDataOperation,
                 TailNum>(p_a_grid,
                          p_b_grid,
                          p_c_grid,
                          p_shared_0,
                          p_shared_1,
                          problem,
                          a_grid_desc_ak0_m_ak1,
                          b_grid_desc_bk0_n_bk1,
                          c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
};
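// End-to-end dispatch sketch (editorial; "Gemm", the memory op and all values are
// hypothetical). Only entities declared above are used: Argument's constructor,
// CheckValidity, CalculateHasMainKBlockLoop and CalculateKBlockLoopTailNum. The
// owning device op performs this selection and owns the real launch shape.
//
//   typename Gemm::Argument karg{p_a, p_b, p_c, M, N, K,
//                                StrideA, StrideB, StrideC, /*k_batch_=*/1};
//   if(!Gemm::CheckValidity(karg))
//       return;
//
//   if(Gemm::CalculateHasMainKBlockLoop(K))
//   {
//       switch(Gemm::CalculateKBlockLoopTailNum(K))
//       {
//           // one kernel_gemm_xdl_cshuffle_v3<Gemm, true, MemOp, Occupancy, Tail>
//           // instantiation per TailNumber value ...
//       }
//   }
//   else
//   {
//       // ... HasMainKBlockLoop == false instantiations
//   }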