/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/utility/amd_xdlops.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/utility/amd_xdlops.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/utility/amd_xdlops.hpp Source File
amd_xdlops.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck {
7 // Define one common macro, __gfx94__, whenever the compiler targets gfx940, gfx941, gfx942 or gfx950, so later code can test a single symbol (the original comment called these "MI300 models").
8 #if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__) || defined(__gfx950__)
9 #define __gfx94__
10 #endif
11 
12 // fp32
// fp32 MFMA wrappers built on __builtin_amdgcn_mfma_f32_32x32x1f32.
// NOTE(review): this listing has lost the `struct <name>` line that should follow every
// `template` header (the embedded line numbers skip 14, 17 and 30), so the struct name and the
// <MPerWave, NPerWave> specialization arguments are not visible here; restore them from the
// upstream header.
13 template <index_t MPerWave, index_t NPerWave>
15 
// First specialization: two MFMA issues, one per `float32_t` slot of reg_c.
16 template <>
18 {
19  template <class FloatC>
// Accumulates in place: each call passes slot i of reg_c as the accumulator operand and stores
// the builtin's result back into the same slot. The two calls differ only in the fifth argument
// (0 for slot 0, 1 for slot 1). The three trailing integers are presumably the cbsz/abid/blgp
// immediates of the LLVM builtin — TODO confirm.
20  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
21  {
22  reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
23  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
24  reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
25  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
26  }
27 };
28 
// Second specialization: a single MFMA issue that updates only slot 0 of reg_c.
29 template <>
31 {
32  template <class FloatC>
33  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
34  {
35  reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
36  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
37  }
38 };
39 
// fp32 MFMA wrapper built on __builtin_amdgcn_mfma_f32_32x32x2f32.
// NOTE(review): the `struct <name>` lines (embedded lines 41 and 44) are missing from this
// listing. Later code in this file calls this wrapper as intrin_mfma_f32_32x32x2f32<32, 32>::Run.
40 template <index_t MPerWave, index_t NPerWave>
42 
43 template <>
45 {
46  template <class FloatC>
// Accumulates in place: slot 0 of reg_c (viewed as float16_t) is both the accumulator input and
// the destination. All three trailing immediates are 0.
47  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
48  {
49  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
50  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
51  }
52 };
53 
// fp32 MFMA wrapper built on __builtin_amdgcn_mfma_f32_16x16x4f32.
// NOTE(review): the `struct <name>` lines (embedded lines 55 and 58) are missing from this
// listing. Later code in this file calls this wrapper as intrin_mfma_f32_16x16x4f32<16, 16>::Run.
54 template <index_t MPerWave, index_t NPerWave>
56 
57 template <>
59 {
60  template <class FloatC>
// Accumulates in place: slot 0 of reg_c (viewed as float4_t) is both the accumulator input and
// the destination. All three trailing immediates are 0.
61  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
62  {
63  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
64  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
65  }
66 };
67 
// fp32 MFMA wrapper built on __builtin_amdgcn_mfma_f32_16x16x1f32.
// NOTE(review): the `struct <name>` lines (embedded lines 69 and 72) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
68 template <index_t MPerWave, index_t NPerWave>
70 
71 template <>
73 {
74  template <class FloatC>
// Accumulates in place into slot 0 of reg_c (viewed as float16_t). The fourth builtin argument
// is 2 here; presumably this is the cbsz immediate — TODO confirm.
75  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
76  {
77  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
78  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
79  }
80 };
81 
// fp32 MFMA wrappers built on __builtin_amdgcn_mfma_f32_4x4x1f32.
// NOTE(review): the `struct <name>` lines (embedded lines 83, 86 and 97) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
82 template <index_t MPerWave, index_t NPerWave>
84 
// First specialization: one MFMA issue that accumulates into slot 0 of reg_c (float4_t view).
85 template <>
87 {
88  template <class FloatC>
89  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
90  {
91  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
92  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
93  }
94 };
95 
// Second specialization: two MFMA issues, one per float4_t slot of reg_c. The calls differ only
// in the fifth argument (0 for slot 0, 1 for slot 1).
96 template <>
98 {
99  template <class FloatC>
100  __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
101  {
102  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
103  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
104  reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
105  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
106  }
107 };
108 
109 // fp16
// fp16-input MFMA wrappers built on __builtin_amdgcn_mfma_f32_32x32x4f16 (half4_t operands,
// results stored through a float32_t view of reg_c).
// NOTE(review): the `struct <name>` lines (embedded lines 111, 114 and 127) are missing from
// this listing, so the struct name and specialization arguments are not visible here.
110 template <index_t MPerWave, index_t NPerWave>
112 
// First specialization: two MFMA issues, one per float32_t slot of reg_c. The calls differ only
// in the fifth argument (0 for slot 0, 1 for slot 1).
113 template <>
115 {
116  template <class FloatC>
117  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
118  {
119  reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
120  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
121  reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
122  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
123  }
124 };
125 
// Second specialization: a single MFMA issue that updates only slot 0 of reg_c.
126 template <>
128 {
129  template <class FloatC>
130  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
131  {
132  reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
133  reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
134  }
135 };
136 
// fp16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_32x32x16_f16 (half8_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 138 and 141) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
137 template <index_t MPerWave, index_t NPerWave>
139 
140 template <>
142 {
143  template <class FloatC>
144  __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
145  {
// The builtin is used only when compiling for gfx950; it accumulates in place into slot 0 of
// reg_c (float16_t view).
146 #if defined(__gfx950__)
147  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_f16(
148  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
149 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
150  ignore = reg_a;
151  ignore = reg_b;
152  ignore = reg_c;
153 #endif // defined(__gfx950__)
154  }
155 };
156 
// fp16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_16x16x32_f16 (half8_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 158 and 161) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
157 template <index_t MPerWave, index_t NPerWave>
159 
160 template <>
162 {
163  template <class FloatC>
164  __device__ static void Run(const half8_t& reg_a, const half8_t& reg_b, FloatC& reg_c)
165  {
// The builtin is used only when compiling for gfx950; it accumulates in place into slot 0 of
// reg_c (float4_t view).
166 #if defined(__gfx950__)
167  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_f16(
168  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
169 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
170  ignore = reg_a;
171  ignore = reg_b;
172  ignore = reg_c;
173 #endif // defined(__gfx950__)
174  }
175 };
176 
// fp16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_32x32x8f16 (half4_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 178 and 181) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
177 template <index_t MPerWave, index_t NPerWave>
179 
180 template <>
182 {
183  template <class FloatC>
// Accumulates in place: slot 0 of reg_c (float16_t view) is both the accumulator input and the
// destination. All three trailing immediates are 0.
184  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
185  {
186  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
187  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
188  }
189 };
190 
// fp16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_16x16x16f16 (half4_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 192 and 195) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
191 template <index_t MPerWave, index_t NPerWave>
193 
194 template <>
196 {
197  template <class FloatC>
// Accumulates in place: slot 0 of reg_c (float4_t view) is both the accumulator input and the
// destination. All three trailing immediates are 0.
198  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
199  {
200  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
201  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
202  }
203 };
204 
// fp16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_16x16x4f16 (half4_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 206 and 209) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
205 template <index_t MPerWave, index_t NPerWave>
207 
208 template <>
210 {
211  template <class FloatC>
// Accumulates in place into slot 0 of reg_c (float16_t view). The fourth builtin argument is 2
// here; presumably this is the cbsz immediate — TODO confirm.
212  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
213  {
214  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
215  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
216  }
217 };
218 
// fp16-input MFMA wrappers built on __builtin_amdgcn_mfma_f32_4x4x4f16 (half4_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 220, 223 and 234) are missing from
// this listing, so the struct name and specialization arguments are not visible here.
219 template <index_t MPerWave, index_t NPerWave>
221 
// First specialization: one MFMA issue that accumulates into slot 0 of reg_c (float4_t view).
222 template <>
224 {
225  template <class FloatC>
226  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
227  {
228  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
229  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
230  }
231 };
232 
// Second specialization: two MFMA issues, one per float4_t slot of reg_c. The calls differ only
// in the fifth argument (0 for slot 0, 1 for slot 1).
233 template <>
235 {
236  template <class FloatC>
237  __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
238  {
239  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
240  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
241  reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
242  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
243  }
244 };
245 
246 // bf16 (bfloat16)
// bf16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_32x32x16_bf16 (bhalf8_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 248 and 251) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
247 template <index_t MPerWave, index_t NPerWave>
249 
250 template <>
252 {
253  template <class FloatC>
254  __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
255  {
// The builtin is used only when compiling for gfx950; it accumulates in place into slot 0 of
// reg_c (float16_t view).
256 #if defined(__gfx950__)
257  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x16_bf16(
258  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
259 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
260  ignore = reg_a;
261  ignore = reg_b;
262  ignore = reg_c;
263 #endif // defined(__gfx950__)
264  }
265 };
266 
// bf16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_16x16x32_bf16 (bhalf8_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 268 and 271) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
267 template <index_t MPerWave, index_t NPerWave>
269 
270 template <>
272 {
273  template <class FloatC>
274  __device__ static void Run(const bhalf8_t& reg_a, const bhalf8_t& reg_b, FloatC& reg_c)
275  {
// The builtin is used only when compiling for gfx950; it accumulates in place into slot 0 of
// reg_c (float4_t view).
276 #if defined(__gfx950__)
277  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf16(
278  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
279 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
280  ignore = reg_a;
281  ignore = reg_b;
282  ignore = reg_c;
283 #endif // defined(__gfx950__)
284  }
285 };
286 
// bf16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_32x32x8bf16_1k (bhalf4_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 288 and 291) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
287 template <index_t MPerWave, index_t NPerWave>
289 
290 template <>
292 {
293  template <class FloatC>
// Accumulates in place: slot 0 of reg_c (float16_t view) is both the accumulator input and the
// destination. All three trailing immediates are 0. There is no target guard around the call.
294  __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
295  {
296  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
297  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
298  }
299 };
300 
// bf16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_16x16x16bf16_1k (bhalf4_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 302 and 305) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
301 template <index_t MPerWave, index_t NPerWave>
303 
304 template <>
306 {
307  template <class FloatC>
// Accumulates in place: slot 0 of reg_c (float4_t view) is both the accumulator input and the
// destination. All three trailing immediates are 0. There is no target guard around the call.
308  __device__ static void Run(const bhalf4_t& reg_a, const bhalf4_t& reg_b, FloatC& reg_c)
309  {
310  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
311  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
312  }
313 };
314 
// bf16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_32x32x4bf16 (bhalf2_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 316 and 319) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
315 template <index_t MPerWave, index_t NPerWave>
317 
318 template <>
320 {
321  template <class FloatC>
// Accumulates in place: slot 0 of reg_c (float16_t view) is both the accumulator input and the
// destination. All three trailing immediates are 0.
322  __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
323  {
324  reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
325  reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
326  }
327 };
328 
// bf16-input MFMA wrapper built on __builtin_amdgcn_mfma_f32_16x16x8bf16 (bhalf2_t operands).
// NOTE(review): the `struct <name>` lines (embedded lines 330 and 333) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
329 template <index_t MPerWave, index_t NPerWave>
331 
332 template <>
334 {
335  template <class FloatC>
// Accumulates in place: slot 0 of reg_c (float4_t view) is both the accumulator input and the
// destination. All three trailing immediates are 0.
336  __device__ static void Run(const bhalf2_t& reg_a, const bhalf2_t& reg_b, FloatC& reg_c)
337  {
338  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
339  reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
340  }
341 };
342 
// int8-input MFMA wrapper built on __builtin_amdgcn_mfma_i32_32x32x8i8 (int8x4_t operands,
// int32 accumulator).
// NOTE(review): the `struct <name>` lines (embedded lines 344 and 347) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
343 template <index_t MPerWave, index_t NPerWave>
345 
346 template <>
348 {
349  template <class FloatC>
// Each int8x4_t operand is reinterpreted as one int32_t with bit_cast before being handed to the
// builtin. The result accumulates in place into slot 0 of reg_c (int32x16_t view).
350  __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
351  {
352  reg_c.template AsType<int32x16_t>()(Number<0>{}) =
353  __builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int32_t>(reg_a),
354  bit_cast<int32_t>(reg_b),
355  reg_c.template AsType<int32x16_t>()[Number<0>{}],
356  0,
357  0,
358  0);
359  }
360 };
361 
// int8-input MFMA wrapper built on __builtin_amdgcn_mfma_i32_16x16x16i8 (int8x4_t operands,
// int32 accumulator).
// NOTE(review): the `struct <name>` lines (embedded lines 363 and 366) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
362 template <index_t MPerWave, index_t NPerWave>
364 
365 template <>
367 {
368  template <class FloatC>
// Each int8x4_t operand is reinterpreted as one int32_t with bit_cast before being handed to the
// builtin. The result accumulates in place into slot 0 of reg_c (int32x4_t view).
369  __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
370  {
371  reg_c.template AsType<int32x4_t>()(Number<0>{}) =
372  __builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int32_t>(reg_a),
373  bit_cast<int32_t>(reg_b),
374  reg_c.template AsType<int32x4_t>()[Number<0>{}],
375  0,
376  0,
377  0);
378  }
379 };
380 
// int8-input MFMA wrapper built on __builtin_amdgcn_mfma_i32_32x32x32_i8 (int8x16_t operands,
// int32 accumulator).
// NOTE(review): the `struct <name>` lines (embedded lines 382 and 385) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
381 template <index_t MPerWave, index_t NPerWave>
383 
384 template <>
386 {
387  template <class FloatC>
388  __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
389  {
// The builtin is used only when compiling for gfx950; operands are passed without a bit_cast and
// the result accumulates in place into slot 0 of reg_c (int32x16_t view).
390 #if defined(__gfx950__)
391  reg_c.template AsType<int32x16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_32x32x32_i8(
392  reg_a, reg_b, reg_c.template AsType<int32x16_t>()[Number<0>{}], 0, 0, 0);
393 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
394  ignore = reg_a;
395  ignore = reg_b;
396  ignore = reg_c;
397 #endif // defined(__gfx950__)
398  }
399 };
400 
// int8-input MFMA wrapper built on __builtin_amdgcn_mfma_i32_16x16x64_i8 (int8x16_t operands,
// int32 accumulator).
// NOTE(review): the `struct <name>` lines (embedded lines 402 and 405) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
401 template <index_t MPerWave, index_t NPerWave>
403 
404 template <>
406 {
407  template <class FloatC>
408  __device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
409  {
// The builtin is used only when compiling for gfx950; operands are passed without a bit_cast and
// the result accumulates in place into slot 0 of reg_c (int32x4_t view).
410 #if defined(__gfx950__)
411  reg_c.template AsType<int32x4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_i32_16x16x64_i8(
412  reg_a, reg_b, reg_c.template AsType<int32x4_t>()[Number<0>{}], 0, 0, 0);
413 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
414  ignore = reg_a;
415  ignore = reg_b;
416  ignore = reg_c;
417 #endif // defined(__gfx950__)
418  }
419 };
420 
// int8-input MFMA wrapper built on __builtin_amdgcn_mfma_i32_32x32x16_i8 (int8x8_t operands,
// int32 accumulator).
// NOTE(review): the `struct <name>` lines (embedded lines 422 and 425) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
421 template <index_t MPerWave, index_t NPerWave>
423 
424 template <>
426 {
427  template <class FloatC>
// Each int8x8_t operand is reinterpreted as one int64_t with bit_cast before being handed to the
// builtin. The result accumulates in place into slot 0 of reg_c (int32x16_t view).
// NOTE(review): unlike the gfx950-only *_i8 wrappers in this file, this call has no #if target
// guard and no fallback branch — confirm the builtin exists on every target that instantiates it.
428  __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
429  {
430  reg_c.template AsType<int32x16_t>()(Number<0>{}) =
431  __builtin_amdgcn_mfma_i32_32x32x16_i8(bit_cast<int64_t>(reg_a),
432  bit_cast<int64_t>(reg_b),
433  reg_c.template AsType<int32x16_t>()[Number<0>{}],
434  0,
435  0,
436  0);
437  }
438 };
439 
// int8-input MFMA wrapper built on __builtin_amdgcn_mfma_i32_16x16x32_i8 (int8x8_t operands,
// int32 accumulator).
// NOTE(review): the `struct <name>` lines (embedded lines 441 and 444) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
440 template <index_t MPerWave, index_t NPerWave>
442 
443 template <>
445 {
446  template <class FloatC>
// Each int8x8_t operand is reinterpreted as one int64_t with bit_cast before being handed to the
// builtin. The result accumulates in place into slot 0 of reg_c (int32x4_t view).
// NOTE(review): unlike the gfx950-only *_i8 wrappers in this file, this call has no #if target
// guard and no fallback branch — confirm the builtin exists on every target that instantiates it.
447  __device__ static void Run(const int8x8_t& reg_a, const int8x8_t& reg_b, FloatC& reg_c)
448  {
449  reg_c.template AsType<int32x4_t>()(Number<0>{}) =
450  __builtin_amdgcn_mfma_i32_16x16x32_i8(bit_cast<int64_t>(reg_a),
451  bit_cast<int64_t>(reg_b),
452  reg_c.template AsType<int32x4_t>()[Number<0>{}],
453  0,
454  0,
455  0);
456  }
457 };
458 
// fp64 MFMA wrapper built on __builtin_amdgcn_mfma_f64_16x16x4f64 (double operands).
// NOTE(review): the `struct <name>` lines (embedded lines 460 and 463) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
459 template <index_t MPerWave, index_t NPerWave>
461 
462 template <>
464 {
465  template <class FloatC>
466  __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
467  {
// The builtin is used when compiling for gfx90a or for any target that defines __gfx94__; it
// accumulates in place into slot 0 of reg_c (double4_t view).
468 #if defined(__gfx90a__) || defined(__gfx94__)
469  reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
470  reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
471 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
472  ignore = reg_a;
473  ignore = reg_b;
474  ignore = reg_c;
475 #endif
476  }
477 };
478 
// 8-bit-float MFMA wrapper (f8x32_t operands) built on
// __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4, called with every scale-related argument
// fixed to 0.
// NOTE(review): the `struct <name>` lines (embedded lines 480 and 489) and embedded lines
// 482-487 are missing from this listing, so the struct name, the specialization arguments and
// any original comment are not visible here.
479 template <index_t MPerWave, index_t NPerWave>
481 
488 template <>
490 {
491  template <class FloatC>
492  __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
493  {
// The builtin is used only when compiling for gfx950; it accumulates in place into slot 0 of
// reg_c (float16_t view). The last four arguments sit in the same positions the scaled wrappers
// use for the two opsel selectors and scale_a/scale_b; here they are all 0.
// NOTE(review): presumably zero scale operands make this behave as the unscaled instruction —
// confirm against the LLVM AMDGPU backend.
494 #if defined(__gfx950__)
495  reg_c.template AsType<float16_t>()(Number<0>{}) =
496  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
497  reg_a,
498  reg_b,
499  reg_c.template AsType<float16_t>()[Number<0>{}],
500  0, // cbsz
501  0, // blgp
502  0,
503  0,
504  0,
505  0);
506 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
507  ignore = reg_a;
508  ignore = reg_b;
509  ignore = reg_c;
510 #endif
511  }
512 };
513 
// Scaled 8-bit-float MFMA wrapper (f8x32_t operands plus one int32_t scale per operand) built on
// __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4.
// NOTE(review): the `struct <name>` lines (embedded lines 515 and 518) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
514 template <index_t MPerWave, index_t NPerWave>
516 
517 template <>
519 {
520  template <class FloatC>
// Note the parameter order (reg_a, scale_a, reg_b, scale_b, reg_c) differs from the builtin's
// argument order: the builtin takes both data operands and the accumulator first, then cbsz and
// blgp, then an opsel selector followed by scale_a, then a second opsel selector followed by
// scale_b.
521  __device__ static void Run(const f8x32_t& reg_a,
522  const int32_t scale_a,
523  const f8x32_t& reg_b,
524  const int32_t scale_b,
525  FloatC& reg_c)
526  {
// The builtin is used only when compiling for gfx950; it accumulates in place into slot 0 of
// reg_c (float16_t view).
527 #if defined(__gfx950__)
528  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
529  reg_c.template AsType<float16_t>()(Number<0>{}) =
530  __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
531  reg_a,
532  reg_b,
533  reg_c.template AsType<float16_t>()[Number<0>{}],
534  0, // cbsz
535  0, // blgp
536  0, // { OPSEL_HI[0], OPSEL[0] }?
537  scale_a,
538  0, // { OPSEL_HI[1], OPSEL[1] }?
539  scale_b);
540 #else
// Any other target: no MFMA is issued and reg_c is not written; all five arguments are only
// assigned to ck's `ignore` object so they count as used.
541  ignore = reg_a;
542  ignore = scale_a;
543  ignore = reg_b;
544  ignore = scale_b;
545  ignore = reg_c;
546 #endif
547  }
548 };
549 
// Scaled 8-bit-float MFMA wrapper (f8x32_t operands plus one int32_t scale per operand) built on
// __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4.
// NOTE(review): the `struct <name>` lines (embedded lines 551 and 554) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
550 template <index_t MPerWave, index_t NPerWave>
552 
553 template <>
555 {
556  template <class FloatC>
// Note the parameter order (reg_a, scale_a, reg_b, scale_b, reg_c) differs from the builtin's
// argument order: the builtin takes both data operands and the accumulator first, then cbsz and
// blgp, then an opsel selector followed by scale_a, then a second opsel selector followed by
// scale_b.
557  __device__ static void Run(const f8x32_t& reg_a,
558  const int32_t scale_a,
559  const f8x32_t& reg_b,
560  const int32_t scale_b,
561  FloatC& reg_c)
562  {
// The builtin is used only when compiling for gfx950; it accumulates in place into slot 0 of
// reg_c (float4_t view).
563 #if defined(__gfx950__)
564  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
565  reg_c.template AsType<float4_t>()(Number<0>{}) =
566  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
567  reg_a,
568  reg_b,
569  reg_c.template AsType<float4_t>()[Number<0>{}],
570  0, // cbsz
571  0, // blgp
572  0, // { OPSEL_HI[0], OPSEL[0] }?
573  scale_a,
574  0, // { OPSEL_HI[1], OPSEL[1] }?
575  scale_b);
576 #else
// Any other target: no MFMA is issued and reg_c is not written; all five arguments are only
// assigned to ck's `ignore` object so they count as used.
577  ignore = reg_a;
578  ignore = scale_a;
579  ignore = reg_b;
580  ignore = scale_b;
581  ignore = reg_c;
582 #endif
583  }
584 };
585 
// 8-bit-float MFMA wrapper (f8x32_t operands) built on
// __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4, called with every scale-related argument
// fixed to 0.
// NOTE(review): the `struct <name>` lines (embedded lines 587 and 596) and embedded lines
// 589-594 are missing from this listing, so the struct name, the specialization arguments and
// any original comment are not visible here.
586 template <index_t MPerWave, index_t NPerWave>
588 
595 template <>
597 {
598  template <class FloatC>
599  __device__ static void Run(const f8x32_t& reg_a, const f8x32_t& reg_b, FloatC& reg_c)
600  {
// The builtin is used only when compiling for gfx950; it accumulates in place into slot 0 of
// reg_c (float4_t view). The last four arguments sit in the same positions the scaled wrappers
// use for the two opsel selectors and scale_a/scale_b; here they are all 0.
// NOTE(review): presumably zero scale operands make this behave as the unscaled instruction —
// confirm against the LLVM AMDGPU backend.
601 #if defined(__gfx950__)
602  // https://github.com/ROCm/llvm-project/blob/656552edc693e2bb4abc9258399c39d190fce2b3/llvm/test/Verifier/AMDGPU/mfma-scale.ll#L10
603  reg_c.template AsType<float4_t>()(Number<0>{}) =
604  __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
605  reg_a,
606  reg_b,
607  reg_c.template AsType<float4_t>()[Number<0>{}],
608  0, // cbsz
609  0, // blgp
610  0,
611  0,
612  0,
613  0);
614 #else
// Any other target: no MFMA is issued and reg_c is not written; the arguments are only assigned
// to ck's `ignore` object so they count as used.
615  ignore = reg_a;
616  ignore = reg_b;
617  ignore = reg_c;
618 #endif
619  }
620 };
621 
// fp8 x fp8 MFMA wrapper (f8x8_t operands) with an fp32 fallback for targets without __gfx94__.
// NOTE(review): the `struct <name>` lines (embedded lines 623 and 626) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
622 template <index_t MPerWave, index_t NPerWave>
624 
625 template <>
627 {
628  template <class FloatC>
629  __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
630  {
// Native path: each f8x8_t operand is reinterpreted as one `long` with bit_cast and passed to
// __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8, accumulating in place into slot 0 of reg_c
// (float16_t view).
631 #if defined(__gfx94__)
632  reg_c.template AsType<float16_t>()(Number<0>{}) =
633  __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
634  bit_cast<long>(reg_a),
635  bit_cast<long>(reg_b),
636  reg_c.template AsType<float16_t>()[Number<0>{}],
637  0,
638  0,
639  0);
640 #else
// Fallback path: wrap both operands so their 8 elements can be indexed, then for k = 0..7
// convert a[k] and b[k] to float and call the fp32 wrapper intrin_mfma_f32_32x32x2f32<32, 32>,
// which accumulates into reg_c.
// NOTE(review): whether this k-by-k fp32 sequence reproduces the native instruction's operand
// layout cannot be determined from this file — confirm.
641  vector_type<f8_t, 8> reg_a_v(reg_a);
642  vector_type<f8_t, 8> reg_b_v(reg_b);
643 
644  static_for<0, 8, 1>{}([&](auto k) {
645  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
646  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
647 
648  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
649  });
650 #endif
651  }
652 };
653 
// fp8 x fp8 MFMA wrapper (f8x8_t operands) with an fp32 fallback for targets without __gfx94__.
// NOTE(review): the `struct <name>` lines (embedded lines 655 and 658) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
654 template <index_t MPerWave, index_t NPerWave>
656 
657 template <>
659 {
660  template <class FloatC>
661  __device__ static void Run(const f8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
662  {
// Native path: each f8x8_t operand is reinterpreted as one `long` with bit_cast and passed to
// __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8, accumulating in place into slot 0 of reg_c
// (float4_t view).
663 #if defined(__gfx94__)
664  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
665  bit_cast<long>(reg_a),
666  bit_cast<long>(reg_b),
667  reg_c.template AsType<float4_t>()[Number<0>{}],
668  0,
669  0,
670  0);
671 #else
// Fallback path: wrap both operands so their 8 elements can be indexed, then for k = 0..7
// convert a[k] and b[k] to float and call the fp32 wrapper intrin_mfma_f32_16x16x4f32<16, 16>,
// which accumulates into reg_c.
// NOTE(review): whether this k-by-k fp32 sequence reproduces the native instruction's operand
// layout cannot be determined from this file — confirm.
672  vector_type<f8_t, 8> reg_a_v(reg_a);
673  vector_type<f8_t, 8> reg_b_v(reg_b);
674 
675  static_for<0, 8, 1>{}([&](auto k) {
676  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
677  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
678 
679  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
680  });
681 #endif
682  }
683 };
684 
// bf8 x bf8 MFMA wrapper (bf8x8_t operands) with an fp32 fallback for targets without __gfx94__.
// NOTE(review): the `struct <name>` lines (embedded lines 686 and 689) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
685 template <index_t MPerWave, index_t NPerWave>
687 
688 template <>
690 {
691  template <class FloatC>
692  __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
693  {
// Native path: each bf8x8_t operand is reinterpreted as one `long` with bit_cast and passed to
// __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8, accumulating in place into slot 0 of reg_c
// (float16_t view).
694 #if defined(__gfx94__)
695  reg_c.template AsType<float16_t>()(Number<0>{}) =
696  __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
697  bit_cast<long>(reg_a),
698  bit_cast<long>(reg_b),
699  reg_c.template AsType<float16_t>()[Number<0>{}],
700  0,
701  0,
702  0);
703 #else
// Fallback path: wrap both operands so their 8 elements can be indexed, then for k = 0..7
// convert a[k] and b[k] to float and call the fp32 wrapper intrin_mfma_f32_32x32x2f32<32, 32>,
// which accumulates into reg_c.
// NOTE(review): whether this k-by-k fp32 sequence reproduces the native instruction's operand
// layout cannot be determined from this file — confirm.
704  vector_type<bf8_t, 8> reg_a_v(reg_a);
705  vector_type<bf8_t, 8> reg_b_v(reg_b);
706 
707  static_for<0, 8, 1>{}([&](auto k) {
708  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
709  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
710 
711  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
712  });
713 #endif
714  }
715 };
716 
// bf8 x bf8 MFMA wrapper (bf8x8_t operands) with an fp32 fallback for targets without __gfx94__.
// NOTE(review): the `struct <name>` lines (embedded lines 718 and 721) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
717 template <index_t MPerWave, index_t NPerWave>
719 
720 template <>
722 {
723  template <class FloatC>
724  __device__ static void Run(const bf8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
725  {
// Native path: each bf8x8_t operand is reinterpreted as one `long` with bit_cast and passed to
// __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8, accumulating in place into slot 0 of reg_c
// (float4_t view).
726 #if defined(__gfx94__)
727  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
728  bit_cast<long>(reg_a),
729  bit_cast<long>(reg_b),
730  reg_c.template AsType<float4_t>()[Number<0>{}],
731  0,
732  0,
733  0);
734 #else
// Fallback path: wrap both operands so their 8 elements can be indexed, then for k = 0..7
// convert a[k] and b[k] to float and call the fp32 wrapper intrin_mfma_f32_16x16x4f32<16, 16>,
// which accumulates into reg_c.
// NOTE(review): whether this k-by-k fp32 sequence reproduces the native instruction's operand
// layout cannot be determined from this file — confirm.
735  vector_type<bf8_t, 8> reg_a_v(reg_a);
736  vector_type<bf8_t, 8> reg_b_v(reg_b);
737 
738  static_for<0, 8, 1>{}([&](auto k) {
739  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
740  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
741 
742  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
743  });
744 #endif
745  }
746 };
747 
// Mixed fp8 (A) x bf8 (B) MFMA wrapper with an fp32 fallback for targets without __gfx94__.
// NOTE(review): the `struct <name>` lines (embedded lines 749 and 752) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
748 template <index_t MPerWave, index_t NPerWave>
750 
751 template <>
753 {
754  template <class FloatC>
755  __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
756  {
// Native path: each 8-element operand is reinterpreted as one `long` with bit_cast and passed to
// __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8, accumulating in place into slot 0 of reg_c
// (float16_t view).
757 #if defined(__gfx94__)
758  reg_c.template AsType<float16_t>()(Number<0>{}) =
759  __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
760  bit_cast<long>(reg_a),
761  bit_cast<long>(reg_b),
762  reg_c.template AsType<float16_t>()[Number<0>{}],
763  0,
764  0,
765  0);
766 #else
// Fallback path: A is indexed as f8_t and B as bf8_t; for k = 0..7 both elements are converted
// to float and fed to the fp32 wrapper intrin_mfma_f32_32x32x2f32<32, 32>, which accumulates
// into reg_c.
// NOTE(review): whether this k-by-k fp32 sequence reproduces the native instruction's operand
// layout cannot be determined from this file — confirm.
767  vector_type<f8_t, 8> reg_a_v(reg_a);
768  vector_type<bf8_t, 8> reg_b_v(reg_b);
769 
770  static_for<0, 8, 1>{}([&](auto k) {
771  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
772  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
773 
774  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
775  });
776 #endif
777  }
778 };
779 
// Mixed fp8 (A) x bf8 (B) MFMA wrapper with an fp32 fallback for targets without __gfx94__.
// NOTE(review): the `struct <name>` lines (embedded lines 781 and 784) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
780 template <index_t MPerWave, index_t NPerWave>
782 
783 template <>
785 {
786  template <class FloatC>
787  __device__ static void Run(const f8x8_t& reg_a, const bf8x8_t& reg_b, FloatC& reg_c)
788  {
// Native path: each 8-element operand is reinterpreted as one `long` with bit_cast and passed to
// __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8, accumulating in place into slot 0 of reg_c
// (float4_t view).
789 #if defined(__gfx94__)
790  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
791  bit_cast<long>(reg_a),
792  bit_cast<long>(reg_b),
793  reg_c.template AsType<float4_t>()[Number<0>{}],
794  0,
795  0,
796  0);
797 #else
// Fallback path: A is indexed as f8_t and B as bf8_t; for k = 0..7 both elements are converted
// to float and fed to the fp32 wrapper intrin_mfma_f32_16x16x4f32<16, 16>, which accumulates
// into reg_c.
// NOTE(review): whether this k-by-k fp32 sequence reproduces the native instruction's operand
// layout cannot be determined from this file — confirm.
798  vector_type<f8_t, 8> reg_a_v(reg_a);
799  vector_type<bf8_t, 8> reg_b_v(reg_b);
800 
801  static_for<0, 8, 1>{}([&](auto k) {
802  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<f8_t>()[Number<k>{}]);
803  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<bf8_t>()[Number<k>{}]);
804 
805  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
806  });
807 #endif
808  }
809 };
810 
// Mixed bf8 (A) x fp8 (B) MFMA wrapper with an fp32 fallback for targets without __gfx94__.
// NOTE(review): the `struct <name>` lines (embedded lines 812 and 815) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
811 template <index_t MPerWave, index_t NPerWave>
813 
814 template <>
816 {
817  template <class FloatC>
818  __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
819  {
// Native path: each 8-element operand is reinterpreted as one `long` with bit_cast and passed to
// __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8, accumulating in place into slot 0 of reg_c
// (float16_t view).
820 #if defined(__gfx94__)
821  reg_c.template AsType<float16_t>()(Number<0>{}) =
822  __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
823  bit_cast<long>(reg_a),
824  bit_cast<long>(reg_b),
825  reg_c.template AsType<float16_t>()[Number<0>{}],
826  0,
827  0,
828  0);
829 #else
// Fallback path: A is indexed as bf8_t and B as f8_t; for k = 0..7 both elements are converted
// to float and fed to the fp32 wrapper intrin_mfma_f32_32x32x2f32<32, 32>, which accumulates
// into reg_c.
// NOTE(review): whether this k-by-k fp32 sequence reproduces the native instruction's operand
// layout cannot be determined from this file — confirm.
830  vector_type<bf8_t, 8> reg_a_v(reg_a);
831  vector_type<f8_t, 8> reg_b_v(reg_b);
832 
833  static_for<0, 8, 1>{}([&](auto k) {
834  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
835  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
836 
837  intrin_mfma_f32_32x32x2f32<32, 32>::Run(reg_a_f32, reg_b_f32, reg_c);
838  });
839 #endif
840  }
841 };
842 
// Mixed bf8 (A) x fp8 (B) MFMA wrapper with an fp32 fallback for targets without __gfx94__.
// NOTE(review): the `struct <name>` lines (embedded lines 844 and 847) are missing from this
// listing, so the struct name and specialization arguments are not visible here.
843 template <index_t MPerWave, index_t NPerWave>
845 
846 template <>
848 {
849  template <class FloatC>
850  __device__ static void Run(const bf8x8_t& reg_a, const f8x8_t& reg_b, FloatC& reg_c)
851  {
// Native path: each 8-element operand is reinterpreted as one `long` with bit_cast and passed to
// __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8, accumulating in place into slot 0 of reg_c
// (float4_t view).
852 #if defined(__gfx94__)
853  reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
854  bit_cast<long>(reg_a),
855  bit_cast<long>(reg_b),
856  reg_c.template AsType<float4_t>()[Number<0>{}],
857  0,
858  0,
859  0);
860 #else
// Fallback path: A is indexed as bf8_t and B as f8_t; for k = 0..7 both elements are converted
// to float and fed to the fp32 wrapper intrin_mfma_f32_16x16x4f32<16, 16>, which accumulates
// into reg_c.
// NOTE(review): whether this k-by-k fp32 sequence reproduces the native instruction's operand
// layout cannot be determined from this file — confirm.
861  vector_type<bf8_t, 8> reg_a_v(reg_a);
862  vector_type<f8_t, 8> reg_b_v(reg_b);
863 
864  static_for<0, 8, 1>{}([&](auto k) {
865  float reg_a_f32 = type_convert<float>(reg_a_v.template AsType<bf8_t>()[Number<k>{}]);
866  float reg_b_f32 = type_convert<float>(reg_b_v.template AsType<f8_t>()[Number<k>{}]);
867 
868  intrin_mfma_f32_16x16x4f32<16, 16>::Run(reg_a_f32, reg_b_f32, reg_c);
869  });
870 #endif
871  }
872 };
873 
874 } // namespace ck
bf8_t __attribute((ext_vector_type(8))) bf8x8_t
Definition: vector_type.hpp:197
Definition: ck.hpp:264
typename vector_type< bhalf_t, 4 >::type bhalf4_t
Definition: data_type.hpp:2498
typename vector_type< bhalf_t, 8 >::type bhalf8_t
Definition: data_type.hpp:2499
typename vector_type< int8_t, 8 >::type int8x8_t
Definition: data_type.hpp:2515
typename vector_type< half_t, 4 >::type half4_t
Definition: data_type.hpp:2490
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename vector_type< bhalf_t, 2 >::type bhalf2_t
Definition: data_type.hpp:2497
typename vector_type< int8_t, 16 >::type int8x16_t
Definition: data_type.hpp:2516
typename vector_type< int8_t, 4 >::type int8x4_t
Definition: data_type.hpp:2514
typename vector_type< half_t, 8 >::type half8_t
Definition: data_type.hpp:2491
Definition: integral_constant.hpp:10
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:599
Definition: amd_xdlops.hpp:587
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:308
Definition: amd_xdlops.hpp:302
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:198
Definition: amd_xdlops.hpp:192
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:75
Definition: amd_xdlops.hpp:69
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:274
Definition: amd_xdlops.hpp:268
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:724
Definition: amd_xdlops.hpp:718
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:850
Definition: amd_xdlops.hpp:844
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:164
Definition: amd_xdlops.hpp:158
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:787
Definition: amd_xdlops.hpp:781
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:661
Definition: amd_xdlops.hpp:655
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:212
Definition: amd_xdlops.hpp:206
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:61
Definition: amd_xdlops.hpp:55
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:336
Definition: amd_xdlops.hpp:330
static __device__ void Run(const bhalf8_t &reg_a, const bhalf8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:254
Definition: amd_xdlops.hpp:248
static __device__ void Run(const bf8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:692
Definition: amd_xdlops.hpp:686
static __device__ void Run(const bf8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:818
Definition: amd_xdlops.hpp:812
static __device__ void Run(const half8_t &reg_a, const half8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:144
Definition: amd_xdlops.hpp:138
static __device__ void Run(const f8x8_t &reg_a, const bf8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:755
Definition: amd_xdlops.hpp:749
static __device__ void Run(const f8x8_t &reg_a, const f8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:629
Definition: amd_xdlops.hpp:623
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:33
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:20
Definition: amd_xdlops.hpp:14
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:47
Definition: amd_xdlops.hpp:41
static __device__ void Run(const bhalf2_t &reg_a, const bhalf2_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:322
Definition: amd_xdlops.hpp:316
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:130
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:117
Definition: amd_xdlops.hpp:111
static __device__ void Run(const f8x32_t &reg_a, const f8x32_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:492
Definition: amd_xdlops.hpp:480
static __device__ void Run(const bhalf4_t &reg_a, const bhalf4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:294
Definition: amd_xdlops.hpp:288
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:184
Definition: amd_xdlops.hpp:178
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:89
static __device__ void Run(const float &reg_a, const float &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:100
Definition: amd_xdlops.hpp:83
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:226
static __device__ void Run(const half4_t &reg_a, const half4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:237
Definition: amd_xdlops.hpp:220
static __device__ void Run(const double &reg_a, const double &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:466
Definition: amd_xdlops.hpp:460
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:369
Definition: amd_xdlops.hpp:363
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:447
Definition: amd_xdlops.hpp:441
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:408
Definition: amd_xdlops.hpp:402
static __device__ void Run(const int8x8_t &reg_a, const int8x8_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:428
Definition: amd_xdlops.hpp:422
static __device__ void Run(const int8x16_t &reg_a, const int8x16_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:388
Definition: amd_xdlops.hpp:382
static __device__ void Run(const int8x4_t &reg_a, const int8x4_t &reg_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:350
Definition: amd_xdlops.hpp:344
static __device__ void Run(const f8x32_t &reg_a, const int32_t scale_a, const f8x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:557
Definition: amd_xdlops.hpp:551
static __device__ void Run(const f8x32_t &reg_a, const int32_t scale_a, const f8x32_t &reg_b, const int32_t scale_b, FloatC &reg_c)
Definition: amd_xdlops.hpp:521
Definition: amd_xdlops.hpp:515
Definition: functional2.hpp:31
Definition: data_type.hpp:347