16 static constexpr
bool is_scale_mfma_data_type()
18 using U = element_type_t<T>;
19 return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
20 is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
23 #ifndef CK_CODE_GEN_RTC
28 static constexpr
bool is_scale_mfma_scale_type()
30 return is_same_v<T, e8m0_bexp_t>;
37 template <
typename ADataType,
typename BDataType,
typename AScaleDataType,
typename BScaleDataType>
38 static constexpr
bool scale_mfma_hw_support()
40 return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
41 is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
103 template <MfmaInstr instr>
110 static constexpr
index_t num_groups_per_blk = 4;
111 static constexpr
index_t num_regs_per_blk = 16;
112 static constexpr
index_t num_threads_per_blk = 32;
119 static constexpr
bool is_k_reduction =
false;
121 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
122 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
132 static constexpr
index_t num_groups_per_blk = 4;
133 static constexpr
index_t num_regs_per_blk = 16;
134 static constexpr
index_t num_threads_per_blk = 32;
141 static constexpr
bool is_k_reduction =
true;
143 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
144 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
154 static constexpr
index_t num_groups_per_blk = 1;
155 static constexpr
index_t num_regs_per_blk = 4;
156 static constexpr
index_t num_threads_per_blk = 16;
163 static constexpr
bool is_k_reduction =
true;
165 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
166 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
176 static constexpr
index_t num_groups_per_blk = 1;
177 static constexpr
index_t num_regs_per_blk = 4;
178 static constexpr
index_t num_threads_per_blk = 16;
185 static constexpr
bool is_k_reduction =
false;
187 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
188 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
199 static constexpr
index_t num_groups_per_blk = 1;
200 static constexpr
index_t num_regs_per_blk = 4;
201 static constexpr
index_t num_threads_per_blk = 64;
208 static constexpr
bool is_k_reduction =
false;
210 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
211 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
221 static constexpr
index_t num_groups_per_blk = 4;
222 static constexpr
index_t num_regs_per_blk = 16;
223 static constexpr
index_t num_threads_per_blk = 32;
230 static constexpr
bool is_k_reduction =
false;
232 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
233 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
243 static constexpr
index_t num_groups_per_blk = 4;
244 static constexpr
index_t num_regs_per_blk = 16;
245 static constexpr
index_t num_threads_per_blk = 32;
252 static constexpr
bool is_k_reduction =
true;
254 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
255 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
265 static constexpr
index_t num_groups_per_blk = 4;
266 static constexpr
index_t num_regs_per_blk = 16;
267 static constexpr
index_t num_threads_per_blk = 32;
274 static constexpr
bool is_k_reduction =
true;
276 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
277 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
287 static constexpr
index_t num_groups_per_blk = 1;
288 static constexpr
index_t num_regs_per_blk = 4;
289 static constexpr
index_t num_threads_per_blk = 16;
296 static constexpr
bool is_k_reduction =
true;
298 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
299 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
309 static constexpr
index_t num_groups_per_blk = 1;
310 static constexpr
index_t num_regs_per_blk = 4;
311 static constexpr
index_t num_threads_per_blk = 16;
318 static constexpr
bool is_k_reduction =
true;
320 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
321 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
331 static constexpr
index_t num_groups_per_blk = 1;
332 static constexpr
index_t num_regs_per_blk = 4;
333 static constexpr
index_t num_threads_per_blk = 16;
340 static constexpr
bool is_k_reduction =
false;
342 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
343 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
353 static constexpr
index_t num_groups_per_blk = 1;
354 static constexpr
index_t num_regs_per_blk = 4;
355 static constexpr
index_t num_threads_per_blk = 64;
362 static constexpr
bool is_k_reduction =
false;
364 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
365 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
375 static constexpr
index_t num_groups_per_blk = 4;
376 static constexpr
index_t num_regs_per_blk = 16;
377 static constexpr
index_t num_threads_per_blk = 32;
384 static constexpr
bool is_k_reduction =
true;
386 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
387 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
397 static constexpr
index_t num_groups_per_blk = 4;
398 static constexpr
index_t num_regs_per_blk = 16;
399 static constexpr
index_t num_threads_per_blk = 32;
406 static constexpr
bool is_k_reduction =
true;
408 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
409 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
419 static constexpr
index_t num_groups_per_blk = 1;
420 static constexpr
index_t num_regs_per_blk = 4;
421 static constexpr
index_t num_threads_per_blk = 16;
428 static constexpr
bool is_k_reduction =
true;
430 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
431 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
441 static constexpr
index_t num_groups_per_blk = 1;
442 static constexpr
index_t num_regs_per_blk = 4;
443 static constexpr
index_t num_threads_per_blk = 16;
450 static constexpr
bool is_k_reduction =
true;
452 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
453 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
463 static constexpr
index_t num_groups_per_blk = 4;
464 static constexpr
index_t num_regs_per_blk = 16;
465 static constexpr
index_t num_threads_per_blk = 32;
472 static constexpr
bool is_k_reduction =
true;
474 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
475 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
485 static constexpr
index_t num_groups_per_blk = 1;
486 static constexpr
index_t num_regs_per_blk = 4;
487 static constexpr
index_t num_threads_per_blk = 16;
494 static constexpr
bool is_k_reduction =
true;
496 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
497 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
507 static constexpr
index_t num_groups_per_blk = 4;
508 static constexpr
index_t num_regs_per_blk = 16;
509 static constexpr
index_t num_threads_per_blk = 32;
516 static constexpr
bool is_k_reduction =
true;
518 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
519 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
529 static constexpr
index_t num_groups_per_blk = 1;
530 static constexpr
index_t num_regs_per_blk = 4;
531 static constexpr
index_t num_threads_per_blk = 16;
538 static constexpr
bool is_k_reduction =
true;
540 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
541 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
551 static constexpr
index_t num_groups_per_blk = 4;
552 static constexpr
index_t num_regs_per_blk = 16;
553 static constexpr
index_t num_threads_per_blk = 32;
560 static constexpr
bool is_k_reduction =
true;
562 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
563 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
573 static constexpr
index_t num_groups_per_blk = 1;
574 static constexpr
index_t num_regs_per_blk = 4;
575 static constexpr
index_t num_threads_per_blk = 16;
582 static constexpr
bool is_k_reduction =
true;
584 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
585 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
595 static constexpr
index_t num_groups_per_blk = 4;
596 static constexpr
index_t num_regs_per_blk = 16;
597 static constexpr
index_t num_threads_per_blk = 32;
604 static constexpr
bool is_k_reduction =
true;
606 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
607 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
617 static constexpr
index_t num_groups_per_blk = 1;
618 static constexpr
index_t num_regs_per_blk = 4;
619 static constexpr
index_t num_threads_per_blk = 16;
626 static constexpr
bool is_k_reduction =
true;
628 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
629 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
639 static constexpr
index_t num_groups_per_blk = 4;
640 static constexpr
index_t num_regs_per_blk = 4;
641 static constexpr
index_t num_threads_per_blk = 16;
648 static constexpr
bool is_k_reduction =
true;
650 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
651 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
661 static constexpr
index_t num_groups_per_blk = 4;
662 static constexpr
index_t num_regs_per_blk = 16;
663 static constexpr
index_t num_threads_per_blk = 32;
670 static constexpr
bool is_k_reduction =
true;
672 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
673 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
683 static constexpr
index_t num_groups_per_blk = 1;
684 static constexpr
index_t num_regs_per_blk = 4;
685 static constexpr
index_t num_threads_per_blk = 16;
692 static constexpr
bool is_k_reduction =
true;
694 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
695 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
705 static constexpr
index_t num_groups_per_blk = 4;
706 static constexpr
index_t num_regs_per_blk = 16;
707 static constexpr
index_t num_threads_per_blk = 32;
714 static constexpr
bool is_k_reduction =
true;
716 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
717 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
727 static constexpr
index_t num_groups_per_blk = 1;
728 static constexpr
index_t num_regs_per_blk = 4;
729 static constexpr
index_t num_threads_per_blk = 16;
736 static constexpr
bool is_k_reduction =
true;
738 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
739 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
749 static constexpr
index_t num_groups_per_blk = 4;
750 static constexpr
index_t num_regs_per_blk = 16;
751 static constexpr
index_t num_threads_per_blk = 32;
758 static constexpr
bool is_k_reduction =
true;
760 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
761 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
771 static constexpr
index_t num_groups_per_blk = 1;
772 static constexpr
index_t num_regs_per_blk = 4;
773 static constexpr
index_t num_threads_per_blk = 16;
780 static constexpr
bool is_k_reduction =
true;
782 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
783 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
793 static constexpr
index_t num_groups_per_blk = 4;
794 static constexpr
index_t num_regs_per_blk = 16;
795 static constexpr
index_t num_threads_per_blk = 32;
802 static constexpr
bool is_k_reduction =
true;
804 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
805 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
815 static constexpr
index_t num_groups_per_blk = 1;
816 static constexpr
index_t num_regs_per_blk = 4;
817 static constexpr
index_t num_threads_per_blk = 16;
824 static constexpr
bool is_k_reduction =
true;
826 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
827 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
838 static constexpr
index_t num_groups_per_blk = 4;
839 static constexpr
index_t num_regs_per_blk = 16;
840 static constexpr
index_t num_threads_per_blk = 32;
847 static constexpr
bool is_k_reduction =
true;
850 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
851 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
862 static constexpr
index_t num_groups_per_blk = 1;
863 static constexpr
index_t num_regs_per_blk = 4;
864 static constexpr
index_t num_threads_per_blk = 16;
871 static constexpr
bool is_k_reduction =
true;
874 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
875 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
886 static constexpr
index_t num_groups_per_blk = 4;
887 static constexpr
index_t num_regs_per_blk = 16;
888 static constexpr
index_t num_threads_per_blk = 32;
895 static constexpr
bool is_k_reduction =
true;
907 __device__
void run(
const FloatA&
a,
908 const ScaleA& scale_a,
910 const ScaleB& scale_b,
914 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
923 static constexpr
index_t num_groups_per_blk = 1;
924 static constexpr
index_t num_regs_per_blk = 4;
925 static constexpr
index_t num_threads_per_blk = 16;
932 static constexpr
bool is_k_reduction =
true;
944 __device__
void run(
const FloatA&
a,
945 const ScaleA& scale_a,
947 const ScaleB& scale_b,
952 a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
981 static constexpr
index_t num_threads_per_blk = n_per_blk;
982 static constexpr
index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size;
983 static constexpr
index_t num_input_blks = m_per_blk / num_regs_per_blk;
985 static constexpr
index_t num_groups_per_blk = 1;
988 static constexpr
bool is_k_reduction =
true;
991 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
992 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1004 static constexpr
index_t num_threads_per_blk = n_per_blk;
1005 static constexpr
index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size;
1006 static constexpr
index_t num_input_blks = m_per_blk / num_regs_per_blk;
1011 static constexpr
bool is_k_reduction =
true;
1013 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
1014 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1027 static constexpr
index_t num_threads_per_blk = 32;
1034 static constexpr
bool is_k_reduction =
true;
1036 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
1037 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1049 static constexpr
index_t num_threads_per_blk = 16;
1056 static constexpr
bool is_k_reduction =
true;
1058 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
1059 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1084 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1085 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1094 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1095 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1112 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1122 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1123 __device__
void run(
const FloatA&,
const FloatB&, FloatC&)
const
1148 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1149 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1158 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1159 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1176 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1186 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1187 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1196 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1197 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1206 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1207 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1216 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1217 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
1227 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
1228 __device__
void run(
const FloatA&,
const FloatB&, FloatC&)
const
1248 template <
typename base_type,
1251 typename additional_type = base_type,
1252 bool is_single_rate_mfma =
false,
1253 bool is_scale_mfma =
false>
1256 template <
typename base_type_,
1259 typename additional_type_ = base_type_,
1260 bool is_single_rate_mfma_ =
false,
1261 bool is_scale_mfma_ =
false>
1265 constexpr
auto GetMfma<double, 16, 16>()
1267 #if defined(__gfx12__)
1269 #elif defined(__gfx11__)
1277 constexpr
auto GetMfma<float, 64, 64>()
1283 constexpr
auto GetMfma<float, 32, 64>()
1289 constexpr
auto GetMfma<float, 16, 64>()
1295 constexpr
auto GetMfma<float, 8, 64>()
1301 constexpr
auto GetMfma<float, 4, 64>()
1307 constexpr
auto GetMfma<float, 32, 32>()
1313 constexpr
auto GetMfma<float, 16, 16>()
1315 #if defined(__gfx12__)
1317 #elif defined(__gfx11__)
1325 constexpr
auto GetMfma<tf32_t, 32, 32, tf32_t>()
1327 #if defined(__gfx12__)
1329 #elif defined(__gfx11__)
1331 #elif defined(__gfx950__)
1333 #elif defined(__gfx942__)
1341 constexpr
auto GetMfma<tf32_t, 16, 16, tf32_t>()
1343 #if defined(__gfx12__)
1345 #elif defined(__gfx11__)
1347 #elif defined(__gfx950__)
1349 #elif defined(__gfx942__)
1357 constexpr
auto GetMfma<half_t, 64, 64>()
1363 constexpr
auto GetMfma<half_t, 32, 64>()
1369 constexpr
auto GetMfma<half_t, 32, 32, half_t, false>()
1371 #if defined(__gfx950__)
1378 constexpr
auto GetMfma<half_t, 32, 32, half_t, true>()
1384 constexpr
auto GetMfma<half_t, 16, 16, half_t, false>()
1386 #if defined(__gfx12__)
1388 #elif defined(__gfx11__)
1390 #elif defined(__gfx950__)
1398 constexpr
auto GetMfma<half_t, 16, 16, half_t, true>()
1400 #if defined(__gfx12__)
1402 #elif defined(__gfx11__)
1410 constexpr
auto GetMfma<half_t, 16, 64>()
1416 constexpr
auto GetMfma<half_t, 8, 64>()
1422 constexpr
auto GetMfma<half_t, 4, 64>()
1428 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1430 #if defined(__gfx950__)
1432 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1440 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1442 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1450 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1452 #if defined(__gfx12__)
1454 #elif defined(__gfx11__)
1456 #elif defined(__gfx950__)
1458 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1466 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1468 #if defined(__gfx12__)
1470 #elif defined(__gfx11__)
1472 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1480 constexpr
auto GetMfma<int8_t, 32, 32, int8_t, false>()
1482 #if defined(__gfx950__)
1484 #elif defined(__gfx942__)
1492 constexpr
auto GetMfma<int8_t, 32, 32, int8_t, true>()
1494 #if defined(__gfx942__) || defined(__gfx950__)
1502 constexpr
auto GetMfma<int8_t, 16, 16, int8_t, false>()
1504 #if defined(__gfx12__)
1506 #elif defined(__gfx11__)
1508 #elif defined(__gfx950__)
1510 #elif defined(__gfx942__)
1518 constexpr
auto GetMfma<int8_t, 16, 16, int8_t, true>()
1520 #if defined(__gfx12__)
1522 #elif defined(__gfx11__)
1524 #elif defined(__gfx942__) || defined(__gfx950__)
1532 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
1538 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
1540 #if defined(__gfx950__)
1548 constexpr
auto GetMfma<f8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1554 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1559 constexpr
auto GetMfma<f4_t, 32, 32, f4_t, is_single_rate_mfma, true>()
1564 constexpr
auto GetMfma<f4_t, 16, 16, f4_t, is_single_rate_mfma, true>()
1566 #if defined(__gfx12__)
1568 #elif defined(__gfx11__)
1576 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
1578 #if defined(__gfx12__)
1580 #elif defined(__gfx11__)
1588 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
1590 #if defined(__gfx12__)
1592 #elif defined(__gfx11__)
1594 #elif defined(__gfx950__)
1602 constexpr
auto GetMfma<f8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1604 #if defined(__gfx12__)
1606 #elif defined(__gfx11__)
1614 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1616 #if defined(__gfx12__)
1618 #elif defined(__gfx11__)
1626 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1628 #if defined(__gfx12__)
1630 #elif defined(__gfx11__)
1638 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1640 #if defined(__gfx12__)
1642 #elif defined(__gfx11__)
1650 constexpr
auto GetMfma<f6_t, 32, 32, f6_t, is_single_rate_mfma, true>()
1655 constexpr
auto GetMfma<f6_t, 16, 16, f6_t, is_single_rate_mfma, true>()
1657 #if defined(__gfx12__)
1659 #elif defined(__gfx11__)
1666 constexpr
auto GetMfma<bf6_t, 32, 32, bf6_t, is_single_rate_mfma, true>()
1671 constexpr
auto GetMfma<bf6_t, 16, 16, bf6_t, is_single_rate_mfma, true>()
1673 #if defined(__gfx12__)
1675 #elif defined(__gfx11__)
1683 constexpr
auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
1689 constexpr
auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
1691 #if defined(__gfx950__)
1699 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
1701 #if defined(__gfx12__)
1703 #elif defined(__gfx11__)
1711 constexpr
auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
1713 #if defined(__gfx12__)
1715 #elif defined(__gfx11__)
1717 #elif defined(__gfx950__)
1725 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
1731 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
1733 #if defined(__gfx950__)
1741 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
1743 #if defined(__gfx12__)
1745 #elif defined(__gfx11__)
1753 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
1755 #if defined(__gfx12__)
1757 #elif defined(__gfx11__)
1759 #elif defined(__gfx950__)
1767 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
1773 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
1775 #if defined(__gfx950__)
1783 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
1785 #if defined(__gfx12__)
1787 #elif defined(__gfx11__)
1795 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
1797 #if defined(__gfx12__)
1799 #elif defined(__gfx11__)
1801 #elif defined(__gfx950__)
1812 is_single_rate_mfma,
1813 is_scale_mfma>()>{};
1819 "wrong! num_regs_per_blk");
1822 "n_per_blk != num_threads_per_blk");
1823 #if defined(__gfx11__)
1824 if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
1828 "m_per_blk != num_input_blks * num_regs_per_blk");
1833 "m_per_blk != num_input_blks * num_regs_per_blk");
1838 "incorrect num_output_blks");
1842 "num_regs_per_blk incorrect");
1846 "is_k_reduction wrong!");
1851 static_assert(NPerXdlops >= MPerXdlops,
"only support ABroadcast");
1864 template <
typename base_type,
1868 typename additional_type = base_type,
1869 bool TransposeC =
false,
1870 bool is_scale_mfma =
false>
1887 return MPerXdlops * NPerXdlops /
1893 static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1895 "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1897 static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1899 "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1900 #if defined(__HIP_DEVICE_COMPILE__)
1901 static_assert(KPack %
mfma_instr.k_per_blk == 0,
"KPack should be a multiple of k_per_blk");
1907 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1908 __host__ __device__
static constexpr
auto
1911 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1912 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1913 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1914 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1918 c_desc_m0_n0_m1_n1_m2_n2,
1943 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1945 const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1947 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1948 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1949 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1950 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1951 const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I4);
1952 const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I5);
1956 c_desc_m0_n0_m1_n1_m2_n2,
1987 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1988 __host__ __device__
static constexpr
auto
1991 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1992 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1993 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1994 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1998 c_desc_m0_n0_m1_n1_m2_n2,
2021 template <
typename CDesc_G_M0_N0_M1_N1_M2_N2>
2023 const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
2025 const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I0);
2026 const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I1);
2027 const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I2);
2028 const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I3);
2029 const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I4);
2033 c_desc_g_m0_n0_m1_n1_m2_n2,
2065 template <
class FloatA,
class FloatB,
class FloatC>
2066 __device__
void Run(
const FloatA& p_a_wave,
const FloatB& p_b_wave, FloatC& p_c_thread)
const
2075 "base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
2078 if constexpr(!TransposeC)
2080 mfma_instr.template run<MPerXdlops, NPerXdlops>(
2081 p_a_wave[k], p_b_wave[k], p_c_thread);
2085 mfma_instr.template run<MPerXdlops, NPerXdlops>(
2086 p_b_wave[k], p_a_wave[k], p_c_thread);
2098 __device__
void Run(
const FloatA& p_a_wave,
2099 const ScaleA& a_scale_thread,
2100 const FloatB& p_b_wave,
2101 const ScaleB& b_scale_thread,
2102 FloatC& p_c_thread)
const
2105 if constexpr(!TransposeC)
2107 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
2108 p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
2112 mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
2113 p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
2131 const auto blk_idx =
2132 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
2134 const auto blk_id = blk_idx[
I1];
2135 const auto blk_td = blk_idx[
I2];
2140 template <
bool SwizzleA>
2144 if constexpr(SwizzleA)
2146 laneId = ((laneId & 1) << 3) | (laneId >> 1);
2154 const auto blk_idx =
2155 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
2157 const auto blk_id = blk_idx[
I1];
2158 const auto blk_td = blk_idx[
I2];
2166 #if defined(__gfx11__)
2167 const auto blk_idx = GetGfx11InputBlkIdx<!TransposeC>();
2172 const auto blk_id = blk_idx[
I0];
2173 const auto blk_td = blk_idx[
I1];
2188 #if defined(__gfx11__)
2189 const auto blk_idx = GetGfx11InputBlkIdx<TransposeC>();
2194 const auto blk_id = blk_idx[
I0];
2195 const auto blk_td = blk_idx[
I1];
2211 const auto blk_id = blk_idx[
I0];
2212 const auto blk_td = blk_idx[
I1];
2217 return TransposeC ?
CIndex{n_offset, m_offset} :
CIndex{m_offset, n_offset};
2224 const auto blk_id = blk_idx[
I0];
2225 const auto blk_td = blk_idx[
I1];
2239 #if defined(__gfx950__)
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:45
@ wmma_f32_16x16x16_bf16_gfx12
@ mfma_f32_32x32x64f8f6f4
@ wmma_unsupport_16x16_gfx11
@ wmma_i32_16x16x16_iu8_gfx12
@ mfma_scale_f32_32x32x64f8f6f4
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_unsupport_16x16_gfx12
@ mfma_f32_16x16x16bf16_1k
@ wmma_f32_16x16x16_f8f8_gfx12
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x32bf8bf8
@ mfma_f32_16x16x128f8f6f4
@ mfma_f32_32x32x16bf8bf8
@ mfma_f32_32x32x8bf16_1k
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:408
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1255
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1815
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1849
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1861
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1808
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1855
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1872
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:2252
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1891
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2185
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:2120
__device__ static constexpr __host__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:2058
static constexpr auto I2
Definition: xdlops_gemm.hpp:1875
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1883
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:2118
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:2256
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1944
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1885
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2163
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:2233
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:2220
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:2063
static __device__ auto GetGfx11InputBlkIdx()
Definition: xdlops_gemm.hpp:2141
static constexpr auto I5
Definition: xdlops_gemm.hpp:1878
static constexpr auto I3
Definition: xdlops_gemm.hpp:1876
static constexpr auto I0
Definition: xdlops_gemm.hpp:1873
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:2098
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1909
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:2022
static constexpr auto I1
Definition: xdlops_gemm.hpp:1874
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:2255
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:2254
static constexpr auto I4
Definition: xdlops_gemm.hpp:1877
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:2066
static constexpr auto mfma
Definition: xdlops_gemm.hpp:2245
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:2207
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1989
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:2258
Definition: integral_constant.hpp:20
Definition: amd_xdlops.hpp:1221
Definition: amd_xdlops.hpp:322
Definition: amd_xdlops.hpp:212
Definition: amd_xdlops.hpp:89
Definition: amd_xdlops.hpp:288
Definition: amd_xdlops.hpp:1502
Definition: amd_xdlops.hpp:1628
Definition: amd_xdlops.hpp:178
Definition: amd_xdlops.hpp:1565
Definition: amd_xdlops.hpp:1439
Definition: amd_xdlops.hpp:1711
Definition: amd_xdlops.hpp:226
Definition: amd_xdlops.hpp:75
Definition: amd_xdlops.hpp:350
Definition: amd_xdlops.hpp:1660
Definition: amd_xdlops.hpp:268
Definition: amd_xdlops.hpp:1470
Definition: amd_xdlops.hpp:1596
Definition: amd_xdlops.hpp:158
Definition: amd_xdlops.hpp:1533
Definition: amd_xdlops.hpp:1407
Definition: amd_xdlops.hpp:1754
Definition: amd_xdlops.hpp:34
Definition: amd_xdlops.hpp:61
Definition: amd_xdlops.hpp:336
Definition: amd_xdlops.hpp:131
Definition: amd_xdlops.hpp:1680
Definition: amd_xdlops.hpp:500
Definition: amd_xdlops.hpp:308
Definition: amd_xdlops.hpp:198
Definition: amd_xdlops.hpp:103
Definition: amd_xdlops.hpp:240
Definition: amd_xdlops.hpp:480
Definition: amd_xdlops.hpp:383
Definition: amd_xdlops.hpp:461
Definition: amd_xdlops.hpp:422
Definition: amd_xdlops.hpp:442
Definition: amd_xdlops.hpp:402
Definition: amd_xdlops.hpp:364
Definition: amd_xdlops.hpp:905
Definition: amd_xdlops.hpp:685
Definition: amd_wmma.hpp:297
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:418
Definition: amd_wmma.hpp:394
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:370
Definition: amd_wmma.hpp:346
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:875
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:453
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:321
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:188
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:431
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:739
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:827
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:299
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:783
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:695
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1059
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:343
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:166
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:497
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:992
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:387
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:717
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:805
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:277
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:761
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:673
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1037
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:122
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:144
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:475
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:233
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1014
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:851
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:409
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:255
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:211
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:365
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:651
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:541
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:585
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:629
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:563
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:607
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:519
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:944
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:907
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1095
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1159
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1217
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1207
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1085
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1149
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1197
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1187
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1112
__device__ void run(const FloatA &a, const FloatB &b, FloatC ®_c) const
Definition: xdlops_gemm.hpp:1176
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1123
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1228
Definition: xdlops_gemm.hpp:1067
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1076
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1068
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1075
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1078
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1071
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1074
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1072
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1073
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1069
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1070
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1077
Definition: xdlops_gemm.hpp:1131
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1140
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1132
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1138
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1139
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1135
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1142
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1134
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1133
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1137
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1136
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1141
Definition: xdlops_gemm.hpp:104
Definition: functional2.hpp:33