template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // single LDS buffer, shared by the A/B staging tiles and the C shuffle stage
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
        p_shared,
        karg);
#else
    ignore = karg;
#endif
}
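// Worked example of the split-K offsetting above (illustrative numbers, not
// from the source): with KBatch = 4 the launch grid gains a z-extent of 4,
// and the block at blockIdx.z = 2 starts reading A at
// p_a_grid + 2 * KRead (row-major A, see SplitKBatchOffset below), so each
// z-slice of the grid multiplies a disjoint K-range and combines its partial
// C according to CGlobalMemoryDataOperation.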
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
        kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    // two LDS buffers for ping-pong (double-buffered) A/B tile staging
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
        p_shared_0,
        p_shared_1,
        karg);
#else
    ignore = karg;
#endif
}
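// The only difference from the single-LDS kernel above is the pair of
// __shared__ allocations: Run_2Lds alternates A/B tiles between p_shared_0
// and p_shared_1, so the global->LDS copy of iteration i+1 can overlap the
// MFMA work reading iteration i without a write/read sync on the same buffer.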
/// "Universal" GEMM kernel with SplitK support.
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
          BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
          typename ComputeTypeA                       = CDataType,
          typename ComputeTypeB                       = ComputeTypeA,
          bool PermuteA                               = false,
          bool PermuteB                               = false,
          bool DoElementwiseBeforeCShuffle            = false>
struct GridwiseGemm_xdl_cshuffle_v3
{
    // ... static constants (I0..I7, AK0Number/BK0Number, AK1Number/BK1Number,
    //     APackedSize/BPackedSize) and the MFMA rate / KPack selection elided;
    //     the surviving fragment of the single-rate MFMA condition reads:
    //         ... KPerBlock < 128 && MPerXdl == 16))
    // ... CalculateGridSize / CalculateMPadded / CalculateNPadded / CalculateKPadded(K) elided ...

    __host__ static auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    __host__ static auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    __host__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    __host__ static auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }

    // ... CalculateMBlock / CalculateNBlock elided ...
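    // Worked example of the ceil-divide-then-rescale pattern above
    // (illustrative numbers, not from the source): K = 1000, K_Batch = 4,
    // KPerBlock = 64, AK1Value = 8 gives K_t = 256 and (1000 + 255) / 256 = 4
    // KPerBlock-chunks per batch, so CalculateAK0Padded = 4 * (64 / 8) = 32
    // and CalculateKPadded = 4 * 64 = 256 per batch: across the 4 batches, K
    // is effectively padded from 1000 up to 1024.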
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        // ... reshape the (K0, MN, K1) tile into the (MNXdlPerWave, MNWaves, MNPerXdl, K)
        //     view consumed by the blockwise GEMM pipeline (transforms elided) ...
    }
    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K, then split KPad into (AK0, AK1)
            const auto a_grid_desc_m_k = /* ... right-pad M -> MPad, K -> KPad ... */;
            // ... transform to (AK0, M, AK1) ...
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M only
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw, /* ...unmerge K, right-pad M... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K only
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw, /* ...right-pad K, then unmerge... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else // GemmSpecialization::Default
        {
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw, /* ...unmerge K into (AK0, AK1)... */);
            return a_grid_desc_ak0_m_ak1;
        }
    }
    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K, then split KPad into (BK0, BK1)
            const auto b_grid_desc_n_k = /* ... right-pad N -> NPad, K -> KPad ... */;
            // ... transform to (BK0, N, BK1) ...
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N only
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw, /* ...unmerge K, right-pad N... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K only
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw, /* ...right-pad K, then unmerge... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            if constexpr(!PermuteB)
            {
                // no padding
                const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                    b_grid_desc_nraw_kraw, /* ...unmerge K into (BK0, BK1)... */);
                return b_grid_desc_bk0_n_bk1;
            }
            else
            {
                // pre-shuffled B: BGlobal[K / KPerBlock, N, KPerBlock / BK1, BK1]
                constexpr index_t BK01 = KPerBlock / BK1Value;
                const index_t BK0_     = StrideB / BK1Value;
                const index_t BK00     = BK0_ / BK01;

                const auto b_grid_desc_bk00_n_bk01_bk1_permute = /* ... */;

                const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
                    b_grid_desc_bk00_n_bk01_bk1_permute, /* ...merge (BK00, BK01) -> BK0... */);

                return b_grid_desc_bk0_n_bk1_permute;
            }
        }
    }
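    // Example of the pre-shuffled B indexing above (illustrative numbers, not
    // from the source): KPerBlock = 64, BK1Value = 8 gives BK01 = 8 inner
    // slices per block; if the packed weight stores StrideB = 512 elements,
    // BK0_ = 512 / 8 = 64 and BK00 = 64 / 8 = 8, i.e. the K axis is addressed
    // as (BK00 = 8 outer blocks) x (BK01 = 8 slices) x (BK1 = 8 contiguous
    // elements).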
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
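    // Worked example of the wave partitioning used by both helpers above
    // (illustrative numbers, not from the source): MPerBlock = 128,
    // MXdlPerWave = 2, MPerXdl = 32 gives MWaves = 128 / (2 * 32) = 2, so the
    // M extent of a block tile is covered by 2 waves, each issuing 2 XDL ops
    // of 32 rows.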
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and N
            return /* ...right-pad M -> MPad and N -> NPad... */;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M only
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N only
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else
        {
            // no padding
            return c_grid_desc_mraw_nraw;
        }
    }
    struct Problem
    {
        __host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_,
                         index_t StrideB_, index_t StrideC_, index_t KBatch_,
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op)
        // ... member initializers: M, N, K, strides, KBatch, MPadded, NPadded, KRead,
        //     KPadded, AK0, BK0, MBlock, NBlock, and the three element-wise ops ...
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      // ... M, N, K, strides, padded sizes ...
                      << "KRead:" << KRead << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        // ... members: index_t M, N, K, StrideA, StrideB, StrideC, KBatch, MPadded,
        //     NPadded, KRead, KPadded, AK0, BK0, MBlock, NBlock;
        //     AElementwiseOperation a_element_op_; BElementwiseOperation b_element_op_;
        //     CElementwiseOperation c_element_op_ ...
    };
    struct Argument // Problem plus the grid pointers and the reduce flag
    {
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_, index_t N_, index_t K_,
                          index_t StrideA_, index_t StrideB_, index_t StrideC_,
                          index_t k_batch_,
                          bool is_reduce_                     = false,
                          AElementwiseOperation a_element_op  = AElementwiseOperation{},
                          BElementwiseOperation b_element_op  = BElementwiseOperation{},
                          CElementwiseOperation c_element_op  = CElementwiseOperation{})
        // ... forwards to Problem and stores p_a_grid, p_b_grid, p_c_grid, is_reduce ...
        {
        }

        // ... IsReduceAdd() / IsAtomicAdd() helpers and members (p_a_grid, p_b_grid,
        //     p_c_grid, is_reduce) elided ...
    };
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
                a_k_split_offset = blockIdx.z * karg.KRead;
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
                a_k_split_offset = blockIdx.z * karg.KRead * karg.M;

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
                b_k_split_offset = blockIdx.z * karg.KRead * karg.N;
            else if constexpr(!PermuteB) // column-major B
                b_k_split_offset = blockIdx.z * karg.KRead;
            else
            {
                const int k0_offset = karg.KRead * karg.N;
                b_k_split_offset    = blockIdx.z * k0_offset;
            }

            // every batch except the last reads exactly KRead; the last takes the remainder
            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
                karg.K = karg.KRead;
            else
                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
            // ... c_reduce_offset (used when karg.IsReduceAdd()) elided ...
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t c_reduce_offset;
    };
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS, dst of the blockwise copy
        if constexpr(ABlockLdsExtraM)
        {
            // ... simple (AK0, MPerBlock + ABlockLdsExtraM, AK1) descriptor ...
        }
        else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            // M-LDS-layer layout: replicate M so each 128-byte LDS row stays full
            // for small KPerBlock; LdsSize (definition elided) counts how many
            // M-tiles fit one row.
            constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            // ... a_lds_block_desc -> (xor) -> a_lds_block_desc_permuted
            //     -> a_lds_block_desc_ak0_mldslayer_m_ak1 ->
            const auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_ak0_mldslayer_m_ak1, /* ...merge (MLdsLayer, M)... */);
            return a_lds_block_desc_ak0_m_ak1;
        }
        else // column-major A: xor-permuted layout
        {
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= mpair <= M0
            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

            // ... build the LDS descriptor (lengths include Number<kfold * M0 / mpair>{}
            //     and Number<KThreadWrite / kfold / KThreadReadPerm>{}); xor-permute into
            //     a_lds_block_desc_permuted, unmerge to a_lds_block_desc_unmerged, and
            //     merge back to the (AK0, M, AK1) view ...
            return a_lds_block_desc_ak0_m_ak1;
        }
    }
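    // Worked example of the fold factors above (illustrative fp16 numbers, not
    // from the source): sizeof(ADataType) = 2, AK1Number = 8, M0 = 4 gives
    // 8 * 4 * 2 = 64 <= 128 bytes, so kfold = 128 / 64 = 2: two K0 groups are
    // folded into one 128-byte LDS row, keeping ds_read/ds_write fully
    // vectorized while the xor transform scatters rows across banks.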
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS; mirrors the A-side descriptor above with N in place of M
        if constexpr(BBlockLdsExtraN)
        {
            // ... simple (BK0, NPerBlock + BBlockLdsExtraN, BK1) descriptor ...
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            // N-LDS-layer layout
            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            // ... b_lds_block_desc -> (xor) -> b_lds_block_desc_permuted
            //     -> b_lds_block_desc_bk0_nldslayer_n_bk1 ->
            const auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_bk0_nldslayer_n_bk1, /* ...merge (NLdsLayer, N)... */);
            return b_lds_block_desc_bk0_n_bk1;
        }
        else // row-major B: xor-permuted layout
        {
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= npair <= N0
            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));

            // ... build lengths with Number<kfold * N0 / npair>{} and
            //     Number<KThreadWrite / kfold / KThreadReadPerm>{}; xor-permute into
            //     b_lds_block_desc_permuted, unmerge to b_lds_block_desc_unmerged ...
            return b_lds_block_desc_bk0_n_bk1;
        }
    }
    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
    // Blockwise GEMM pipeline type, chosen by BlockGemmPipeline_Selector from
    // BlkGemmPipelineVer / BlkGemmPipeSched, the block/tile descriptors, and the
    // vector widths ABlockTransferSrcScalarPerVector / BBlockTransferSrcScalarPerVector
    // (full instantiation elided).
    using BlockwiseGemmPipe = remove_cvref_t<decltype(BlockGemmPipeline_Selector</* ... */>())>;

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS for the A and B staging tiles, aligned to max_lds_align
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // LDS for the C shuffle stage
        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        // A/B staging and the C shuffle reuse the same allocation: take the max
        return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
                          b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
                         c_block_size * sizeof(CShuffleDataType));
    }
    static constexpr __host__ bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");

        // (checked only when GemmSpec does not pad M)
        if(!(karg.M % MPerBlock == 0))
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
            }
            return false;
        }

        // (checked only when GemmSpec does not pad N)
        if(!(karg.N % NPerBlock == 0))
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
            }
            return false;
        }

        // (checked only when GemmSpec does not pad K)
        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        // (otherwise make sure the last split-K batch still has work)
        {
            constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
            auto K_t                = karg.KBatch * KReadVec;
            auto KReadPadSplited    = (karg.K + K_t - 1) / K_t * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
        {
            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }

        // (for C data types without atomic-add support; exact type guard elided)
        if(karg.KBatch > 1)
        {
            if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
            {
                std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet " << __FILE__
                          << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
            }
            return false;
        }

        // multi-stage pipelines need more than PrefetchStages main-loop iterations
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        {
            return false;
        }

        return true;
    }
    static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
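    // Worked example of the loop bookkeeping above (illustrative numbers, not
    // from the source): K = 1024, KPerBlock = 64 gives num_loop = 16; a
    // pipeline with PrefetchStages = 2 then runs a hot loop
    // (BlockHasHotloop(16) == true), and BlockLoopTailNum selects the epilogue
    // variant (e.g. TailNumber::Full) that drains the final prefetched tiles.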
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem,
                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // divide block work by [M, N]
        // ... construct block_2_ctile_map (elided) ...
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        // block descriptors of A/B in LDS
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy: global -> LDS
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                /* ...dst element op, mem op, slice lengths... */
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                /* ...src/dst data types... */
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                /* ...dim access order... */
                                                ABlockTransferSrcVectorDim,
                                                /* ...dst vector dim... */
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                /* ...coordinate step flags... */
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                /* ...dst reset flag... */
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                problem.a_element_op_,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                /* ...PassThrough dst op... */);

        // B matrix blockwise copy: global -> LDS
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                /* ... */
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                /* ... */
                                                BBlockTransferSrcVectorDim,
                                                /* ... */
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                /* ... */
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                problem.b_element_op_,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                /* ...PassThrough dst op... */);
        // LDS allocation: A tile first, B tile after the aligned A chunk
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                         a_block_space_size_aligned * sizeof(ADataType) /
                                             APackedSize),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        // ... a_block_slice_copy_step / b_block_slice_copy_step (advance one KPerBlock) elided ...

        // blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);
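        // num_k_block_main_loop above is K / KPerBlock expressed through the
        // descriptor: GetLength(I0) * GetLength(I2) = AK0 * AK1 = K
        // (illustrative: K = 1024, KPerBlock = 64 -> 16 main-loop iterations,
        // each staging one KPerBlock-deep A and B tile through LDS).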
        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // used only to query lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            // ... c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2: the same LDS tile viewed through
            //     the (M0, N0, M1, N1, M2, M3, M4, N2) transform (construction elided) ...

            // calculate the origin of this thread's C output on the block tile
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(/* ...merge (M0, M1, M2, M3, M4)... */);
            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(/* ...merge (N0, N1, N2)... */);
            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // run the C elementwise op on either the VGPR->LDS or the LDS->global copy
            // (pass_through: element_wise::PassThrough instance, definition elided)
            const auto& vpgr_to_lds_element_op = [&] {
                if constexpr(DoElementwiseBeforeCShuffle)
                    return problem.c_element_op_;
                else
                    return pass_through;
            };
            const auto& lds_to_global_element_op = [&] {
                if constexpr(!DoElementwiseBeforeCShuffle)
                    return problem.c_element_op_;
                else
                    return pass_through;
            };
            // per-thread copy: C accumulators (VGPR) -> LDS shuffle tile
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType,
                CShuffleDataType,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                CElementwiseOperation, // (PassThrough when the op runs on the LDS->global copy)
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         /* ...remaining slice lengths... */>,
                /* ...access order, vector dim/width, flags... */
                true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                      make_multi_index(0,
                                       0,
                                       m_thread_data_on_block_idx[I1],
                                       n_thread_data_on_block_idx[I1],
                                       m_thread_data_on_block_idx[I2],
                                       m_thread_data_on_block_idx[I3],
                                       m_thread_data_on_block_idx[I4],
                                       n_thread_data_on_block_idx[I2]),
                      vpgr_to_lds_element_op()};

            // blockwise copy: LDS shuffle tile -> global C
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,
                CElementwiseOperation,
                CGlobalMemoryDataOperation, // Set, or AtomicAdd for split-K
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                /* ...access order, src/dst data types... */
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                /* ...reset-coordinate flags... */>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 lds_to_global_element_op()};
            // space-filling curves over the per-wave C tile (VGPR side and global side)
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1, 1, M2, 1, M4, 1>>{};

            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
            static_for<0, num_access, 1>{}([&](auto access_id) {
                block_sync_lds(); // safe to overwrite the shuffle LDS slice

                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                block_sync_lds(); // shuffle slice is complete; safe to read

                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
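    // Access-count example for the shuffle loop above (illustrative numbers,
    // not from the source): MXdlPerWave = 4, NXdlPerWave = 4 with
    // CShuffleMXdlPerWavePerShuffle = CShuffleNXdlPerWavePerShuffle = 2 gives
    // num_access = (4 / 2) * (4 / 2) = 4 LDS round trips, each bracketed by
    // the two block_sync_lds() calls.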
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                MakeCGridDescriptor_M_N(
                    problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC),
                problem.MBlock, problem.NBlock);

        Run<decltype(a_grid_desc_ak0_m_ak1),
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            HasMainKBlockLoop,
            CGlobalMemoryDataOperation,
            TailNum>(p_a_grid, p_b_grid, p_c_grid, p_shared, problem,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
                     c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem,
                                    const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                    const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                    const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                        c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        // divide block work by [M, N]; same tile mapping as Run
        // ... construct block_2_ctile_map (elided) ...
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        // block descriptors of A/B in LDS
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy: global -> LDS (same instantiation as in Run)
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                /* ... */
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                /* ... */
                                                ABlockTransferSrcVectorDim,
                                                /* ... */
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                /* ... */
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                problem.a_element_op_,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                /* ...PassThrough dst op... */);

        // B matrix blockwise copy: global -> LDS (same instantiation as in Run)
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                /* ... */
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                /* ... */
                                                BBlockTransferSrcVectorDim,
                                                /* ... */
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                /* ... */
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                problem.b_element_op_,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                /* ...PassThrough dst op... */);
        // ping-pong LDS allocation: each buffer holds one A tile plus one B tile
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
        // ... a_block_slice_copy_step / b_block_slice_copy_step (advance one KPerBlock) elided ...

        // blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_bufs,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_bufs,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);
        // shuffle C and write out (same structure as in Run)
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // used only to query lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);
            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            // the shuffle stage reuses the first LDS buffer
            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared_0),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

            // ... c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 view (construction elided) ...

            // calculate the origin of this thread's C output on the block tile
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(/* ...merge (M0, M1, M2, M3, M4)... */);
            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(/* ...merge (N0, N1, N2)... */);
            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));
            // per-thread copy: C accumulators (VGPR) -> LDS shuffle tile
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType,
                CShuffleDataType,
                decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                /* ...element op... */
                Sequence<CShuffleMXdlPerWavePerShuffle,
                         CShuffleNXdlPerWavePerShuffle,
                         /* ...remaining slice lengths... */>,
                /* ...access order, vector dim/width, flags... */>{
                c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                make_multi_index(0,
                                 0,
                                 m_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3],
                                 m_thread_data_on_block_idx[I4],
                                 n_thread_data_on_block_idx[I2]),
                /* ...element op instance... */};

            // blockwise copy: LDS shuffle tile -> global C
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,
                CElementwiseOperation,
                CGlobalMemoryDataOperation,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                /* ...access order, src/dst data types... */
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                /* ...reset-coordinate flags... */>
                {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 /* ...c_element_op... */};
            // space-filling curves over the per-wave C tile (VGPR side and global side)
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence<MXdlPerWave, NXdlPerWave, 1, 1, M2, 1, M4, 1>,
                                  Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           1, 1, M2, 1, M4, 1>>{};

            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();
            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
            static_for<0, num_access, 1>{}([&](auto access_id) {
                block_sync_lds(); // safe to overwrite the shuffle LDS slice

                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                block_sync_lds(); // shuffle slice is complete; safe to read

                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                MakeCGridDescriptor_M_N(
                    problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC),
                problem.MBlock, problem.NBlock);

        Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
                 decltype(b_grid_desc_bk0_n_bk1),
                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                 HasMainKBlockLoop,
                 CGlobalMemoryDataOperation,
                 TailNum>(p_a_grid, p_b_grid, p_c_grid, p_shared_0, p_shared_1, problem,
                          a_grid_desc_ak0_m_ak1,
                          b_grid_desc_bk0_n_bk1,
                          c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
};
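// ---------------------------------------------------------------------------
// Illustrative host-side launch sketch (not part of this header; names and the
// kernel's template-argument order follow the reconstruction above). `Gemm`
// stands for a fully-instantiated GridwiseGemm_xdl_cshuffle_v3 and `karg` for
// a populated Gemm::Argument:
//
//     const dim3 grid(Gemm::CalculateGridSize(M, N, KBatch));
//     const dim3 block(BlockSize);
//     kernel_gemm_xdl_cshuffle_v3<Gemm,
//                                 true /*HasMainKBlockLoop*/,
//                                 InMemoryDataOperationEnum::Set>
//         <<<grid, block, 0, stream>>>(karg);
//
// With KBatch > 1 the z grid dimension carries the split-K batches and the C
// write typically switches to InMemoryDataOperationEnum::AtomicAdd so the
// partial products accumulate.
// ---------------------------------------------------------------------------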