53 template <MfmaInstr instr>
60 static constexpr
index_t num_groups_per_blk = 4;
61 static constexpr
index_t num_regs_per_blk = 16;
62 static constexpr
index_t num_threads_per_blk = 32;
64 static constexpr
index_t num_input_blks = 2;
65 static constexpr
index_t num_output_blks = 2;
69 static constexpr
bool is_k_reduction =
false;
71 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
72 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
82 static constexpr
index_t num_groups_per_blk = 4;
83 static constexpr
index_t num_regs_per_blk = 16;
84 static constexpr
index_t num_threads_per_blk = 32;
86 static constexpr
index_t num_input_blks = 2;
87 static constexpr
index_t num_output_blks = 1;
91 static constexpr
bool is_k_reduction =
true;
93 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
94 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
104 static constexpr
index_t num_groups_per_blk = 1;
105 static constexpr
index_t num_regs_per_blk = 4;
106 static constexpr
index_t num_threads_per_blk = 16;
113 static constexpr
bool is_k_reduction =
true;
115 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
116 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
126 static constexpr
index_t num_groups_per_blk = 1;
127 static constexpr
index_t num_regs_per_blk = 4;
128 static constexpr
index_t num_threads_per_blk = 16;
135 static constexpr
bool is_k_reduction =
false;
137 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
138 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
149 static constexpr
index_t num_groups_per_blk = 1;
150 static constexpr
index_t num_regs_per_blk = 4;
151 static constexpr
index_t num_threads_per_blk = 64;
158 static constexpr
bool is_k_reduction =
false;
160 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
161 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
171 static constexpr
index_t num_groups_per_blk = 4;
172 static constexpr
index_t num_regs_per_blk = 16;
173 static constexpr
index_t num_threads_per_blk = 32;
180 static constexpr
bool is_k_reduction =
false;
182 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
183 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
193 static constexpr
index_t num_groups_per_blk = 4;
194 static constexpr
index_t num_regs_per_blk = 16;
195 static constexpr
index_t num_threads_per_blk = 32;
202 static constexpr
bool is_k_reduction =
true;
204 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
205 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
215 static constexpr
index_t num_groups_per_blk = 4;
216 static constexpr
index_t num_regs_per_blk = 16;
217 static constexpr
index_t num_threads_per_blk = 32;
224 static constexpr
bool is_k_reduction =
true;
226 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
227 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
237 static constexpr
index_t num_groups_per_blk = 1;
238 static constexpr
index_t num_regs_per_blk = 4;
239 static constexpr
index_t num_threads_per_blk = 16;
246 static constexpr
bool is_k_reduction =
true;
248 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
249 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
259 static constexpr
index_t num_groups_per_blk = 1;
260 static constexpr
index_t num_regs_per_blk = 4;
261 static constexpr
index_t num_threads_per_blk = 16;
268 static constexpr
bool is_k_reduction =
true;
270 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
271 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
281 static constexpr
index_t num_groups_per_blk = 1;
282 static constexpr
index_t num_regs_per_blk = 4;
283 static constexpr
index_t num_threads_per_blk = 16;
290 static constexpr
bool is_k_reduction =
false;
292 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
293 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
303 static constexpr
index_t num_groups_per_blk = 1;
304 static constexpr
index_t num_regs_per_blk = 4;
305 static constexpr
index_t num_threads_per_blk = 64;
312 static constexpr
bool is_k_reduction =
false;
314 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
315 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
325 static constexpr
index_t num_groups_per_blk = 4;
326 static constexpr
index_t num_regs_per_blk = 16;
327 static constexpr
index_t num_threads_per_blk = 32;
334 static constexpr
bool is_k_reduction =
true;
336 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
337 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
347 static constexpr
index_t num_groups_per_blk = 4;
348 static constexpr
index_t num_regs_per_blk = 16;
349 static constexpr
index_t num_threads_per_blk = 32;
356 static constexpr
bool is_k_reduction =
true;
358 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
359 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
369 static constexpr
index_t num_groups_per_blk = 1;
370 static constexpr
index_t num_regs_per_blk = 4;
371 static constexpr
index_t num_threads_per_blk = 16;
378 static constexpr
bool is_k_reduction =
true;
380 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
381 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
391 static constexpr
index_t num_groups_per_blk = 1;
392 static constexpr
index_t num_regs_per_blk = 4;
393 static constexpr
index_t num_threads_per_blk = 16;
400 static constexpr
bool is_k_reduction =
true;
402 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
403 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
413 static constexpr
index_t num_groups_per_blk = 4;
414 static constexpr
index_t num_regs_per_blk = 16;
415 static constexpr
index_t num_threads_per_blk = 32;
422 static constexpr
bool is_k_reduction =
true;
424 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
425 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
435 static constexpr
index_t num_groups_per_blk = 1;
436 static constexpr
index_t num_regs_per_blk = 4;
437 static constexpr
index_t num_threads_per_blk = 16;
444 static constexpr
bool is_k_reduction =
true;
446 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
447 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
457 static constexpr
index_t num_groups_per_blk = 4;
458 static constexpr
index_t num_regs_per_blk = 16;
459 static constexpr
index_t num_threads_per_blk = 32;
466 static constexpr
bool is_k_reduction =
true;
468 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
469 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
479 static constexpr
index_t num_groups_per_blk = 1;
480 static constexpr
index_t num_regs_per_blk = 4;
481 static constexpr
index_t num_threads_per_blk = 16;
488 static constexpr
bool is_k_reduction =
true;
490 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
491 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
501 static constexpr
index_t num_groups_per_blk = 4;
502 static constexpr
index_t num_regs_per_blk = 16;
503 static constexpr
index_t num_threads_per_blk = 32;
510 static constexpr
bool is_k_reduction =
true;
512 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
513 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
523 static constexpr
index_t num_groups_per_blk = 1;
524 static constexpr
index_t num_regs_per_blk = 4;
525 static constexpr
index_t num_threads_per_blk = 16;
532 static constexpr
bool is_k_reduction =
true;
534 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
535 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
545 static constexpr
index_t num_groups_per_blk = 4;
546 static constexpr
index_t num_regs_per_blk = 16;
547 static constexpr
index_t num_threads_per_blk = 32;
554 static constexpr
bool is_k_reduction =
true;
556 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
557 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
567 static constexpr
index_t num_groups_per_blk = 1;
568 static constexpr
index_t num_regs_per_blk = 4;
569 static constexpr
index_t num_threads_per_blk = 16;
576 static constexpr
bool is_k_reduction =
true;
578 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
579 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
589 static constexpr
index_t num_groups_per_blk = 4;
590 static constexpr
index_t num_regs_per_blk = 4;
591 static constexpr
index_t num_threads_per_blk = 16;
598 static constexpr
bool is_k_reduction =
true;
600 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
601 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
611 static constexpr
index_t num_groups_per_blk = 4;
612 static constexpr
index_t num_regs_per_blk = 16;
613 static constexpr
index_t num_threads_per_blk = 32;
620 static constexpr
bool is_k_reduction =
true;
622 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
623 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
633 static constexpr
index_t num_groups_per_blk = 1;
634 static constexpr
index_t num_regs_per_blk = 4;
635 static constexpr
index_t num_threads_per_blk = 16;
642 static constexpr
bool is_k_reduction =
true;
644 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
645 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
655 static constexpr
index_t num_groups_per_blk = 4;
656 static constexpr
index_t num_regs_per_blk = 16;
657 static constexpr
index_t num_threads_per_blk = 32;
664 static constexpr
bool is_k_reduction =
true;
666 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
667 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
677 static constexpr
index_t num_groups_per_blk = 1;
678 static constexpr
index_t num_regs_per_blk = 4;
679 static constexpr
index_t num_threads_per_blk = 16;
686 static constexpr
bool is_k_reduction =
true;
688 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
689 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
699 static constexpr
index_t num_groups_per_blk = 4;
700 static constexpr
index_t num_regs_per_blk = 16;
701 static constexpr
index_t num_threads_per_blk = 32;
708 static constexpr
bool is_k_reduction =
true;
710 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
711 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
721 static constexpr
index_t num_groups_per_blk = 1;
722 static constexpr
index_t num_regs_per_blk = 4;
723 static constexpr
index_t num_threads_per_blk = 16;
730 static constexpr
bool is_k_reduction =
true;
732 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
733 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
743 static constexpr
index_t num_groups_per_blk = 4;
744 static constexpr
index_t num_regs_per_blk = 16;
745 static constexpr
index_t num_threads_per_blk = 32;
752 static constexpr
bool is_k_reduction =
true;
754 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
755 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
765 static constexpr
index_t num_groups_per_blk = 1;
766 static constexpr
index_t num_regs_per_blk = 4;
767 static constexpr
index_t num_threads_per_blk = 16;
774 static constexpr
bool is_k_reduction =
true;
776 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
777 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
789 static constexpr
index_t num_groups_per_blk = 4;
790 static constexpr
index_t num_regs_per_blk = 16;
791 static constexpr
index_t num_threads_per_blk = 32;
798 static constexpr
bool is_k_reduction =
true;
801 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
802 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
813 static constexpr
index_t num_groups_per_blk = 1;
814 static constexpr
index_t num_regs_per_blk = 4;
815 static constexpr
index_t num_threads_per_blk = 16;
822 static constexpr
bool is_k_reduction =
true;
825 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
826 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
837 static constexpr
index_t num_groups_per_blk = 4;
838 static constexpr
index_t num_regs_per_blk = 16;
839 static constexpr
index_t num_threads_per_blk = 32;
846 static constexpr
bool is_k_reduction =
true;
849 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
850 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
861 static constexpr
index_t num_groups_per_blk = 1;
862 static constexpr
index_t num_regs_per_blk = 4;
863 static constexpr
index_t num_threads_per_blk = 16;
870 static constexpr
bool is_k_reduction =
true;
873 template <index_t MPerXdlops, index_t NPerXdlops,
class FloatA,
class FloatB,
class FloatC>
874 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
880 template <
typename base_type,
883 typename additional_type = base_type,
884 bool is_single_rate_mfma =
false>
887 template <
typename base_type_,
890 typename additional_type_ = base_type_,
891 bool is_single_rate_mfma_ =
false>
895 constexpr
auto GetMfma<double, 16, 16>()
901 constexpr
auto GetMfma<float, 64, 64>()
907 constexpr
auto GetMfma<float, 32, 64>()
913 constexpr
auto GetMfma<float, 16, 64>()
919 constexpr
auto GetMfma<float, 8, 64>()
925 constexpr
auto GetMfma<float, 4, 64>()
931 constexpr
auto GetMfma<float, 32, 32>()
937 constexpr
auto GetMfma<float, 16, 16>()
943 constexpr
auto GetMfma<half_t, 64, 64>()
949 constexpr
auto GetMfma<half_t, 32, 64>()
955 constexpr
auto GetMfma<half_t, 32, 32, half_t, false>()
957 #if defined(__gfx950__)
964 constexpr
auto GetMfma<half_t, 32, 32, half_t, true>()
970 constexpr
auto GetMfma<half_t, 16, 16, half_t, false>()
972 #if defined(__gfx950__)
980 constexpr
auto GetMfma<half_t, 16, 16, half_t, true>()
986 constexpr
auto GetMfma<half_t, 16, 64>()
992 constexpr
auto GetMfma<half_t, 8, 64>()
998 constexpr
auto GetMfma<half_t, 4, 64>()
1004 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1006 #if defined(__gfx950__)
1008 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1016 constexpr
auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1018 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1026 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1028 #if defined(__gfx950__)
1030 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1038 constexpr
auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1040 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1047 #if defined(__gfx950__)
1049 constexpr
auto GetMfma<int8_t, 32, 32>()
1054 constexpr
auto GetMfma<int8_t, 16, 16>()
1058 #elif defined(__gfx942__)
1060 constexpr
auto GetMfma<int8_t, 32, 32>()
1065 constexpr
auto GetMfma<int8_t, 16, 16>()
1071 constexpr
auto GetMfma<int8_t, 32, 32>()
1076 constexpr
auto GetMfma<int8_t, 16, 16>()
1083 constexpr
auto GetMfma<f8_t, 32, 32>()
1089 constexpr
auto GetMfma<f8_t, 16, 16>()
1095 constexpr
auto GetMfma<bf8_t, 32, 32>()
1101 constexpr
auto GetMfma<bf8_t, 16, 16>()
1107 constexpr
auto GetMfma<f8_t, 32, 32, bf8_t>()
1113 constexpr
auto GetMfma<f8_t, 16, 16, bf8_t>()
1119 constexpr
auto GetMfma<bf8_t, 32, 32, f8_t>()
1125 constexpr
auto GetMfma<bf8_t, 16, 16, f8_t>()
1131 GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
1137 "wrong! num_regs_per_blk");
1140 "n_per_blk != num_threads_per_blk");
1144 "m_per_blk != num_input_blks * num_regs_per_blk");
1148 "incorrect num_output_blks");
1152 "num_regs_per_blk incorrect");
1156 "is_k_reduction wrong!");
1161 static_assert(NPerXdlops >= MPerXdlops,
"only support ABroadcast");
1174 template <
typename base_type,
1178 typename additional_type = base_type,
1179 bool TransposeC =
false>
1196 return MPerXdlops * NPerXdlops /
1202 static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1204 "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1206 static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1208 "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1210 static_assert(KPack %
mfma_instr.k_per_blk == 0,
"KPack cannot be divided by k_per_blk");
1215 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1216 __host__ __device__
static constexpr
auto
1219 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1220 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1221 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1222 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1225 c_desc_m0_n0_m1_n1_m2_n2,
1250 template <
typename CDesc_M0_N0_M1_N1_M2_N2>
1251 __host__ __device__
static constexpr
auto
1254 const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1255 const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1256 const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1257 const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1260 c_desc_m0_n0_m1_n1_m2_n2,
1283 template <
typename CDesc_G_M0_N0_M1_N1_M2_N2>
1285 const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1287 const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I0);
1288 const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I1);
1289 const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I2);
1290 const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I3);
1291 const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(
I4);
1294 c_desc_g_m0_n0_m1_n1_m2_n2,
1322 return MPerXdlops * NPerXdlops /
mfma_instr.wave_size;
1327 template <
class FloatA,
class FloatB,
class FloatC>
1328 __device__
void Run(
const FloatA& p_a_wave,
const FloatB& p_b_wave, FloatC& p_c_thread)
const
1337 "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
1340 if constexpr(!TransposeC)
1342 mfma_instr.template run<MPerXdlops, NPerXdlops>(
1343 p_a_wave[k], p_b_wave[k], p_c_thread);
1347 mfma_instr.template run<MPerXdlops, NPerXdlops>(
1348 p_b_wave[k], p_a_wave[k], p_c_thread);
1365 const auto blk_idx =
1366 threadidx_to_blk_idx_adaptor.CalculateBottomIndex(
make_multi_index(laneId));
1368 const auto blk_id = blk_idx[
I1];
1369 const auto blk_td = blk_idx[
I2];
1379 const auto blk_id = blk_idx[
I0];
1380 const auto blk_td = blk_idx[
I1];
1397 const auto blk_id = blk_idx[
I0];
1398 const auto blk_td = blk_idx[
I1];
1414 const auto blk_id = blk_idx[
I0];
1415 const auto blk_td = blk_idx[
I1];
1420 return TransposeC ?
CIndex{n_offset, m_offset} :
CIndex{m_offset, n_offset};
1427 const auto blk_id = blk_idx[
I0];
1428 const auto blk_td = blk_idx[
I1];
1434 static constexpr
auto
1436 MPerXdlops, NPerXdlops, additional_type,
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:13
@ mfma_f32_32x32x64f8f6f4
@ mfma_scale_f32_32x32x64f8f6f4
@ mfma_f32_16x16x16bf16_1k
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x32bf8bf8
@ mfma_f32_16x16x128f8f6f4
@ mfma_f32_32x32x16bf8bf8
@ mfma_f32_32x32x8bf16_1k
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:16
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: xdlops_gemm.hpp:886
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1159
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1165
static constexpr auto GetMfma()
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1133
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1130
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1171
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1181
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:1353
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1392
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1192
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1252
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:1443
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1200
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:1441
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:1355
static constexpr __device__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:1320
static constexpr auto I5
Definition: xdlops_gemm.hpp:1187
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:1445
static constexpr auto I4
Definition: xdlops_gemm.hpp:1186
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:1325
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:1444
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:1410
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1374
static constexpr auto mfma
Definition: xdlops_gemm.hpp:1435
static constexpr auto I3
Definition: xdlops_gemm.hpp:1185
static constexpr auto I2
Definition: xdlops_gemm.hpp:1184
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1284
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1328
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:1447
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1217
static constexpr auto I1
Definition: xdlops_gemm.hpp:1183
static constexpr auto I0
Definition: xdlops_gemm.hpp:1182
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:1423
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1194
Definition: integral_constant.hpp:10
Definition: amd_xdlops.hpp:587
Definition: amd_xdlops.hpp:302
Definition: amd_xdlops.hpp:192
Definition: amd_xdlops.hpp:69
Definition: amd_xdlops.hpp:268
Definition: amd_xdlops.hpp:718
Definition: amd_xdlops.hpp:844
Definition: amd_xdlops.hpp:158
Definition: amd_xdlops.hpp:781
Definition: amd_xdlops.hpp:655
Definition: amd_xdlops.hpp:206
Definition: amd_xdlops.hpp:55
Definition: amd_xdlops.hpp:330
Definition: amd_xdlops.hpp:248
Definition: amd_xdlops.hpp:686
Definition: amd_xdlops.hpp:812
Definition: amd_xdlops.hpp:138
Definition: amd_xdlops.hpp:749
Definition: amd_xdlops.hpp:623
Definition: amd_xdlops.hpp:14
Definition: amd_xdlops.hpp:41
Definition: amd_xdlops.hpp:316
Definition: amd_xdlops.hpp:111
Definition: amd_xdlops.hpp:480
Definition: amd_xdlops.hpp:288
Definition: amd_xdlops.hpp:178
Definition: amd_xdlops.hpp:83
Definition: amd_xdlops.hpp:220
Definition: amd_xdlops.hpp:460
Definition: amd_xdlops.hpp:363
Definition: amd_xdlops.hpp:441
Definition: amd_xdlops.hpp:402
Definition: amd_xdlops.hpp:422
Definition: amd_xdlops.hpp:382
Definition: amd_xdlops.hpp:344
Definition: amd_xdlops.hpp:551
Definition: amd_xdlops.hpp:515
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:826
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:403
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:271
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:138
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:381
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:689
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:777
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:249
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:733
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:645
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:293
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:116
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:447
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:337
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:667
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:755
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:227
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:711
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:623
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:72
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:94
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:425
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:183
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:802
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:359
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:205
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:161
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:315
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:601
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:491
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:535
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:579
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:513
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:557
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:469
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:874
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:850
Definition: xdlops_gemm.hpp:54
Definition: functional2.hpp:31