namespace ck {

// Primary template; it is only ever instantiated through the
// scheduler-specific specialisations below.
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType, typename BDataType,
          typename ComputeDataType, typename AccDataType,
          typename ATileDesc, typename BTileDesc,
          typename AMmaTileDesc, typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL,
          index_t MRepeat, index_t NRepeat, index_t KPack>
struct BlockwiseGemmXdlops_pipeline_v1
{
};
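// Specialisation for BlockGemmPipelineScheduler::Intrawave: a single-stage
// prefetch pipeline in which each wave interleaves its own global prefetch,
// LDS traffic and MFMA work.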
template <index_t BlockSize,
          typename ADataType, typename BDataType,
          typename ComputeDataType, typename AccDataType,
          typename ATileDesc, typename BTileDesc,
          typename AMmaTileDesc, typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL,
          index_t MRepeat, index_t NRepeat, index_t KPack>
struct BlockwiseGemmXdlops_pipeline_v1<
    BlockGemmPipelineScheduler::Intrawave,
    BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
    ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
    ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
    MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<
          BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
          ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
          ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
          MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<
        BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
        ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
        ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
        MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>;

    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;
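    // One tile is prefetched ahead of the hot loop in this pipeline, so a hot
    // loop exists only when there is more than one K-block iteration to run.
    static constexpr index_t PrefetchStages = 1;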
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());
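        // Pipeline prologue: global prefetch of tile 0, advance the source
        // windows, then prefill LDS once before entering the loop.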
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
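        // Hot loop: the global prefetch of tile i+1 overlaps LDS reads and
        // MFMA compute on tile i; LDS is single-buffered, hence the two
        // block_sync_lds() fences per iteration.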
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();

                // LDS -> VGPR: stage the A and B fragments of the current tile.
                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           /* ...src/dst origins... */
                                           a_thread_buf);
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           /* ...src/dst origins... */
                                           b_thread_buf);
                    });
                });

                // MFMA stage: one xdlops_gemm call per (m0, n0) repeat and K
                // sub-slice, fed from the VGPR staging buffers.
                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, Number<0>{}, k0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, Number<0>{}, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds();

                // Commit the prefetched tile to LDS for the next iteration.
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }
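        // Epilogue: the last tile is already resident in LDS; compute it
        // without issuing further global reads.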
        {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       /* ...src/dst origins... */
                                       a_thread_buf);
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       /* ...src/dst origins... */
                                       b_thread_buf);
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, Number<0>{}, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, Number<0>{}, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
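// A minimal usage sketch (illustrative only: `blockwise_gemm`, the copy and
// step objects and `num_k_block_main_loop` are assumed to be set up by the
// surrounding grid-level kernel; this header does not define them):
//
//   if(BlockwiseGemm::BlockHasHotloop(num_k_block_main_loop))
//       blockwise_gemm.template Run<true>(a_grid_desc, a_block_desc, a_blockwise_copy,
//                                         a_grid_buf, a_block_buf, a_block_copy_step,
//                                         b_grid_desc, b_block_desc, b_blockwise_copy,
//                                         b_grid_buf, b_block_buf, b_block_copy_step,
//                                         c_thread_buf, num_k_block_main_loop);
//   else
//       blockwise_gemm.template Run<false>(/* same argument list */);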
// Specialisation for BlockGemmPipelineScheduler::Interwave: the same
// single-stage prefetch pipeline, with the MFMA work split into MAC clusters
// whose wave priority is raised so the memory and compute phases of different
// waves overlap.
template <index_t BlockSize,
          typename ADataType, typename BDataType,
          typename ComputeDataType, typename AccDataType,
          typename ATileDesc, typename BTileDesc,
          typename AMmaTileDesc, typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
          index_t MPerXDL, index_t NPerXDL,
          index_t MRepeat, index_t NRepeat, index_t KPack>
struct BlockwiseGemmXdlops_pipeline_v1<
    BlockGemmPipelineScheduler::Interwave,
    BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
    ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
    ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
    MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
    : BlockwiseGemmXdlops_pipeline_base<
          BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
          ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
          ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
          MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<
        BlockSize, ADataType, BDataType, ComputeDataType, AccDataType,
        ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc,
        ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector,
        MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack>;

    using Base::KPerThread;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;
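    // Interwave scheduling splits the per-thread K slice into MAC clusters:
    // CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS chunks of at least
    // KPack elements each, executed back to back at raised wave priority.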
    static constexpr index_t KPerInnerLoop =
        math::max(KPerThread / CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS, KPack);
    static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
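    // Example (hypothetical values): KPerThread = 16, KPack = 4 and
    // 2 MAC clusters give KPerInnerLoop = max(16 / 2, 4) = 8 and
    // KRepeat = 16 / 8 = 2 inner loops per tile.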
    static constexpr index_t PrefetchStages = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Pipeline prologue: global prefetch of tile 0, advance the source
        // windows, then prefill LDS once before entering the loop.
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
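        // Hot loop: as in the Intrawave pipeline, the global prefetch of tile
        // i+1 overlaps compute on tile i, but every MAC cluster is additionally
        // fenced with sched_barrier()/s_setprio() so the compute and memory
        // phases of different waves interleave.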
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    // LDS -> VGPR: stage one KPerInnerLoop chunk of A and B.
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           /* ...src/dst origins... */
                                           a_thread_buf);
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           /* ...src/dst origins... */
                                           b_thread_buf);
                    });

                    __builtin_amdgcn_sched_barrier(0);
                    // Sync the waves of the workgroup at the start of every MAC
                    // cluster except the first, so that they execute their MFMA
                    // work in lock-step.
                    if constexpr(k0.value != 0 || KRepeat == 1)
                    {
                        __builtin_amdgcn_s_barrier();
                        __builtin_amdgcn_sched_barrier(0);
                    }
                    static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataType, KPack> a_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(m0, Number<0>{}, k0, k_ + ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(n0, Number<0>{}, k0, k_ + ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                // On the very last MFMA of the tile, fence and
                                // sync LDS so the RunWrite below cannot race
                                // the LDS reads above.
                                if constexpr(k0.value == KRepeat - 1 &&
                                             k_.value == KPerInnerLoop - KPack &&
                                             m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
                                {
                                    __builtin_amdgcn_sched_barrier(0);
                                    block_sync_lds();
                                    __builtin_amdgcn_sched_barrier(0);
                                }
                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                                // Raise wave priority once the first MFMA of
                                // the MAC cluster has been issued.
                                if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                                {
                                    __builtin_amdgcn_sched_barrier(0);
                                    __builtin_amdgcn_s_setprio(1);
                                    __builtin_amdgcn_sched_barrier(0);
                                }
                            });
                        });
                    });
                    // Drop wave priority again at the end of the MAC cluster.
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_setprio(0);
                    __builtin_amdgcn_sched_barrier(0);
                });

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }
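        // Epilogue: drain the last tile out of LDS with the same MAC-cluster
        // body, but with no further global prefetch or LDS write.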
        {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       /* ...src/dst origins... */
                                       a_thread_buf);
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       /* ...src/dst origins... */
                                       b_thread_buf);
                });

                __builtin_amdgcn_sched_barrier(0);
                if constexpr(k0.value != 0 || KRepeat == 1)
                {
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }
                static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, Number<0>{}, k0, k_ + ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, Number<0>{}, k0, k_ + ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            if constexpr(k0.value == KRepeat - 1 &&
                                         k_.value == KPerInnerLoop - KPack &&
                                         m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                block_sync_lds();
                                __builtin_amdgcn_sched_barrier(0);
                            }
                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                __builtin_amdgcn_s_setprio(1);
                                __builtin_amdgcn_sched_barrier(0);
                            }
                        });
                    });
                });
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
            });
        }
    }
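    // Per-thread VGPR layouts: each (repeat, k-chunk) owns a contiguous
    // KPerInnerLoop-long K slice, matching the staging order used in Run().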
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<MRepeat>{}, Number<1>{}, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
        make_tuple(Number<KPerInnerLoop>{},
                   Number<KRepeat * MRepeat * KPerInnerLoop>{},
                   Number<MRepeat * KPerInnerLoop>{},
                   Number<1>{}));

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<NRepeat>{}, Number<1>{}, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
        make_tuple(Number<KPerInnerLoop>{},
                   Number<KRepeat * NRepeat * KPerInnerLoop>{},
                   Number<NRepeat * KPerInnerLoop>{},
                   Number<1>{}));
    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType, ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPerInnerLoop>,
                                                         Sequence<0, 1, 2, 3>, 3, A_K1, A_K1>;
    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType, ComputeDataType,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerInnerLoop>,
                                                         Sequence<0, 1, 2, 3>, 3, B_K1, B_K1>;
    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
    using Base::c_thread_desc_;
};

} // namespace ck