template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
__global__ void kernel_gemm_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 has no packed atomic-add instructions for f16/bf16, so skip those
    // AtomicAdd instantiations entirely
    using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_c_grid)>>;
    if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
                   (std::is_same_v<c_data_type, ck::half_t> ||
                    std::is_same_v<c_data_type, ck::bhalf_t>)))
    {
#endif
        __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

        GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
            karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
            karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
            karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
            p_shared,
            karg);
#if defined(__gfx11__)
    }
#endif
#else
    ignore = karg;
#endif
}
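// Illustration (standalone sketch, not part of this header): how split-K slices
// the K dimension across blockIdx.z. Each of the KBatch grid slices works on
// KRead elements of K (the last one takes the remainder), mirroring the logic in
// SplitKBatchOffset further down. All numbers here are made up.
#include <cassert>
int main()
{
    const int K = 200, KBatch = 3, KReadVec = 8;      // KReadVec = lcm(AK1, BK1)
    const int K_t   = KBatch * KReadVec;              // 24
    const int KRead = (K + K_t - 1) / K_t * KReadVec; // ceil(200 / 24) * 8 = 72
    int total       = 0;
    for(int z = 0; z < KBatch; ++z)
        total += (z < KBatch - 1) ? KRead : K - KRead * (KBatch - 1); // 72 + 72 + 56
    assert(total == K);
    return 0;
}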
// "Universal" GEMM kernel with SplitK support.
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched,
          BlockGemmPipelineVersion BlkGemmPipelineVer,
          typename ComputeTypeA,
          typename ComputeTypeB,
          bool PermuteA = false,
          bool PermuteB = false>
struct GridwiseGemm_wmma_cshuffle_v3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr auto AK0Number = Number<KPerBlock / AK1Value>{};
    static constexpr auto BK0Number = Number<KPerBlock / BK1Value>{};
    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
    static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    static __host__ auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    static __host__ auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }
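// Worked example (standalone, not part of this header): the helpers above all
// round K up to a multiple of K_Batch * <step> with the classic
// (K + K_t - 1) / K_t idiom, so every split-K batch gets the same step count.
// The concrete numbers below are made up.
#include <cassert>
int main()
{
    const int KPerBlock = 32, AK1Value = 8, K_Batch = 2, K = 100;
    const int K_t = K_Batch * KPerBlock;                          // 64
    const int AK0 = (K + K_t - 1) / K_t * (KPerBlock / AK1Value); // ceil(100/64) * 4 = 8
    assert(AK0 == 8);
    const int KPadded = (K + K_t - 1) / K_t * KPerBlock;          // 2 * 32 = 64 per batch
    assert(KPadded == 64);
    return 0;
}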
    template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
    __host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
    {
        // BlockDesc is (K0, MN, K1); K0 and K1 are the two halves of the K dimension
        constexpr auto K0 = BlockDesc{}.GetLength(I0);
        constexpr auto K1 = BlockDesc{}.GetLength(I2);
#ifdef __gfx12__
        constexpr auto KRow = I2;
#else
        constexpr auto KRow = I1;
#endif
        return transform_tensor_descriptor(
            BlockDesc{},
            /* ... unmerge MN into (MNRepeat, MNWaves, MNPerWmma) and regroup (K0, KRow, K1) ... */);
    }
    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw,
                                            make_tuple(make_right_pad_transform(M, MPad - M),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(MPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M, but not K
            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_right_pad_transform(M, MPad - M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K, but not M
            const auto a_grid_desc_m_k = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_m_k,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // not pad M or K
            static_assert(!PermuteA, "PermuteA is not supported");

            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_grid_desc_mraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)),
                           make_pass_through_transform(M)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return a_grid_desc_ak0_m_ak1;
        }
    }
    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(I1, StrideB));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1));
            }
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                            make_tuple(make_right_pad_transform(N, NPad - N),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(NPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N, but not K
            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K, but not N
            const auto b_grid_desc_n_k = transform_tensor_descriptor(
                b_grid_desc_nraw_kraw,
                make_tuple(make_pass_through_transform(N), make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                           make_pass_through_transform(N)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // not pad N or K
            if constexpr(!PermuteB)
            {
                const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                    b_grid_desc_nraw_kraw,
                    make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)),
                               make_pass_through_transform(N)),
                    make_tuple(Sequence<1>{}, Sequence<0>{}),
                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

                return b_grid_desc_bk0_n_bk1;
            }
            else
            {
                // pre-shuffled weight: B is stored in (BK00, N, BK01, BK1) blocks
                constexpr index_t BK01 = KPerBlock / BK1Value;
                const index_t BK0_     = StrideB / BK1Value;
                const index_t BK00     = BK0_ / BK01;

                const auto b_grid_desc_bk00_n_bk01_bk1_permute =
                    make_naive_tensor_descriptor_packed(make_tuple(BK00, N, BK01, BK1Value));

                const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
                    b_grid_desc_bk00_n_bk01_bk1_permute,
                    make_tuple(make_merge_transform(make_tuple(BK00, BK01)),
                               make_pass_through_transform(N),
                               make_pass_through_transform(BK1Value)),
                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

                return b_grid_desc_bk0_n_bk1_permute;
            }
        }
    }
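// Worked example (standalone, not part of this header): the pre-permuted-B path
// above splits BK0 into (BK00, BK01) blocks and merges them back, so a weight
// tensor stored as (BK00, N, BK01, BK1) reads as a plain (BK0, N, BK1) view.
// The concrete numbers below are made up.
#include <cassert>
int main()
{
    const int KPerBlock = 64, BK1Value = 16, StrideB = 256;
    const int BK01 = KPerBlock / BK1Value; // 4: BK0 sub-blocks per block tile
    const int BK0_ = StrideB / BK1Value;   // 16: total BK0 along the stride
    const int BK00 = BK0_ / BK01;          // 4: block tiles along K
    assert(BK00 * BK01 == BK0_);
    return 0;
}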
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto MakeAWmmaTileDescriptor(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);

        return MakeWmmaTileDescriptor<MRepeat, MWaves, MPerWmma>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto MakeBWmmaTileDescriptor(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

        return MakeWmmaTileDescriptor<NRepeat, NWaves, NPerWmma>(BBlockDesc_BK0_N_BK1{});
    }
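// Worked example (standalone, not part of this header): wave counts per block.
// Each wave owns an (MRepeat * MPerWmma) x (NRepeat * NPerWmma) sub-tile, so the
// block tile must factor evenly into waves (cf. the static_assert in
// CheckValidity below). The concrete numbers are made up.
#include <cassert>
int main()
{
    const int MPerBlock = 128, MRepeat = 4, MPerWmma = 16;
    const int NPerBlock = 64, NRepeat = 2, NPerWmma = 16;
    const int MWaves = MPerBlock / (MRepeat * MPerWmma); // 2
    const int NWaves = NPerBlock / (NRepeat * NPerWmma); // 2
    assert(MWaves * NWaves == 4); // 4 waves cooperate on one block tile
    return 0;
}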
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, CLayout>)
            {
                return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
            }
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(M, MPad - M),
                                                          make_right_pad_transform(N, NPad - N)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M, but not N
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_right_pad_transform(M, MPad - M), make_pass_through_transform(N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N, but not M
            return transform_tensor_descriptor(
                c_grid_desc_mraw_nraw,
                make_tuple(make_pass_through_transform(M), make_right_pad_transform(N, NPad - N)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else
        {
            // not pad M or N
            return c_grid_desc_mraw_nraw;
        }
    }
    struct Problem
    {
        __host__ Problem(index_t M_,
                         index_t N_,
                         index_t K_,
                         index_t StrideA_,
                         index_t StrideB_,
                         index_t StrideC_,
                         index_t KBatch_)
            : M{M_}, N{N_}, K{K_}, StrideA{StrideA_}, StrideB{StrideB_}, StrideC{StrideC_},
              KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)},
              KRead{CalculateKRead(K_, KBatch_)}, KPadded{CalculateKPadded(K_, KBatch_)},
              AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)},
              MBlock{CalculateMBlock(M_)}, NBlock{CalculateNBlock(N_)}
        {
        }

        __host__ void Print() const
        {
            std::cout << "problem {"
                      << "M:" << M << ", "
                      << "N:" << N << ", "
                      << "K:" << K << ", "
                      << "SA:" << StrideA << ", "
                      << "SB:" << StrideB << ", "
                      << "SC:" << StrideC << ", "
                      << "MP:" << MPadded << ", "
                      << "NP:" << NPadded << ", "
                      << "KRead:" << KRead << ", "
                      << "KP:" << KPadded << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }

        index_t M;
        index_t N;
        index_t K;
        index_t StrideA;
        index_t StrideB;
        index_t StrideC;
        index_t KBatch;
        index_t MPadded;
        index_t NPadded;
        index_t KRead;
        index_t KPadded;
        index_t AK0;
        index_t BK0;
        index_t MBlock;
        index_t NBlock;
    };
    struct Argument : public tensor_operation::device::BaseArgument, public Problem
    {
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_,
                          index_t k_batch_,
                          bool is_reduce_ = false)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_},
              is_reduce(is_reduce_)
        {
        }

        __host__ __device__ bool IsReduceAdd() const
        {
            return (Problem::KBatch > 1) && is_reduce;
        }

        __host__ __device__ bool IsAtomicAdd() const
        {
            return (Problem::KBatch > 1) && (!is_reduce);
        }

        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
        bool is_reduce;
    };
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            {
                a_k_split_offset = blockIdx.z * karg.KRead;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            {
                a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
            }

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            {
                b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
            }
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                if constexpr(!PermuteB)
                {
                    b_k_split_offset = blockIdx.z * karg.KRead;
                }
                else
                {
                    const int k0_offset = karg.KRead * karg.N;
                    b_k_split_offset    = blockIdx.z * k0_offset / BPackedSize;
                }
            }

            // every batch but the last reads KRead elements of K; the last batch
            // takes whatever remains
            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
            {
                karg.K = karg.KRead;
            }
            else
            {
                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
            }

            if(karg.IsReduceAdd())
            {
                c_reduce_offset = blockIdx.z * karg.M * karg.N;
            }
            else
            {
                c_reduce_offset = 0;
            }
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t c_reduce_offset;
    };
    static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A matrix in LDS memory, dst of blockwise copy
        if constexpr(ABlockLdsExtraM)
        {
            // plain layout with ABlockLdsExtraM padding rows to dodge bank conflicts
            return make_naive_tensor_descriptor(
                make_tuple(AK0Number, Number<MPerBlock>{}, AK1Number),
                make_tuple(Number<MPerBlock + ABlockLdsExtraM>{} * AK1Number, AK1Number, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
        {
            // K is contiguous in global memory: spread AK0 over MLdsLayer sub-layers
            // and xor-permute them against M to spread writes over LDS banks
            constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;

            // base layout: (AK0 * MLdsLayer, MPerBlock / MLdsLayer, AK1)
            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(/* ... */);

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                make_tuple(make_xor_with_modulo_transform(/* ... */),
                           make_pass_through_transform(AK1Number)),
                /* ... */);

            constexpr auto a_lds_block_desc_ak0_mldslayer_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(AK0Number, Number<MLdsLayer>{})),
                           make_pass_through_transform(Number<MPerBlock / MLdsLayer>{}),
                           make_pass_through_transform(AK1Number)),
                /* ... */);

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_ak0_mldslayer_m_ak1,
                make_tuple(make_pass_through_transform(AK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<MPerBlock / MLdsLayer>{}, Number<MLdsLayer>{})),
                           make_pass_through_transform(AK1Number)),
                /* ... */);

            return a_lds_block_desc_ak0_m_ak1;
        }
        else // ColumnMajor A
        {
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerWmma;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            // one LDS write burst is capped at 128 bytes
            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= mpair <= M0
            constexpr auto mpair = (AK1Number * MPerWmma * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerWmma * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerWmma * sizeof(ADataType)));

            constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * M1>{},
                           Number<kfold * M0 / mpair>{},
                           Number<mpair>{},
                           AK1Number));

            constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
                a_lds_block_desc,
                /* ... xor-permute (KThreadReadPerm * M1) against (kfold * M0 / mpair) ... */);

            constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
                a_lds_block_desc_permuted, /* ... unmerge the permuted dimensions ... */);

            constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
                a_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           /* ... merge M back to MPerBlock, pass AK1 through ... */),
                /* ... */);

            return a_lds_block_desc_ak0_m_ak1;
        }
    }
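// Worked example (standalone, not part of this header): the fold factors above
// cap one LDS write burst at 128 bytes. With AK1 = 8 fp16 values (16 B) and
// M0 = 4, one (AK1, M0) slab is 64 B, so kfold = 128 / 64 = 2 slabs interleave.
// The concrete numbers are made up.
#include <cassert>
int main()
{
    const int AK1 = 8, M0 = 4, bytes_per_half = 2;
    const int slab  = AK1 * M0 * bytes_per_half;     // 64 bytes
    const int kfold = (slab > 128) ? 1 : 128 / slab; // 2
    assert(kfold == 2);
    return 0;
}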
    static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B matrix in LDS memory, dst of blockwise copy
        if constexpr(BBlockLdsExtraN)
        {
            return make_naive_tensor_descriptor(
                make_tuple(BK0Number, Number<NPerBlock>{}, BK1Number),
                make_tuple(Number<NPerBlock + BBlockLdsExtraN>{} * BK1Number, BK1Number, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
        {
            // K is contiguous in global memory: spread BK0 over NLdsLayer sub-layers
            // and xor-permute them against N, mirroring the A-side scheme above
            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;

            // base layout: (BK0 * NLdsLayer, NPerBlock / NLdsLayer, BK1)
            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(/* ... */);

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                make_tuple(make_xor_with_modulo_transform(/* ... */),
                           make_pass_through_transform(BK1Number)),
                /* ... */);

            constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(BK0Number, Number<NLdsLayer>{})),
                           make_pass_through_transform(Number<NPerBlock / NLdsLayer>{}),
                           make_pass_through_transform(BK1Number)),
                /* ... */);

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_bk0_nldslayer_n_bk1,
                make_tuple(make_pass_through_transform(BK0Number),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<NPerBlock / NLdsLayer>{}, Number<NLdsLayer>{})),
                           make_pass_through_transform(BK1Number)),
                /* ... */);

            return b_lds_block_desc_bk0_n_bk1;
        }
        else // RowMajor B
        {
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / NPerWmma;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            // one LDS write burst is capped at 128 bytes
            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= npair <= N0
            constexpr auto npair = (BK1Number * NPerWmma * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerWmma * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerWmma * sizeof(BDataType)));

            constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * N1>{},
                           Number<kfold * N0 / npair>{},
                           Number<npair>{},
                           BK1Number));

            constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
                b_lds_block_desc,
                /* ... xor-permute (KThreadReadPerm * N1) against (kfold * N0 / npair) ... */);

            constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
                b_lds_block_desc_permuted, /* ... unmerge the permuted dimensions ... */);

            constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
                b_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           /* ... merge N back to NPerBlock, pass BK1 through ... */),
                /* ... */);

            return b_lds_block_desc_bk0_n_bk1;
        }
    }
    __host__ __device__ static constexpr auto
    GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
    {
        constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
        constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);

        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMRepeatPerShuffle * MWaves * MPerWmma>{},
                           I1,
                           Number<CShuffleNRepeatPerShuffle * NWaves * NPerWmma>{}));

        return c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat;
    }
    using BlockwiseGemmPipe = remove_cvref_t<decltype(
        BlockGemmPipeline_Selector<
            BlkGemmPipelineVer,
            BlkGemmPipeSched,
            BlockSize,
            ADataType,
            BDataType,
            ComputeTypeA,
            ComputeTypeB,
            AccDataType,
            decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
            decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
            ABlockTransferSrcScalarPerVector,
            BBlockTransferSrcScalarPerVector,
            MPerBlock,
            NPerBlock,
            KPerBlock,
            MPerWmma,
            NPerWmma,
            MRepeat,
            NRepeat,
            KPack>())>;

    static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
    {
        // LDS allocation for A and B, aligned to the widest LDS vector access
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        // the C shuffle stage reuses the same LDS
        constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
            GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

        constexpr auto c_block_size =
            c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                .GetElementSpaceSize();

        return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
                          b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
                         c_block_size * sizeof(CShuffleDataType));
    }
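// Worked example (standalone, not part of this header): the GEMM stage (A + B
// tiles) and the C-shuffle stage reuse the same LDS, so the allocation is the
// max of the two, not the sum. Hypothetical fp16/fp32 tile sizes:
#include <algorithm>
#include <cassert>
int main()
{
    const int a_bytes = 128 * 64 * 2; // A tile: MPerBlock x KPerBlock, fp16
    const int b_bytes = 64 * 64 * 2;  // B tile: NPerBlock x KPerBlock, fp16
    const int c_bytes = 64 * 32 * 4;  // one C shuffle slab, fp32
    const int lds     = std::max(a_bytes + b_bytes, c_bytes);
    assert(lds == a_bytes + b_bytes); // the GEMM stage dominates here
    return 0;
}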
    static constexpr __host__ bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerWmma * MRepeat) == 0) &&
                          (NPerBlock % (NPerWmma * NRepeat)) == 0,
                      "Invalid tuning param!");

        // unless M is padded, it must tile exactly
        if constexpr(!(GemmSpec == GemmSpecialization::MPadding ||
                       GemmSpec == GemmSpecialization::MNPadding ||
                       GemmSpec == GemmSpecialization::MKPadding ||
                       GemmSpec == GemmSpecialization::MNKPadding))
        {
            if(!(karg.M % MPerBlock == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M
                              << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        // unless N is padded, it must tile exactly
        if constexpr(!(GemmSpec == GemmSpecialization::NPadding ||
                       GemmSpec == GemmSpecialization::MNPadding ||
                       GemmSpec == GemmSpecialization::NKPadding ||
                       GemmSpec == GemmSpecialization::MNKPadding))
        {
            if(!(karg.N % NPerBlock == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N
                              << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        // unless K is padded, every split-K batch must tile exactly
        if constexpr(!(GemmSpec == GemmSpecialization::KPadding ||
                       GemmSpec == GemmSpecialization::MKPadding ||
                       GemmSpec == GemmSpecialization::NKPadding ||
                       GemmSpec == GemmSpecialization::MNKPadding))
        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                              << karg.K << " " << __FILE__ << ":" << __LINE__
                              << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            // with K padding, no batch may start past the end of the real K extent
            constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
            auto K_t                = karg.KBatch * KReadVec;
            auto KReadPadSplited    = (karg.K + K_t - 1) / K_t * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                              << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg K (" << karg.K
                              << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                              << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                              << __LINE__ << ", in function: " << __func__ << std::endl;
                }
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
        {
            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg N (" << karg.N
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }
        else
        {
            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << "Arg M (" << karg.M
                              << ") value is not a multiple of "
                                 "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                              << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                return false;
            }
        }

        // KBatch > 1 needs either an atomically-addable CDataType or the
        // separate-reduction path
        if constexpr(!(is_same_v<remove_cvref_t<CDataType>, half_t> ||
                       is_same_v<remove_cvref_t<CDataType>, float> ||
                       is_same_v<remove_cvref_t<CDataType>, bhalf_t> ||
                       is_same_v<remove_cvref_t<CDataType>, int32_t>))
        {
            if(!karg.IsReduceAdd())
            {
                if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
                {
                    std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet"
                              << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                              << std::endl;
                }
                if(karg.KBatch > 1)
                {
                    return false;
                }
            }
        }

        // check gridwise gemm pipeline: the prefetch prologue alone must not
        // already cover all of K
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);

        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        {
            return false;
        }

        return true;
    }
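    // Usage sketch (hypothetical host-side code, not part of this header):
    // CheckValidity runs on the host before launching the kernel, so unsupported
    // shape / vector-width combinations are rejected up front instead of
    // corrupting the output on device. `Gemm` stands for some full instantiation
    // of GridwiseGemm_wmma_cshuffle_v3.
    //
    //     Gemm::Argument karg{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, k_batch};
    //     if(!Gemm::CheckValidity(karg))
    //     {
    //         return; // fall back to a padded or scalar-access instantiation
    //     }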
    static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
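// Worked example (standalone sketch, not part of this header): whether the main
// K loop runs at all depends on how many KPerBlock steps fit in K versus the
// pipeline depth. With made-up numbers: K = 256 and KPerBlock = 64 give 4 loop
// iterations, so a 2-stage prefetch pipeline has a hot loop; K = 128 would not.
#include <cassert>
int main()
{
    const int KPerBlock = 64, PrefetchStages = 2;
    assert((256 / KPerBlock) > PrefetchStages);  // hot loop taken
    assert((128 / KPerBlock) <= PrefetchStages); // tail-only path
    return 0;
}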
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem,
                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};
        // divide block work by [M, N]
        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        // HACK: this forces the block indices into SGPRs
        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        // A/B tile descriptors in LDS
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A matrix blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<AK0Number, MPerBlock, AK1Number>,
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                ADataType,
                                                ADataType,
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                Sequence<0, 1, 2>,
                                                ABlockTransferSrcVectorDim,
                                                2,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                1,
                                                1,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                true,
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});

        // B matrix blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                tensor_operation::element_wise::PassThrough,
                                                InMemoryDataOperationEnum::Set,
                                                Sequence<BK0Number, NPerBlock, BK1Number>,
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                BDataType,
                                                BDataType,
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                Sequence<0, 1, 2>,
                                                BBlockTransferSrcVectorDim,
                                                2,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                1,
                                                1,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                true,
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                tensor_operation::element_wise::PassThrough{});
        // LDS allocation for A and B
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) +
                                         a_block_space_size_aligned * sizeof(ADataType) /
                                             APackedSize),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);

        // blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);
        // shuffle C and write out
        {
            constexpr auto
                c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                    blockwise_gemm_pipeline
                        .GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            constexpr auto
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp =
                    blockwise_gemm_pipeline
                        .GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs();

            constexpr auto MWave =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I1);
            constexpr auto MSubGroup =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I2);
            constexpr auto NWave =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I4);
            constexpr auto NThreadPerSubGroup =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I5);
            constexpr auto MAccVgprs =
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
                    .GetLength(I6);

            // LDS buffer for the shuffle, reusing the GEMM-stage allocation
            constexpr auto c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat =
                GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
                    .GetElementSpaceSize());

            constexpr auto
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
                    transform_tensor_descriptor(
                        c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                        make_tuple(
                            make_freeze_transform(I0),
                            make_unmerge_transform(make_tuple(Number<CShuffleMRepeatPerShuffle>{},
                                                              MWave,
                                                              MSubGroup,
                                                              MAccVgprs)),
                            make_freeze_transform(I0),
                            make_unmerge_transform(make_tuple(Number<CShuffleNRepeatPerShuffle>{},
                                                              NWave,
                                                              NThreadPerSubGroup))),
                        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                        make_tuple(Sequence<>{},
                                   Sequence<0, 1, 2, 6>{},
                                   Sequence<>{},
                                   Sequence<3, 4, 5>{}));
            // calculate the origin of this thread's C slab on the block tile
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(
                        make_tuple(MRepeat, MWave, MSubGroup, MAccVgprs))),
                    make_tuple(Sequence<0, 1, 2, 3>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_mrepeat_mwave_msubgroup_maccvgprs_adaptor
                    .CalculateBottomIndex(make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(
                        make_merge_transform(make_tuple(NRepeat, NWave, NThreadPerSubGroup))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_nrepeat_nwave_nthreadpersubgroup_adaptor
                    .CalculateBottomIndex(make_multi_index(n_thread_data_on_block));
            // VGPR-to-LDS shuffle copy: one (CShuffleMRepeatPerShuffle x
            // CShuffleNRepeatPerShuffle) slab per pass
            auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType,
                CShuffleDataType,
                decltype(c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                decltype(c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs),
                tensor_operation::element_wise::PassThrough,
                Sequence<CShuffleMRepeatPerShuffle,
                         /* ... */
                         CShuffleNRepeatPerShuffle,
                         /* ... */>,
                /* ... access order, vector dim, scalar-per-vector ... */>{
                c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                make_multi_index(0,
                                 m_thread_data_on_block_idx[I1],
                                 m_thread_data_on_block_idx[I2],
                                 0,
                                 n_thread_data_on_block_idx[I1],
                                 n_thread_data_on_block_idx[I2],
                                 m_thread_data_on_block_idx[I3]),
                tensor_operation::element_wise::PassThrough{}};
            // cooperative LDS-to-global copy of each shuffled slab
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,
                CElementwiseOperation,
                CGlobalMemoryDataOperation,
                Sequence<1,
                         CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                         1,
                         CShuffleNRepeatPerShuffle * NWave * NPerWmma>, // BlockSliceLengths
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                Sequence<0, 1, 2, 3>,
                CShuffleDataType,
                CDataType,
                decltype(c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                Sequence<0, 1, 2, 3>,
                3,
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                true,   // ThreadTransferSrcResetCoordinateAfterRun
                false>  // ThreadTransferDstResetCoordinateAfterRun
                {c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                 make_multi_index(0, 0, 0, 0),
                 c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(block_m_id, 0, block_n_id, 0),
                 c_element_op};
            // space filling curve for threadwise C in VGPR
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve<Sequence</* ... thread slice lengths ... */>,
                                  Sequence</* ... dim access order ... */>,
                                  Sequence<CShuffleMRepeatPerShuffle,
                                           /* ... */
                                           CShuffleNRepeatPerShuffle,
                                           /* ... */>>{};

            // space filling curve for shuffled blockwise C on global memory
            constexpr auto sfc_c_global =
                SpaceFillingCurve<Sequence<1, MPerBlock, 1, NPerBlock>,
                                  Sequence<0, 2, 1, 3>,
                                  Sequence<1,
                                           CShuffleMRepeatPerShuffle * MWave * MPerWmma,
                                           1,
                                           CShuffleNRepeatPerShuffle * NWave * NPerWmma>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // each thread shuffles its slab of C from VGPR into LDS
                c_thread_copy_vgpr_to_lds.Run(
                    c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                    sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                    c_thread_buf,
                    c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs,
                    c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // the whole block copies the shuffled slab from LDS to global memory
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum = TailNumber::Full>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);

        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        Run<decltype(a_grid_desc_ak0_m_ak1),
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            HasMainKBlockLoop,
            CGlobalMemoryDataOperation,
            TailNum>(p_a_grid,
                     p_b_grid,
                     p_c_grid,
                     p_shared,
                     problem,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
                     c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
};