Composable Kernel — source listing of include/ck/utility/amd_xdlops.hpp
(extracted from the Read the Docs "Source File" page; Doxygen hyperlink lines,
including struct declaration headers, were lost in extraction).
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
6 
7 namespace ck {
8 // Define the common macro for MI300 models
9 #if defined(__gfx942__) || defined(__gfx950__)
10 #define __gfx94__
11 #endif
12 
// Helper: split each lane of a float vector into a "big" bf16 part (the bf16
// conversion of the lane) and a "small" bf16 part (the bf16 conversion of the
// residual left after subtracting the big part back in f32). Used by the
// tf32 and xf32 emulated-precision paths.
// NOTE(review): the function's declaration line (name + first parameter) was
// lost in the extracted listing; the identifier below is reconstructed from
// usage -- confirm against the original header before relying on the name.
template <index_t VecSize>
__device__ __forceinline__ void
convert_f32_to_bf16_big_small(const vector_type<float, VecSize>& reg_f32,
                              vector_type<bhalf_t, VecSize>& reg_bf16_big,
                              vector_type<bhalf_t, VecSize>& reg_bf16_small)
{
    static_for<0, VecSize, 1>{}([&](auto k) {
        using IK = Number<k>;
        // big = bf16(x)
        reg_bf16_big.template AsType<bhalf_t>()(k) =
            type_convert<bhalf_t, float>(reg_f32.template AsType<float>()[IK{}]);
        // small = bf16(x - float(big)): captures the truncated low-order bits
        reg_bf16_small.template AsType<bhalf_t>()(k) = type_convert<bhalf_t, float>(
            reg_f32.template AsType<float>()[IK{}] -
            type_convert<float, bhalf_t>(reg_bf16_big.template AsType<bhalf_t>()[IK{}]));
    });
}
31 
// fp32 MFMA wrappers: one primary template per instruction, specialized on the
// per-wave output tile <MPerWave, NPerWave>. Each Run() issues the matching
// compiler builtin, accumulating in place into chunks of reg_c.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x1f32;

// NOTE(review): the specialization headers in this family were lost in the
// extracted listing and are reconstructed here -- confirm the tile parameters.
template <>
struct intrin_mfma_f32_32x32x1f32<64, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Two issues with different broadcast immediates (cbsz=1, abid=0/1 per
        // the LLVM AMDGPU builtin signature) fill two float32_t chunks of C.
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
    }
};

template <>
struct intrin_mfma_f32_32x32x1f32<32, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
    }
};
59 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x2f32;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_32x32x2f32<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};
73 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f32;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_16x16x4f32<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};
87 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x1f32;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_16x16x1f32<16, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // cbsz=2 broadcast immediate (per the LLVM AMDGPU builtin signature).
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
    }
};
101 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x1f32;

// NOTE(review): specialization headers reconstructed from the builtins below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_4x4x1f32<4, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
    }
};

template <>
struct intrin_mfma_f32_4x4x1f32<8, 64>
{
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        // Two issues with abid 0/1 fill two float4_t chunks of C.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
    }
};
128 
// fp16 MFMA wrappers.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4f16;

// NOTE(review): specialization headers reconstructed from the builtins below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_32x32x4f16<64, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Two issues with broadcast immediates (cbsz=1, abid=0/1) fill two
        // float32_t chunks of C.
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
        reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
    }
};

template <>
struct intrin_mfma_f32_32x32x4f16<32, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
            reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
    }
};
156 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16f16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_32x32x16f16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
    {
        // Builtin only exists on gfx950; a no-op elsewhere.
#if defined(__gfx950__)
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
176 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32f16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_16x16x32f16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
    {
        // Builtin only exists on gfx950; a no-op elsewhere.
#if defined(__gfx950__)
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
196 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8f16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_32x32x8f16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};
210 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16f16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_16x16x16f16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};
224 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x4f16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_16x16x4f16<16, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // cbsz=2 broadcast immediate (per the LLVM AMDGPU builtin signature).
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
    }
};
238 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x4f16;

// NOTE(review): specialization headers reconstructed from the builtins below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_4x4x4f16<4, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
    }
};

template <>
struct intrin_mfma_f32_4x4x4f16<8, 64>
{
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        // Two issues with abid 0/1 fill two float4_t chunks of C.
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
        reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
    }
};
265 
// bf16 MFMA wrappers.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x16bf16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_32x32x16bf16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
    {
        // Builtin only exists on gfx950; a no-op elsewhere.
#if defined(__gfx950__)
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
286 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x32bf16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_16x16x32bf16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
    {
        // Builtin only exists on gfx950; a no-op elsewhere.
#if defined(__gfx950__)
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
306 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x8bf16_1k;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};
320 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x16bf16_1k;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};
334 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x4bf16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_32x32x4bf16<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
            reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
    }
};
348 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_16x16x8bf16;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_16x16x8bf16<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
    {
        reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
            reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
    }
};
362 
// int8 MFMA wrappers (int32 accumulators).
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_i32_32x32x8i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // Builtin takes packed i8x4 operands as int32.
        reg_c.template AsType<int32x16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
                                                bit_cast<int32_t>(reg_b),
                                                reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                0,
                                                0,
                                                0);
    }
};
381 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_i32_16x16x16i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        // Builtin takes packed i8x4 operands as int32.
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
                                                 bit_cast<int32_t>(reg_b),
                                                 reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                 0,
                                                 0,
                                                 0);
    }
};
400 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x32i8;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_i32_32x32x32i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
        // Builtin only exists on gfx950; a no-op elsewhere.
#if defined(__gfx950__)
        reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
            reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
420 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x64i8;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_i32_16x16x64i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
    {
        // Builtin only exists on gfx950; a no-op elsewhere.
#if defined(__gfx950__)
        reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
            reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif // defined(__gfx950__)
    }
};
440 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x16i8;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_i32_32x32x16i8<32, 32>
{
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        // Builtin takes packed i8x8 operands as int64.
        reg_c.template AsType<int32x16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
                                                  bit_cast<int64_t>(reg_b),
                                                  reg_c.template AsType<int32x16_t>()[Number<0>{}],
                                                  0,
                                                  0,
                                                  0);
    }
};
459 
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x32i8;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_i32_16x16x32i8<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
    {
        // Builtin takes packed i8x8 operands as int64.
        reg_c.template AsType<int32x4_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
                                                  bit_cast<int64_t>(reg_b),
                                                  reg_c.template AsType<int32x4_t>()[Number<0>{}],
                                                  0,
                                                  0,
                                                  0);
    }
};
478 
// fp64 MFMA wrapper.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f64_16x16x4f64;

