typename AScaleDataType,
typename BScaleDataType,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
template <
    index_t ThreadBlockSize,
    typename AScaleDataType,
    typename BScaleDataType,
    typename AMmaTileDesc,
    typename BMmaTileDesc,
    index_t ABlockTransferSrcScalarPerVector,
    index_t BBlockTransferSrcScalarPerVector,
    ABlockTransferSrcScalarPerVector,
    BBlockTransferSrcScalarPerVector,
    ABlockTransferSrcScalarPerVector,
    BBlockTransferSrcScalarPerVector,
    ABlockTransferSrcScalarPerVector,
    BBlockTransferSrcScalarPerVector,
using Base::WaveSize;
using Base::xdlops_gemm;

using Base::CalculateCThreadOriginDataIndex;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetWaveIdx;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

using Base::a_block_desc_m0_m1_m2_m3_k;
using Base::b_block_desc_n0_n1_n2_n3_k;

using Base::AMmaKStride;
using Base::APackedSize;
using Base::BMmaKStride;
using Base::BPackedSize;
using Base::KThreadChunk;

using Base::KXdlPack;
using Base::MXdlPack;
using Base::NXdlPack;
// Scale blocks spanned by the K extent of one block tile.
static constexpr auto ScalesPerKBlockSize = KPerBlock / ScaleBlockSize;

// Scales consumed by one full xdlops run along K.
static constexpr auto ScalesPerXdlopsRun =
    (APackedSize * KPack * xdlops_gemm.K0PerXdlops) / ScaleBlockSize;

// Share of those scales held by a single lane.
static constexpr auto ScalesPerXdlopsRunPerThread =
    ScalesPerXdlopsRun / xdlops_gemm.mfma_instr.num_input_blks;
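// Worked example (illustrative MX-FP4 parameters, not fixed by this header):
// APackedSize == 2, KPack == 32, xdlops_gemm.K0PerXdlops == 2 and
// ScaleBlockSize == 32 give ScalesPerXdlopsRun == (2 * 32 * 2) / 32 == 4;
// with xdlops_gemm.mfma_instr.num_input_blks == 4, each lane then holds one
// scale per xdlops run.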
static constexpr auto scale_pack_size_a = sizeof(AScaleDataType) / sizeof(mx_scale_t);
static constexpr auto scale_pack_size_b = sizeof(BScaleDataType) / sizeof(mx_scale_t);
static_assert(KXdlPack * MXdlPack % scale_pack_size_a == 0,
              "A scale pack data type too large!");
static_assert(KXdlPack * NXdlPack % scale_pack_size_b == 0,
              "B scale pack data type too large!");
static constexpr auto a_scale_thread_vec_size = KXdlPack * MXdlPack / scale_pack_size_a;
static constexpr auto b_scale_thread_vec_size = KXdlPack * NXdlPack / scale_pack_size_b;
return num_loop > PrefetchStages;
// Effective LDS read instruction counts: full for 128-bit reads, halved otherwise.
constexpr auto num_ds_read_inst_a =
    HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
        ? HotLoopInstList::A_LDS_Read_Inst_Num
        : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
    HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
        ? HotLoopInstList::B_LDS_Read_Inst_Num
        : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

constexpr auto num_buffer_load_a_scale = MRepeat / MXdlPack * KRepeat / KXdlPack;
constexpr auto num_buffer_load_b_scale = NRepeat / NXdlPack * KRepeat / KXdlPack;
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * APackedSize;

constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
constexpr auto ds_read_a_issue_cycle =
    HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
    HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
// ds_reads that fit under one MFMA: ceil((mfma_cycle - 4) / (2 * issue_cycle)).
constexpr auto ds_read_a_mfma_rate =
    (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
    (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
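// Worked example (illustrative numbers): mfma_cycle == 32 and
// ds_read_a_issue_cycle == 8 give ds_read_a_mfma_rate ==
// (32 - 4 + 16 - 1) / 16 == 2, i.e. up to two ds_reads are slotted under
// each MFMA.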
constexpr auto num_dsread_a_mfma =
    (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
    (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;

constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_buffer_load_total = num_buffer_load_inst_a + num_buffer_load_inst_b +
                                       num_buffer_load_a_scale + num_buffer_load_b_scale;
// Spread stage-1 MFMAs across the global-load issue slots: the first
// mfma_stages_more slots take the ceiling share, the rest the floor share.
constexpr auto mfma_perstage_more =
    integer_divide_ceil(num_mfma_stage1, num_buffer_load_total);
constexpr auto mfma_perstage_less =
    integer_divide_floor(num_mfma_stage1, num_buffer_load_total);

constexpr auto mfma_stages_more =
    num_mfma_stage1 - mfma_perstage_less * num_buffer_load_total;
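// Worked example (illustrative counts): num_mfma_stage1 == 26 and
// num_buffer_load_total == 8 give mfma_perstage_more == 4,
// mfma_perstage_less == 3 and mfma_stages_more == 26 - 3 * 8 == 2: the first
// two load slots carry four MFMAs each, the remaining six carry three
// (2 * 4 + 6 * 3 == 26).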
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
// Stage 1: interleave MFMAs with A-tile global loads and LDS writes.
if constexpr(i < mfma_stages_more)
{
    static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        if constexpr(imfma < num_dswrite_per_issue_a)
        {
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
        }
    });
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
    static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        if constexpr(imfma < num_dswrite_per_issue_a)
        {
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
        }
    });
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// B-tile global loads.
if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more)
{
    static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        if constexpr(imfma < num_dswrite_per_issue_a)
        {
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
        }
    });
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
    static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        if constexpr(imfma < num_dswrite_per_issue_b)
        {
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
        }
    });
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// A-scale global loads.
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b) < mfma_stages_more)
{
    static_for<0, mfma_perstage_more, 1>{}([&](auto) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
    });
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
    static_for<0, mfma_perstage_less, 1>{}([&](auto) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
    });
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// B-scale global loads.
if constexpr((i + num_buffer_load_inst_a + num_buffer_load_inst_b +
              num_buffer_load_a_scale) < mfma_stages_more)
{
    static_for<0, mfma_perstage_more, 1>{}([&](auto) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
    });
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
else
{
    static_for<0, mfma_perstage_less, 1>{}([&](auto) {
        __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
    });
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
}
// Stage 2: overlap the remaining LDS reads with the trailing MFMAs.
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
    if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
                 ds_read_a_mfma_rate)
    {
        __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
    }
    else
    {
        __builtin_amdgcn_sched_group_barrier(
            0x100, num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate, 0);
    }
});
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
    if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
                 ds_read_b_mfma_rate)
    {
        __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
    }
    else
    {
        __builtin_amdgcn_sched_group_barrier(
            0x100, num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate, 0);
    }
});
template <bool HasMainLoop,
          TailNumber TailNum,
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename AScaleGridBuffer,
          typename AScaleGridDesc,
          typename AScaleThreadTransfer,
          typename BScaleGridBuffer,
          typename BScaleGridDesc,
          typename BScaleThreadTransfer>
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    const AScaleGridDesc& a_scale_grid_desc,
                    AScaleThreadTransfer& a_scale_thread_copy,
                    const AScaleGridBuffer& a_scale_grid_buf,
                    const BScaleGridDesc& b_scale_grid_desc,
                    BScaleThreadTransfer& b_scale_thread_copy,
                    const BScaleGridBuffer& b_scale_grid_buf,
                    index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
    a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
    b_thread_desc_.GetElementSpaceSize());

auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleDataType>(
    a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleDataType>(
    b_scale_thread_desc.GetElementSpaceSize());
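// All four staging buffers above live in VGPRs: A/B fragments are re-read
// from LDS into a_thread_buf / b_thread_buf each K step, while the mx scales
// bypass LDS and are copied from global memory straight into the per-thread
// scale buffers (double-buffered below via a_scale_thread_bufs /
// b_scale_thread_bufs).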
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Prefetch the first K tile of A scales into buffer I0
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
    static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
        a_scale_thread_copy.Run(a_scale_grid_desc,
                                a_scale_thread_bufs(I0));

        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
    a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_thread_copy.MoveSrcSliceWindow(
// Prefetch the first K tile of B scales into buffer I0
static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
    static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
        b_scale_thread_copy.Run(b_scale_grid_desc,
                                b_scale_thread_bufs(I0));

        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
    b_scale_thread_copy.MoveSrcSliceWindow(
b_scale_thread_copy.MoveSrcSliceWindow(
// Commit prefetch 1 to LDS
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

// Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
                        (APackedSize * KPack / xdlops_gemm.K1PerXdlops);
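// Illustrative k_step arithmetic (values assumed, not fixed by this header):
// with xdlops_gemm.KPerXdlops == 64, APackedSize == 2, KPack == 32 and
// xdlops_gemm.K1PerXdlops == 32, k_step == k * (64 / 2) * (2 * 32 / 32) ==
// k * 64, i.e. each k iteration advances the LDS read offset by 64.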
static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
    [&](auto chunk) {
        constexpr auto a_k_step_chunk =
            k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
        a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
    [&](auto chunk) {
        constexpr auto b_k_step_chunk =
            k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
        b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
// Initialize C
c_thread_buf.Clear();

__builtin_amdgcn_sched_barrier(0);
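// Full scheduling fence: mask 0 means no instruction may be reordered across
// this point, keeping the prologue above separate from the pipelined loop.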
// Main body
if constexpr(HasMainLoop)
{
    index_t i = 0;

    auto LoopFunc = [&](auto scale_comp_buf, auto scale_mem_buf) {
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);

        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
        // Refill the A-scale buffer not being consumed this iteration
        static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
            static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                a_scale_thread_copy.Run(a_scale_grid_desc,
                                        a_scale_thread_bufs(scale_mem_buf));

                a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
            a_scale_thread_copy.MoveSrcSliceWindow(
        a_scale_thread_copy.MoveSrcSliceWindow(
        // Refill the B-scale buffer not being consumed this iteration
        static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
            static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                b_scale_thread_copy.Run(b_scale_grid_desc,
                                        b_scale_thread_bufs(scale_mem_buf));

                b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
            b_scale_thread_copy.MoveSrcSliceWindow(
        b_scale_thread_copy.MoveSrcSliceWindow(
        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
        static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
            static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
                static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
                    constexpr index_t a_scale_offset =
                        a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
                    constexpr index_t b_scale_offset =
                        b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
                    static_assert(0 < ScalesPerXdlopsRunPerThread,
                                  "Must have at least one scale per Xdlops run per thread");

                    a_scale_thread_vec.template AsType<AScaleDataType>()(s) =

                    b_scale_thread_vec.template AsType<BScaleDataType>()(s) =

                    constexpr auto kxdl = ikxdl + k0 * KXdlPack;
                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    using mfma_input_type_a =
                        typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops /
                    using mfma_input_type_b =
                        typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops /
                    using mfma_scale_input_type_a =
                        typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
                    using mfma_scale_input_type_b =
                        typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
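                    // The a/b thread vectors carry quantized mantissa
                    // fragments and the scale vectors carry their mx block
                    // scales; the scaled xdlops_gemm.Run call below hands
                    // both to the MFMA so the matrix instruction can apply
                    // the scales directly, with no separate VALU
                    // dequantization pass.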
                    constexpr index_t c_offset = c_thread_desc_.CalculateOffset(

                    xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
                                             ikxdl * NXdlPack + inxdl>(
                        a_thread_vec.template AsType<mfma_input_type_a>(),
                        a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
                        b_thread_vec.template AsType<mfma_input_type_b>(),
                        b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
                        c_thread_buf.GetVectorTypeReference(
        constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
                                (APackedSize * KPack / xdlops_gemm.K1PerXdlops);

        static_for<0,
                   xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk),
                   1>{}([&](auto chunk) {
            constexpr auto a_k_step_chunk =
                k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,
        static_for<0,
                   xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk),
                   1>{}([&](auto chunk) {
            constexpr auto b_k_step_chunk =
                k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
        __builtin_amdgcn_sched_barrier(0);

    } while(i < (num_loop - 2));
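    // With two iterations in flight (global prefetch 1 and 2 issued before
    // the loop), the do-while stops at num_loop - 2; the final two K tiles
    // are drained by the tail code below.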
// Tail: fetch the last K tile of scales into buffer I1
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
    static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
        a_scale_thread_copy.Run(a_scale_grid_desc,
                                a_scale_thread_bufs(I1));

        a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
    a_scale_thread_copy.MoveSrcSliceWindow(

static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
    static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
        b_scale_thread_copy.Run(b_scale_grid_desc,
                                b_scale_thread_bufs(I1));

        b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
    b_scale_thread_copy.MoveSrcSliceWindow(
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
    static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
        static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
            constexpr index_t a_scale_offset =
                a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
            constexpr index_t b_scale_offset =
                b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
            static_assert(0 < ScalesPerXdlopsRunPerThread,
                          "Must have at least one scale per Xdlops run per thread");

            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =

            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =

            constexpr auto kxdl = ikxdl + k0 * KXdlPack;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
            using mfma_input_type_a =
                typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops /
            using mfma_input_type_b =
                typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops /
            using mfma_scale_input_type_a =
                typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
            using mfma_scale_input_type_b =
                typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
            constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
            xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
                                     ikxdl * NXdlPack + inxdl>(
                a_thread_vec.template AsType<mfma_input_type_a>(),
                a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
                b_thread_vec.template AsType<mfma_input_type_b>(),
                b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
constexpr auto k_step = k * xdlops_gemm.KPerXdlops / APackedSize *
                        (APackedSize * KPack / xdlops_gemm.K1PerXdlops);

static_for<0, xdlops_gemm.K1PerXdlops / (APackedSize * KThreadChunk), 1>{}(
    [&](auto chunk) {
        constexpr auto a_k_step_chunk =
            k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
        a_thread_copy_.Run(a_block_desc_m0_m1_m2_m3_k,

static_for<0, xdlops_gemm.K1PerXdlops / (BPackedSize * KThreadChunk), 1>{}(
    [&](auto chunk) {
        constexpr auto b_k_step_chunk =
            k_step + chunk * KThreadChunk * xdlops_gemm.mfma_instr.num_input_blks;
        b_thread_copy_.Run(b_block_desc_n0_n1_n2_n3_k,
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
    static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
        static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
            constexpr index_t a_scale_offset =
                a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
            constexpr index_t b_scale_offset =
                b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
            static_assert(0 < ScalesPerXdlopsRunPerThread,
                          "Must have at least one scale per Xdlops run per thread");

            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =

            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =

            constexpr auto kxdl = ikxdl + k0 * KXdlPack;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
            using mfma_input_type_a =
                typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops /
            using mfma_input_type_b =
                typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops /
            using mfma_scale_input_type_a =
                typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
            using mfma_scale_input_type_b =
                typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
            constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
            xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
                                     ikxdl * NXdlPack + inxdl>(
                a_thread_vec.template AsType<mfma_input_type_a>(),
                a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
                b_thread_vec.template AsType<mfma_input_type_b>(),
                b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
static_for<0, MRepeat / MXdlPack, 1>{}([&](auto m0) {
    static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) {
        static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) {
            constexpr index_t a_scale_offset =
                a_scale_thread_desc.CalculateOffset(make_tuple(m0, k0, I0));
            constexpr index_t b_scale_offset =
                b_scale_thread_desc.CalculateOffset(make_tuple(n0, k0, I0));
            static_assert(0 < ScalesPerXdlopsRunPerThread,
                          "Must have at least one scale per Xdlops run per thread");

            a_scale_thread_vec.template AsType<AScaleDataType>()(s) =

            b_scale_thread_vec.template AsType<BScaleDataType>()(s) =

            constexpr auto kxdl = ikxdl + k0 * KXdlPack;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
            using mfma_input_type_a =
                typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops /
            using mfma_input_type_b =
                typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops /
            using mfma_scale_input_type_a =
                typename vector_type<AScaleDataType, a_scale_thread_vec_size>::type;
            using mfma_scale_input_type_b =
                typename vector_type<BScaleDataType, b_scale_thread_vec_size>::type;
            constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
            xdlops_gemm.template Run<ikxdl * MXdlPack + imxdl,
                                     ikxdl * NXdlPack + inxdl>(
                a_thread_vec.template AsType<mfma_input_type_a>(),
                a_scale_thread_vec.template AsType<mfma_scale_input_type_a>(),
                b_thread_vec.template AsType<mfma_input_type_b>(),
                b_scale_thread_vec.template AsType<mfma_scale_input_type_b>(),
static constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
    make_tuple(Number<MRepeat / MXdlPack>{},
               Number<KRepeat / KXdlPack>{},
               Number<ScalesPerXdlopsRunPerThread * a_scale_thread_vec_size>{}));

static constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
    make_tuple(Number<NRepeat / NXdlPack>{},
               Number<KRepeat / KXdlPack>{},
               Number<ScalesPerXdlopsRunPerThread * b_scale_thread_vec_size>{}));
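// Per-thread scale window layout: a packed (MRepeat/MXdlPack,
// KRepeat/KXdlPack, ScalesPerXdlopsRunPerThread * vec_size) box, so the
// CalculateOffset(make_tuple(m0, k0, I0)) calls above land directly on the
// first scale of tile (m0, k0).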
using Base::a_thread_copy_;
using Base::a_thread_desc_;
using Base::b_thread_copy_;
using Base::b_thread_desc_;
using Base::c_thread_desc_;