20 typename ComputeDataType,
24 typename AMmaTileDesc,
25 typename BMmaTileDesc,
26 index_t ABlockTransferSrcScalarPerVector,
27 index_t BBlockTransferSrcScalarPerVector,
36 struct BlockwiseGemmXdlops_pipeline_v4
43 typename ComputeDataType,
47 typename AMmaTileDesc,
48 typename BMmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
71 ABlockTransferSrcScalarPerVector,
72 BBlockTransferSrcScalarPerVector,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
124 using Base::xdlops_gemm;
127 using Base::CalculateCThreadOriginDataIndex;
128 using Base::CalculateCThreadOriginDataIndex8D;
129 using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
130 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
131 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
132 using Base::GetCThreadBuffer;
133 using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
134 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
135 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
136 using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
137 using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
139 using Base::a_block_desc_m0_m1_m2_k;
140 using Base::b_block_desc_n0_n1_n2_k;
142 using Base::AMmaKStride;
143 using Base::BMmaKStride;
154 return num_loop > PrefetchStages;
159 if(num_loop % HotloopUnroll == 1)
173 constexpr
auto num_ds_read_inst_a =
177 constexpr
auto num_ds_read_inst_b =
183 constexpr
auto num_dswrite_per_issue_a =
185 constexpr
auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
188 constexpr
auto num_dswrite_per_issue_b =
190 constexpr
auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
192 constexpr
auto num_mfma_per_issue =
199 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
200 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
205 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
206 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
209 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
210 __builtin_amdgcn_sched_group_barrier(0x008,
211 num_mfma_per_issue - num_dsread_per_issue_a -
212 num_dswrite_per_issue_a,
220 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
221 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
226 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
227 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
230 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
231 __builtin_amdgcn_sched_group_barrier(0x008,
232 num_mfma_per_issue - num_dsread_per_issue_a -
233 num_dswrite_per_issue_b,
236 __builtin_amdgcn_sched_barrier(0);
239 template <
bool HasMainLoop,
243 typename ABlockTransfer,
244 typename AGridBuffer,
245 typename ABlockBuffer,
246 typename ABlockTransferStep,
249 typename BBlockTransfer,
250 typename BGridBuffer,
251 typename BBlockBuffer,
252 typename BBlockTransferStep,
253 typename CThreadBuffer>
254 __device__
void Run(
const AGridDesc& a_grid_desc,
255 const ABlockDesc& a_block_desc,
256 ABlockTransfer& a_blockwise_copy,
257 const AGridBuffer& a_grid_buf,
258 ABlockBuffer& a_block_buf,
259 const ABlockTransferStep& a_block_copy_step,
260 const BGridDesc& b_grid_desc,
261 const BBlockDesc& b_block_desc,
262 BBlockTransfer& b_blockwise_copy,
263 const BGridBuffer& b_grid_buf,
264 BBlockBuffer& b_block_buf,
265 const BBlockTransferStep& b_block_copy_step,
266 CThreadBuffer& c_thread_buf,
269 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
271 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
278 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
279 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
281 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
282 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
285 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I0));
286 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I0));
310 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
311 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
313 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
314 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
317 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I1));
318 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I1));
321 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
322 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
324 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
325 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
328 c_thread_buf.Clear();
331 if constexpr(HasMainLoop)
337 auto LoopFunc = [&](
auto lds_read_buf,
338 auto lds_read_reg_buf,
347 a_block_buf.At(lds_read_buf),
350 a_thread_bufs(lds_read_reg_buf));
355 b_block_buf.At(lds_read_buf),
358 b_thread_bufs(lds_read_reg_buf));
362 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
363 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
365 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
366 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
368 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
369 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
378 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
379 a_thread_bufs[mfma_reg_buf]
382 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
383 b_thread_bufs[mfma_reg_buf]
388 using mfma_input_type =
396 a_thread_vec.template AsType<mfma_input_type>(),
397 b_thread_vec.template AsType<mfma_input_type>(),
410 }
while(i < (num_loop - PrefetchStages));
413 auto ReadWriteCompFunc = [&](
auto lds_read_buf,
414 auto lds_read_reg_buf,
423 a_block_buf.At(lds_read_buf),
426 a_thread_bufs(lds_read_reg_buf));
431 b_block_buf.At(lds_read_buf),
434 b_thread_bufs(lds_read_reg_buf));
438 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
439 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
448 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
451 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
456 using mfma_input_type =
462 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
463 b_thread_vec.template AsType<mfma_input_type>(),
472 auto ReadCompFunc = [&](
auto lds_read_buf,
auto lds_read_reg_buf,
auto mfma_reg_buf) {
479 a_block_buf.At(lds_read_buf),
482 a_thread_bufs(lds_read_reg_buf));
487 b_block_buf.At(lds_read_buf),
490 b_thread_bufs(lds_read_reg_buf));
501 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
504 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
509 using mfma_input_type =
515 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
516 b_thread_vec.template AsType<mfma_input_type>(),
525 auto CompFunc = [&](
auto mfma_reg_buf) {
533 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
536 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
541 using mfma_input_type =
547 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
548 b_thread_vec.template AsType<mfma_input_type>(),
569 using Base::a_thread_copy_;
570 using Base::a_thread_desc_;
571 using Base::b_thread_copy_;
572 using Base::b_thread_desc_;
573 using Base::c_thread_desc_;
587 typename ComputeDataType,
588 typename AccDataType,
591 typename AMmaTileDesc,
592 typename BMmaTileDesc,
593 index_t ABlockTransferSrcScalarPerVector,
594 index_t BBlockTransferSrcScalarPerVector,
610 typename ComputeDataType,
611 typename AccDataType,
614 typename AMmaTileDesc,
615 typename BMmaTileDesc,
616 index_t ABlockTransferSrcScalarPerVector,
617 index_t BBlockTransferSrcScalarPerVector,
638 ABlockTransferSrcScalarPerVector,
639 BBlockTransferSrcScalarPerVector,
657 ABlockTransferSrcScalarPerVector,
658 BBlockTransferSrcScalarPerVector,
678 ABlockTransferSrcScalarPerVector,
679 BBlockTransferSrcScalarPerVector,
691 using Base::xdlops_gemm;
694 using Base::CalculateCThreadOriginDataIndex;
695 using Base::CalculateCThreadOriginDataIndex8D;
696 using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
697 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
698 using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
699 using Base::GetCThreadBuffer;
700 using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
701 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
702 using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
703 using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
704 using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
706 using Base::a_block_desc_m0_m1_m2_k;
707 using Base::b_block_desc_n0_n1_n2_k;
709 using Base::AMmaKStride;
710 using Base::BMmaKStride;
721 return num_loop > PrefetchStages;
726 if(num_loop % HotloopUnroll == 1)
740 constexpr
auto num_ds_read_inst_a =
741 HotLoopInstList::A_LDS_Read_Width *
sizeof(ADataType) == 16
742 ? HotLoopInstList::A_LDS_Read_Inst_Num
743 : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
744 constexpr
auto num_ds_read_inst_b =
745 HotLoopInstList::B_LDS_Read_Width *
sizeof(BDataType) == 16
746 ? HotLoopInstList::B_LDS_Read_Inst_Num
747 : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
749 constexpr
auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
750 constexpr
auto num_dswrite_per_issue_a = 0;
751 constexpr
auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
753 constexpr
auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
754 constexpr
auto num_dswrite_per_issue_b = 0;
755 constexpr
auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
757 constexpr
auto num_mfma_per_issue =
758 HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);
764 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
765 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
770 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
771 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
774 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
775 __builtin_amdgcn_sched_group_barrier(0x008,
776 num_mfma_per_issue - num_dsread_per_issue_a -
777 num_dswrite_per_issue_a,
785 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
786 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
791 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
792 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
795 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
796 __builtin_amdgcn_sched_group_barrier(0x008,
797 num_mfma_per_issue - num_dsread_per_issue_a -
798 num_dswrite_per_issue_b,
801 __builtin_amdgcn_sched_barrier(0);
804 template <
bool HasMainLoop,
808 typename ABlockTransfer,
809 typename AGridBuffer,
810 typename ABlockBuffer,
811 typename ABlockTransferStep,
814 typename BBlockTransfer,
815 typename BGridBuffer,
816 typename BBlockBuffer,
817 typename BBlockTransferStep,
818 typename CThreadBuffer>
819 __device__
void Run(
const AGridDesc& a_grid_desc,
820 const ABlockDesc& a_block_desc,
821 ABlockTransfer& a_blockwise_copy,
822 const AGridBuffer& a_grid_buf,
823 ABlockBuffer& a_block_buf,
824 const ABlockTransferStep& a_block_copy_step,
825 const BGridDesc& b_grid_desc,
826 const BBlockDesc& b_block_desc,
827 BBlockTransfer& b_blockwise_copy,
828 const BGridBuffer& b_grid_buf,
829 BBlockBuffer& b_block_buf,
830 const BBlockTransferStep& b_block_copy_step,
831 CThreadBuffer& c_thread_buf,
834 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
835 a_thread_desc_.GetElementSpaceSize());
836 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
837 b_thread_desc_.GetElementSpaceSize());
843 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(I0));
844 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(I0));
846 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
847 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
854 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
862 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
872 a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(I1));
873 b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(I1));
875 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
876 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
879 c_thread_buf.Clear();
882 if constexpr(HasMainLoop)
888 auto LoopFunc = [&](
auto lds_read_buf,
889 auto lds_read_reg_buf,
896 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
898 a_block_buf.At(lds_read_buf),
901 a_thread_bufs(lds_read_reg_buf));
904 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
906 b_block_buf.At(lds_read_buf),
909 b_thread_bufs(lds_read_reg_buf));
913 a_blockwise_copy.Run(
914 a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
915 b_blockwise_copy.Run(
916 b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));
918 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
919 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
928 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
929 a_thread_bufs[mfma_reg_buf]
930 [
Number<a_thread_desc_.CalculateOffset(
932 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
933 b_thread_bufs[mfma_reg_buf]
934 [
Number<b_thread_desc_.CalculateOffset(
938 using mfma_input_type =
940 xdlops_gemm.K1PerXdlops>::type;
943 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
946 a_thread_vec.template AsType<mfma_input_type>(),
947 b_thread_vec.template AsType<mfma_input_type>(),
956 LoopFunc(I1, I1, I0, I0);
957 LoopFunc(I0, I0, I1, I1);
960 }
while(i < (num_loop - PrefetchStages));
963 auto ReadWriteCompFunc = [&](
auto lds_read_buf,
964 auto lds_read_reg_buf,
971 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
973 a_block_buf.At(lds_read_buf),
976 a_thread_bufs(lds_read_reg_buf));
979 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
981 b_block_buf.At(lds_read_buf),
984 b_thread_bufs(lds_read_reg_buf));
988 a_blockwise_copy.Run(
989 a_grid_desc, a_grid_buf, a_block_desc, a_block_buf.At(lds_write_buf));
990 b_blockwise_copy.Run(
991 b_grid_desc, b_grid_buf, b_block_desc, b_block_buf.At(lds_write_buf));
1000 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1001 a_thread_bufs[mfma_reg_buf][
Number<a_thread_desc_.CalculateOffset(
1003 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1004 b_thread_bufs[mfma_reg_buf][
Number<b_thread_desc_.CalculateOffset(
1008 using mfma_input_type =
1012 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
1014 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
1015 b_thread_vec.template AsType<mfma_input_type>(),
1024 auto ReadCompFunc = [&](
auto lds_read_buf,
auto lds_read_reg_buf,
auto mfma_reg_buf) {
1029 a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
1031 a_block_buf.At(lds_read_buf),
1034 a_thread_bufs(lds_read_reg_buf));
1037 b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
1039 b_block_buf.At(lds_read_buf),
1042 b_thread_bufs(lds_read_reg_buf));
1053 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1054 a_thread_bufs[mfma_reg_buf][
Number<a_thread_desc_.CalculateOffset(
1056 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1057 b_thread_bufs[mfma_reg_buf][
Number<b_thread_desc_.CalculateOffset(
1061 using mfma_input_type =
1065 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
1067 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
1068 b_thread_vec.template AsType<mfma_input_type>(),
1077 auto CompFunc = [&](
auto mfma_reg_buf) {
1085 a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1086 a_thread_bufs[mfma_reg_buf][
Number<a_thread_desc_.CalculateOffset(
1088 b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
1089 b_thread_bufs[mfma_reg_buf][
Number<b_thread_desc_.CalculateOffset(
1093 using mfma_input_type =
1097 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
1099 xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
1100 b_thread_vec.template AsType<mfma_input_type>(),
1109 ReadWriteCompFunc(I1, I1, I0, I0);
1110 ReadCompFunc(I0, I0, I1);
1115 ReadCompFunc(I1, I1, I0);
1121 using Base::a_thread_copy_;
1122 using Base::a_thread_desc_;
1123 using Base::b_thread_copy_;
1124 using Base::b_thread_desc_;
1125 using Base::c_thread_desc_;
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
__device__ void block_sync_lds_direct_load()
Definition: synchronization.hpp:43
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
conditional_t< std::is_same< ComputeDataType, ck::tf32_t >::value, float, ComputeDataType > ComputeDataTypeBuf
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
static constexpr index_t A_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:82
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
static constexpr index_t B_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:83
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:152
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ void HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:169
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:157
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:254
Definition: blockwise_gemm_pipeline_xdlops.hpp:103
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops.hpp:105
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:961
static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:373
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:967
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:991
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops.hpp:104
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:453
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:990
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:454
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:955
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops.hpp:120
ck::BlockwiseGemmXdlopsDirectLoad_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:724
ck::BlockwiseGemmXdlopsDirectLoad_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ void HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:736
ck::BlockwiseGemmXdlopsDirectLoad_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:719
ck::BlockwiseGemmXdlopsDirectLoad_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:819
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:604
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1293
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10