// NOTE(review): specialization header reconstructed from the builtin below --
// confirm tile parameters.
template <>
struct intrin_mfma_f64_16x16x4f64<16, 16>
{
    template <class FloatC>
    __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
    {
        // fp64 MFMA exists on gfx90a and the MI300 family (__gfx94__ macro
        // defined at the top of this file); a no-op elsewhere.
#if defined(__gfx90a__) || defined(__gfx94__)
        reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
            reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
498 
// F8F6F4 MFMA, 32x32x64 tile, no block scaling (scale operands fixed to 0).
// The cbsz immediate selects the A-matrix format and blgp the B-matrix format:
// {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}.
// Only available on gfx950; a no-op elsewhere.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_32x32x64f8f6f4;

// NOTE(review): specialization header reconstructed from the builtins below --
// confirm tile parameters.
template <>
struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
{
    // fp8 A x fp8 B
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                0, // blgp
                0,
                0,
                0,
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    // bf8 A x bf8 B
    template <class FloatC>
    __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                1, // blgp
                0,
                0,
                0,
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    // bf8 A x fp8 B
    template <class FloatC>
    __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                0, // blgp
                0,
                0,
                0,
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    // fp8 A x bf8 B
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                1, // blgp
                0,
                0,
                0,
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    // fp4 A x fp4 B: operands are widened to the builtin's int32x8 argument
    // type with zero padding in the unused upper words.
    template <class FloatC>
    __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)

        int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
        int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);

        using arg_type = int32x8_t;

        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
                arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
                reg_c.template AsType<float16_t>()[Number<0>{}],
                4, // cbsz
                4, // blgp
                0, // OPSEL
                0,
                0, // OPSEL
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    // fp6 A x fp6 B: six packed words, zero padded to int32x8.
    template <class FloatC>
    __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)

        int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
        int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);

        using arg_type = int32x8_t;

        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
                arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
                reg_c.template AsType<float16_t>()[Number<0>{}],
                2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                2, // blgp
                0, // OPSEL
                0,
                0, // OPSEL
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }

    // bf6 A x bf6 B: six packed words, zero padded to int32x8.
    template <class FloatC>
    __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
    {
#if defined(__gfx950__)

        int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
        int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);

        using arg_type = int32x8_t;

        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
                arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
                reg_c.template AsType<float16_t>()[Number<0>{}],
                3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                3, // blgp
                0, // OPSEL
                0,
                0, // OPSEL
                0);
#else
        ignore = reg_a;
        ignore = reg_b;
        ignore = reg_c;
#endif
    }
};
683 
// Scaled F8F6F4 MFMA, 32x32x64 tile, with per-matrix e8m0 block-scale operands
// (scale_a / scale_b) and OPSEL immediates choosing the scale byte.
// cbsz/blgp select A/B formats: {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3;
// 3 FP6 E3M2; 4 FP4 E2M1}. Only available on gfx950; a no-op elsewhere.
template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
struct intrin_mfma_scale_f32_32x32x64f8f6f4;

template <index_t OpselA, index_t OpselB>
struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32, OpselA, OpselB>
{
    // fp8 A x fp8 B
    template <class FloatC>
    __device__ static void Run(const f8x32_t& reg_a,
                               const int32_t& scale_a,
                               const f8x32_t& reg_b,
                               const int32_t& scale_b,
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                0, // blgp
                OpselA, // OPSEL
                scale_a,
                OpselB, // OPSEL
                scale_b);
        // XXX: Note on the scale_a and scale_b parameters:
        // If compiler detects that one or both scales are constant values, it will treat that
        // constant as F32 constant. I.e., if scale_a at some point was declared as
        // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
        // assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.

        // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
        // when OPSEL is set otherwise.
#else
        ignore = reg_a;
        ignore = scale_a;
        ignore = reg_b;
        ignore = scale_b;
        ignore = reg_c;
#endif
    }

    // bf8 A x bf8 B
    template <class FloatC>
    __device__ static void Run(const bf8x32_t& reg_a,
                               const int32_t& scale_a,
                               const bf8x32_t& reg_b,
                               const int32_t& scale_b,
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                1, // blgp
                OpselA, // OPSEL
                scale_a,
                OpselB, // OPSEL
                scale_b);
        // XXX: Note on the scale_a and scale_b parameters:
        // If compiler detects that one or both scales are constant values, it will treat that
        // constant as F32 constant. I.e., if scale_a at some point was declared as
        // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
        // assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.

        // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
        // when OPSEL is set otherwise.
#else
        ignore = reg_a;
        ignore = scale_a;
        ignore = reg_b;
        ignore = scale_b;
        ignore = reg_c;
#endif
    }

    // bf8 A x fp8 B
    template <class FloatC>
    __device__ static void Run(const bf8x32_t& reg_a,
                               const int32_t& scale_a,
                               const f8x32_t& reg_b,
                               const int32_t& scale_b,
                               FloatC& reg_c)
    {
#if defined(__gfx950__)
        // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                reg_a,
                reg_b,
                reg_c.template AsType<float16_t>()[Number<0>{}],
                1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                0, // blgp
                OpselA, // OPSEL
                scale_a,
                OpselB, // OPSEL
                scale_b);
        // XXX: Note on the scale_a and scale_b parameters:
        // If compiler detects that one or both scales are constant values, it will treat that
        // constant as F32 constant. I.e., if scale_a at some point was declared as
        // `e8m0_bexp_t a_scale{1.0f}`, the instruction would only work if scale_a parameter is
        // assigned value `bit_cast<int32_t>(static_cast<float>(a_scale))`.

        // XXX: Note on the OPSEL parameters: Instruction always takes byte0 as a scale value even
        // when OPSEL is set otherwise.
#else
        ignore = reg_a;
        ignore = scale_a;
        ignore = reg_b;
        ignore = scale_b;
        ignore = reg_c;
#endif
    }

    // fp6 A x fp6 B: six packed words, zero padded to the builtin's int32x8
    // argument type.
    template <class FloatC>
    __device__ static void Run(const f6x32_t& reg_a,
                               const int32_t scale_a,
                               const f6x32_t& reg_b,
                               const int32_t scale_b,
                               FloatC& reg_c)
    {
#if defined(__gfx950__)

        int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
        int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);

        using arg_type = int32x8_t;

        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
                arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
                reg_c.template AsType<float16_t>()[Number<0>{}],
                2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                2, // blgp
                OpselA, // OPSEL
                scale_a,
                OpselB, // OPSEL
                scale_b);
#else
        ignore = reg_a;
        ignore = scale_a;
        ignore = reg_b;
        ignore = scale_b;
        ignore = reg_c;
#endif
    }

    // bf6 A x bf6 B: six packed words, zero padded to int32x8.
    template <class FloatC>
    __device__ static void Run(const bf6x32_t& reg_a,
                               const int32_t scale_a,
                               const bf6x32_t& reg_b,
                               const int32_t scale_b,
                               FloatC& reg_c)
    {
#if defined(__gfx950__)

        int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
        int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);

        using arg_type = int32x8_t;

        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
                arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
                reg_c.template AsType<float16_t>()[Number<0>{}],
                3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                3, // blgp
                OpselA, // OPSEL
                scale_a,
                OpselB, // OPSEL
                scale_b);
