template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

    GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
        p_shared,
        karg);
#else
    ignore = karg;
#endif
}
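// A minimal host-side launch sketch (illustrative only; the alias `Gemm` and the
// stream handle are assumptions, not part of this header). Each z-slice of the grid
// handles one split-K batch; SplitKBatchOffset (defined below) turns blockIdx.z into
// per-operand pointer offsets before Run() is entered:
//
//   using Gemm = GridwiseGemm_xdl_cshuffle_v3</* ... tuning parameters ... */>;
//   auto karg  = typename Gemm::Argument{p_a, p_b, p_c, M, N, K,
//                                        StrideA, StrideB, StrideC, k_batch};
//   dim3 grid(Gemm::CalculateGridSize(karg.M, karg.N, karg.KBatch));
//   kernel_gemm_xdl_cshuffle_v3<Gemm, true, InMemoryDataOperationEnum::AtomicAdd>
//       <<<grid, BlockSize, 0, stream>>>(karg);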
template <typename GridwiseGemm,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          TailNumber TailNum       = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);

    GridwiseGemm::template Run_2Lds<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
        karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
        karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
        karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
        p_shared_0,
        p_shared_1,
        karg);
#else
    ignore = karg;
#endif
}
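// Why two LDS allocations: with a single buffer the pipeline must block_sync_lds()
// before overwriting a tile, while with the ping-pong pair the global->LDS copy of
// K-tile i+1 can overlap the MFMA work reading K-tile i. The selected
// BlockGemmPipeline version decides which of the two kernel entries is used.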
template <typename ALayout,
          typename BLayout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CShuffleDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          GemmSpecialization GemmSpec,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t AK1Value,
          index_t BK1Value,
          index_t MPerXdl,
          index_t NPerXdl,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_AK1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          index_t ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_BK1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          index_t BBlockLdsExtraN,
          index_t CShuffleMXdlPerWavePerShuffle,
          index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          BlockGemmPipelineScheduler BlkGemmPipeSched,
          BlockGemmPipelineVersion BlkGemmPipelineVer,
          typename ComputeTypeA = CDataType,
          typename ComputeTypeB = ComputeTypeA,
          bool PermuteA         = false,
          bool PermuteB         = false>
struct GridwiseGemm_xdl_cshuffle_v3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};
    static constexpr auto I4 = Number<4>{};
    static constexpr auto I5 = Number<5>{};
    static constexpr auto I6 = Number<6>{};
    static constexpr auto I7 = Number<7>{};

    static constexpr auto AK1Number = Number<AK1Value>{};
    static constexpr auto BK1Number = Number<BK1Value>{};

    // ... (AK0Number/BK0Number, KPack, packed-size constants, CalculateGridSize and the
    //      M/N padding helpers are elided in this listing)
    static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
    }

    static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
    }

    static __host__ auto CalculateKPadded(index_t K, index_t K_Batch = 1)
    {
        auto K_t = K_Batch * KPerBlock;
        return (K + K_t - 1) / K_t * KPerBlock;
    }

    static __host__ auto CalculateKRead(index_t K, index_t K_Batch = 1)
    {
        constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
        auto K_t                = K_Batch * KReadVec;
        return (K + K_t - 1) / K_t * KReadVec;
    }
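    // Worked example (values illustrative): with KPerBlock = 32, AK1Value = 8,
    // K = 1000 and K_Batch = 2, K_t = 64, so ceil(1000 / 64) = 16 K tiles per batch.
    // CalculateAK0Padded returns 16 * (32 / 8) = 64 AK1-wide slices per batch, and
    // CalculateKPadded returns the matching padded K of 16 * 32 = 512 per batch.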
    template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
    __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&)
    {
        // ... (reshapes a (K0, MN, K1) tile into the per-wave MMA view consumed by the
        //      blockwise GEMM pipeline)
    }
    __host__ __device__ static auto MakeAGridDescriptor_AK0_M_AK1(
        index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
        }();

        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both M and K
            const auto a_grid_desc_m_k = /* ... right-pad M and K ... */;
            // ... (unmerge K into (AK0, AK1))
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad M only
            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad K only
            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
        else
        {
            // no padding
            const auto a_grid_desc_ak0_m_ak1 =
                transform_tensor_descriptor(a_grid_desc_mraw_kraw, /* ... */);
            return a_grid_desc_ak0_m_ak1;
        }
    }
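    // GemmSpecialization selects which of M/N/K are right-padded up to the block tile,
    // trading a little wasted work at the matrix edges for branch-free loads. With
    // Default, M, N and K must already be multiples of the tile sizes, which
    // CheckValidity() below enforces at argument-setup time.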
    __host__ __device__ static auto MakeBGridDescriptor_BK0_N_BK1(
        index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            // ... (RowMajor/ColumnMajor naive descriptor, analogous to A)
        }();

        static_assert(!(is_same_v<remove_cvref_t<BDataType>, pk_i4_t> &&
                        GemmSpec != GemmSpecialization::Default),
                      "pk_i4_t does not support padding");

        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad both N and K
            const auto b_grid_desc_n_k = /* ... right-pad N and K ... */;
            // ... (unmerge K into (BK0, BK1))
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::MNPadding)
        {
            // pad N only
            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw, /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad K only
            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw, /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else if constexpr(!PermuteB)
        {
            // no padding
            const auto b_grid_desc_bk0_n_bk1 =
                transform_tensor_descriptor(b_grid_desc_nraw_kraw, /* ... */);
            return b_grid_desc_bk0_n_bk1;
        }
        else
        {
            // pre-shuffled weight: view K as BK00 * BK01 * BK1 and merge back
            constexpr index_t BK01 = KPerBlock / BK1Value;
            const index_t BK0_     = StrideB / BK1Value;
            const index_t BK00     = BK0_ / BK01;

            const auto b_grid_desc_bk00_n_bk01_bk1_permute = /* ... */;

            const auto b_grid_desc_bk0_n_bk1_permute = transform_tensor_descriptor(
                b_grid_desc_bk00_n_bk01_bk1_permute, /* ... */);

            return b_grid_desc_bk0_n_bk1_permute;
        }
    }
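    // For pre-shuffled B (PermuteB) the K axis is factored as BK00 * BK01 * BK1: e.g.
    // with KPerBlock = 32 and BK1Value = 8 (illustrative), BK01 = 4, and StrideB
    // carries the shuffled K extent, so BK00 = (StrideB / 8) / 4. The merge transform
    // above restores the (BK0, N, BK1) view without moving any data.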
    template <typename ABlockDesc_AK0_M_AK1>
    __host__ __device__ static constexpr auto
    MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1&)
    {
        constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);

        return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
    }

    template <typename BBlockDesc_BK0_N_BK1>
    __host__ __device__ static constexpr auto
    MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1&)
    {
        constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);

        return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
    }
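    // MWaves/NWaves count how many waves tile each block dimension: e.g. with
    // MPerBlock = 128, MXdlPerWave = 2 and MPerXdl = 32 (illustrative values),
    // MWaves = 128 / (2 * 32) = 2. On the wave64 gfx9 targets this kernel is guarded
    // for, MWaves * NWaves * 64 lanes is expected to match BlockSize.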
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ... (RowMajor/ColumnMajor naive descriptor)
        }();

        if constexpr(GemmSpec == GemmSpecialization::MNPadding ||
                     GemmSpec == GemmSpecialization::MNKPadding)
        {
            // pad M and N
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw,
                                               make_tuple(make_right_pad_transform(M, MPad - M),
                                                          make_right_pad_transform(N, NPad - N)),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }
        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
                          GemmSpec == GemmSpecialization::MKPadding)
        {
            // pad M only
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
                          GemmSpec == GemmSpecialization::NKPadding)
        {
            // pad N only
            return transform_tensor_descriptor(c_grid_desc_mraw_nraw, /* ... */);
        }
        else
        {
            // no padding
            return c_grid_desc_mraw_nraw;
        }
    }
    struct Problem
    {
        // ... (constructor and the M/N/K/stride/split-K members are elided here)

        __host__ void Print() const
        {
            std::cout << "problem {"
                      // ... M, N, K, the strides and the padded sizes ...
                      << "KRead:" << KRead << ", "
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << ", "
                      << "MBlock: " << MBlock << ", "
                      << "NBlock: " << NBlock << "}" << std::endl;
        }
        // ...
    };
    struct Argument : public Problem
    {
        __host__ Argument(const ADataType* p_a_grid_,
                          const BDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_,
                          index_t N_,
                          index_t K_,
                          index_t StrideA_,
                          index_t StrideB_,
                          index_t StrideC_,
                          index_t k_batch_,
                          bool is_reduce_ = false)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_},
              is_reduce(is_reduce_)
        {
        }

        // ... (IsReduceAdd() / IsAtomicAdd() predicates elided)

        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
        bool is_reduce;
    };
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(Argument& karg)
        {
            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
                a_k_split_offset = /* ... blockIdx.z K-offset, K-contiguous A ... */;
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
                a_k_split_offset = /* ... */;

            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
                b_k_split_offset = /* ... */;
            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            {
                if constexpr(!PermuteB)
                    b_k_split_offset = /* ... */;
                else
                {
                    const int k0_offset = karg.KRead * karg.N;
                    b_k_split_offset    = blockIdx.z * k0_offset;
                }
            }

            // every batch except the last reads KRead elements of K; the last batch
            // takes the remainder
            if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
                karg.K = karg.KRead;
            else
                karg.K = karg.K - karg.KRead * (karg.KBatch - 1);

            if(karg.IsReduceAdd())
                c_reduce_offset = /* ... per-batch slice of the reduction workspace ... */;
            else
                c_reduce_offset = 0;
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t c_reduce_offset;
    };
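    // Note: the K rewrite above is what makes split-K transparent to the main loop;
    // each kernel instance simply sees a smaller K. CheckValidity() below rejects
    // configurations where the last batch's remainder would be non-positive.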
    __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A tile in LDS, dst of the blockwise copy
        if constexpr(ABlockLdsExtraM)
        {
            // plain (AK0, MPerBlock, AK1) descriptor with an extra-M pad to dodge
            // bank conflicts
            // ...
        }
        else if constexpr(/* ... MLdsLayer fast path ... */)
        {
            constexpr auto LdsSize   = /* ... */;
            constexpr auto MLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            // ... (xor-permuted descriptor built via a_lds_block_desc_permuted and
            //      a_lds_block_desc_ak0_mldslayer_m_ak1)
            return a_lds_block_desc_ak0_m_ak1;
        }
        else
        {
            constexpr auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I1);
            constexpr auto M1 = MPerBlock / M0;

            constexpr auto KThreadWrite     = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(I0);
            constexpr auto K0PerThreadWrite = AK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MPerXdl;
            constexpr auto K0PerThreadRead  = AK0Number / KThreadRead;

            constexpr auto kfold = (AK1Number * M0 * sizeof(ADataType) > 128)
                                       ? 1
                                       : 128 / (AK1Number * M0 * sizeof(ADataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto mpair = (AK1Number * MPerXdl * sizeof(ADataType) > 128)
                                       ? 1
                                       : ((128 / (AK1Number * MPerXdl * sizeof(ADataType))) > M0
                                              ? M0
                                              : 128 / (AK1Number * MPerXdl * sizeof(ADataType)));

            // ... (build the base descriptor, xor-permute it via a_lds_block_desc_permuted,
            //      unmerge with Number<kfold * M0 / mpair>{} and
            //      Number<KThreadWrite / kfold / KThreadReadPerm>{}, then merge back into
            //      the (AK0, MPerBlock, AK1) view via a_lds_block_desc_unmerged)
            return a_lds_block_desc_ak0_m_ak1;
        }
    }
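    // The xor/permute construction above avoids LDS bank conflicts between the write
    // pattern of the global->LDS copy and the read pattern of the MFMA lanes. Worked
    // example (illustrative): fp16 A with AK1 = 8 and M0 = 4 gives 8 * 4 * 2 = 64
    // bytes <= 128, so kfold = 128 / 64 = 2 rows are folded to fill one 128-byte line
    // per access.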
    __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B tile in LDS; mirrors the A descriptor with N in place of M
        if constexpr(BBlockLdsExtraN)
        {
            // plain (BK0, NPerBlock, BK1) descriptor with an extra-N pad
            // ...
        }
        else if constexpr(/* ... NLdsLayer fast path ... */)
        {
            constexpr auto LdsSize     = /* ... */;
            constexpr index_t NLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            // ... (xor-permuted descriptor built via b_lds_block_desc_permuted and
            //      b_lds_block_desc_bk0_nldslayer_n_bk1)
            return b_lds_block_desc_bk0_n_bk1;
        }
        else
        {
            constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
            constexpr auto N1 = NPerBlock / N0;

            constexpr auto KThreadWrite     = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
            constexpr auto K0PerThreadWrite = BK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / NPerXdl;
            constexpr auto K0PerThreadRead  = BK0Number / KThreadRead;

            constexpr auto kfold = (BK1Number * N0 * sizeof(BDataType) > 128)
                                       ? 1
                                       : 128 / (BK1Number * N0 * sizeof(BDataType));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            constexpr auto npair = (BK1Number * NPerXdl * sizeof(BDataType) > 128)
                                       ? 1
                                       : ((128 / (BK1Number * NPerXdl * sizeof(BDataType))) > N0
                                              ? N0
                                              : 128 / (BK1Number * NPerXdl * sizeof(BDataType)));

            // ... (same xor/unmerge construction as for A, with Number<kfold * N0 / npair>{}
            //      and Number<KThreadWrite / kfold / KThreadReadPerm>{} feeding
            //      b_lds_block_desc_permuted and b_lds_block_desc_unmerged)
            return b_lds_block_desc_bk0_n_bk1;
        }
    }
    __device__ static constexpr auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
    {
        constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
        constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            make_naive_tensor_descriptor_packed(
                make_tuple(I1,
                           Number<CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl>{},
                           I1,
                           Number<CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>{}));

        return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
    }
    using BlockwiseGemmPipe = remove_cvref_t<decltype(BlockGemmPipeline_Selector<
        BlkGemmPipelineVer,
        BlkGemmPipeSched,
        BlockSize,
        ADataType,
        BDataType,
        ComputeTypeA,
        AccDataType,
        decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()),
        decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()),
        decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())),
        decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())),
        ABlockTransferSrcScalarPerVector,
        BBlockTransferSrcScalarPerVector,
        MPerBlock,
        NPerBlock,
        KPerBlock,
        MPerXdl,
        NPerXdl,
        MXdlPerWave,
        NXdlPerWave,
        KPack>())>;

    __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS usage is max(A tile + B tile, C shuffle staging tile)
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
            GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();
        constexpr auto c_block_size =
            c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();

        return math::max((a_block_space_size_aligned * sizeof(ADataType) / APackedSize +
                          b_block_space_size_aligned * sizeof(BDataType) / BPackedSize),
                         c_block_size * sizeof(CShuffleDataType));
    }
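    // Illustrative sizing (assumed values, not a guarantee): a 256x128x32 fp16 tile
    // needs roughly 256 * 32 * 2 = 16 KiB for A plus 128 * 32 * 2 = 8 KiB for B before
    // padding; that sum is compared against the C-shuffle staging tile, and the larger
    // of the two is what the __shared__ char arrays in the kernels above allocate.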
    static constexpr __host__ bool CheckValidity(const Argument& karg)
    {
        static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
                          (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
                      "Invalid tuning param!");
        // NB: in the full source each check below sits behind the matching GemmSpec
        // padding guard, and each message behind an EnvIsEnabled(CK_ENV(CK_LOGGING))
        // check; those guards are elided in this listing.
        if(!(karg.M % MPerBlock == 0))
        {
            std::cout << "Arg M value is not a multiple of MPerBlock! M: " << karg.M << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            return false;
        }

        if(!(karg.N % NPerBlock == 0))
        {
            std::cout << "Arg N value is not a multiple of NPerBlock! N: " << karg.N << " "
                      << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                      << std::endl;
            return false;
        }

        {
            auto K_t = karg.KBatch * KPerBlock;
            if(!(karg.K % K_t == 0))
            {
                std::cout << "Arg K value is not a multiple of K_Batch * K0PerBlock * K1! K: "
                          << karg.K << " " << __FILE__ << ":" << __LINE__
                          << ", in function: " << __func__ << std::endl;
                return false;
            }
        }

        // K-padded split: the last batch must still get a positive share of K
        {
            constexpr auto KReadVec = math::lcm(AK1Number, BK1Number);
            auto K_t                = karg.KBatch * KReadVec;
            auto KReadPadSplited    = (karg.K + K_t - 1) / K_t * KReadVec;
            if((KReadPadSplited * (karg.KBatch - 1)) >= karg.K)
            {
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
        {
            if(karg.K % ABlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }
        else
        {
            if(karg.M % ABlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of ABlockTransferSrcScalarPerVector ("
                          << ABlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
        {
            if(karg.N % BBlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }
        else
        {
            if(karg.K % BBlockTransferSrcScalarPerVector != 0)
            {
                std::cout << "Arg K (" << karg.K
                          << ") value is not a multiple of BBlockTransferSrcScalarPerVector ("
                          << BBlockTransferSrcScalarPerVector << " )! " << __FILE__ << ":"
                          << __LINE__ << ", in function: " << __func__ << std::endl;
                return false;
            }
        }

        if constexpr(is_same_v<tensor_layout::gemm::RowMajor, CLayout>)
        {
            if(karg.N % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                std::cout << "Arg N (" << karg.N
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
                return false;
            }
        }
        else
        {
            if(karg.M % CShuffleBlockTransferScalarPerVector_NPerBlock != 0)
            {
                std::cout << "Arg M (" << karg.M
                          << ") value is not a multiple of "
                             "CShuffleBlockTransferScalarPerVector_NPerBlock ("
                          << CShuffleBlockTransferScalarPerVector_NPerBlock << " )! "
                          << __FILE__ << ":" << __LINE__ << ", in function: " << __func__
                          << std::endl;
                return false;
            }
        }

        if(!karg.IsReduceAdd())
        {
            std::cout << " KBatch: " << karg.KBatch << " > 1 is not supported yet" << __FILE__
                      << ":" << __LINE__ << ", in function: " << __func__ << std::endl;
            if(karg.KBatch > 1)
            {
                return false;
            }
        }
        // the blockwise GEMM pipeline needs more K loops than its prefetch depth
        const auto num_k_loop = karg.AK0 / (KPerBlock / AK1Value);
        if(num_k_loop <= BlockwiseGemmPipe::PrefetchStages)
        {
            return false;
        }

        return true;
    }

    static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
    }

    static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
    {
        const index_t num_loop = K / KPerBlock;

        return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
    }
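    // The pipeline distinguishes a "hot loop" (enough K iterations to keep the software
    // pipeline full) from a short tail. HasMainKBlockLoop and TailNum are computed on the
    // host from K and baked into the kernel instantiation as template arguments, so the
    // device code pays no runtime branch for them.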
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
        const CGridDesc& c_grid_desc_m_n, index_t MBlock, index_t NBlock)
    {
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(MBlock, Number<MPerBlock>{})),
                       make_unmerge_transform(make_tuple(NBlock, Number<NPerBlock>{}))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

        return c_grid_desc_mblock_mperblock_nblock_nperblock;
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem,
                               const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                   c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // keep the per-block origin in scalar registers
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        // LDS tile descriptors (dst of the blockwise copies)
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A: global -> LDS blockwise copy
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                /* ... */,
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                /* ... */,
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                /* ... */,
                                                ABlockTransferSrcVectorDim,
                                                /* ... */,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                /* ... */,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */,
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                /* ... */);

        // B: global -> LDS blockwise copy
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                /* ... */,
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                /* ... */,
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                /* ... */,
                                                BBlockTransferSrcVectorDim,
                                                /* ... */,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                /* ... */,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */,
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                /* ... */);
        // LDS allocation for A and B
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            reinterpret_cast<BDataType*>(static_cast<char*>(p_shared) + a_block_space_size_aligned *
                                                                            sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Value, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Value, 0, 0);

        // blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_buf,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_buf,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);
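        // At this point every K tile flows global -> LDS -> MFMA accumulators: the
        // pipeline interleaves a_blockwise_copy/b_blockwise_copy with the xdlops math
        // and advances the copy windows by one KPerBlock slice per iteration (the
        // *_slice_copy_step indices above).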
        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // only used to query lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                /* ... freeze MBlock/NBlock, unmerge the per-wave M and N factors ... */);

            // where this thread's accumulators sit inside the block tile
            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))),
                    make_tuple(Sequence<0, 1, 2, 3, 4>{}),
                    make_tuple(Sequence<0>{}));

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(
                    make_tuple(make_merge_transform(make_tuple(N0, N1, N2))),
                    make_tuple(Sequence<0, 1, 2>{}),
                    make_tuple(Sequence<0>{}));

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            // VGPR accumulators -> LDS staging tile
            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3</* ... */,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   /* ... */,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            /* ... */>,
                                                   /* ... */>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    /* ... */};
            // LDS staging tile -> global C, strip-mined across the block
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,
                CElementwiseOperation,
                CGlobalMemoryDataOperation,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // block slice lengths
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                /* ... */,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */,
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                /* ... */>{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                           make_multi_index(0, 0, 0, 0),
                           c_grid_desc_mblock_mperblock_nblock_nperblock,
                           make_multi_index(block_m_id, 0, block_n_id, 0),
                           c_element_op};

            // space-filling curves over the per-thread and per-block C tiles
            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve</* ... */,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           /* ... */>>{};

            constexpr auto sfc_c_global =
                SpaceFillingCurve</* ... */,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // VGPR -> LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // LDS -> global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
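    // Design note: C is drained in num_access passes rather than all at once, so the
    // LDS staging tile can be much smaller than the full accumulator footprint. The two
    // space-filling curves keep the VGPR reads and the global writes visiting sub-tiles
    // in the same order, which is why their access counts must match.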
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void Run(const ADataType* p_a_grid,
                               const BDataType* p_b_grid,
                               CDataType* p_c_grid,
                               void* p_shared,
                               const Problem& problem)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        Run<decltype(a_grid_desc_ak0_m_ak1),
            decltype(b_grid_desc_bk0_n_bk1),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            HasMainKBlockLoop,
            CGlobalMemoryDataOperation,
            TailNum>(p_a_grid,
                     p_b_grid,
                     p_c_grid,
                     p_shared,
                     problem,
                     a_grid_desc_ak0_m_ak1,
                     b_grid_desc_bk0_n_bk1,
                     c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
    template <typename AGridDesc_AK0_M_K1,
              typename BGridDesc_BK0_N_K1,
              typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
              bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem,
                                    const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
                                    const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
                                    const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
                                        c_grid_desc_mblock_mperblock_nblock_nperblock)
    {
        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};

        // divide block work by [M, N]
        const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
                          c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
        {
            return;
        }

        const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
        const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);

        // keep the per-block origin in scalar registers
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
        // LDS tile descriptors (dst of the blockwise copies)
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        // A/B global -> LDS blockwise copies: same ThreadGroupTensorSliceTransfer_v4r1
        // construction as in Run() above
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                /* ... as in Run() ... */,
                                                ABlockTransferThreadClusterLengths_AK0_M_AK1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                /* ... */,
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                /* ... */,
                                                ABlockTransferSrcVectorDim,
                                                /* ... */,
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_AK1,
                                                /* ... */,
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */,
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                a_grid_desc_ak0_m_ak1,
                make_multi_index(0, m_block_data_idx_on_grid, 0),
                a_element_op,
                a_block_desc_ak0_m_ak1,
                make_multi_index(0, 0, 0),
                /* ... */);

        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                /* ... as in Run() ... */,
                                                BBlockTransferThreadClusterLengths_BK0_N_BK1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                /* ... */,
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                /* ... */,
                                                BBlockTransferSrcVectorDim,
                                                /* ... */,
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_BK1,
                                                /* ... */,
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */,
                                                BlockwiseGemmPipe::GlobalBufferNum>(
                b_grid_desc_bk0_n_bk1,
                make_multi_index(0, n_block_data_idx_on_grid, 0),
                b_element_op,
                b_block_desc_bk0_n_bk1,
                make_multi_index(0, 0, 0),
                /* ... */);
        // LDS allocation: ping-pong buffers for A and B
        constexpr auto max_lds_align = math::lcm(AK1Number, BK1Number);

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(static_cast<char*>(p_shared_0) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ADataType*>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
        auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            bit_cast<BDataType*>(bit_cast<char*>(p_shared_1) +
                                 a_block_space_size_aligned * sizeof(ADataType)),
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());

        auto a_block_bufs = make_tuple(a_block_buf_ping, a_block_buf_pong);
        auto b_block_bufs = make_tuple(b_block_buf_ping, b_block_buf_pong);
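        // The ping/pong pair lives in two separate __shared__ allocations (p_shared_0
        // and p_shared_1 from the *_2lds kernel entry), so the pipeline can prefetch
        // tile i+1 into one while the MFMAs read tile i from the other, at the cost of
        // twice the LDS footprint.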
        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Value, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Value, 0, 0);

        // blockwise GEMM pipeline
        static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
        auto blockwise_gemm_pipeline = BlockwiseGemmPipe{};
        auto c_thread_buf            = blockwise_gemm_pipeline.GetCThreadBuffer();

        const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
            (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) /
            KPerBlock);

        blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
                                                                         a_block_desc_ak0_m_ak1,
                                                                         a_blockwise_copy,
                                                                         a_grid_buf,
                                                                         a_block_bufs,
                                                                         a_block_slice_copy_step,
                                                                         b_grid_desc_bk0_n_bk1,
                                                                         b_block_desc_bk0_n_bk1,
                                                                         b_blockwise_copy,
                                                                         b_grid_buf,
                                                                         b_block_bufs,
                                                                         b_block_slice_copy_step,
                                                                         c_thread_buf,
                                                                         num_k_block_main_loop);
        // shuffle C and write out (the staging tile reuses p_shared_0)
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
                              NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
                          "wrong!");

            constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
            constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);

            constexpr auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
                blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            // only used to query lengths
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
                blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

            constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0);
            constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1);
            constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2);
            constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3);
            constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4);
            constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5);
            constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6);
            constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7);

            constexpr auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
                GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock();

            auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
                static_cast<CShuffleDataType*>(p_shared_0),
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
            constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 = transform_tensor_descriptor(
                c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                /* ... as in Run() ... */);

            const auto c_thread_mtx_on_block =
                blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

            const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
            const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

            const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
                make_single_stage_tensor_adaptor(/* ... merge (M0, M1, M2, M3, M4) ... */);

            const auto m_thread_data_on_block_idx =
                m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                    make_multi_index(m_thread_data_on_block));

            const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
                make_single_stage_tensor_adaptor(/* ... merge (N0, N1, N2) ... */);

            const auto n_thread_data_on_block_idx =
                n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                    make_multi_index(n_thread_data_on_block));

            auto c_thread_copy_vgpr_to_lds =
                ThreadwiseTensorSliceTransfer_v1r3</* ... as in Run() ... */,
                                                   decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
                                                   /* ... */,
                                                   Sequence<CShuffleMXdlPerWavePerShuffle,
                                                            CShuffleNXdlPerWavePerShuffle,
                                                            /* ... */>,
                                                   /* ... */>{
                    c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                    make_multi_index(0,
                                     0,
                                     m_thread_data_on_block_idx[I1],
                                     n_thread_data_on_block_idx[I1],
                                     m_thread_data_on_block_idx[I2],
                                     m_thread_data_on_block_idx[I3],
                                     m_thread_data_on_block_idx[I4],
                                     n_thread_data_on_block_idx[I2]),
                    /* ... */};
            auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1<
                ThisThreadBlock,
                CElementwiseOperation,
                CGlobalMemoryDataOperation,
                Sequence<1,
                         CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                         1,
                         CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, // block slice lengths
                CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
                /* ... */,
                decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                /* ... */,
                CShuffleBlockTransferScalarPerVector_NPerBlock,
                /* ... */>{c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                           make_multi_index(0, 0, 0, 0),
                           c_grid_desc_mblock_mperblock_nblock_nperblock,
                           make_multi_index(block_m_id, 0, block_n_id, 0),
                           c_element_op};

            constexpr auto sfc_c_vgpr =
                SpaceFillingCurve</* ... */,
                                  Sequence<CShuffleMXdlPerWavePerShuffle,
                                           CShuffleNXdlPerWavePerShuffle,
                                           /* ... */>>{};

            constexpr auto sfc_c_global =
                SpaceFillingCurve</* ... */,
                                  Sequence<1,
                                           CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
                                           1,
                                           CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};

            constexpr index_t num_access = sfc_c_vgpr.GetNumOfAccess();

            static_assert(num_access == sfc_c_global.GetNumOfAccess(), "wrong!");
            static_for<0, num_access, 1>{}([&](auto access_id) {
                // make sure it's safe to write to LDS
                block_sync_lds();

                // VGPR -> LDS
                c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
                                              c_thread_buf,
                                              c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                              c_shuffle_block_buf);

                // make sure it's safe to read from LDS
                block_sync_lds();

                // LDS -> global
                c_shuffle_block_copy_lds_to_global.Run(
                    c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
                    c_shuffle_block_buf,
                    c_grid_desc_mblock_mperblock_nblock_nperblock,
                    c_grid_buf);

                if constexpr(access_id < num_access - 1)
                {
                    constexpr auto c_global_step = sfc_c_global.GetForwardStep(access_id);

                    // move on C
                    c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
                        c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
                }
            });
        }
    }
    template <bool HasMainKBlockLoop,
              InMemoryDataOperationEnum CGlobalMemoryDataOperation,
              TailNumber TailNum>
    __device__ static void Run_2Lds(const ADataType* p_a_grid,
                                    const BDataType* p_b_grid,
                                    CDataType* p_c_grid,
                                    void* p_shared_0,
                                    void* p_shared_1,
                                    const Problem& problem)
    {
        const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
            problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
        const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
            problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
        const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
            problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
            MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                c_grid_desc_m_n, problem.MBlock, problem.NBlock);

        Run_2Lds<decltype(a_grid_desc_ak0_m_ak1),
                 decltype(b_grid_desc_bk0_n_bk1),
                 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                 HasMainKBlockLoop,
                 CGlobalMemoryDataOperation,
                 TailNum>(p_a_grid,
                          p_b_grid,
                          p_c_grid,
                          p_shared_0,
                          p_shared_1,
                          problem,
                          a_grid_desc_ak0_m_ak1,
                          b_grid_desc_bk0_n_bk1,
                          c_grid_desc_mblock_mperblock_nblock_nperblock);
    }
};