15 static constexpr
bool is_scale_mfma_data_type()
17 using U = element_type_t<T>;
18 return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
19 is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
26 static constexpr
bool is_scale_mfma_scale_type()
28 return is_same_v<T, e8m0_bexp_t>;
34 template <
typename ADataType,
typename BDataType,
typename AScaleDataType,
typename BScaleDataType>
35 static constexpr
bool scale_mfma_hw_support()
37 return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
38 is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
82 template <MfmaInstr instr>
89 static constexpr
index_t num_groups_per_blk = 4;
90 static constexpr
index_t num_regs_per_blk = 16;
91 static constexpr
index_t num_threads_per_blk = 32;
93 static constexpr
index_t num_input_blks = 2;
94 static constexpr
index_t num_output_blks = 2;
98 static constexpr
bool is_k_reduction =
false;
100 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
101 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
111 static constexpr
index_t num_groups_per_blk = 4;
112 static constexpr
index_t num_regs_per_blk = 16;
113 static constexpr
index_t num_threads_per_blk = 32;
120 static constexpr
bool is_k_reduction =
true;
122 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
123 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
133 static constexpr
index_t num_groups_per_blk = 1;
134 static constexpr
index_t num_regs_per_blk = 4;
135 static constexpr
index_t num_threads_per_blk = 16;
142 static constexpr
bool is_k_reduction =
true;
144 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
145 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
155 static constexpr
index_t num_groups_per_blk = 1;
156 static constexpr
index_t num_regs_per_blk = 4;
157 static constexpr
index_t num_threads_per_blk = 16;
164 static constexpr
bool is_k_reduction =
false;
166 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
167 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
178 static constexpr
index_t num_groups_per_blk = 1;
179 static constexpr
index_t num_regs_per_blk = 4;
180 static constexpr
index_t num_threads_per_blk = 64;
187 static constexpr
bool is_k_reduction =
false;
189 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
190 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
200 static constexpr
index_t num_groups_per_blk = 4;
201 static constexpr
index_t num_regs_per_blk = 16;
202 static constexpr
index_t num_threads_per_blk = 32;
209 static constexpr
bool is_k_reduction =
false;
211 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
212 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
222 static constexpr
index_t num_groups_per_blk = 4;
223 static constexpr
index_t num_regs_per_blk = 16;
224 static constexpr
index_t num_threads_per_blk = 32;
231 static constexpr
bool is_k_reduction =
true;
233 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
234 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
244 static constexpr
index_t num_groups_per_blk = 4;
245 static constexpr
index_t num_regs_per_blk = 16;
246 static constexpr
index_t num_threads_per_blk = 32;
253 static constexpr
bool is_k_reduction =
true;
255 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
256 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
266 static constexpr
index_t num_groups_per_blk = 1;
267 static constexpr
index_t num_regs_per_blk = 4;
268 static constexpr
index_t num_threads_per_blk = 16;
275 static constexpr
bool is_k_reduction =
true;
277 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
278 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
288 static constexpr
index_t num_groups_per_blk = 1;
289 static constexpr
index_t num_regs_per_blk = 4;
290 static constexpr
index_t num_threads_per_blk = 16;
297 static constexpr
bool is_k_reduction =
true;
299 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
300 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
310 static constexpr
index_t num_groups_per_blk = 1;
311 static constexpr
index_t num_regs_per_blk = 4;
312 static constexpr
index_t num_threads_per_blk = 16;
319 static constexpr
bool is_k_reduction =
false;
321 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
322 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
332 static constexpr
index_t num_groups_per_blk = 1;
333 static constexpr
index_t num_regs_per_blk = 4;
334 static constexpr
index_t num_threads_per_blk = 64;
341 static constexpr
bool is_k_reduction =
false;
343 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
344 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
354 static constexpr
index_t num_groups_per_blk = 4;
355 static constexpr
index_t num_regs_per_blk = 16;
356 static constexpr
index_t num_threads_per_blk = 32;
363 static constexpr
bool is_k_reduction =
true;
365 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
366 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
376 static constexpr
index_t num_groups_per_blk = 4;
377 static constexpr
index_t num_regs_per_blk = 16;
378 static constexpr
index_t num_threads_per_blk = 32;
385 static constexpr
bool is_k_reduction =
true;
387 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
388 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
398 static constexpr
index_t num_groups_per_blk = 1;
399 static constexpr
index_t num_regs_per_blk = 4;
400 static constexpr
index_t num_threads_per_blk = 16;
407 static constexpr
bool is_k_reduction =
true;
409 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
410 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
420 static constexpr
index_t num_groups_per_blk = 1;
421 static constexpr
index_t num_regs_per_blk = 4;
422 static constexpr
index_t num_threads_per_blk = 16;
429 static constexpr
bool is_k_reduction =
true;
431 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
432 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
442 static constexpr
index_t num_groups_per_blk = 4;
443 static constexpr
index_t num_regs_per_blk = 16;
444 static constexpr
index_t num_threads_per_blk = 32;
451 static constexpr
bool is_k_reduction =
true;
453 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
454 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
464 static constexpr
index_t num_groups_per_blk = 1;
465 static constexpr
index_t num_regs_per_blk = 4;
466 static constexpr
index_t num_threads_per_blk = 16;
473 static constexpr
bool is_k_reduction =
true;
475 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
476 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
486 static constexpr
index_t num_groups_per_blk = 4;
487 static constexpr
index_t num_regs_per_blk = 16;
488 static constexpr
index_t num_threads_per_blk = 32;
495 static constexpr
bool is_k_reduction =
true;
497 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
498 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
508 static constexpr
index_t num_groups_per_blk = 1;
509 static constexpr
index_t num_regs_per_blk = 4;
510 static constexpr
index_t num_threads_per_blk = 16;
517 static constexpr
bool is_k_reduction =
true;
519 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
520 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
530 static constexpr
index_t num_groups_per_blk = 4;
531 static constexpr
index_t num_regs_per_blk = 16;
532 static constexpr
index_t num_threads_per_blk = 32;
539 static constexpr
bool is_k_reduction =
true;
541 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
542 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
552 static constexpr
index_t num_groups_per_blk = 1;
553 static constexpr
index_t num_regs_per_blk = 4;
554 static constexpr
index_t num_threads_per_blk = 16;
561 static constexpr
bool is_k_reduction =
true;
563 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
564 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
574 static constexpr
index_t num_groups_per_blk = 4;
575 static constexpr
index_t num_regs_per_blk = 16;
576 static constexpr
index_t num_threads_per_blk = 32;
583 static constexpr
bool is_k_reduction =
true;
585 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
586 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
596 static constexpr
index_t num_groups_per_blk = 1;
597 static constexpr
index_t num_regs_per_blk = 4;
598 static constexpr
index_t num_threads_per_blk = 16;
605 static constexpr
bool is_k_reduction =
true;
607 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
608 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
618 static constexpr
index_t num_groups_per_blk = 4;
619 static constexpr
index_t num_regs_per_blk = 4;
620 static constexpr
index_t num_threads_per_blk = 16;
627 static constexpr
bool is_k_reduction =
true;
629 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
630 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
640 static constexpr
index_t num_groups_per_blk = 4;
641 static constexpr
index_t num_regs_per_blk = 16;
642 static constexpr
index_t num_threads_per_blk = 32;
649 static constexpr
bool is_k_reduction =
true;
651 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
652 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
662 static constexpr
index_t num_groups_per_blk = 1;
663 static constexpr
index_t num_regs_per_blk = 4;
664 static constexpr
index_t num_threads_per_blk = 16;
671 static constexpr
bool is_k_reduction =
true;
673 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
674 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
684 static constexpr
index_t num_groups_per_blk = 4;
685 static constexpr
index_t num_regs_per_blk = 16;
686 static constexpr
index_t num_threads_per_blk = 32;
693 static constexpr
bool is_k_reduction =
true;
695 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
696 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
706 static constexpr
index_t num_groups_per_blk = 1;
707 static constexpr
index_t num_regs_per_blk = 4;
708 static constexpr
index_t num_threads_per_blk = 16;
715 static constexpr
bool is_k_reduction =
true;
717 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
718 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
728 static constexpr
index_t num_groups_per_blk = 4;
729 static constexpr
index_t num_regs_per_blk = 16;
730 static constexpr
index_t num_threads_per_blk = 32;
737 static constexpr
bool is_k_reduction =
true;
739 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
740 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
750 static constexpr
index_t num_groups_per_blk = 1;
751 static constexpr
index_t num_regs_per_blk = 4;
752 static constexpr
index_t num_threads_per_blk = 16;
759 static constexpr
bool is_k_reduction =
true;
761 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
762 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
772 static constexpr
index_t num_groups_per_blk = 4;
773 static constexpr
index_t num_regs_per_blk = 16;
774 static constexpr
index_t num_threads_per_blk = 32;
781 static constexpr
bool is_k_reduction =
true;
783 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
784 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
794 static constexpr
index_t num_groups_per_blk = 1;
795 static constexpr
index_t num_regs_per_blk = 4;
796 static constexpr
index_t num_threads_per_blk = 16;
803 static constexpr
bool is_k_reduction =
true;
805 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
806 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
817 static constexpr
index_t num_groups_per_blk = 4;
818 static constexpr
index_t num_regs_per_blk = 16;
819 static constexpr
index_t num_threads_per_blk = 32;
826 static constexpr
bool is_k_reduction =
true;
829 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
830 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
841 static constexpr
index_t num_groups_per_blk = 1;
842 static constexpr
index_t num_regs_per_blk = 4;
843 static constexpr
index_t num_threads_per_blk = 16;
850 static constexpr
bool is_k_reduction =
true;
853 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
854 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
865 static constexpr
index_t num_groups_per_blk = 4;
866 static constexpr
index_t num_regs_per_blk = 16;
867 static constexpr
index_t num_threads_per_blk = 32;
874 static constexpr
bool is_k_reduction =
true;
886 __device__
void run(
const FloatA& a,
887 const ScaleA& scale_a,
889 const ScaleB& scale_b,
893 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
902 static constexpr
index_t num_groups_per_blk = 1;
903 static constexpr
index_t num_regs_per_blk = 4;
904 static constexpr
index_t num_threads_per_blk = 16;
911 static constexpr
bool is_k_reduction =
true;
923 __device__
void run(
const FloatA& a,
924 const ScaleA& scale_a,
926 const ScaleB& scale_b,
931 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
935 template <
typename base_type,
938 typename additional_type = base_type,
939 bool is_single_rate_mfma =
false,
940 bool is_scale_mfma =
false>
943 template <
typename base_type_,
946 typename additional_type_ = base_type_,
947 bool is_single_rate_mfma_ =
false,
948 bool is_scale_mfma_ =
false>
952 constexpr
auto GetMfma<double, 16, 16>()
958 constexpr
auto GetMfma<float, 64, 64>()
964 constexpr
auto GetMfma<float, 32, 64>()
970 constexpr
auto GetMfma<float, 16, 64>()
976 constexpr
auto GetMfma<float, 8, 64>()
982 constexpr
auto GetMfma<float, 4, 64>()
988 constexpr
auto GetMfma<float, 32, 32>()
994 constexpr
auto GetMfma<float, 16, 16>()
1000 constexpr
auto GetMfma<half_t, 64, 64>()
1006 constexpr
auto GetMfma<half_t, 32, 64>()
1012 constexpr
auto GetMfma<half_t, 32, 32, half_t, false>()
1014 #if defined(__gfx950__)
1021 constexpr
auto GetMfma<half_t, 32, 32, half_t, true>()
1027 constexpr
auto GetMfma<half_t, 16, 16, half_t, false>()
1029 #if defined(__gfx950__)
1037 constexpr
auto GetMfma<half_t, 16, 16, half_t, true>()
1043 constexpr
auto GetMfma<half_t, 16, 64>()
1049 constexpr
auto GetMfma<half_t, 8, 64>()
1055 constexpr
auto GetMfma<half_t, 4, 64>()
1061 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1063 #if defined(__gfx950__)
1065 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1073 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1075 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1083 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1085 #if defined(__gfx950__)
1087 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1095 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1097 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1105 constexpr
auto GetMfma<int8_t, 32, 32, int8_t, false>()
1107 #if defined(__gfx950__)
1109 #elif defined(__gfx942__)
1117 constexpr
auto GetMfma<int8_t, 32, 32, int8_t, true>()
1119 #if defined(__gfx942__) || defined(__gfx950__)
1127 constexpr
auto GetMfma<int8_t, 16, 16, int8_t, false>()
1129 #if defined(__gfx950__)
1131 #elif defined(__gfx942__)
1139 constexpr
auto GetMfma<int8_t, 16, 16, int8_t, true>()
1141 #if defined(__gfx942__) || defined(__gfx950__)
1149 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
1155 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
1157 #if defined(__gfx950__)
1165 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
1171 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, false, true>()
1176 constexpr
auto GetMfma<f4_t, 32, 32, f4_t, false, true>()
1181 constexpr
auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
1187 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
1193 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
1195 #if defined(__gfx950__)
1203 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
1209 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, false, true>()
1215 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, false, true>()
1221 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, false, true>()
1227 constexpr
auto GetMfma<f6_t, 32, 32, f6_t, false, true>()
1232 constexpr
auto GetMfma<f6_t, 16, 16, f6_t, false, true>()
1237 constexpr
auto GetMfma<bf6_t, 32, 32, bf6_t, false, true>()
1242 constexpr
auto GetMfma<bf6_t, 16, 16, bf6_t, false, true>()
1248 constexpr
auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
1254 constexpr
auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
1256 #if defined(__gfx950__)
1264 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
1270 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
1272 #if defined(__gfx950__)
1280 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
1286 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
1288 #if defined(__gfx950__)
1296 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
1302 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
1304 #if defined(__gfx950__)
1312 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
1318 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
1320 #if defined(__gfx950__)
1328 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
1334 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
1336 #if defined(__gfx950__)
1347 is_single_rate_mfma,
1348 is_scale_mfma>()>{};
1354 "wrong! num_regs_per_blk");
1357 "n_per_blk != num_threads_per_blk");
1361 "m_per_blk != num_input_blks * num_regs_per_blk");
1365 "incorrect num_output_blks");
1369 "num_regs_per_blk incorrect");
1373 "is_k_reduction wrong!");
1378 static_assert(NPerXdlops >= MPerXdlops,
"only support ABroadcast");
1391 template <
typename base_type,
1395 typename additional_type = base_type,
1396 bool TransposeC =
false,
1397 bool is_scale_mfma =
false>
1414 return MPerXdlops * NPerXdlops /
1420 static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1422 "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1424 static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1426 "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1428 static_assert(KPack %
mfma_instr.k_per_blk == 0,
"KPack should be a multiple of k_per_blk");
1433 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1434 __host__ __device__
static constexpr
auto
1437 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1438 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1439 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1440 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1443 c_desc_m0_n0_m1_n1_m2_n2,
1468 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1470 const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1472 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1473 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1474 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1475 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1476 const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I4);
1477 const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I5);
1480 c_desc_m0_n0_m1_n1_m2_n2,
1511 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1512 __host__ __device__
static constexpr
auto
1515 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1516 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1517 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1518 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1521 c_desc_m0_n0_m1_n1_m2_n2,
1544 template <
typename CDesc_G_M0_N0_M1_N1_M2_N2>
1546 const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1548 const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1549 const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1550 const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1551 const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1552 const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I4);
1555 c_desc_g_m0_n0_m1_n1_m2_n2,
1583 return MPerXdlops * NPerXdlops /
mfma_instr.wave_size;
1588 template <
class FloatA,
class FloatB,
class FloatC>
1589 __device__
void Run(
const FloatA& p_a_wave,
const FloatB& p_b_wave, FloatC& p_c_thread)
const
1598 "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
1601 if constexpr(!TransposeC)
1603 mfma_instr.template run<MPerXdlops, NPerXdlops>(
1604 p_a_wave[k], p_b_wave[k], p_c_thread);
1608 mfma_instr.template run<MPerXdlops, NPerXdlops>(
1609 p_b_wave[k], p_a_wave[k], p_c_thread);
1621 __device__
void Run(
const FloatA& p_a_wave,
1622 const ScaleA& a_scale_thread,
1623 const FloatB& p_b_wave,
1624 const ScaleB& b_scale_thread,
1625 FloatC& p_c_thread)
const
1628 if constexpr(!TransposeC)
1630 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
1631 p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
1635 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
1636 p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
1653 const auto blk_idx =
1654 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
1656 const auto blk_id = blk_idx[
I1];
1657 const auto blk_td = blk_idx[
I2];
1667 const auto blk_id = blk_idx[
I0];
1668 const auto blk_td = blk_idx[
I1];
1685 const auto blk_id = blk_idx[
I0];
1686 const auto blk_td = blk_idx[
I1];
1702 const auto blk_id = blk_idx[
I0];
1703 const auto blk_td = blk_idx[
I1];
1708 return TransposeC ?
CIndex{n_offset, m_offset} :
CIndex{m_offset, n_offset};
1715 const auto blk_id = blk_idx[
I0];
1716 const auto blk_td = blk_idx[
I1];
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:42
@ mfma_f32_32x32x64f8f6f4
@ mfma_scale_f32_32x32x64f8f6f4
@ mfma_f32_16x16x16bf16_1k
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x32bf8bf8
@ mfma_f32_16x16x128f8f6f4
@ mfma_f32_32x32x16bf8bf8
@ mfma_f32_32x32x8bf16_1k
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:405
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:300
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: xdlops_gemm.hpp:942
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1350
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1376
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1388
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1343
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1382
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1399
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:1739
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1418
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1680
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:1643
static constexpr auto I2
Definition: xdlops_gemm.hpp:1402
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1410
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:1641
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:1743
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1469
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1412
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1662
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:1724
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:1711
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:1586
static constexpr __device__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:1581
static constexpr auto I5
Definition: xdlops_gemm.hpp:1405
static constexpr auto I3
Definition: xdlops_gemm.hpp:1403
static constexpr auto I0
Definition: xdlops_gemm.hpp:1400
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1621
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1435
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1545
static constexpr auto I1
Definition: xdlops_gemm.hpp:1401
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:1742
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:1741
static constexpr auto I4
Definition: xdlops_gemm.hpp:1404
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1589
static constexpr auto mfma
Definition: xdlops_gemm.hpp:1732
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:1698
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1513
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:1745
Definition: integral_constant.hpp:20
Definition: amd_xdlops.hpp:1202
Definition: amd_xdlops.hpp:303
Definition: amd_xdlops.hpp:193
Definition: amd_xdlops.hpp:70
Definition: amd_xdlops.hpp:269
Definition: amd_xdlops.hpp:1483
Definition: amd_xdlops.hpp:1609
Definition: amd_xdlops.hpp:159
Definition: amd_xdlops.hpp:1546
Definition: amd_xdlops.hpp:1420
Definition: amd_xdlops.hpp:207
Definition: amd_xdlops.hpp:56
Definition: amd_xdlops.hpp:331
Definition: amd_xdlops.hpp:249
Definition: amd_xdlops.hpp:1451
Definition: amd_xdlops.hpp:1577
Definition: amd_xdlops.hpp:139
Definition: amd_xdlops.hpp:1514
Definition: amd_xdlops.hpp:1388
Definition: amd_xdlops.hpp:15
Definition: amd_xdlops.hpp:42
Definition: amd_xdlops.hpp:317
Definition: amd_xdlops.hpp:112
Definition: amd_xdlops.hpp:481
Definition: amd_xdlops.hpp:289
Definition: amd_xdlops.hpp:179
Definition: amd_xdlops.hpp:84
Definition: amd_xdlops.hpp:221
Definition: amd_xdlops.hpp:461
Definition: amd_xdlops.hpp:364
Definition: amd_xdlops.hpp:442
Definition: amd_xdlops.hpp:403
Definition: amd_xdlops.hpp:423
Definition: amd_xdlops.hpp:383
Definition: amd_xdlops.hpp:345
Definition: amd_xdlops.hpp:886
Definition: amd_xdlops.hpp:666
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:854
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:432
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:300
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:167
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:410
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:718
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:806
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:278
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:762
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:674
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:322
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:145
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:476
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:366
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:696
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:784
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:256
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:740
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:652
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:101
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:123
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:454
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:212
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:830
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:388
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:234
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:190
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:344
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:630
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:520
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:564
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:608
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:542
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:586
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:498
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:923
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:886
Definition: xdlops_gemm.hpp:83
Definition: functional2.hpp:33