// ...
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
          bool TransposeC = false,
          bool BSkipLDS   = false>
struct BlockwiseGemmWmmaops_pipeline_v1;
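// ---------------------------------------------------------------------------
// First specialization below: apparently BlockGemmPipelineScheduler::Intrawave
// with BSkipLDS == false, i.e. a one-stage-prefetch pipeline that stages both
// the A and B tiles through LDS. Elided source lines are marked "// ...".
// ---------------------------------------------------------------------------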
template <// ...
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                        // ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...>
    : BlockwiseGemmWmmaops_pipeline_base<// ...
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         // ...>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<// ...
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    // ...>;
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;
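    // Run() below implements the one-stage pipeline: while the WMMA math for
    // iteration i consumes the tiles already in LDS, the global loads for
    // iteration i + 1 are in flight, and RunWrite() commits them to LDS at
    // the end of the iteration.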
    template <bool HasMainLoop,
              // ...
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
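        // blockwise_gemm_func: the LDS -> VGPR reads plus the WMMA calls for
        // one full K-tile. Defined as a lambda so the same body serves both
        // the hot loop and the tail after it.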
        auto blockwise_gemm_func = [&]() {
            // ... (static_for loops over k0, m0, n0 elided)
            a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                               // ...
            if constexpr(m0 == I0)
            {
                // ...
                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                   // ...
                // ... (presumably a second copy variant carrying the B scale)
                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                   // ...
                                   b_scale_struct.b_scale_thread_bufs(
                                       I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                                  k0 / BScaleStruct::num_scale_krepeat>{}],
                                   // ...
            }
            vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
            vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

            static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                        // ...
            });

            static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                        // ...
            });

            using wmma_input_type_a =
                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
            using wmma_input_type_b =
                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                          b_thread_vec.template AsType<wmma_input_type_b>(),
                          // ...
            // ... (static_for closings elided)
        };
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // ...
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                blockwise_gemm_func();

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

                // ...
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                constexpr index_t num_ds_write_inst =
                    HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;

                constexpr index_t num_buffer_load_inst =
                    HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
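                // __builtin_amdgcn_sched_group_barrier(mask, size, sync_id)
                // is the LLVM AMDGPU scheduling hint that pins a group of
                // `size` instructions matching `mask` at this point of the
                // instruction stream. Masks used in this file: 0x008 =
                // MFMA/WMMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 =
                // DS write. The instruction counts above presumably drive
                // elided loops that interleave DS writes and buffer loads
                // with the WMMA stream, as in the hints visible below.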
                // ...
                __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 1 VMEM read
                // ...
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // 1 DS read
                if constexpr(m0 == I0)
                {
                    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // 1 DS read
                }
                // ...
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 WMMA
                // ...
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // 1 DS write
                // ...
                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        blockwise_gemm_func();
        // ...
    }
    using AThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<ADataType,
                                         ComputeTypeA,
                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                         decltype(a_thread_desc_),
                                         Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
                                         Sequence<0, 1, 2, 3, 4, 5, 6>,
                                         6,
                                         A_K1,
                                         A_K1>;

    using BThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<BDataType,
                                         ComputeTypeB,
                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                         decltype(b_thread_desc_),
                                         Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
                                         Sequence<0, 1, 2, 3, 4, 5, 6>,
                                         6,
                                         B_K1,
                                         B_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
};
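// ---------------------------------------------------------------------------
// Second specialization below: apparently BlockGemmPipelineScheduler::Interwave
// with BSkipLDS == false. Same dataflow as above, but scheduled with
// s_setprio priority windows instead of per-instruction sched_group_barrier
// hints.
// ---------------------------------------------------------------------------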
template <// ...
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                        // ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...>
    : BlockwiseGemmWmmaops_pipeline_base<// ...
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         // ...>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<// ...
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    // ...>;

    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;
    template <bool HasMainLoop,
              // ...
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
        auto blockwise_gemm_func = [&]() {
            // ... (static_for loops over k0_offset, k0_inner, m0, n0 elided)
            a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                               make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
                               // ...
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
                               // ...
            // ... (presumably a second copy variant carrying the B scale)
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
                               // ...
                               b_scale_struct.b_scale_thread_bufs(
                                   I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                              (k0_offset + k0_inner) /
                                                  BScaleStruct::num_scale_krepeat>{}],
                               // ...
            __builtin_amdgcn_sched_barrier(0);
            // ...
            if constexpr(k0_offset != 0 || KRepeat == 1)
            {
                __builtin_amdgcn_s_barrier();
                __builtin_amdgcn_sched_barrier(0);
            }
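            // __builtin_amdgcn_sched_barrier(0) tells the compiler scheduler
            // that no instruction may be moved across this point (the mask
            // names what may cross; 0 permits nothing). Here it fences the
            // s_barrier so the waves really synchronize between K-chunks.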
            vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
            vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

            static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                        // ...
            });

            static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                        // ...
            });

            using wmma_input_type_a =
                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
            using wmma_input_type_b =
                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

            if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
                         n0 == NRepeat - 1)
            {
                __builtin_amdgcn_sched_barrier(0);
                // ...
                __builtin_amdgcn_sched_barrier(0);
            }
            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                          b_thread_vec.template AsType<wmma_input_type_b>(),
                          // ...
            if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
            {
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_sched_barrier(0);
            }
            // ...
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(0);
            __builtin_amdgcn_sched_barrier(0);
            // ... (static_for closings elided)
        };
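        // Interwave scheduling: s_setprio(1) raises this wave's issue
        // priority for the duration of the WMMA burst and s_setprio(0)
        // restores it, so co-resident waves win instruction-issue
        // arbitration for their memory phases while this wave computes. The
        // sched_barrier(0) pairs keep the compiler from hoisting
        // instructions into or out of the priority window.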
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // ...
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                blockwise_gemm_func();

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

                // ...
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                // ...
                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        blockwise_gemm_func();
        // ...
    }
    static constexpr auto a_thread_desc_ =
        // ...
        Number<KRepeatPerCluster>{},
        // ...
        Number<KPack / A_KRow * MRepeat>{},
        // ...

    static constexpr auto b_thread_desc_ =
        // ...
        Number<KRepeatPerCluster>{},
        // ...
        Number<KPack / B_KRow * NRepeat>{},
        // ...
    using AThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<ADataType,
                                         ComputeTypeA,
                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                         decltype(a_thread_desc_),
                                         Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
                                         Sequence<0, 1, 2, 3, 4, 5, 6>,
                                         6,
                                         A_K1,
                                         A_K1>;

    using BThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<BDataType,
                                         ComputeTypeB,
                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                         decltype(b_thread_desc_),
                                         Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
                                         Sequence<0, 1, 2, 3, 4, 5, 6>,
                                         6,
                                         B_K1,
                                         B_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
};
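// ---------------------------------------------------------------------------
// Third specialization below: apparently BlockGemmPipelineScheduler::Intrawave
// with BSkipLDS == true. A still goes through LDS, but B tiles are streamed
// from global memory straight into a double-buffered register array
// (b_thread_bufs), so the BBlockDesc/BBlockBuffer parameters are unused.
// ---------------------------------------------------------------------------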
template <// ...
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                        // ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...
                                        /* BSkipLDS = */ true>
    : BlockwiseGemmWmmaops_pipeline_base<// ...
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         // ...>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<// ...
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    // ...>;

    using Base::WaveSize;
    // ...
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;
    // ...

    static constexpr __device__ auto HotLoopScheduler()
    {
        constexpr auto num_ds_read_inst_a     = HotLoopInstList::A_LDS_Read_Inst_Num;
        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
        constexpr auto wmma_interleave        = 2;

        // ...
        if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
        {
            __builtin_amdgcn_sched_group_barrier(0x008, 2 * wmma_interleave, 0); // WMMA
            // ...
        }
        // ...
        __builtin_amdgcn_sched_group_barrier(0x008, wmma_interleave, 0); // WMMA
        // ...
        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 1 VMEM read
        // ...
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 WMMA
        __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // 1 DS write
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 WMMA
        __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // 1 VMEM read
        // ...
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // 1 WMMA
        __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // 1 DS read
        // ...
    }
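    // HotLoopScheduler() emits one loop iteration's worth of scheduling
    // hints: the WMMA stream is split into groups of wmma_interleave (or
    // 2 * wmma_interleave for large tiles) and single buffer-load, DS-write,
    // and DS-read instructions are slotted between the groups. The elided
    // parts presumably scale this pattern by num_ds_read_inst_a and
    // num_buffer_load_inst_{a,b}.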
    template <bool HasMainLoop,
              // ...
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc&, // unused: B bypasses LDS
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer&, // unused: B bypasses LDS
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct&, // unused in this specialization
                        index_t num_loop,
                        index_t /* num_loop_per_scale */) const
    {
        __builtin_amdgcn_sched_barrier(0);

        constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
        // ... (b_thread_bufs, used below, is presumably declared here as a
        //      double-buffered pair of such B register tiles)

        constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.Run(b_grid_desc,
                             // ...
                             b_block_desc_k0_n0_n1_n2_k1,
                             // ...

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        __builtin_amdgcn_sched_barrier(0);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);

        // ...
        a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                           // ...

        c_thread_buf.Clear();

        __builtin_amdgcn_sched_barrier(0);
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            // ...
            auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
                b_blockwise_copy.Run(b_grid_desc,
                                     // ...
                                     b_block_desc_k0_n0_n1_n2_k1,
                                     // ...
                                     b_thread_bufs(local_read_buf));
                // ...
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
                // ...
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                // ...
                vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                            // ...
                });

                static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                        b_thread_bufs[wmma_reg_buf]
                                     [Number<b_thread_desc_.CalculateOffset(
                                         // ...
                });

                using wmma_input_type_a =
                    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                using wmma_input_type_b =
                    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                constexpr index_t c_offset =
                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                              b_thread_vec.template AsType<wmma_input_type_b>(),
                              // ...

                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   // ...

                __builtin_amdgcn_sched_barrier(0);
                // ...
            };

            do
            {
                // ... (alternates the two B register buffers via LoopFunc)
            } while(i < (num_loop - 2));
        }
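        // Note the bound: the hot loop runs to num_loop - 2 rather than
        // num_loop - 1, because B is double-buffered in registers and two
        // iterations are always in flight; the last two K-tiles (one in
        // b_thread_bufs[I0], one in b_thread_bufs[I1]) are drained by the
        // tail code below.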
        // ...
        b_blockwise_copy.Run(b_grid_desc,
                             // ...
                             b_block_desc_k0_n0_n1_n2_k1,
                             // ...

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        // ...
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    // ...
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                    // ...
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      // ...

        a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                           // ...

        __builtin_amdgcn_sched_barrier(0);
        // ...
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    // ...
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
                    // ...
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      // ...
        // ... (a further tail branch, again consuming b_thread_bufs[I0])
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    // ...
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
                    // ...
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      // ...
        // ...
    }
    static constexpr auto b_thread_desc_ =
        // ...

    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::c_thread_desc_;
};
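// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this header): a grid-level
// GEMM kernel would drive one of these pipelines roughly as follows; the
// argument bundle and `num_k_loop` stand in for values the kernel derives
// from its tile configuration.
//
//   using Pipeline = BlockwiseGemmWmmaops_pipeline_v1</* config... */>;
//
//   const index_t    num_k_loop = K / KPerBlock;
//   const TailNumber tail       = Pipeline::BlockLoopTailNum(num_k_loop);
//
//   if(Pipeline::BlockHasHotloop(num_k_loop))
//       Pipeline{}.template Run<true>(a_grid_desc, a_block_desc,
//                                     a_blockwise_copy, a_grid_buf,
//                                     a_block_buf, a_block_copy_step,
//                                     b_grid_desc, b_block_desc,
//                                     b_blockwise_copy, b_grid_buf,
//                                     b_block_buf, b_block_copy_step,
//                                     c_thread_buf, b_scale_struct,
//                                     num_k_loop, num_loop_per_scale);
//   else
//       Pipeline{}.template Run<false>(/* same arguments */);
//
// `tail` would select the epilogue variant in pipelines that distinguish
// odd and even tail iteration counts.
// ---------------------------------------------------------------------------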