8 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
13 template <index_t MPerWave, index_t NPerWave>
19 template <
class FloatC>
20 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
22 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
23 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
24 reg_c.template AsType<float32_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
25 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<1>{}], 1, 1, 0);
32 template <
class FloatC>
33 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
35 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
36 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
40 template <index_t MPerWave, index_t NPerWave>
46 template <
class FloatC>
47 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
49 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
50 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
54 template <index_t MPerWave, index_t NPerWave>
60 template <
class FloatC>
61 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
63 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
64 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
68 template <index_t MPerWave, index_t NPerWave>
74 template <
class FloatC>
75 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
77 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
78 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 2, 0, 0);
82 template <index_t MPerWave, index_t NPerWave>
88 template <
class FloatC>
89 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
91 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
92 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
99 template <
class FloatC>
100 __device__
static void Run(
const float& reg_a,
const float& reg_b, FloatC& reg_c)
102 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
103 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
104 reg_c.template AsType<float4_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
105 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<1>{}], 4, 1, 0);
110 template <index_t MPerWave, index_t NPerWave>
116 template <
class FloatC>
119 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
120 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
121 reg_c.template AsType<float32_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
122 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<1>{}], 1, 1, 0);
129 template <
class FloatC>
132 reg_c.template AsType<float32_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
133 reg_a, reg_b, reg_c.template AsType<float32_t>()[
Number<0>{}], 1, 0, 0);
137 template <index_t MPerWave, index_t NPerWave>
143 template <
class FloatC>
146 #if defined(__gfx950__)
147 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
148 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
157 template <index_t MPerWave, index_t NPerWave>
163 template <
class FloatC>
166 #if defined(__gfx950__)
167 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
168 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
177 template <index_t MPerWave, index_t NPerWave>
183 template <
class FloatC>
186 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
187 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
191 template <index_t MPerWave, index_t NPerWave>
197 template <
class FloatC>
200 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
201 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
205 template <index_t MPerWave, index_t NPerWave>
211 template <
class FloatC>
214 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
215 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 2, 0, 0);
219 template <index_t MPerWave, index_t NPerWave>
225 template <
class FloatC>
228 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
229 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
236 template <
class FloatC>
239 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
240 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 4, 0, 0);
241 reg_c.template AsType<float4_t>()(
Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
242 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<1>{}], 4, 1, 0);
247 template <index_t MPerWave, index_t NPerWave>
253 template <
class FloatC>
256 #if defined(__gfx950__)
257 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
258 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
267 template <index_t MPerWave, index_t NPerWave>
273 template <
class FloatC>
276 #if defined(__gfx950__)
277 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
278 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
287 template <index_t MPerWave, index_t NPerWave>
293 template <
class FloatC>
296 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
297 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
301 template <index_t MPerWave, index_t NPerWave>
307 template <
class FloatC>
310 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
311 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
315 template <index_t MPerWave, index_t NPerWave>
321 template <
class FloatC>
324 reg_c.template AsType<float16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
325 reg_a, reg_b, reg_c.template AsType<float16_t>()[
Number<0>{}], 0, 0, 0);
329 template <index_t MPerWave, index_t NPerWave>
335 template <
class FloatC>
338 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
339 reg_a, reg_b, reg_c.template AsType<float4_t>()[
Number<0>{}], 0, 0, 0);
343 template <index_t MPerWave, index_t NPerWave>
349 template <
class FloatC>
352 reg_c.template AsType<int32x16_t>()(
Number<0>{}) =
353 __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
354 bit_cast<int32_t>(reg_b),
355 reg_c.template AsType<int32x16_t>()[
Number<0>{}],
362 template <index_t MPerWave, index_t NPerWave>
368 template <
class FloatC>
371 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
372 __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
373 bit_cast<int32_t>(reg_b),
374 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
381 template <index_t MPerWave, index_t NPerWave>
387 template <
class FloatC>
390 #if defined(__gfx950__)
391 reg_c.template AsType<int32x16_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
392 reg_a, reg_b, reg_c.template AsType<int32x16_t>()[
Number<0>{}], 0, 0, 0);
401 template <index_t MPerWave, index_t NPerWave>
407 template <
class FloatC>
410 #if defined(__gfx950__)
411 reg_c.template AsType<int32x4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
412 reg_a, reg_b, reg_c.template AsType<int32x4_t>()[
Number<0>{}], 0, 0, 0);
421 template <index_t MPerWave, index_t NPerWave>
427 template <
class FloatC>
430 reg_c.template AsType<int32x16_t>()(
Number<0>{}) =
431 __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
432 bit_cast<int64_t>(reg_b),
433 reg_c.template AsType<int32x16_t>()[
Number<0>{}],
440 template <index_t MPerWave, index_t NPerWave>
446 template <
class FloatC>
449 reg_c.template AsType<int32x4_t>()(
Number<0>{}) =
450 __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
451 bit_cast<int64_t>(reg_b),
452 reg_c.template AsType<int32x4_t>()[
Number<0>{}],
459 template <index_t MPerWave, index_t NPerWave>
465 template <
class FloatC>
466 __device__
static void Run(
const double& reg_a,
const double& reg_b, FloatC& reg_c)
468 #if defined(__gfx90a__) || defined(__gfx94__)
469 reg_c.template AsType<double4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
470 reg_a, reg_b, reg_c.template AsType<double4_t>()[
Number<0>{}], 0, 0, 0);
479 template <index_t MPerWave, index_t NPerWave>
491 template <
class FloatC>
492 __device__
static void Run(
const f8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
494 #if defined(__gfx950__)
495 reg_c.template AsType<float16_t>()(
Number<0>{}) =
496 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
499 reg_c.template AsType<float16_t>()[
Number<0>{}],
514 template <index_t MPerWave, index_t NPerWave>
520 template <
class FloatC>
521 __device__
static void Run(
const f8x32_t& reg_a,
522 const int32_t scale_a,
523 const f8x32_t& reg_b,
524 const int32_t scale_b,
527 #if defined(__gfx950__)
529 reg_c.template AsType<float16_t>()(
Number<0>{}) =
530 __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
533 reg_c.template AsType<float16_t>()[
Number<0>{}],
550 template <index_t MPerWave, index_t NPerWave>
556 template <
class FloatC>
557 __device__
static void Run(
const f8x32_t& reg_a,
558 const int32_t scale_a,
559 const f8x32_t& reg_b,
560 const int32_t scale_b,
563 #if defined(__gfx950__)
565 reg_c.template AsType<float4_t>()(
Number<0>{}) =
566 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
569 reg_c.template AsType<float4_t>()[
Number<0>{}],
586 template <index_t MPerWave, index_t NPerWave>
598 template <
class FloatC>
599 __device__
static void Run(
const f8x32_t& reg_a,
const f8x32_t& reg_b, FloatC& reg_c)
601 #if defined(__gfx950__)
603 reg_c.template AsType<float4_t>()(
Number<0>{}) =
604 __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
607 reg_c.template AsType<float4_t>()[
Number<0>{}],
622 template <index_t MPerWave, index_t NPerWave>
628 template <
class FloatC>
629 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
631 #if defined(__gfx94__)
632 reg_c.template AsType<float16_t>()(
Number<0>{}) =
633 __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
634 bit_cast<long>(reg_a),
635 bit_cast<long>(reg_b),
636 reg_c.template AsType<float16_t>()[
Number<0>{}],
645 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
646 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
654 template <index_t MPerWave, index_t NPerWave>
660 template <
class FloatC>
661 __device__
static void Run(
const f8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
663 #if defined(__gfx94__)
664 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
665 bit_cast<long>(reg_a),
666 bit_cast<long>(reg_b),
667 reg_c.template AsType<float4_t>()[
Number<0>{}],
676 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
677 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
685 template <index_t MPerWave, index_t NPerWave>
691 template <
class FloatC>
694 #if defined(__gfx94__)
695 reg_c.template AsType<float16_t>()(
Number<0>{}) =
696 __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
697 bit_cast<long>(reg_a),
698 bit_cast<long>(reg_b),
699 reg_c.template AsType<float16_t>()[
Number<0>{}],
708 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
709 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
717 template <index_t MPerWave, index_t NPerWave>
723 template <
class FloatC>
726 #if defined(__gfx94__)
727 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
728 bit_cast<long>(reg_a),
729 bit_cast<long>(reg_b),
730 reg_c.template AsType<float4_t>()[
Number<0>{}],
739 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
740 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
748 template <index_t MPerWave, index_t NPerWave>
754 template <
class FloatC>
755 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
757 #if defined(__gfx94__)
758 reg_c.template AsType<float16_t>()(
Number<0>{}) =
759 __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
760 bit_cast<long>(reg_a),
761 bit_cast<long>(reg_b),
762 reg_c.template AsType<float16_t>()[
Number<0>{}],
771 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
772 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
780 template <index_t MPerWave, index_t NPerWave>
786 template <
class FloatC>
787 __device__
static void Run(
const f8x8_t& reg_a,
const bf8x8_t& reg_b, FloatC& reg_c)
789 #if defined(__gfx94__)
790 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
791 bit_cast<long>(reg_a),
792 bit_cast<long>(reg_b),
793 reg_c.template AsType<float4_t>()[
Number<0>{}],
802 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[
Number<k>{}]);
803 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[
Number<k>{}]);
811 template <index_t MPerWave, index_t NPerWave>
817 template <
class FloatC>
818 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
820 #if defined(__gfx94__)
821 reg_c.template AsType<float16_t>()(
Number<0>{}) =
822 __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
823 bit_cast<long>(reg_a),
824 bit_cast<long>(reg_b),
825 reg_c.template AsType<float16_t>()[
Number<0>{}],
834 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
835 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
843 template <index_t MPerWave, index_t NPerWave>
849 template <
class FloatC>
850 __device__
static void Run(
const bf8x8_t& reg_a,
const f8x8_t& reg_b, FloatC& reg_c)
852 #if defined(__gfx94__)
853 reg_c.template AsType<float4_t>()(
Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
854 bit_cast<long>(reg_a),
855 bit_cast<long>(reg_b),
856 reg_c.template AsType<float4_t>()[
Number<0>{}],
865 float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[
Number<k>{}]);
866 float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[
Number<k>{}]);
bf8_t __attribute((ext_vector_type(8))) bf8x8_t
Definition: vector_type.hpp:197
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition: data_type.hpp:2498
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: data_type.hpp:2499
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: data_type.hpp:2515
typename vector_type< half_t, 4 >::type half4_t
Definition: data_type.hpp:2490
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< bhalf_t, 2 >::type bhalf2_t
Definition: data_type.hpp:2497
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: data_type.hpp:2516
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: data_type.hpp:2514
typename vector_type< half_t, 8 >::type half8_t
Definition: data_type.hpp:2491
Definition: integral_constant.hpp:10
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:599
Definition: amd_xdlops.hpp:587
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:308
Definition: amd_xdlops.hpp:302
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:198
Definition: amd_xdlops.hpp:192
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:75
Definition: amd_xdlops.hpp:69
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:274
Definition: amd_xdlops.hpp:268
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:724
Definition: amd_xdlops.hpp:718
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:850
Definition: amd_xdlops.hpp:844
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:164
Definition: amd_xdlops.hpp:158
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:787
Definition: amd_xdlops.hpp:781
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:661
Definition: amd_xdlops.hpp:655
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:212
Definition: amd_xdlops.hpp:206
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:61
Definition: amd_xdlops.hpp:55
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:336
Definition: amd_xdlops.hpp:330
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:254
Definition: amd_xdlops.hpp:248
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:692
Definition: amd_xdlops.hpp:686
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:818
Definition: amd_xdlops.hpp:812
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:144
Definition: amd_xdlops.hpp:138
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:755
Definition: amd_xdlops.hpp:749
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:629
Definition: amd_xdlops.hpp:623
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:33
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:20
Definition: amd_xdlops.hpp:14
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:47
Definition: amd_xdlops.hpp:41
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:322
Definition: amd_xdlops.hpp:316
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:130
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:117
Definition: amd_xdlops.hpp:111
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:492
Definition: amd_xdlops.hpp:480
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:294
Definition: amd_xdlops.hpp:288
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:184
Definition: amd_xdlops.hpp:178
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:89
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:100
Definition: amd_xdlops.hpp:83
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:226
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:237
Definition: amd_xdlops.hpp:220
static __device__ void Run(const double &reg_a, const double &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:466
Definition: amd_xdlops.hpp:460
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:369
Definition: amd_xdlops.hpp:363
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:447
Definition: amd_xdlops.hpp:441
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:408
Definition: amd_xdlops.hpp:402
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:428
Definition: amd_xdlops.hpp:422
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:388
Definition: amd_xdlops.hpp:382
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:350
Definition: amd_xdlops.hpp:344
static __device__ void Run(const f8x32_t &reg_a, const int32_t scale_a, const f8x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:557
Definition: amd_xdlops.hpp:551
static __device__ void Run(const f8x32_t &reg_a, const int32_t scale_a, const f8x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:521
Definition: amd_xdlops.hpp:515
Definition: functional2.hpp:31
Definition: data_type.hpp:347