#else
        ignore = reg_a;
        ignore = scale_a;
        ignore = reg_b;
        ignore = scale_b;
        ignore = reg_c;
#endif
    }

    // fp4 A x fp4 B: four packed words, zero padded to int32x8.
    template <class FloatC>
    __device__ static void Run(const f4x32_t& reg_a,
                               const int32_t scale_a,
                               const f4x32_t& reg_b,
                               const int32_t scale_b,
                               FloatC& reg_c)
    {
#if defined(__gfx950__)

        int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
        int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);

        using arg_type = int32x8_t;

        reg_c.template AsType<float16_t>()(Number<0>{}) =
            __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
                arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
                arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
                reg_c.template AsType<float16_t>()[Number<0>{}],
                4, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
                4, // blgp
                OpselA, // OPSEL
                scale_a,
                OpselB, // OPSEL
                scale_b);
#else
        ignore = reg_a;
        ignore = scale_a;
        ignore = reg_b;
        ignore = scale_b;
        ignore = reg_c;
#endif
    }
};
903 
904 template <index_t MPerWave, index_t NPerWave, index_t OpselA, index_t OpselB>
906 
907 template <index_t OpselA, index_t OpselB>
908 struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16, OpselA, OpselB>
909 {
910  template <class FloatC>
911  __device__ static void Run(const f8x32_t& reg_a,
912  const int32_t& scale_a,
913  const f8x32_t& reg_b,
914  const int32_t& scale_b,
915  FloatC& reg_c)
916  {
917 #if defined(__gfx950__)
918  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
919  reg_c.template AsType<float4_t>()(Number<0>{}) =
920  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
921  reg_a,
922  reg_b,
923  reg_c.template AsType<float4_t>()[Number<0>{}],
924  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
925  0, // blgp
926  OpselA, // OPSEL
927  scale_a,
928  OpselB, // OPSEL
929  scale_b);
930 #else
931  ignore = reg_a;
932  ignore = scale_a;
933  ignore = reg_b;
934  ignore = scale_b;
935  ignore = reg_c;
936 #endif
937  }
938 
939  template <class FloatC>
940  __device__ static void Run(const bf8x32_t& reg_a,
941  const int32_t& scale_a,
942  const bf8x32_t& reg_b,
943  const int32_t& scale_b,
944  FloatC& reg_c)
945  {
946 #if defined(__gfx950__)
947  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
948  reg_c.template AsType<float4_t>()(Number<0>{}) =
949  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
950  reg_a,
951  reg_b,
952  reg_c.template AsType<float4_t>()[Number<0>{}],
953  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
954  1, // blgp
955  OpselA, // OPSEL
956  scale_a,
957  OpselB, // OPSEL
958  scale_b);
959 #else
960  ignore = reg_a;
961  ignore = scale_a;
962  ignore = reg_b;
963  ignore = scale_b;
964  ignore = reg_c;
965 #endif
966  }
967 
968  template <class FloatC>
969  __device__ static void Run(const f8x32_t& reg_a,
970  const int32_t& scale_a,
971  const bf8x32_t& reg_b,
972  const int32_t& scale_b,
973  FloatC& reg_c)
974  {
975 #if defined(__gfx950__)
976  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
977  reg_c.template AsType<float4_t>()(Number<0>{}) =
978  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
979  reg_a,
980  reg_b,
981  reg_c.template AsType<float4_t>()[Number<0>{}],
982  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
983  1, // blgp
984  OpselA, // OPSEL
985  scale_a,
986  OpselB, // OPSEL
987  scale_b);
988 #else
989  ignore = reg_a;
990  ignore = scale_a;
991  ignore = reg_b;
992  ignore = scale_b;
993  ignore = reg_c;
994 #endif
995  }
996 
997  template <class FloatC>
998  __device__ static void Run(const bf8x32_t& reg_a,
999  const int32_t& scale_a,
1000  const f8x32_t& reg_b,
1001  const int32_t& scale_b,
1002  FloatC& reg_c)
1003  {
1004 #if defined(__gfx950__)
1005  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1006  reg_c.template AsType<float4_t>()(Number<0>{}) =
1007  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1008  reg_a,
1009  reg_b,
1010  reg_c.template AsType<float4_t>()[Number<0>{}],
1011  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1012  0, // blgp
1013  OpselA, // OPSEL
1014  scale_a,
1015  OpselB, // OPSEL
1016  scale_b);
1017 #else
1018  ignore = reg_a;
1019  ignore = scale_a;
1020  ignore = reg_b;
1021  ignore = scale_b;
1022  ignore = reg_c;
1023 #endif
1024  }
1025 
1026  template <class FloatC>
1027  __device__ static void Run(const f6x32_t& reg_a,
1028  const int32_t scale_a,
1029  const f6x32_t& reg_b,
1030  const int32_t scale_b,
1031  FloatC& reg_c)
1032  {
1033 #if defined(__gfx950__)
1034  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1035  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1036 
1037  using arg_type = int32x8_t;
1038 
1039  reg_c.template AsType<float4_t>()(Number<0>{}) =
1040  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1041  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1042  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1043  reg_c.template AsType<float4_t>()[Number<0>{}],
1044  2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1045  2, // blgp
1046  OpselA, // OPSEL
1047  scale_a,
1048  OpselB, // OPSEL
1049  scale_b);
1050 #else
1051  ignore = reg_a;
1052  ignore = scale_a;
1053  ignore = reg_b;
1054  ignore = scale_b;
1055  ignore = reg_c;
1056 #endif
1057  }
1058 
// Same scaled FP6 (E2M3) 16x16x128 MFMA as the f6x32_t overload, but for the
// f6x16x2_t layout: two 16-element packed halves. The three dwords of each
// half are copied element-by-element into the builtin's 8-dword operand
// (upper two lanes zeroed) instead of a flat bit_cast, because the data sits
// behind the vector_type AsType accessor.
// cbsz = blgp = 2 selects FP6 E2M3; OpselA/OpselB (enclosing scope) select
// scale application. gfx950 only; otherwise args are swallowed via `ignore`.
1059  template <class FloatC>
1060  __device__ static void Run(const f6x16x2_t& reg_a,
1061  const int32_t scale_a,
1062  const f6x16x2_t& reg_b,
1063  const int32_t scale_b,
1064  FloatC& reg_c)
1065  {
1066 #if defined(__gfx950__)
1067  using arg_type = int32x8_t;
1068  arg_type arg_a{
1069  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
1070  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
1071  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
1072  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
1073  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
1074  static_cast<int32_t>(reg_a.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
1075  0,
1076  0};
1077  arg_type arg_b{
1078  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][0]),
1079  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][1]),
1080  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<0>{}][2]),
1081  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][0]),
1082  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][1]),
1083  static_cast<int32_t>(reg_b.template AsType<f6x16x2_t::data_t>()[Number<1>{}][2]),
1084  0,
1085  0};
1086 
1087  reg_c.template AsType<float4_t>()(Number<0>{}) =
1088  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1089  arg_a,
1090  arg_b,
1091  reg_c.template AsType<float4_t>()[Number<0>{}],
1092  2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1093  2, // blgp
1094  OpselA, // OPSEL
1095  scale_a,
1096  OpselB, // OPSEL
1097  scale_b);
1098 #else
1099  ignore = reg_a;
1100  ignore = scale_a;
1101  ignore = reg_b;
1102  ignore = scale_b;
1103  ignore = reg_c;
1104 #endif
1105  }
1106 
// Scaled 16x16x128 MFMA for packed BF6 (FP6 E3M2) operands — identical in
// structure to the f6x32_t overload, except cbsz = blgp = 3 selects the
// E3M2 encoding. 6 packed dwords are widened to the 8-dword operand with
// the upper lanes zeroed; OpselA/OpselB (enclosing scope) select the scale
// application. gfx950 only; otherwise args are swallowed via `ignore`.
1107  template <class FloatC>
1108  __device__ static void Run(const bf6x32_t& reg_a,
1109  const int32_t scale_a,
1110  const bf6x32_t& reg_b,
1111  const int32_t scale_b,
1112  FloatC& reg_c)
1113  {
1114 #if defined(__gfx950__)
1115  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1116  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1117 
1118  using arg_type = int32x8_t;
1119 
1120  reg_c.template AsType<float4_t>()(Number<0>{}) =
1121  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1122  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1123  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1124  reg_c.template AsType<float4_t>()[Number<0>{}],
1125  3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1126  3, // blgp
1127  OpselA, // OPSEL
1128  scale_a,
1129  OpselB, // OPSEL
1130  scale_b);
1131 #else
1132  ignore = reg_a;
1133  ignore = scale_a;
1134  ignore = reg_b;
1135  ignore = scale_b;
1136  ignore = reg_c;
1137 #endif
1138  }
1139 
// Scaled 16x16x128 MFMA for the two-half BF6 layout (bf6x16x2_t). Mirrors
// the f6x16x2_t overload: each half's three dwords are copied into the
// builtin's 8-dword operand (upper lanes zeroed); cbsz = blgp = 3 selects
// the FP6 E3M2 encoding. OpselA/OpselB come from the enclosing scope.
// gfx950 only; otherwise args are swallowed via `ignore`.
1140  template <class FloatC>
1141  __device__ static void Run(const bf6x16x2_t& reg_a,
1142  const int32_t scale_a,
1143  const bf6x16x2_t& reg_b,
1144  const int32_t scale_b,
1145  FloatC& reg_c)
1146  {
1147 #if defined(__gfx950__)
1148  using arg_type = int32x8_t;
1149  arg_type arg_a{
1150  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
1151  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
1152  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
1153  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
1154  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
1155  static_cast<int32_t>(reg_a.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
1156  0,
1157  0};
1158  arg_type arg_b{
1159  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][0]),
1160  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][1]),
1161  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<0>{}][2]),
1162  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][0]),
1163  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][1]),
1164  static_cast<int32_t>(reg_b.template AsType<bf6x16x2_t::data_t>()[Number<1>{}][2]),
1165  0,
1166  0};
1167 
1168  reg_c.template AsType<float4_t>()(Number<0>{}) =
1169  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1170  arg_a,
1171  arg_b,
1172  reg_c.template AsType<float4_t>()[Number<0>{}],
1173  3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1174  3, // blgp
1175  OpselA, // OPSEL
1176  scale_a,
1177  OpselB, // OPSEL
1178  scale_b);
1179 #else
1180  ignore = reg_a;
1181  ignore = scale_a;
1182  ignore = reg_b;
1183  ignore = scale_b;
1184  ignore = reg_c;
1185 #endif
1186  }
1187 
// Scaled 16x16x128 MFMA for packed FP4 (E2M1) operands. Only 4 dwords of
// packed data per operand; they are widened to the builtin's 8-dword operand
// with the upper four lanes zeroed. cbsz = blgp = 4 selects the FP4 E2M1
// encoding; OpselA/OpselB (enclosing scope) select the scale application.
// gfx950 only; otherwise args are swallowed via `ignore`.
1188  template <class FloatC>
1189  __device__ static void Run(const f4x32_t& reg_a,
1190  const int32_t scale_a,
1191  const f4x32_t& reg_b,
1192  const int32_t scale_b,
1193  FloatC& reg_c)
1194  {
1195 #if defined(__gfx950__)
1196  int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1197  int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1198  using arg_type = int32x8_t;
1199  reg_c.template AsType<float4_t>()(Number<0>{}) =
1200  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1201  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1202  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1203  reg_c.template AsType<float4_t>()[Number<0>{}],
1204  4, // cbsz
1205  4, // blgp
1206  OpselA, // OPSEL
1207  scale_a,
1208  OpselB, // OPSEL
1209  scale_b);
1210 #else
1211  ignore = reg_a;
1212  ignore = scale_a;
1213  ignore = reg_b;
1214  ignore = scale_b;
1215  ignore = reg_c;
1216 #endif
1217  }
1218 };
1219 
1220 template <index_t MPerWave, index_t NPerWave>
1222 
// Unscaled wrappers over the gfx950 scaled 16x16x128 f8f6f4 MFMA builtin:
// every overload passes scale = 0 and OPSEL = 0, so only the cbsz/blgp
// operand-format fields differ per element type. Each Run accumulates into
// the 4xfloat slice of reg_c; on non-gfx950 targets the args are consumed
// via `ignore` and reg_c is untouched.
// NOTE(review): the specialization's declarator line was dropped by the doc
// extractor; per the symbol index this is presumably
// intrin_mfma_f32_16x16x128_f8f6f4<16, 16> — verify against the original header.
1229 template <>
1231 {
// FP8 E4M3 x FP8 E4M3 (cbsz = 0, blgp = 0).
1232  template <class FloatC>
1233  __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
1234  {
1235 #if defined(__gfx950__)
1236  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1237  reg_c.template AsType<float4_t>()(Number<0>{}) =
1238  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1239  reg_a,
1240  reg_b,
1241  reg_c.template AsType<float4_t>()[Number<0>{}],
1242  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1243  0, // blgp
1244  0,
1245  0,
1246  0,
1247  0);
1248 #else
1249  ignore = reg_a;
1250  ignore = reg_b;
1251  ignore = reg_c;
1252 #endif
1253  }
1254 
// BF8 E5M2 x BF8 E5M2 (cbsz = 1, blgp = 1).
1255  template <class FloatC>
1256  __device__ static void Run(const bf8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
1257  {
1258 #if defined(__gfx950__)
1259  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1260  reg_c.template AsType<float4_t>()(Number<0>{}) =
1261  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1262  reg_a,
1263  reg_b,
1264  reg_c.template AsType<float4_t>()[Number<0>{}],
1265  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1266  1, // blgp
1267  0,
1268  0,
1269  0,
1270  0);
1271 #else
1272  ignore = reg_a;
1273  ignore = reg_b;
1274  ignore = reg_c;
1275 #endif
1276  }
1277 
// Mixed BF8 (A) x FP8 (B): cbsz = 1 for A's E5M2, blgp = 0 for B's E4M3.
1278  template <class FloatC>
1279  __device__ static void Run(const bf8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
1280  {
1281 #if defined(__gfx950__)
1282  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1283  reg_c.template AsType<float4_t>()(Number<0>{}) =
1284  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1285  reg_a,
1286  reg_b,
1287  reg_c.template AsType<float4_t>()[Number<0>{}],
1288  1, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1289  0, // blgp
1290  0,
1291  0,
1292  0,
1293  0);
1294 #else
1295  ignore = reg_a;
1296  ignore = reg_b;
1297  ignore = reg_c;
1298 #endif
1299  }
1300 
// Mixed FP8 (A) x BF8 (B): cbsz = 0 for A's E4M3, blgp = 1 for B's E5M2.
1301  template <class FloatC>
1302  __device__ static void Run(const f8x32_t& reg_a, const bf8x32_t& reg_b, FloatC& reg_c)
1303  {
1304 #if defined(__gfx950__)
1305  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
1306  reg_c.template AsType<float4_t>()(Number<0>{}) =
1307  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1308  reg_a,
1309  reg_b,
1310  reg_c.template AsType<float4_t>()[Number<0>{}],
1311  0, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1312  1, // blgp
1313  0,
1314  0,
1315  0,
1316  0);
1317 #else
1318  ignore = reg_a;
1319  ignore = reg_b;
1320  ignore = reg_c;
1321 #endif
1322  }
1323 
// FP4 E2M1 x FP4 E2M1 (cbsz = blgp = 4); 4 packed dwords widened to the
// 8-dword operand with upper lanes zeroed.
1324  template <class FloatC>
1325  __device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
1326  {
1327 #if defined(__gfx950__)
1328  int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
1329  int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
1330 
1331  using arg_type = int32x8_t;
1332 
1333  reg_c.template AsType<float4_t>()(Number<0>{}) =
1334  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1335  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
1336  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
1337  reg_c.template AsType<float4_t>()[Number<0>{}],
1338  4, // cbsz
1339  4, // blgp
1340  0, // OPSEL
1341  0,
1342  0, // OPSEL
1343  0);
1344 #else
1345  ignore = reg_a;
1346  ignore = reg_b;
1347  ignore = reg_c;
1348 #endif
1349  }
1350 
// FP6 E2M3 x FP6 E2M3 (cbsz = blgp = 2); 6 packed dwords widened, upper
// lanes zeroed.
1351  template <class FloatC>
1352  __device__ static void Run(const f6x32_t& reg_a, const f6x32_t& reg_b, FloatC& reg_c)
1353  {
1354 #if defined(__gfx950__)
1355  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1356  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1357 
1358  using arg_type = int32x8_t;
1359 
1360  reg_c.template AsType<float4_t>()(Number<0>{}) =
1361  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1362  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1363  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1364  reg_c.template AsType<float4_t>()[Number<0>{}],
1365  2, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1366  2, // blgp
1367  0, // OPSEL
1368  0,
1369  0, // OPSEL
1370  0);
1371 #else
1372  ignore = reg_a;
1373  ignore = reg_b;
1374  ignore = reg_c;
1375 #endif
1376  }
1377 
// BF6 (FP6 E3M2) x BF6 (cbsz = blgp = 3); 6 packed dwords widened, upper
// lanes zeroed.
1378  template <class FloatC>
1379  __device__ static void Run(const bf6x32_t& reg_a, const bf6x32_t& reg_b, FloatC& reg_c)
1380  {
1381 #if defined(__gfx950__)
1382  int32x6_t arg_a = bit_cast<int32x6_t>(reg_a);
1383  int32x6_t arg_b = bit_cast<int32x6_t>(reg_b);
1384 
1385  using arg_type = int32x8_t;
1386 
1387  reg_c.template AsType<float4_t>()(Number<0>{}) =
1388  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1389  arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], arg_a[4], arg_a[5], 0, 0},
1390  arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], arg_b[4], arg_b[5], 0, 0},
1391  reg_c.template AsType<float4_t>()[Number<0>{}],
1392  3, // cbsz {0 FP8 E4M3; 1 FP8 E5M2; 2 FP6 E2M3; 3 FP6 E3M2; 4 FP4 E2M1}
1393  3, // blgp
1394  0, // OPSEL
1395  0,
1396  0, // OPSEL
1397  0);
1398 #else
1399  ignore = reg_a;
1400  ignore = reg_b;
1401  ignore = reg_c;
1402 #endif
1403  }
1404 };
1405 
1406 template <index_t MPerWave, index_t NPerWave>
1408 
// 32x32x16 FP8(E4M3) x FP8(E4M3) MFMA, accumulating into a 16xfloat slice of
// reg_c. On gfx94x the native builtin is used (the two 8-byte FP8 operands
// are bit_cast to int64_t). Elsewhere, a software fallback converts each of
// the thread's 8 k-elements to float and issues 8 f32 MFMAs via
// intrin_mfma_f32_32x32x2f32<32, 32>.
// NOTE(review): the specialization's declarator line was dropped by the doc
// extractor; presumably intrin_mfma_f32_32x32x16_fp8_fp8<32, 32> — verify.
1409 template <>
1411 {
1412  template <class FloatC>
1413  __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1414  {
1415 #if defined(__gfx94__)
1416  reg_c.template AsType<float16_t>()(Number<0>{}) =
1417  __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1418  bit_cast<int64_t>(reg_a),
1419  bit_cast<int64_t>(reg_b),
1420  reg_c.template AsType<float16_t>()[Number<0>{}],
1421  0,
1422  0,
1423  0);
1424 #else
1425  vector_type<f8_t, 8> reg_a_v(reg_a);
1426  vector_type<f8_t, 8> reg_b_v(reg_b);
1427 
1428  static_for<0, 8, 1>{}([&](auto k) {
1429  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1430  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1431 
1432  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1433  });
1434 #endif
1435  }
1436 };
1437 
1438 template <index_t MPerWave, index_t NPerWave>
1440 
// 16x16x32 FP8(E4M3) x FP8(E4M3) MFMA, accumulating into a 4xfloat slice of
// reg_c. gfx94x uses the native builtin (operands bit_cast to int64_t);
// other targets emulate by converting each of the 8 k-elements to float and
// issuing 8 f32 MFMAs via intrin_mfma_f32_16x16x4f32<16, 16>.
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_16x16x32_fp8_fp8<16, 16> — verify against the original.
1441 template <>
1443 {
1444  template <class FloatC>
1445  __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1446  {
1447 #if defined(__gfx94__)
1448  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1449  bit_cast<int64_t>(reg_a),
1450  bit_cast<int64_t>(reg_b),
1451  reg_c.template AsType<float4_t>()[Number<0>{}],
1452  0,
1453  0,
1454  0);
1455 #else
1456  vector_type<f8_t, 8> reg_a_v(reg_a);
1457  vector_type<f8_t, 8> reg_b_v(reg_b);
1458 
1459  static_for<0, 8, 1>{}([&](auto k) {
1460  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1461  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1462 
1463  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1464  });
1465 #endif
1466  }
1467 };
1468 
1469 template <index_t MPerWave, index_t NPerWave>
1471 
// 32x32x16 BF8(E5M2) x BF8(E5M2) MFMA into a 16xfloat accumulator slice.
// gfx94x: native builtin with int64_t-packed operands. Fallback: per-element
// float conversion plus 8 f32 MFMAs (intrin_mfma_f32_32x32x2f32<32, 32>).
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_32x32x16_bf8_bf8<32, 32> — verify against the original.
1472 template <>
1474 {
1475  template <class FloatC>
1476  __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1477  {
1478 #if defined(__gfx94__)
1479  reg_c.template AsType<float16_t>()(Number<0>{}) =
1480  __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1481  bit_cast<int64_t>(reg_a),
1482  bit_cast<int64_t>(reg_b),
1483  reg_c.template AsType<float16_t>()[Number<0>{}],
1484  0,
1485  0,
1486  0);
1487 #else
1488  vector_type<bf8_t, 8> reg_a_v(reg_a);
1489  vector_type<bf8_t, 8> reg_b_v(reg_b);
1490 
1491  static_for<0, 8, 1>{}([&](auto k) {
1492  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1493  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1494 
1495  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1496  });
1497 #endif
1498  }
1499 };
1500 
1501 template <index_t MPerWave, index_t NPerWave>
1503 
// 16x16x32 BF8(E5M2) x BF8(E5M2) MFMA into a 4xfloat accumulator slice.
// gfx94x: native builtin. Fallback: per-element float conversion plus 8 f32
// MFMAs (intrin_mfma_f32_16x16x4f32<16, 16>).
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_16x16x32_bf8_bf8<16, 16> — verify against the original.
1504 template <>
1506 {
1507  template <class FloatC>
1508  __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1509  {
1510 #if defined(__gfx94__)
1511  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1512  bit_cast<int64_t>(reg_a),
1513  bit_cast<int64_t>(reg_b),
1514  reg_c.template AsType<float4_t>()[Number<0>{}],
1515  0,
1516  0,
1517  0);
1518 #else
1519  vector_type<bf8_t, 8> reg_a_v(reg_a);
1520  vector_type<bf8_t, 8> reg_b_v(reg_b);
1521 
1522  static_for<0, 8, 1>{}([&](auto k) {
1523  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1524  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1525 
1526  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1527  });
1528 #endif
1529  }
1530 };
1531 
1532 template <index_t MPerWave, index_t NPerWave>
1534 
// 32x32x16 mixed FP8(E4M3, A) x BF8(E5M2, B) MFMA into a 16xfloat slice.
// gfx94x: native builtin. Fallback: per-element float conversion plus 8 f32
// MFMAs (intrin_mfma_f32_32x32x2f32<32, 32>).
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_32x32x16_fp8_bf8<32, 32> — verify against the original.
1535 template <>
1537 {
1538  template <class FloatC>
1539  __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1540  {
1541 #if defined(__gfx94__)
1542  reg_c.template AsType<float16_t>()(Number<0>{}) =
1543  __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1544  bit_cast<int64_t>(reg_a),
1545  bit_cast<int64_t>(reg_b),
1546  reg_c.template AsType<float16_t>()[Number<0>{}],
1547  0,
1548  0,
1549  0);
1550 #else
1551  vector_type<f8_t, 8> reg_a_v(reg_a);
1552  vector_type<bf8_t, 8> reg_b_v(reg_b);
1553 
1554  static_for<0, 8, 1>{}([&](auto k) {
1555  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1556  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1557 
1558  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1559  });
1560 #endif
1561  }
1562 };
1563 
1564 template <index_t MPerWave, index_t NPerWave>
1566 
// 16x16x32 mixed FP8(E4M3, A) x BF8(E5M2, B) MFMA into a 4xfloat slice.
// gfx94x: native builtin. Fallback: per-element float conversion plus 8 f32
// MFMAs (intrin_mfma_f32_16x16x4f32<16, 16>).
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_16x16x32_fp8_bf8<16, 16> — verify against the original.
1567 template <>
1569 {
1570  template <class FloatC>
1571  __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
1572  {
1573 #if defined(__gfx94__)
1574  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1575  bit_cast<int64_t>(reg_a),
1576  bit_cast<int64_t>(reg_b),
1577  reg_c.template AsType<float4_t>()[Number<0>{}],
1578  0,
1579  0,
1580  0);
1581 #else
1582  vector_type<f8_t, 8> reg_a_v(reg_a);
1583  vector_type<bf8_t, 8> reg_b_v(reg_b);
1584 
1585  static_for<0, 8, 1>{}([&](auto k) {
1586  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
1587  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
1588 
1589  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1590  });
1591 #endif
1592  }
1593 };
1594 
1595 template <index_t MPerWave, index_t NPerWave>
1597 
// 32x32x16 mixed BF8(E5M2, A) x FP8(E4M3, B) MFMA into a 16xfloat slice.
// gfx94x: native builtin. Fallback: per-element float conversion plus 8 f32
// MFMAs (intrin_mfma_f32_32x32x2f32<32, 32>).
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_32x32x16_bf8_fp8<32, 32> — verify against the original.
1598 template <>
1600 {
1601  template <class FloatC>
1602  __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1603  {
1604 #if defined(__gfx94__)
1605  reg_c.template AsType<float16_t>()(Number<0>{}) =
1606  __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1607  bit_cast<int64_t>(reg_a),
1608  bit_cast<int64_t>(reg_b),
1609  reg_c.template AsType<float16_t>()[Number<0>{}],
1610  0,
1611  0,
1612  0);
1613 #else
1614  vector_type<bf8_t, 8> reg_a_v(reg_a);
1615  vector_type<f8_t, 8> reg_b_v(reg_b);
1616 
1617  static_for<0, 8, 1>{}([&](auto k) {
1618  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1619  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1620 
1621  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
1622  });
1623 #endif
1624  }
1625 };
1626 
1627 template <index_t MPerWave, index_t NPerWave>
1629 
// 16x16x32 mixed BF8(E5M2, A) x FP8(E4M3, B) MFMA into a 4xfloat slice.
// gfx94x: native builtin. Fallback: per-element float conversion plus 8 f32
// MFMAs (intrin_mfma_f32_16x16x4f32<16, 16>).
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_16x16x32_bf8_fp8<16, 16> — verify against the original.
1630 template <>
1632 {
1633  template <class FloatC>
1634  __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
1635  {
1636 #if defined(__gfx94__)
1637  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1638  bit_cast<int64_t>(reg_a),
1639  bit_cast<int64_t>(reg_b),
1640  reg_c.template AsType<float4_t>()[Number<0>{}],
1641  0,
1642  0,
1643  0);
1644 #else
1645  vector_type<bf8_t, 8> reg_a_v(reg_a);
1646  vector_type<f8_t, 8> reg_b_v(reg_b);
1647 
1648  static_for<0, 8, 1>{}([&](auto k) {
1649  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
1650  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
1651 
1652  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
1653  });
1654 #endif
1655  }
1656 };
1657 
1658 /******************* tf32 on gfx942 *************************************/
1659 template <index_t MPerWave, index_t NPerWave>
1661 
// 16x16x8 XF32 (tf32-style) MFMA: two f32 values per operand, accumulated
// into a 4xfloat slice of reg_c. gfx942 only — other targets (including
// gfx950, which emulates xf32 via bf16x3 below) consume the args via
// `ignore` and leave reg_c untouched.
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_16x16x8xf32<16, 16> — verify against the original.
1662 template <>
1664 {
1665  template <class FloatC>
1666  __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
1667  {
1668 #if defined(__gfx942__)
1669  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8_xf32(
1670  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
1671 #else
1672  ignore = reg_a;
1673  ignore = reg_b;
1674  ignore = reg_c;
1675 #endif
1676  }
1677 };
1678 
1679 template <index_t MPerWave, index_t NPerWave>
1681 
// 32x32x4 XF32 (tf32-style) MFMA: two f32 values per operand, accumulated
// into a 16xfloat slice of reg_c. gfx942 only; elsewhere args are consumed
// via `ignore` and reg_c is untouched.
// NOTE(review): declarator line lost in extraction; presumably
// intrin_mfma_f32_32x32x4xf32<32, 32> — verify against the original.
1682 template <>
1684 {
1685  template <class FloatC>
1686  __device__ static void Run(const float2_t& reg_a, const float2_t& reg_b, FloatC& reg_c)
1687  {
1688 #if defined(__gfx942__)
1689  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4_xf32(
1690  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
1691 #else
1692  ignore = reg_a;
1693  ignore = reg_b;
1694  ignore = reg_c;
1695 #endif
1696  }
1697 };
1698 
1699 /******************* tf32/xf32 on gfx950 ********************************/
1700 /* bf16x3 simulate tf32/xf32: input/output/accumulator are all float; */
1701 /* step: */
1702 /* 1. separate one input to 2 bf16 registers: */
1703 /* in_bf16_big = f32_to_bf16(in_f32) */
1704 /* in_bf16_small = in_f32 - in_bf16_big */
1705 /* 2. run 3 xdlops gemm: the accumulator of each gemm is the same. */
1706 /* out_f32 = A_bf16_big * B_bf16_big */
1707 /* out_f32 += A_bf16_small * B_bf16_big */
1708 /* out_f32 += A_bf16_big * B_bf16_small */
1709 /************************************************************************/
1710 template <index_t MPerWave, index_t NPerWave>
1712 
// gfx950 xf32 emulation (16x16 tile) via the bf16x3 trick described in the
// banner above: each f32 input is split into a "big" bf16 part and a "small"
// bf16 remainder (convert_float_to_bf16_pairs), then three bf16 MFMAs
// accumulate into the same reg_c, recovering close-to-tf32 precision.
// NOTE(review): the three call-head lines were dropped by the doc extractor;
// per the symbol index they are presumably
// intrin_mfma_f32_16x16x16bf16<16, 16>::Run(...) — verify against the
// original header.
1713 template <>
1715 {
1716  template <class FloatC>
1717  __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
1718  {
1719 #if defined(__gfx950__)
1720  using I0 = Number<0>;
1721  vector_type<float, 8> reg_a_v(reg_a);
1722  vector_type<float, 8> reg_b_v(reg_b);
1723 
1724  vector_type<bhalf_t, 8> v_reg_a_bf16_big;
1725  vector_type<bhalf_t, 8> v_reg_a_bf16_small;
1726  vector_type<bhalf_t, 8> v_reg_b_bf16_big;
1727  vector_type<bhalf_t, 8> v_reg_b_bf16_small;
1728 
1729  convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
1730  convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
1731 
1732  // Accumulate three bf16 MFMAs into reg_c, in order:
1732  // a_small*b_big, then a_big*b_small, then a_big*b_big.
1734  v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
1735  v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
1736  reg_c);
1738  v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
1739  v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
1740  reg_c);
1742  v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
1743  v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
1744  reg_c);
1745 #else
1746  ignore = reg_a;
1747  ignore = reg_b;
1748  ignore = reg_c;
1749 #endif // defined(__gfx950__)
1750  }
1751 };
1752 
1753 template <index_t MPerWave, index_t NPerWave>
1755 
// gfx950 xf32 emulation (32x32 tile), same bf16x3 scheme as the 16x16
// variant: split each f32 operand into big + small bf16 parts, then run
// three bf16 MFMAs accumulating into the same reg_c.
// NOTE(review): the three call-head lines were dropped by the doc extractor;
// per the symbol index they are presumably
// intrin_mfma_f32_32x32x8bf16<32, 32>::Run(...) — verify against the
// original header.
1756 template <>
1758 {
1759  template <class FloatC>
1760  __device__ static void Run(const float8_t& reg_a, const float8_t& reg_b, FloatC& reg_c)
1761  {
1762 #if defined(__gfx950__)
1763  using I0 = Number<0>;
1764  vector_type<float, 8> reg_a_v(reg_a);
1765  vector_type<float, 8> reg_b_v(reg_b);
1766 
1767  vector_type<bhalf_t, 8> v_reg_a_bf16_big;
1768  vector_type<bhalf_t, 8> v_reg_a_bf16_small;
1769  vector_type<bhalf_t, 8> v_reg_b_bf16_big;
1770  vector_type<bhalf_t, 8> v_reg_b_bf16_small;
1771 
1772  convert_float_to_bf16_pairs(reg_a_v, v_reg_a_bf16_big, v_reg_a_bf16_small);
1773  convert_float_to_bf16_pairs(reg_b_v, v_reg_b_bf16_big, v_reg_b_bf16_small);
1774 
1775  // Accumulate three bf16 MFMAs into reg_c, in order:
1775  // a_small*b_big, then a_big*b_small, then a_big*b_big.
1777  v_reg_a_bf16_small.template AsType<bhalf8_t>()[I0{}],
1778  v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
1779  reg_c);
1781  v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
1782  v_reg_b_bf16_small.template AsType<bhalf8_t>()[I0{}],
1783  reg_c);
1785  v_reg_a_bf16_big.template AsType<bhalf8_t>()[I0{}],
1786  v_reg_b_bf16_big.template AsType<bhalf8_t>()[I0{}],
1787  reg_c);
1788 #else
1789  ignore = reg_a;
1790  ignore = reg_b;
1791  ignore = reg_c;
1792 #endif // defined(__gfx950__)
1793  }
1794 };
1795 
1796 /******************* tf32/xf32 on gfx950 end ************************************/
1797 } // namespace ck
bf8_t bf8x32_t
Definition: vector_type.hpp:240
bf8_t bf8x8_t
Definition: vector_type.hpp:238
Definition: ck.hpp:270
typename vector_type< bf6x16_pk_t, 2 >::type bf6x16x2_t
Definition: dtype_vector.hpp:2273
typename vector_type< f6x16_pk_t, 2 >::type f6x16x2_t
Definition: dtype_vector.hpp:2268
typename vector_type< f6x32_pk_t, 1 >::type f6x32_t
Definition: dtype_vector.hpp:2269
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition: dtype_vector.hpp:2162
__device__ __forceinline__ void convert_float_to_bf16_pairs(const vector_type< float, VecSize > &reg_f32, vector_type< bhalf_t, VecSize > &reg_bf16_big, vector_type< bhalf_t, VecSize > &reg_bf16_small)
Definition: amd_xdlops.hpp:17
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: dtype_vector.hpp:2163
typename vector_type< float, 2 >::type float2_t
Definition: dtype_vector.hpp:2146
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: dtype_vector.hpp:2179
typename vector_type< half_t, 4 >::type half4_t
Definition: dtype_vector.hpp:2155
typename vector_type< bf6x32_pk_t, 1 >::type bf6x32_t
Definition: dtype_vector.hpp:2274
typename vector_type< int32_t, 8 >::type int32x8_t
Definition: dtype_vector.hpp:2171
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< float, 8 >::type float8_t
Definition: dtype_vector.hpp:2148
typename vector_type< f4x2_pk_t, 16 >::type f4x32_t
Definition: dtype_vector.hpp:2263
typename vector_type< bhalf_t, 2 >::type bhalf2_t
Definition: dtype_vector.hpp:2161
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: dtype_vector.hpp:2180
typename vector_type< int32_t, 4 >::type int32x4_t
Definition: dtype_vector.hpp:2169
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: dtype_vector.hpp:2178
typename vector_type< int32_t, 6 >::type int32x6_t
Definition: dtype_vector.hpp:2170
__host__ constexpr __device__ bhalf_t type_convert< bhalf_t, float >(float x)
Definition: type_convert.hpp:133
typename vector_type< half_t, 8 >::type half8_t
Definition: dtype_vector.hpp:2156
__host__ constexpr __device__ float type_convert< float, bhalf_t >(bhalf_t x)
Definition: type_convert.hpp:120
signed int int32_t
Definition: stdint.h:123
Definition: integral_constant.hpp:20
static __device__ void Run(const bf6x32_t &reg_a, const bf6x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1379
static __device__ void Run(const f6x32_t &reg_a, const f6x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1352
static __device__ void Run(const f8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1302
static __device__ void Run(const bf8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1256
static __device__ void Run(const bf8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1279
static __device__ void Run(const f4x32_t &reg_a, const f4x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1325
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1233
Definition: amd_xdlops.hpp:1221
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:328
Definition: amd_xdlops.hpp:322
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:218
Definition: amd_xdlops.hpp:212
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:95
Definition: amd_xdlops.hpp:89
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:294
Definition: amd_xdlops.hpp:288
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1508
Definition: amd_xdlops.hpp:1502
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1634
Definition: amd_xdlops.hpp:1628
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:184
Definition: amd_xdlops.hpp:178
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1571
Definition: amd_xdlops.hpp:1565
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1445
Definition: amd_xdlops.hpp:1439
static __device__ void Run(const float8_t &reg_a, const float8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1717
Definition: amd_xdlops.hpp:1711
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:232
Definition: amd_xdlops.hpp:226
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:81
Definition: amd_xdlops.hpp:75
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:356
Definition: amd_xdlops.hpp:350
static __device__ void Run(const float2_t &reg_a, const float2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1666
Definition: amd_xdlops.hpp:1660
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:274
Definition: amd_xdlops.hpp:268
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1476
Definition: amd_xdlops.hpp:1470
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1602
Definition: amd_xdlops.hpp:1596
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:164
Definition: amd_xdlops.hpp:158
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1539
Definition: amd_xdlops.hpp:1533
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1413
Definition: amd_xdlops.hpp:1407
static __device__ void Run(const float8_t &reg_a, const float8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1760
Definition: amd_xdlops.hpp:1754
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:53
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:40
Definition: amd_xdlops.hpp:34
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:67
Definition: amd_xdlops.hpp:61
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:342
Definition: amd_xdlops.hpp:336
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:150
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:137
Definition: amd_xdlops.hpp:131
static __device__ void Run(const float2_t &reg_a, const float2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1686
Definition: amd_xdlops.hpp:1680
static __device__ void Run(const bf8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:556
static __device__ void Run(const f8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:578
static __device__ void Run(const bf6x32_t &reg_a, const bf6x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:656
static __device__ void Run(const f6x32_t &reg_a, const f6x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:628
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:512
static __device__ void Run(const f4x32_t &reg_a, const f4x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:600
static __device__ void Run(const bf8x32_t &reg_a, const bf8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:534
Definition: amd_xdlops.hpp:500
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:314
Definition: amd_xdlops.hpp:308
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:204
Definition: amd_xdlops.hpp:198
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:109
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:120
Definition: amd_xdlops.hpp:103
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:246
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:257
Definition: amd_xdlops.hpp:240
static __device__ void Run(const double &reg_a, const double &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:486
Definition: amd_xdlops.hpp:480
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:389
Definition: amd_xdlops.hpp:383
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:467
Definition: amd_xdlops.hpp:461
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:428
Definition: amd_xdlops.hpp:422
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:448
Definition: amd_xdlops.hpp:442
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:408
Definition: amd_xdlops.hpp:402
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:370
Definition: amd_xdlops.hpp:364
static __device__ void Run(const f6x16x2_t &reg_a, const int32_t scale_a, const f6x16x2_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1060
static __device__ void Run(const f4x32_t &reg_a, const int32_t scale_a, const f4x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1189
static __device__ void Run(const f6x32_t &reg_a, const int32_t scale_a, const f6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1027
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:911
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:998
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:969
static __device__ void Run(const bf6x16x2_t &reg_a, const int32_t scale_a, const bf6x16x2_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1141
static __device__ void Run(const bf6x32_t &reg_a, const int32_t scale_a, const bf6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:1108
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:940
Definition: amd_xdlops.hpp:905
static __device__ void Run(const f8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:691
static __device__ void Run(const f6x32_t &reg_a, const int32_t scale_a, const f6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:802
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const bf8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:728
static __device__ void Run(const bf8x32_t &reg_a, const int32_t &scale_a, const f8x32_t &reg_b, const int32_t &scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:765
static __device__ void Run(const f4x32_t &reg_a, const int32_t scale_a, const f4x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:870
static __device__ void Run(const bf6x32_t &reg_a, const int32_t scale_a, const bf6x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:836
Definition: amd_xdlops.hpp:685
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:11