9 #if defined(__gfx942__) || defined(__gfx950__)
15 template <index_t VecSize>
16 __device__ __forceinline__
void
23 reg_bf16_big.template AsType<bhalf_t>()(k) =
26 reg_f32.template AsType<float>()[IK{}] -
33 template <index_t MPerWave, index_t NPerWave>
39 template <
class FloatC>
40 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
42 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
43 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
44 reg_c.template AsType<float32_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
45 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<1>{}], 1, 1, 0);
52 template <
class FloatC>
53 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
55 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
56 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
60 template <index_t MPerWave, index_t NPerWave>
66 template <
class FloatC>
67 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
69 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
70 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
74 template <index_t MPerWave, index_t NPerWave>
80 template <
class FloatC>
81 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
83 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
84 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
88 template <index_t MPerWave, index_t NPerWave>
94 template <
class FloatC>
95 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
97 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
98 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 2, 0, 0);
102 template <index_t MPerWave, index_t NPerWave>
108 template <
class FloatC>
109 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
111 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
112 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
119 template <
class FloatC>
120 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
122 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
123 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
124 reg_c.template AsType<float4_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
125 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<1>{}], 4, 1, 0);
130 template <index_t MPerWave, index_t NPerWave>
136 template <
class FloatC>
139 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
140 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
141 reg_c.template AsType<float32_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
142 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<1>{}], 1, 1, 0);
149 template <
class FloatC>
152 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
153 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
157 template <index_t MPerWave, index_t NPerWave>
163 template <
class FloatC>
166 #if defined(__gfx950__)
167 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
168 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
177 template <index_t MPerWave, index_t NPerWave>
183 template <
class FloatC>
186 #if defined(__gfx950__)
187 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
188 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
197 template <index_t MPerWave, index_t NPerWave>
203 template <
class FloatC>
206 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
207 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
211 template <index_t MPerWave, index_t NPerWave>
217 template <
class FloatC>
220 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
221 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
225 template <index_t MPerWave, index_t NPerWave>
231 template <
class FloatC>
234 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
235 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 2, 0, 0);
239 template <index_t MPerWave, index_t NPerWave>
245 template <
class FloatC>
248 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
249 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
256 template <
class FloatC>
259 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
260 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
261 reg_c.template AsType<float4_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
262 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<1>{}], 4, 1, 0);
267 template <index_t MPerWave, index_t NPerWave>
273 template <
class FloatC>
276 #if defined(__gfx950__)
277 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
278 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
287 template <index_t MPerWave, index_t NPerWave>
293 template <
class FloatC>
296 #if defined(__gfx950__)
297 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
298 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
307 template <index_t MPerWave, index_t NPerWave>
313 template <
class FloatC>
316 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
317 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
321 template <index_t MPerWave, index_t NPerWave>
327 template <
class FloatC>
330 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
331 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
335 template <index_t MPerWave, index_t NPerWave>
341 template <
class FloatC>
344 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
345 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
349 template <index_t MPerWave, index_t NPerWave>
355 template <
class FloatC>
358 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
359 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
363 template <index_t MPerWave, index_t NPerWave>
369 template <
class FloatC>
372 reg_c.template AsType<int32x16_t>()(
Number<0>{}) =
373 __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
374 bit_cast<int32_t>(reg_b),
375 reg_c.template AsType<int32x16_t>()[
Number<0>{}],
382 template <index_t MPerWave, index_t NPerWave>
388 template <
class FloatC>
391 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
392 __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
393 bit_cast<int32_t>(reg_b),
394 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
401 template <index_t MPerWave, index_t NPerWave>
407 template <
class FloatC>
410 #if defined(__gfx950__)
411 reg_c.template AsType<int32x16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
412 reg_a, reg_b, reg_c.template AsType<int32x16_t>()[
Number<0>{}], 0, 0, 0);
421 template <index_t MPerWave, index_t NPerWave>
427 template <
class FloatC>
430 #if defined(__gfx950__)
431 reg_c.template AsType<int32x4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
432 reg_a, reg_b, reg_c.template AsType<int32x4_t>()[
Number<0>{}], 0, 0, 0);
441 template <index_t MPerWave, index_t NPerWave>
447 template <
class FloatC>
450 reg_c.template AsType<int32x16_t>()(
Number<0>{}) =
451 __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
452 bit_cast<int64_t>(reg_b),
453 reg_c.template AsType<int32x16_t>()[
Number<0>{}],
460 template <index_t MPerWave, index_t NPerWave>
466 template <
class FloatC>
469 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
470 __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
471 bit_cast<int64_t>(reg_b),
472 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
479 template <index_t MPerWave, index_t NPerWave>
485 template <
class FloatC>
486 __device__
static void Run(
const double& reg_a,
const double& reg_b, FloatC& reg_c)
488 #if defined(__gfx90a__) || defined(__gfx94__)
489 reg_c.template AsType<double4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
490 reg_a, reg_b, reg_c.template AsType<double4_t>()[
Number<0>{}], 0, 0, 0);
499 template <index_t MPerWave, index_t NPerWave>
511 template <
class FloatC>
512 __device__
static void Run(
const f8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
514 #if defined(__gfx950__)
515 reg_c.template AsType<float16_t>()(
Number<0>{}) =
516 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
519 reg_c.template AsType<float16_t>()[
Number<0>{}],
533 template <
class FloatC>
536 #if defined(__gfx950__)
537 reg_c.template AsType<float16_t>()(
Number<0>{}) =
538 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
541 reg_c.template AsType<float16_t>()[
Number<0>{}],
555 template <
class FloatC>
556 __device__
static void Run(
const bf8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
558 #if defined(__gfx950__)
559 reg_c.template AsType<float16_t>()(
Number<0>{}) =
560 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
563 reg_c.template AsType<float16_t>()[
Number<0>{}],
577 template <
class FloatC>
578 __device__
static void Run(
const f8x32_t& reg_a,
const bf8x32_t& reg_b, FloatC& reg_c)
580 #if defined(__gfx950__)
581 reg_c.template AsType<float16_t>()(
Number<0>{}) =
582 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
585 reg_c.template AsType<float16_t>()[
Number<0>{}],
599 template <
class FloatC>
602 #if defined(__gfx950__)
604 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
605 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
609 reg_c.template AsType<float16_t>()(
Number<0>{}) =
610 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
611 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
612 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
613 reg_c.template AsType<float16_t>()[
Number<0>{}],
627 template <
class FloatC>
630 #if defined(__gfx950__)
632 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
633 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
637 reg_c.template AsType<float16_t>()(
Number<0>{}) =
638 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
639 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
640 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
641 reg_c.template AsType<float16_t>()[
Number<0>{}],
655 template <
class FloatC>
658 #if defined(__gfx950__)
660 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
661 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
665 reg_c.template AsType<float16_t>()(
Number<0>{}) =
666 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
667 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
668 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
669 reg_c.template AsType<float16_t>()[
Number<0>{}],
684 template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
687 template <index_t OpselA, index_t OpselB>
690 template <
class FloatC>
691 __device__
static void Run(
const f8x32_t& reg_a,
693 const f8x32_t& reg_b,
697 #if defined(__gfx950__)
699 reg_c.template AsType<float16_t>()(
Number<0>{}) =
700 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
703 reg_c.template AsType<float16_t>()[
Number<0>{}],
727 template <
class FloatC>
734 #if defined(__gfx950__)
736 reg_c.template AsType<float16_t>()(
Number<0>{}) =
737 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
740 reg_c.template AsType<float16_t>()[
Number<0>{}],
764 template <
class FloatC>
767 const f8x32_t& reg_b,
771 #if defined(__gfx950__)
773 reg_c.template AsType<float16_t>()(
Number<0>{}) =
774 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
777 reg_c.template AsType<float16_t>()[
Number<0>{}],
801 template <
class FloatC>
808 #if defined(__gfx950__)
810 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
811 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
815 reg_c.template AsType<float16_t>()(
Number<0>{}) =
816 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
817 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
818 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
819 reg_c.template AsType<float16_t>()[
Number<0>{}],
835 template <
class FloatC>
842 #if defined(__gfx950__)
844 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
845 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
849 reg_c.template AsType<float16_t>()(
Number<0>{}) =
850 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
851 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
852 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
853 reg_c.template AsType<float16_t>()[
Number<0>{}],
869 template <
class FloatC>
876 #if defined(__gfx950__)
878 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
879 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
883 reg_c.template AsType<float16_t>()(
Number<0>{}) =
884 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
885 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
886 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
887 reg_c.template AsType<float16_t>()[
Number<0>{}],
904 template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
907 template <index_t OpselA, index_t OpselB>
910 template <
class FloatC>
911 __device__
static void Run(
const f8x32_t& reg_a,
913 const f8x32_t& reg_b,
917 #if defined(__gfx950__)
919 reg_c.template AsType<float4_t>()(
Number<0>{}) =
920 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
923 reg_c.template AsType<float4_t>()[
Number<0>{}],
939 template <
class FloatC>
946 #if defined(__gfx950__)
948 reg_c.template AsType<float4_t>()(
Number<0>{}) =
949 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
952 reg_c.template AsType<float4_t>()[
Number<0>{}],
968 template <
class FloatC>
969 __device__
static void Run(
const f8x32_t& reg_a,
975 #if defined(__gfx950__)
977 reg_c.template AsType<float4_t>()(
Number<0>{}) =
978 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
981 reg_c.template AsType<float4_t>()[
Number<0>{}],
997 template <
class FloatC>
1000 const f8x32_t& reg_b,
1004 #if defined(__gfx950__)
1006 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1007 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1010 reg_c.template AsType<float4_t>()[
Number<0>{}],
1026 template <
class FloatC>
1033 #if defined(__gfx950__)
1034 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1035 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1039 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1040 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1041 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1042 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1043 reg_c.template AsType<float4_t>()[
Number<0>{}],
1059 template <
class FloatC>
1066 #if defined(__gfx950__)
1069 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][0]),
1070 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][1]),
1071 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][2]),
1072 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][0]),
1073 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][1]),
1074 static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][2]),
1078 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][0]),
1079 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][1]),
1080 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<0>{}][2]),
1081 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][0]),
1082 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][1]),
1083 static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[
Number<1>{}][2]),
1087 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1088 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1091 reg_c.template AsType<float4_t>()[
Number<0>{}],
1107 template <
class FloatC>
1114 #if defined(__gfx950__)
1115 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1116 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1120 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1121 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1122 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1123 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1124 reg_c.template AsType<float4_t>()[
Number<0>{}],
1140 template <
class FloatC>
1147 #if defined(__gfx950__)
1150 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][0]),
1151 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][1]),
1152 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][2]),
1153 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][0]),
1154 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][1]),
1155 static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][2]),
1159 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][0]),
1160 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][1]),
1161 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<0>{}][2]),
1162 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][0]),
1163 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][1]),
1164 static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[
Number<1>{}][2]),
1168 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1169 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1172 reg_c.template AsType<float4_t>()[
Number<0>{}],
1188 template <
class FloatC>
1195 #if defined(__gfx950__)
1196 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1197 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1199 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1200 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1201 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1202 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1203 reg_c.template AsType<float4_t>()[
Number<0>{}],
1220 template <index_t MPerWave, index_t NPerWave>
1232 template <
class FloatC>
1233 __device__
static void Run(
const f8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
1235 #if defined(__gfx950__)
1237 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1238 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1241 reg_c.template AsType<float4_t>()[
Number<0>{}],
1255 template <
class FloatC>
1258 #if defined(__gfx950__)
1260 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1261 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1264 reg_c.template AsType<float4_t>()[
Number<0>{}],
1278 template <
class FloatC>
1279 __device__
static void Run(
const bf8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
1281 #if defined(__gfx950__)
1283 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1284 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1287 reg_c.template AsType<float4_t>()[
Number<0>{}],
1301 template <
class FloatC>
1302 __device__
static void Run(
const f8x32_t& reg_a,
const bf8x32_t& reg_b, FloatC& reg_c)
1304 #if defined(__gfx950__)
1306 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1307 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1310 reg_c.template AsType<float4_t>()[
Number<0>{}],
1324 template <
class FloatC>
1327 #if defined(__gfx950__)
1328 int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1329 int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1333 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1334 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1335 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1336 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1337 reg_c.template AsType<float4_t>()[
Number<0>{}],
1351 template <
class FloatC>
1354 #if defined(__gfx950__)
1355 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1356 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1360 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1361 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1362 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1363 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1364 reg_c.template AsType<float4_t>()[
Number<0>{}],
1378 template <
class FloatC>
1381 #if defined(__gfx950__)
1382 int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1383 int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1387 reg_c.template AsType<float4_t>()(
Number<0>{}) =
1388 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1389 arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1390 arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1391 reg_c.template AsType<float4_t>()[
Number<0>{}],
1406 template <index_t MPerWave, index_t NPerWave>
1412 template <
class FloatC>
1413 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1415 #if defined(__gfx94__)
1416 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1417 __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1418 bit_cast<int64_t>(reg_a),
1419 bit_cast<int64_t>(reg_b),
1420 reg_c.template AsType<float16_t>()[
Number<0>{}],
1429 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
1430 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
1438 template <index_t MPerWave, index_t NPerWave>
1444 template <
class FloatC>
1445 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1447 #if defined(__gfx94__)
1448 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1449 bit_cast<int64_t>(reg_a),
1450 bit_cast<int64_t>(reg_b),
1451 reg_c.template AsType<float4_t>()[
Number<0>{}],
1460 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
1461 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
1469 template <index_t MPerWave, index_t NPerWave>
1475 template <
class FloatC>
1478 #if defined(__gfx94__)
1479 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1480 __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1481 bit_cast<int64_t>(reg_a),
1482 bit_cast<int64_t>(reg_b),
1483 reg_c.template AsType<float16_t>()[
Number<0>{}],
1492 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
1493 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
1501 template <index_t MPerWave, index_t NPerWave>
1507 template <
class FloatC>
1510 #if defined(__gfx94__)
1511 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1512 bit_cast<int64_t>(reg_a),
1513 bit_cast<int64_t>(reg_b),
1514 reg_c.template AsType<float4_t>()[
Number<0>{}],
1523 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
1524 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
1532 template <index_t MPerWave, index_t NPerWave>
1538 template <
class FloatC>
1539 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
1541 #if defined(__gfx94__)
1542 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1543 __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1544 bit_cast<int64_t>(reg_a),
1545 bit_cast<int64_t>(reg_b),
1546 reg_c.template AsType<float16_t>()[
Number<0>{}],
1555 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
1556 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
1564 template <index_t MPerWave, index_t NPerWave>
1570 template <
class FloatC>
1571 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
1573 #if defined(__gfx94__)
1574 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1575 bit_cast<int64_t>(reg_a),
1576 bit_cast<int64_t>(reg_b),
1577 reg_c.template AsType<float4_t>()[
Number<0>{}],
1586 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
1587 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
1595 template <index_t MPerWave, index_t NPerWave>
1601 template <
class FloatC>
1602 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1604 #if defined(__gfx94__)
1605 reg_c.template AsType<float16_t>()(
Number<0>{}) =
1606 __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1607 bit_cast<int64_t>(reg_a),
1608 bit_cast<int64_t>(reg_b),
1609 reg_c.template AsType<float16_t>()[
Number<0>{}],
1618 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
1619 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
1627 template <index_t MPerWave, index_t NPerWave>
1633 template <
class FloatC>
1634 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
1636 #if defined(__gfx94__)
1637 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1638 bit_cast<int64_t>(reg_a),
1639 bit_cast<int64_t>(reg_b),
1640 reg_c.template AsType<float4_t>()[
Number<0>{}],
1649 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
1650 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
1659 template <index_t MPerWave, index_t NPerWave>
1665 template <
class FloatC>
1668 #if defined(__gfx942__)
1669 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
1670 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
1679 template <index_t MPerWave, index_t NPerWave>
1685 template <
class FloatC>
1688 #if defined(__gfx942__)
1689 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
1690 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
1710 template <index_t MPerWave, index_t NPerWave>
1716 template <
class FloatC>
1719 #if defined(__gfx950__)
1734 v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
1735 v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
1738 v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
1739 v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
1742 v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
1743 v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
1753 template <index_t MPerWave, index_t NPerWave>
1759 template <
class FloatC>
1762 #if defined(__gfx950__)
1777 v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
1778 v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
1781 v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
1782 v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
1785 v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
1786 v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
bf8_t bf8x32_t
Definition: vector_type.hpp:240
bf8_t bf8x8_t
Definition: vector_type.hpp:238
typename vector_type< bf6x16_pk_t, 2 >::type bf6x16x2_t
Definition: dtype_vector.hpp:2273
typename vector_type< f6x16_pk_t, 2 >::type f6x16x2_t
Definition: dtype_vector.hpp:2268
typename vector_type< f6x32_pk_t, 1 >::type f6x32_t
Definition: dtype_vector.hpp:2269
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition: dtype_vector.hpp:2162
__device__ __forceinline__ void convert_float_to_bf16_pairs(const vector_type< float, VecSize > ®_f32, vector_type< bhalf_t, VecSize > ®_bf16_big, vector_type< bhalf_t, VecSize > ®_bf16_small)
Definition: amd_xdlops.hpp:17
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2163
typename vector_type< float, 2 >::type float2_t
Definition: dtype_vector.hpp:2146
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2179
typename vector_type< half_t, 4 >::type half4_t
Definition: dtype_vector.hpp:2155
typename vector_type< bf6x32_pk_t, 1 >::type bf6x32_t
Definition: dtype_vector.hpp:2274
typename vector_type< int32_t, 8 >::type int32x8_t
Definition: dtype_vector.hpp:2171
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< float, 8 >::type float8_t
Definition: dtype_vector.hpp:2148
typename vector_type< f4x2_pk_t, 16 >::type f4x32_t
Definition: dtype_vector.hpp:2263
typename vector_type< bhalf_t, 2 >::type bhalf2_t
Definition: dtype_vector.hpp:2161
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2180
typename vector_type< int32_t, 4 >::type int32x4_t
Definition: dtype_vector.hpp:2169
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2178
typename vector_type< int32_t, 6 >::type int32x6_t
Definition: dtype_vector.hpp:2170
__host__ constexpr __device__ bhalf_t type_convert< bhalf_t, float >(float x)
Definition: type_convert.hpp:133
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2156
__host__ constexpr __device__ float type_convert< float, bhalf_t >(bhalf_t x)
Definition: type_convert.hpp:120
signed int int32_t
Definition: stdint.h:123
Definition: integral_constant.hpp:20
static __device__ void Run(const bf6x32_t ®_a, const bf6x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1379
static __device__ void Run(const f6x32_t ®_a, const f6x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1352
static __device__ void Run(const f8x32_t ®_a, const bf8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1302
static __device__ void Run(const bf8x32_t ®_a, const bf8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1256
static __device__ void Run(const bf8x32_t ®_a, const f8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1279
static __device__ void Run(const f4x32_t ®_a, const f4x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1325
static __device__ void Run(const f8x32_t ®_a, const f8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1233
Definition: amd_xdlops.hpp:1221
static __device__ void Run(const bhalf4_t ®_a, const bhalf4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:328
Definition: amd_xdlops.hpp:322
static __device__ void Run(const half4_t ®_a, const half4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:218
Definition: amd_xdlops.hpp:212
static __device__ void Run(const float ®_a, const float ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:95
Definition: amd_xdlops.hpp:89
static __device__ void Run(const bhalf8_t ®_a, const bhalf8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:294
Definition: amd_xdlops.hpp:288
static __device__ void Run(const bf8x8_t ®_a, const bf8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1508
Definition: amd_xdlops.hpp:1502
static __device__ void Run(const bf8x8_t ®_a, const f8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1634
Definition: amd_xdlops.hpp:1628
static __device__ void Run(const half8_t ®_a, const half8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:184
Definition: amd_xdlops.hpp:178
static __device__ void Run(const f8x8_t ®_a, const bf8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1571
Definition: amd_xdlops.hpp:1565
static __device__ void Run(const f8x8_t ®_a, const f8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1445
Definition: amd_xdlops.hpp:1439
static __device__ void Run(const float8_t ®_a, const float8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1717
Definition: amd_xdlops.hpp:1711
static __device__ void Run(const half4_t ®_a, const half4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:232
Definition: amd_xdlops.hpp:226
static __device__ void Run(const float ®_a, const float ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:81
Definition: amd_xdlops.hpp:75
static __device__ void Run(const bhalf2_t ®_a, const bhalf2_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:356
Definition: amd_xdlops.hpp:350
static __device__ void Run(const float2_t ®_a, const float2_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1666
Definition: amd_xdlops.hpp:1660
static __device__ void Run(const bhalf8_t ®_a, const bhalf8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:274
Definition: amd_xdlops.hpp:268
static __device__ void Run(const bf8x8_t ®_a, const bf8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1476
Definition: amd_xdlops.hpp:1470
static __device__ void Run(const bf8x8_t ®_a, const f8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1602
Definition: amd_xdlops.hpp:1596
static __device__ void Run(const half8_t ®_a, const half8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:164
Definition: amd_xdlops.hpp:158
static __device__ void Run(const f8x8_t ®_a, const bf8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1539
Definition: amd_xdlops.hpp:1533
static __device__ void Run(const f8x8_t ®_a, const f8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1413
Definition: amd_xdlops.hpp:1407
static __device__ void Run(const float8_t ®_a, const float8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1760
Definition: amd_xdlops.hpp:1754
static __device__ void Run(const float ®_a, const float ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:53
static __device__ void Run(const float ®_a, const float ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:40
Definition: amd_xdlops.hpp:34
static __device__ void Run(const float ®_a, const float ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:67
Definition: amd_xdlops.hpp:61
static __device__ void Run(const bhalf2_t ®_a, const bhalf2_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:342
Definition: amd_xdlops.hpp:336
static __device__ void Run(const half4_t ®_a, const half4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:150
static __device__ void Run(const half4_t ®_a, const half4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:137
Definition: amd_xdlops.hpp:131
static __device__ void Run(const float2_t ®_a, const float2_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1686
Definition: amd_xdlops.hpp:1680
static __device__ void Run(const bf8x32_t ®_a, const f8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:556
static __device__ void Run(const f8x32_t ®_a, const bf8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:578
static __device__ void Run(const bf6x32_t ®_a, const bf6x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:656
static __device__ void Run(const f6x32_t ®_a, const f6x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:628
static __device__ void Run(const f8x32_t ®_a, const f8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:512
static __device__ void Run(const f4x32_t ®_a, const f4x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:600
static __device__ void Run(const bf8x32_t ®_a, const bf8x32_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:534
Definition: amd_xdlops.hpp:500
static __device__ void Run(const bhalf4_t ®_a, const bhalf4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:314
Definition: amd_xdlops.hpp:308
static __device__ void Run(const half4_t ®_a, const half4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:204
Definition: amd_xdlops.hpp:198
static __device__ void Run(const float ®_a, const float ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:109
static __device__ void Run(const float ®_a, const float ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:120
Definition: amd_xdlops.hpp:103
static __device__ void Run(const half4_t ®_a, const half4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:246
static __device__ void Run(const half4_t ®_a, const half4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:257
Definition: amd_xdlops.hpp:240
static __device__ void Run(const double ®_a, const double ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:486
Definition: amd_xdlops.hpp:480
static __device__ void Run(const int8x4_t ®_a, const int8x4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:389
Definition: amd_xdlops.hpp:383
static __device__ void Run(const int8x8_t ®_a, const int8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:467
Definition: amd_xdlops.hpp:461
static __device__ void Run(const int8x16_t ®_a, const int8x16_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:428
Definition: amd_xdlops.hpp:422
static __device__ void Run(const int8x8_t ®_a, const int8x8_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:448
Definition: amd_xdlops.hpp:442
static __device__ void Run(const int8x16_t ®_a, const int8x16_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:408
Definition: amd_xdlops.hpp:402
static __device__ void Run(const int8x4_t ®_a, const int8x4_t ®_b, FloatC ®_c)
Definition: amd_xdlops.hpp:370
Definition: amd_xdlops.hpp:364
static __device__ void Run(const f6x16x2_t ®_a, const int32_t scale_a, const f6x16x2_t ®_b, const int32_t scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1060
static __device__ void Run(const f4x32_t ®_a, const int32_t scale_a, const f4x32_t ®_b, const int32_t scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1189
static __device__ void Run(const f6x32_t ®_a, const int32_t scale_a, const f6x32_t ®_b, const int32_t scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1027
static __device__ void Run(const f8x32_t ®_a, const int32_t &scale_a, const f8x32_t ®_b, const int32_t &scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:911
static __device__ void Run(const bf8x32_t ®_a, const int32_t &scale_a, const f8x32_t ®_b, const int32_t &scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:998
static __device__ void Run(const f8x32_t ®_a, const int32_t &scale_a, const bf8x32_t ®_b, const int32_t &scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:969
static __device__ void Run(const bf6x16x2_t ®_a, const int32_t scale_a, const bf6x16x2_t ®_b, const int32_t scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1141
static __device__ void Run(const bf6x32_t ®_a, const int32_t scale_a, const bf6x32_t ®_b, const int32_t scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:1108
static __device__ void Run(const bf8x32_t ®_a, const int32_t &scale_a, const bf8x32_t ®_b, const int32_t &scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:940
Definition: amd_xdlops.hpp:905
static __device__ void Run(const f8x32_t ®_a, const int32_t &scale_a, const f8x32_t ®_b, const int32_t &scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:691
static __device__ void Run(const f6x32_t ®_a, const int32_t scale_a, const f6x32_t ®_b, const int32_t scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:802
static __device__ void Run(const bf8x32_t ®_a, const int32_t &scale_a, const bf8x32_t ®_b, const int32_t &scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:728
static __device__ void Run(const bf8x32_t ®_a, const int32_t &scale_a, const f8x32_t ®_b, const int32_t &scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:765
static __device__ void Run(const f4x32_t ®_a, const int32_t scale_a, const f4x32_t ®_b, const int32_t scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:870
static __device__ void Run(const bf6x32_t ®_a, const int32_t scale_a, const bf6x32_t ®_b, const int32_t scale_b, FloatC ®_c)
Definition: amd_xdlops.hpp:836
Definition: amd_xdlops.hpp:685
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:11