typename ComputeDataType,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,

typename ComputeDataType,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,

ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,

ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,

ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
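// Helpers inherited from the shared XDLOPS pipeline base
// (blockwise_gemm_pipeline_xdlops_base.hpp): the xdlops_gemm wrapper,
// C-descriptor and thread-origin utilities, the A/B LDS block descriptors,
// and the per-MMA K strides.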
using Base::xdlops_gemm;

using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;

using Base::AMmaKStride;
using Base::BMmaKStride;
static constexpr __host__ bool BlockHasHotloop(index_t num_loop)
{
    return num_loop > PrefetchStages;
}
static constexpr __host__ TailNumber BlockLoopTailNum(index_t num_loop)
{
    // With the 2-stage unrolled hot loop, an odd iteration count leaves one
    // extra iteration to drain.
    return num_loop % HotloopUnroll == 1 ? TailNumber::Odd : TailNumber::Even;
}
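// HotLoopScheduler() builds a compile-time schedule for one hot-loop
// iteration: it first counts the per-iteration instructions (LDS reads and
// writes, buffer loads, MFMAs) from HotLoopInstList, then emits
// __builtin_amdgcn_sched_group_barrier hints that interleave them in three
// stages (stage 1: most of the DS reads; stage 2: DS writes and buffer
// loads; stage 3: the remaining DS reads).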
constexpr auto num_ds_read_inst_a =
    HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
        ? HotLoopInstList::A_LDS_Read_Inst_Num
        : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
    HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
        ? HotLoopInstList::B_LDS_Read_Inst_Num
        : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
    HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
    HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
    (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
    (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
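// The ds_read_*_mfma_rate values are ceiling divisions: how many DS reads can
// be paired with a single MFMA, apparently budgeting mfma_cycle - 4 cycles
// across two issue slots. Worked example: mfma_cycle = 32 with 16-byte reads
// (ds_read_a_issue_cycle = 8) gives (32 - 4 + 16 - 1) / 16 = 43 / 16 = 2
// DS reads per scheduled MFMA.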
// Split the DS reads: stage 1 covers the first (KRepeat - 1) K slices,
// stage 3 the final slice.
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
constexpr auto num_dsread_stage1_a_mfma =
    (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_stage1_b_mfma =
    (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
constexpr auto num_dsread_stage3_a_mfma =
    (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_stage3_b_mfma =
    (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// MFMAs left for stage 2 after pairing MFMAs with all DS reads, spread
// across the buffer_load issues together with the DS writes.
constexpr auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate -
                                 num_ds_read_inst_b / ds_read_b_mfma_rate;
constexpr auto num_mfma_per_issue =
    num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
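// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) pins the next
// `size` instructions of the masked class at this point in the schedule.
// Mask bits used below (LLVM AMDGPU scheduling intrinsics):
// 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.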
// Stage 1 (A side): interleave the first (KRepeat - 1) K slices of DS reads
// with MFMAs.
if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >= ds_read_a_mfma_rate)
    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
else
    __builtin_amdgcn_sched_group_barrier(
        0x100, num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate, 0);
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Stage 1 (B side).
if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >= ds_read_b_mfma_rate)
    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
else
    __builtin_amdgcn_sched_group_barrier(
        0x100, num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate, 0);
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Stage 2 (A side): per buffer_load issue, pair each DS write with an MFMA,
// then one VMEM read followed by the remaining MFMAs of this issue.
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
    0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
// Stage 2 (B side).
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
    0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
// Stage 3 (A side): the final K slice of DS reads, again paced against MFMAs.
if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >= ds_read_a_mfma_rate)
    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
else
    __builtin_amdgcn_sched_group_barrier(
        0x100, num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate, 0);
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
// Stage 3 (B side).
if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >= ds_read_b_mfma_rate)
    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
else
    __builtin_amdgcn_sched_group_barrier(
        0x100, num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate, 0);
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_barrier(0); // seal the schedule: nothing may cross this point
template <bool HasMainLoop,
          TailNumber TailNum,
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
    a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
    b_thread_desc_.GetElementSpaceSize());
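// Prologue: three rounds of global prefetch. Tile 0 is read and committed to
// LDS immediately; tiles 1 and 2 are read into the two copy buffers (I0/I1)
// and only written to LDS inside the loop.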
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
c_thread_buf.Clear();
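// Preload the first K slice of A/B from LDS into VGPRs so the hot loop can
// issue MFMAs immediately.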
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
if constexpr(HasMainLoop)

auto LoopFunc = [&](auto vmem_buf) {

if constexpr(k0 == (KRepeat - 1))

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

a_thread_vec.template AsType<ComputeDataType>()(ik) =
    a_thread_buf[Number<a_thread_desc_.CalculateOffset(

b_thread_vec.template AsType<ComputeDataType>()(ik) =
    b_thread_buf[Number<b_thread_desc_.CalculateOffset(

using mfma_input_type =
    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec.template AsType<mfma_input_type>(),

a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,

} while(i < (num_loop - PrefetchStages));
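// Pipeline drain: ReadWriteCompFunc computes one tile while committing a
// prefetched tile from its copy buffer to LDS; ReadCompFunc computes the
// final tile from LDS with no further memory traffic.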
auto ReadWriteCompFunc = [&](auto vmem_buf) {

if constexpr(k0 == (KRepeat - 1))

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);

a_thread_vec.template AsType<ComputeDataType>()(ik) =
    a_thread_buf[Number<a_thread_desc_.CalculateOffset(

b_thread_vec.template AsType<ComputeDataType>()(ik) =
    b_thread_buf[Number<b_thread_desc_.CalculateOffset(

using mfma_input_type =
    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec.template AsType<mfma_input_type>(),

a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
auto ReadCompFunc = [&]() {

static_for<0, KRepeat - 1, 1>{}([&](auto k0) {

a_thread_vec.template AsType<ComputeDataType>()(ik) =
    a_thread_buf[Number<a_thread_desc_.CalculateOffset(

b_thread_vec.template AsType<ComputeDataType>()(ik) =
    b_thread_buf[Number<b_thread_desc_.CalculateOffset(

using mfma_input_type =
    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec.template AsType<mfma_input_type>(),

a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,

a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf

b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf

using mfma_input_type =
    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec.template AsType<mfma_input_type>(),
// Tail: TailNumber::Even drains both register-prefetch buffers,
// TailNumber::Odd only one, before the final compute-only pass.
ReadWriteCompFunc(I0);
ReadWriteCompFunc(I1);

ReadWriteCompFunc(I0);
static constexpr auto a_thread_desc_ =

static constexpr auto b_thread_desc_ =
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType, ComputeDataType,
    decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_),
    Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, A_K1, A_K1>;

using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType, ComputeDataType,
    decltype(b_block_desc_n0_n1_n2_k), decltype(b_thread_desc_),
    Sequence<1, 1, 1, KPack>, Sequence<0, 1, 2, 3>, 3, B_K1, B_K1>;
AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

using Base::c_thread_desc_;