// Excerpt from blockwise_gemm_pipeline_xdlops_v2.hpp. Elided source lines are
// marked "// ...".

// Primary template.
template <// ...
          typename ComputeDataType,
          // ...
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...>

// Specialization for BlockGemmPipelineScheduler::Intrawave.
template <// ...
          typename ComputeDataType,
          // ...
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...>
struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
                                       // ...
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       // ...>
    : BlockwiseGemmXdlops_pipeline_base<// ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<// ...
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   // ...>;
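    // Helpers inherited from BlockwiseGemmXdlops_pipeline_base and re-exported
    // by this specialization: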
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;
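    // Prefetch sizing: one pipeline stage moves
    // (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock
    // bytes of A and B tile data, and enough stages are kept in flight to cover
    // the full memory-bandwidth latency, clamped to [2, 8] below. For example,
    // with a 256x128x32 fp16 tile one stage moves (256*2 + 128*2) * 32 = 24576
    // bytes.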
    // ... (declaration elided)
        (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
    static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
        /* ... */,
        (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
    static constexpr index_t PrefetchStages =
        FullMemBandPrefetchStages >= 2
            ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
            : 2;

    static constexpr index_t GlobalBufferNum = PrefetchStages;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }
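    // BlockLoopTailNum maps the loop count onto a TailNumber tag; the elided
    // branch bodies presumably return the tag matching each remainder.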
    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % PrefetchStages == 1) { /* ... */ }
        else if(num_loop % PrefetchStages == 2) { /* ... */ }
        else if(num_loop % PrefetchStages == 3) { /* ... */ }
        else if(num_loop % PrefetchStages == 4) { /* ... */ }
        else if(num_loop % PrefetchStages == 5) { /* ... */ }
        else if(num_loop % PrefetchStages == 6) { /* ... */ }
        else if(num_loop % PrefetchStages == 7) { /* ... */ }
        // ...
    }
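    // Run executes the three pipeline phases: a prologue that fills all
    // PrefetchStages global buffers (GlobalBufferNum == PrefetchStages) and
    // pre-fills LDS once, a hot loop that retires one buffered stage per pass
    // while issuing its replacement read, and a tail that drains whatever
    // stages remain.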
    template <bool HasMainLoop,
              // ...
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
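        // Per-thread VGPR staging buffers for the A/B fragments consumed by
        // the MFMA instructions: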
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch, stage 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize the C accumulator
        c_thread_buf.Clear();

        // Local prefill, stage 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

        // Global prefetch, stages 2..PrefetchStages (loop binding iprefetch elided)
        // ...
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        // ...
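        // Hot loop: each pass drains one prefetched stage into LDS, runs the
        // MFMA work for it, then issues the replacement global read. The
        // per-stage and MFMA indexing loops (binding iprefetch, m0, n0, ik and
        // the a/b_thread_vec registers) are elided below.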
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // ...
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   // ...
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   // ...

                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

                using mfma_input_type =
                    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                constexpr index_t c_offset =
                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                /* ... C accumulator at c_offset ... */);
                // ...

                a_blockwise_copy.RunWrite(
                    a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
                b_blockwise_copy.RunWrite(
                    b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                // ...

                i += PrefetchStages;
            } while(i < (num_loop - PrefetchStages));
        }
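        // Tail: LoopTailFunc drains the remaining buffered stages. Every stage
        // but the last still writes the next prefetched tile to LDS; the final
        // stage only computes.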
        auto LoopTailFunc = [&](auto tail_num) {
            // ... (drain loop binding iprefetch elided)
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               // ...
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               // ...

            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

            using mfma_input_type =
                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            /* ... */);
            // ...

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
            // Final tail stage: compute only, no further LDS writes
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               // ...
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               // ...

            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

            using mfma_input_type =
                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            /* ... */);
        };
        // Single-stage tail path (the dispatch over tail numbers is elided)
        // ...
        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                           // ...
        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                           // ...

        a_thread_vec.template AsType<ComputeDataType>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
        b_thread_vec.template AsType<ComputeDataType>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

        using mfma_input_type =
            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        /* ... */);
        // ...
    }
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
// Specialization for BlockGemmPipelineScheduler::Interwave.
template <// ...
          typename ComputeDataType,
          typename AccDataType,
          // ...
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...>
struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
                                       // ...
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       // ...>
    : BlockwiseGemmXdlops_pipeline_base<// ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...>
{
    using Base = BlockwiseGemmXdlops_pipeline_base<// ...
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   // ...>;

    using Base::KPerThread;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    // ... (KPerInnerLoop definition elided)
    static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
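    // Interwave scheduling note: KPerThread is consumed in KRepeat chunks of
    // KPerInnerLoop (its elided definition appears to spread KPerThread across
    // CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS clusters, floored at
    // KPack), so each chunk forms one ds_read + MFMA "cluster" that can be
    // interleaved with the other waves on the CU.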
    // ... (declaration elided)
        (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
    static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
        /* ... */,
        (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
    static constexpr index_t PrefetchStages =
        FullMemBandPrefetchStages >= 2
            ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
            : 2;

    static constexpr index_t GlobalBufferNum = PrefetchStages;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }
    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % PrefetchStages == 1) { /* ... */ }
        else if(num_loop % PrefetchStages == 2) { /* ... */ }
        else if(num_loop % PrefetchStages == 3) { /* ... */ }
        else if(num_loop % PrefetchStages == 4) { /* ... */ }
        else if(num_loop % PrefetchStages == 5) { /* ... */ }
        else if(num_loop % PrefetchStages == 6) { /* ... */ }
        else if(num_loop % PrefetchStages == 7) { /* ... */ }
        // ...
    }
    template <bool HasMainLoop,
              // ...
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
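        // The prologue below is identical to the Intrawave pipeline; the two
        // specializations differ only in how the hot loop is scheduled.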
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch, stage 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize the C accumulator
        c_thread_buf.Clear();

        // Local prefill, stage 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

        // Global prefetch, stages 2..PrefetchStages (loop binding iprefetch elided)
        // ...
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        // ...
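        // Hot loop with explicit interwave scheduling:
        // __builtin_amdgcn_sched_barrier(0) fences compiler instruction
        // reordering, __builtin_amdgcn_s_barrier() synchronizes the block
        // between K chunks, and __builtin_amdgcn_s_setprio(1/0) brackets each
        // MFMA burst so a wave holds high priority only while issuing MFMAs,
        // letting other waves' memory phases overlap. The loops binding
        // iprefetch, k0, k_, m0, n0 and ik are elided below.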
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // ...
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   // ...
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   // ...

                __builtin_amdgcn_sched_barrier(0);
                // Block-level barrier between K chunks
                if constexpr(k0.value != 0 || KRepeat == 1)
                {
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }
                // ...

                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

                using mfma_input_type =
                    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                constexpr index_t c_offset =
                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                // Fence the scheduler around the very last MFMA of the stage
                if constexpr(k0.value == KRepeat - 1 &&
                             k_.value == KPerInnerLoop - KPack &&
                             m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
                {
                    __builtin_amdgcn_sched_barrier(0);
                    // ...
                    __builtin_amdgcn_sched_barrier(0);
                }
                xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                /* ... */);
                // Raise wave priority at the first MFMA of the chunk
                if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
                {
                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_setprio(1);
                    __builtin_amdgcn_sched_barrier(0);
                }
                // ...
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
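                // Priority returns to 0 once this stage's MFMA burst is done,
                // so the LDS writes and global reads below can be overtaken by
                // other waves' compute.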
                a_blockwise_copy.RunWrite(
                    a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
                b_blockwise_copy.RunWrite(
                    b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                // ...

                i += PrefetchStages;
            } while(i < (num_loop - PrefetchStages));
        }
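        // Tail: same drain structure as the Intrawave pipeline, but each stage
        // keeps the sched_barrier/s_setprio bracketing around its MFMA burst.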
        auto LoopTailFunc = [&](auto tail_num) {
            // ... (drain loop binding iprefetch elided)
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               // ...
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               // ...

            __builtin_amdgcn_sched_barrier(0);
            if constexpr(k0.value != 0 || KRepeat == 1)
            {
                __builtin_amdgcn_s_barrier();
                __builtin_amdgcn_sched_barrier(0);
            }
            // ...

            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

            using mfma_input_type =
                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

            if constexpr(k0.value == KRepeat - 1 &&
                         k_.value == KPerInnerLoop - KPack &&
                         m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
            {
                __builtin_amdgcn_sched_barrier(0);
                // ...
                __builtin_amdgcn_sched_barrier(0);
            }
            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            /* ... */);
            if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
            {
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_sched_barrier(0);
            }
            // ...
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(0);
            __builtin_amdgcn_sched_barrier(0);

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
            // Final tail stage: compute only, no further LDS writes
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               // ...
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               // ...

            __builtin_amdgcn_sched_barrier(0);
            if constexpr(k0.value != 0 || KRepeat == 1)
            {
                __builtin_amdgcn_s_barrier();
                __builtin_amdgcn_sched_barrier(0);
            }
            // ...

            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

            using mfma_input_type =
                typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

            if constexpr(k0.value == KRepeat - 1 &&
                         k_.value == KPerInnerLoop - KPack &&
                         m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
            {
                __builtin_amdgcn_sched_barrier(0);
                // ...
                __builtin_amdgcn_sched_barrier(0);
            }
            xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                            b_thread_vec.template AsType<mfma_input_type>(),
                            /* ... */);
            if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
            {
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_sched_barrier(0);
            }
            // ...
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(0);
            __builtin_amdgcn_sched_barrier(0);
        };
        // Single-stage tail path (the dispatch over tail numbers is elided)
        // ...
        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                           // ...
        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                           // ...

        __builtin_amdgcn_sched_barrier(0);
        if constexpr(k0.value != 0 || KRepeat == 1)
        {
            __builtin_amdgcn_s_barrier();
            __builtin_amdgcn_sched_barrier(0);
        }
        // ...

        a_thread_vec.template AsType<ComputeDataType>()(ik) =
            a_thread_buf[Number<a_thread_desc_.CalculateOffset(/* ... */)>{}];
        b_thread_vec.template AsType<ComputeDataType>()(ik) =
            b_thread_buf[Number<b_thread_desc_.CalculateOffset(/* ... */)>{}];

        using mfma_input_type =
            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

        if constexpr(k0.value == KRepeat - 1 &&
                     k_.value == KPerInnerLoop - KPack &&
                     m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
        {
            __builtin_amdgcn_sched_barrier(0);
            // ...
            __builtin_amdgcn_sched_barrier(0);
        }
        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        /* ... */);
        if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
        {
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(1);
            __builtin_amdgcn_sched_barrier(0);
        }
        // ...
        __builtin_amdgcn_sched_barrier(0);
        __builtin_amdgcn_s_setprio(0);
        __builtin_amdgcn_sched_barrier(0);
        // ...
    }
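    // Per-thread A/B staging descriptors, sized in KPerInnerLoop units so that
    // a whole inner-loop K chunk per M/N repeat is resident in VGPRs at once
    // (construction partially elided):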
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        // ...
        make_tuple(Number<KRepeat * MRepeat * KPerInnerLoop>{},
                   Number<MRepeat * KPerInnerLoop>{},
                   // ...

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        // ...
        make_tuple(Number<KRepeat * NRepeat * KPerInnerLoop>{},
                   Number<NRepeat * KPerInnerLoop>{},
                   // ...

    // LDS -> VGPR thread copies. The a-side instantiation below matches the
    // generated documentation; the b-side mirrors it with B_K1. The alias
    // names themselves are illustrative, as the surrounding declarations are
    // elided.
    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
                                                         ComputeDataType,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;
    // ...

    using Base::c_thread_desc_;
};
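// A minimal standalone sketch (not part of this header) of how the hot-loop /
// tail selection above behaves. It assumes an illustrative PrefetchStages of 4
// and mirrors BlockHasHotloop and the BlockLoopTailNum remainder test in
// plain C++.
#include <cstdio>

int main()
{
    constexpr int prefetch_stages = 4; // assumed stage count, for illustration only

    for(int num_loop = 1; num_loop <= 10; ++num_loop)
    {
        // BlockHasHotloop: the do/while main body runs only if there are more
        // K-loop iterations than buffered prefetch stages.
        const bool has_hotloop = num_loop > prefetch_stages;

        // BlockLoopTailNum keys its TailNumber tag off this remainder.
        const int tail_key = num_loop % prefetch_stages;

        std::printf("num_loop=%2d hotloop=%d tail_key=%d\n",
                    num_loop, has_hotloop, tail_key);
    }
    return 0;
}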