/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File
xdlops_gemm.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
10 
11 namespace ck {
15 template <typename T>
16 static constexpr bool is_scale_mfma_data_type()
17 {
18  using U = element_type_t<T>;
19  return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
20  is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
21 }
22 
23 #ifndef CK_CODE_GEN_RTC
27 template <typename T>
28 static constexpr bool is_scale_mfma_scale_type()
29 {
30  return is_same_v<T, e8m0_bexp_t>;
31 }
32 #endif
33 
37 template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
38 static constexpr bool scale_mfma_hw_support()
39 {
40  return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
41  is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
42 }
43 
// Enumerates the MFMA (matrix fused-multiply-accumulate) instruction variants
// that the mfma_type trait below is specialized on.
// NOTE(review): this listing is extraction-damaged — the embedded numbering
// jumps from 45 to 83 and from 87 to 92, so most enumerators (and the gfx11 /
// gfx12 WMMA-era entries) were lost. Only the tf32-related enumerators are
// visible; restore the full list from the upstream header before relying on
// this enum's contents.
enum struct MfmaInstr
{
 mfma_f32_16x16x8xf32, // tf32 on gfx942
 mfma_f32_32x32x4xf32, // tf32 on gfx942
 mfma_f32_16x16x32xf32, // bf16x3 simulate tf32 on gfx950
 mfma_f32_32x32x16xf32, // bf16x3 simulate tf32 on gfx950
 // gfx11
 // gfx12
};
102 
// Primary template for the per-instruction MFMA trait bundle. Each explicit
// specialization (one per MfmaInstr value) provides the instruction's tile
// geometry and layout constants (group_size, num_groups_per_blk,
// num_regs_per_blk, num_threads_per_blk, wave_size, num_input_blks,
// num_output_blks, m_per_blk, n_per_blk, k_per_blk, is_k_reduction) plus a
// __device__ run() that issues the corresponding MFMA builtin. The primary
// template is intentionally left undefined so that using an unsupported
// instruction fails at compile time.
template <MfmaInstr instr>
struct mfma_type;
105 
106 template <>
108 {
109  static constexpr index_t group_size = 4;
110  static constexpr index_t num_groups_per_blk = 4;
111  static constexpr index_t num_regs_per_blk = 16;
112  static constexpr index_t num_threads_per_blk = 32;
113  static constexpr index_t wave_size = 64;
114  static constexpr index_t num_input_blks = 2;
115  static constexpr index_t num_output_blks = 2;
116  static constexpr index_t m_per_blk = 32;
117  static constexpr index_t n_per_blk = 32;
118  static constexpr index_t k_per_blk = 1;
119  static constexpr bool is_k_reduction = false;
120 
121  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
122  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
123  {
125  }
126 };
127 
128 template <>
130 {
131  static constexpr index_t group_size = 4;
132  static constexpr index_t num_groups_per_blk = 4;
133  static constexpr index_t num_regs_per_blk = 16;
134  static constexpr index_t num_threads_per_blk = 32;
135  static constexpr index_t wave_size = 64;
136  static constexpr index_t num_input_blks = 2;
137  static constexpr index_t num_output_blks = 1;
138  static constexpr index_t m_per_blk = 32;
139  static constexpr index_t n_per_blk = 32;
140  static constexpr index_t k_per_blk = 1;
141  static constexpr bool is_k_reduction = true;
142 
143  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
144  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
145  {
147  }
148 };
149 
150 template <>
152 {
153  static constexpr index_t group_size = 4;
154  static constexpr index_t num_groups_per_blk = 1;
155  static constexpr index_t num_regs_per_blk = 4;
156  static constexpr index_t num_threads_per_blk = 16;
157  static constexpr index_t wave_size = 64;
158  static constexpr index_t num_input_blks = 4;
159  static constexpr index_t num_output_blks = 1;
160  static constexpr index_t m_per_blk = 16;
161  static constexpr index_t n_per_blk = 16;
162  static constexpr index_t k_per_blk = 1;
163  static constexpr bool is_k_reduction = true;
164 
165  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
166  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
167  {
169  }
170 };
171 
172 template <>
174 {
175  static constexpr index_t group_size = 4;
176  static constexpr index_t num_groups_per_blk = 1;
177  static constexpr index_t num_regs_per_blk = 4;
178  static constexpr index_t num_threads_per_blk = 16;
179  static constexpr index_t wave_size = 64;
180  static constexpr index_t num_input_blks = 4;
181  static constexpr index_t num_output_blks = 4;
182  static constexpr index_t m_per_blk = 16;
183  static constexpr index_t n_per_blk = 16;
184  static constexpr index_t k_per_blk = 1;
185  static constexpr bool is_k_reduction = false;
186 
187  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
188  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
189  {
191  }
192 };
193 
194 // treat 4x4x1 as a single-blk 4x64 mfma
195 template <>
197 {
198  static constexpr index_t group_size = 4;
199  static constexpr index_t num_groups_per_blk = 1;
200  static constexpr index_t num_regs_per_blk = 4;
201  static constexpr index_t num_threads_per_blk = 64;
202  static constexpr index_t wave_size = 64;
203  static constexpr index_t num_input_blks = 1;
204  static constexpr index_t num_output_blks = 1;
205  static constexpr index_t m_per_blk = 4;
206  static constexpr index_t n_per_blk = 64;
207  static constexpr index_t k_per_blk = 1;
208  static constexpr bool is_k_reduction = false;
209 
210  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
211  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
212  {
214  }
215 };
216 
217 template <>
219 {
220  static constexpr index_t group_size = 4;
221  static constexpr index_t num_groups_per_blk = 4;
222  static constexpr index_t num_regs_per_blk = 16;
223  static constexpr index_t num_threads_per_blk = 32;
224  static constexpr index_t wave_size = 64;
225  static constexpr index_t num_input_blks = 2;
226  static constexpr index_t num_output_blks = 2;
227  static constexpr index_t m_per_blk = 32;
228  static constexpr index_t n_per_blk = 32;
229  static constexpr index_t k_per_blk = 4;
230  static constexpr bool is_k_reduction = false;
231 
232  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
233  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
234  {
236  }
237 };
238 
239 template <>
241 {
242  static constexpr index_t group_size = 4;
243  static constexpr index_t num_groups_per_blk = 4;
244  static constexpr index_t num_regs_per_blk = 16;
245  static constexpr index_t num_threads_per_blk = 32;
246  static constexpr index_t wave_size = 64;
247  static constexpr index_t num_input_blks = 2;
248  static constexpr index_t num_output_blks = 1;
249  static constexpr index_t m_per_blk = 32;
250  static constexpr index_t n_per_blk = 32;
251  static constexpr index_t k_per_blk = 4;
252  static constexpr bool is_k_reduction = true;
253 
254  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
255  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
256  {
258  }
259 };
260 
261 template <>
263 {
264  static constexpr index_t group_size = 4;
265  static constexpr index_t num_groups_per_blk = 4;
266  static constexpr index_t num_regs_per_blk = 16;
267  static constexpr index_t num_threads_per_blk = 32;
268  static constexpr index_t wave_size = 64;
269  static constexpr index_t num_input_blks = 2;
270  static constexpr index_t num_output_blks = 1;
271  static constexpr index_t m_per_blk = 32;
272  static constexpr index_t n_per_blk = 32;
273  static constexpr index_t k_per_blk = 8;
274  static constexpr bool is_k_reduction = true;
275 
276  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
277  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
278  {
280  }
281 };
282 
283 template <>
285 {
286  static constexpr index_t group_size = 4;
287  static constexpr index_t num_groups_per_blk = 1;
288  static constexpr index_t num_regs_per_blk = 4;
289  static constexpr index_t num_threads_per_blk = 16;
290  static constexpr index_t wave_size = 64;
291  static constexpr index_t num_input_blks = 4;
292  static constexpr index_t num_output_blks = 1;
293  static constexpr index_t m_per_blk = 16;
294  static constexpr index_t n_per_blk = 16;
295  static constexpr index_t k_per_blk = 8;
296  static constexpr bool is_k_reduction = true;
297 
298  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
299  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
300  {
302  }
303 };
304 
305 template <>
307 {
308  static constexpr index_t group_size = 4;
309  static constexpr index_t num_groups_per_blk = 1;
310  static constexpr index_t num_regs_per_blk = 4;
311  static constexpr index_t num_threads_per_blk = 16;
312  static constexpr index_t wave_size = 64;
313  static constexpr index_t num_input_blks = 4;
314  static constexpr index_t num_output_blks = 1;
315  static constexpr index_t m_per_blk = 16;
316  static constexpr index_t n_per_blk = 16;
317  static constexpr index_t k_per_blk = 4;
318  static constexpr bool is_k_reduction = true;
319 
320  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
321  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
322  {
324  }
325 };
326 
327 template <>
329 {
330  static constexpr index_t group_size = 4;
331  static constexpr index_t num_groups_per_blk = 1;
332  static constexpr index_t num_regs_per_blk = 4;
333  static constexpr index_t num_threads_per_blk = 16;
334  static constexpr index_t wave_size = 64;
335  static constexpr index_t num_input_blks = 4;
336  static constexpr index_t num_output_blks = 4;
337  static constexpr index_t m_per_blk = 16;
338  static constexpr index_t n_per_blk = 16;
339  static constexpr index_t k_per_blk = 4;
340  static constexpr bool is_k_reduction = false;
341 
342  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
343  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
344  {
346  }
347 };
348 
349 template <>
351 {
352  static constexpr index_t group_size = 4;
353  static constexpr index_t num_groups_per_blk = 1;
354  static constexpr index_t num_regs_per_blk = 4;
355  static constexpr index_t num_threads_per_blk = 64;
356  static constexpr index_t wave_size = 64;
357  static constexpr index_t num_input_blks = 1;
358  static constexpr index_t num_output_blks = 1;
359  static constexpr index_t m_per_blk = 4;
360  static constexpr index_t n_per_blk = 64;
361  static constexpr index_t k_per_blk = 4;
362  static constexpr bool is_k_reduction = false;
363 
364  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
365  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
366  {
368  }
369 };
370 
371 template <>
373 {
374  static constexpr index_t group_size = 4;
375  static constexpr index_t num_groups_per_blk = 4;
376  static constexpr index_t num_regs_per_blk = 16;
377  static constexpr index_t num_threads_per_blk = 32;
378  static constexpr index_t wave_size = 64;
379  static constexpr index_t num_input_blks = 2;
380  static constexpr index_t num_output_blks = 1;
381  static constexpr index_t m_per_blk = 32;
382  static constexpr index_t n_per_blk = 32;
383  static constexpr index_t k_per_blk = 8;
384  static constexpr bool is_k_reduction = true;
385 
386  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
387  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
388  {
390  }
391 };
392 
393 template <>
395 {
396  static constexpr index_t group_size = 4;
397  static constexpr index_t num_groups_per_blk = 4;
398  static constexpr index_t num_regs_per_blk = 16;
399  static constexpr index_t num_threads_per_blk = 32;
400  static constexpr index_t wave_size = 64;
401  static constexpr index_t num_input_blks = 2;
402  static constexpr index_t num_output_blks = 1;
403  static constexpr index_t m_per_blk = 32;
404  static constexpr index_t n_per_blk = 32;
405  static constexpr index_t k_per_blk = 4;
406  static constexpr bool is_k_reduction = true;
407 
408  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
409  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
410  {
412  }
413 };
414 
415 template <>
417 {
418  static constexpr index_t group_size = 4;
419  static constexpr index_t num_groups_per_blk = 1;
420  static constexpr index_t num_regs_per_blk = 4;
421  static constexpr index_t num_threads_per_blk = 16;
422  static constexpr index_t wave_size = 64;
423  static constexpr index_t num_input_blks = 4;
424  static constexpr index_t num_output_blks = 1;
425  static constexpr index_t m_per_blk = 16;
426  static constexpr index_t n_per_blk = 16;
427  static constexpr index_t k_per_blk = 8;
428  static constexpr bool is_k_reduction = true;
429 
430  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
431  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
432  {
434  }
435 };
436 
437 template <>
439 {
440  static constexpr index_t group_size = 4;
441  static constexpr index_t num_groups_per_blk = 1;
442  static constexpr index_t num_regs_per_blk = 4;
443  static constexpr index_t num_threads_per_blk = 16;
444  static constexpr index_t wave_size = 64;
445  static constexpr index_t num_input_blks = 4;
446  static constexpr index_t num_output_blks = 1;
447  static constexpr index_t m_per_blk = 16;
448  static constexpr index_t n_per_blk = 16;
449  static constexpr index_t k_per_blk = 4;
450  static constexpr bool is_k_reduction = true;
451 
452  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
453  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
454  {
456  }
457 };
458 
459 template <>
461 {
462  static constexpr index_t group_size = 4;
463  static constexpr index_t num_groups_per_blk = 4;
464  static constexpr index_t num_regs_per_blk = 16;
465  static constexpr index_t num_threads_per_blk = 32;
466  static constexpr index_t wave_size = 64;
467  static constexpr index_t num_input_blks = 2;
468  static constexpr index_t num_output_blks = 1;
469  static constexpr index_t m_per_blk = 32;
470  static constexpr index_t n_per_blk = 32;
471  static constexpr index_t k_per_blk = 2;
472  static constexpr bool is_k_reduction = true;
473 
474  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
475  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
476  {
478  }
479 };
480 
481 template <>
483 {
484  static constexpr index_t group_size = 4;
485  static constexpr index_t num_groups_per_blk = 1;
486  static constexpr index_t num_regs_per_blk = 4;
487  static constexpr index_t num_threads_per_blk = 16;
488  static constexpr index_t wave_size = 64;
489  static constexpr index_t num_input_blks = 4;
490  static constexpr index_t num_output_blks = 1;
491  static constexpr index_t m_per_blk = 16;
492  static constexpr index_t n_per_blk = 16;
493  static constexpr index_t k_per_blk = 2;
494  static constexpr bool is_k_reduction = true;
495 
496  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
497  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
498  {
500  }
501 };
502 
503 template <>
505 {
506  static constexpr index_t group_size = 4;
507  static constexpr index_t num_groups_per_blk = 4;
508  static constexpr index_t num_regs_per_blk = 16;
509  static constexpr index_t num_threads_per_blk = 32;
510  static constexpr index_t wave_size = 64;
511  static constexpr index_t num_input_blks = 2;
512  static constexpr index_t num_output_blks = 1;
513  static constexpr index_t m_per_blk = 32;
514  static constexpr index_t n_per_blk = 32;
515  static constexpr index_t k_per_blk = 4;
516  static constexpr bool is_k_reduction = true;
517 
518  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
519  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
520  {
522  }
523 };
524 
525 template <>
527 {
528  static constexpr index_t group_size = 4;
529  static constexpr index_t num_groups_per_blk = 1;
530  static constexpr index_t num_regs_per_blk = 4;
531  static constexpr index_t num_threads_per_blk = 16;
532  static constexpr index_t wave_size = 64;
533  static constexpr index_t num_input_blks = 4;
534  static constexpr index_t num_output_blks = 1;
535  static constexpr index_t m_per_blk = 16;
536  static constexpr index_t n_per_blk = 16;
537  static constexpr index_t k_per_blk = 4;
538  static constexpr bool is_k_reduction = true;
539 
540  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
541  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
542  {
544  }
545 };
546 
547 template <>
549 {
550  static constexpr index_t group_size = 4;
551  static constexpr index_t num_groups_per_blk = 4;
552  static constexpr index_t num_regs_per_blk = 16;
553  static constexpr index_t num_threads_per_blk = 32;
554  static constexpr index_t wave_size = 64;
555  static constexpr index_t num_input_blks = 2;
556  static constexpr index_t num_output_blks = 1;
557  static constexpr index_t m_per_blk = 32;
558  static constexpr index_t n_per_blk = 32;
559  static constexpr index_t k_per_blk = 8;
560  static constexpr bool is_k_reduction = true;
561 
562  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
563  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
564  {
566  }
567 };
568 
569 template <>
571 {
572  static constexpr index_t group_size = 4;
573  static constexpr index_t num_groups_per_blk = 1;
574  static constexpr index_t num_regs_per_blk = 4;
575  static constexpr index_t num_threads_per_blk = 16;
576  static constexpr index_t wave_size = 64;
577  static constexpr index_t num_input_blks = 4;
578  static constexpr index_t num_output_blks = 1;
579  static constexpr index_t m_per_blk = 16;
580  static constexpr index_t n_per_blk = 16;
581  static constexpr index_t k_per_blk = 8;
582  static constexpr bool is_k_reduction = true;
583 
584  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
585  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
586  {
588  }
589 };
590 
591 template <>
593 {
594  static constexpr index_t group_size = 4;
595  static constexpr index_t num_groups_per_blk = 4;
596  static constexpr index_t num_regs_per_blk = 16;
597  static constexpr index_t num_threads_per_blk = 32;
598  static constexpr index_t wave_size = 64;
599  static constexpr index_t num_input_blks = 2;
600  static constexpr index_t num_output_blks = 1;
601  static constexpr index_t m_per_blk = 32;
602  static constexpr index_t n_per_blk = 32;
603  static constexpr index_t k_per_blk = 16;
604  static constexpr bool is_k_reduction = true;
605 
606  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
607  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
608  {
610  }
611 };
612 
613 template <>
615 {
616  static constexpr index_t group_size = 4;
617  static constexpr index_t num_groups_per_blk = 1;
618  static constexpr index_t num_regs_per_blk = 4;
619  static constexpr index_t num_threads_per_blk = 16;
620  static constexpr index_t wave_size = 64;
621  static constexpr index_t num_input_blks = 4;
622  static constexpr index_t num_output_blks = 1;
623  static constexpr index_t m_per_blk = 16;
624  static constexpr index_t n_per_blk = 16;
625  static constexpr index_t k_per_blk = 16;
626  static constexpr bool is_k_reduction = true;
627 
628  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
629  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
630  {
632  }
633 };
634 
635 template <>
637 {
638  static constexpr index_t group_size = 1;
639  static constexpr index_t num_groups_per_blk = 4;
640  static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
641  static constexpr index_t num_threads_per_blk = 16;
642  static constexpr index_t wave_size = 64;
643  static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
644  static constexpr index_t num_output_blks = 1;
645  static constexpr index_t m_per_blk = 16;
646  static constexpr index_t n_per_blk = 16;
647  static constexpr index_t k_per_blk = 1;
648  static constexpr bool is_k_reduction = true;
649 
650  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
651  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
652  {
654  }
655 };
656 
657 template <>
659 {
660  static constexpr index_t group_size = 4;
661  static constexpr index_t num_groups_per_blk = 4;
662  static constexpr index_t num_regs_per_blk = 16;
663  static constexpr index_t num_threads_per_blk = 32;
664  static constexpr index_t wave_size = 64;
665  static constexpr index_t num_input_blks = 2;
666  static constexpr index_t num_output_blks = 1;
667  static constexpr index_t m_per_blk = 32;
668  static constexpr index_t n_per_blk = 32;
669  static constexpr index_t k_per_blk = 8;
670  static constexpr bool is_k_reduction = true;
671 
672  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
673  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
674  {
676  }
677 };
678 
679 template <>
681 {
682  static constexpr index_t group_size = 4;
683  static constexpr index_t num_groups_per_blk = 1;
684  static constexpr index_t num_regs_per_blk = 4;
685  static constexpr index_t num_threads_per_blk = 16;
686  static constexpr index_t wave_size = 64;
687  static constexpr index_t num_input_blks = 4;
688  static constexpr index_t num_output_blks = 1;
689  static constexpr index_t m_per_blk = 16;
690  static constexpr index_t n_per_blk = 16;
691  static constexpr index_t k_per_blk = 8;
692  static constexpr bool is_k_reduction = true;
693 
694  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
695  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
696  {
698  }
699 };
700 
701 template <>
703 {
704  static constexpr index_t group_size = 4;
705  static constexpr index_t num_groups_per_blk = 4;
706  static constexpr index_t num_regs_per_blk = 16;
707  static constexpr index_t num_threads_per_blk = 32;
708  static constexpr index_t wave_size = 64;
709  static constexpr index_t num_input_blks = 2;
710  static constexpr index_t num_output_blks = 1;
711  static constexpr index_t m_per_blk = 32;
712  static constexpr index_t n_per_blk = 32;
713  static constexpr index_t k_per_blk = 8;
714  static constexpr bool is_k_reduction = true;
715 
716  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
717  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
718  {
720  }
721 };
722 
723 template <>
725 {
726  static constexpr index_t group_size = 4;
727  static constexpr index_t num_groups_per_blk = 1;
728  static constexpr index_t num_regs_per_blk = 4;
729  static constexpr index_t num_threads_per_blk = 16;
730  static constexpr index_t wave_size = 64;
731  static constexpr index_t num_input_blks = 4;
732  static constexpr index_t num_output_blks = 1;
733  static constexpr index_t m_per_blk = 16;
734  static constexpr index_t n_per_blk = 16;
735  static constexpr index_t k_per_blk = 8;
736  static constexpr bool is_k_reduction = true;
737 
738  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
739  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
740  {
742  }
743 };
744 
745 template <>
747 {
748  static constexpr index_t group_size = 4;
749  static constexpr index_t num_groups_per_blk = 4;
750  static constexpr index_t num_regs_per_blk = 16;
751  static constexpr index_t num_threads_per_blk = 32;
752  static constexpr index_t wave_size = 64;
753  static constexpr index_t num_input_blks = 2;
754  static constexpr index_t num_output_blks = 1;
755  static constexpr index_t m_per_blk = 32;
756  static constexpr index_t n_per_blk = 32;
757  static constexpr index_t k_per_blk = 8;
758  static constexpr bool is_k_reduction = true;
759 
760  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
761  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
762  {
764  }
765 };
766 
767 template <>
769 {
770  static constexpr index_t group_size = 4;
771  static constexpr index_t num_groups_per_blk = 1;
772  static constexpr index_t num_regs_per_blk = 4;
773  static constexpr index_t num_threads_per_blk = 16;
774  static constexpr index_t wave_size = 64;
775  static constexpr index_t num_input_blks = 4;
776  static constexpr index_t num_output_blks = 1;
777  static constexpr index_t m_per_blk = 16;
778  static constexpr index_t n_per_blk = 16;
779  static constexpr index_t k_per_blk = 8;
780  static constexpr bool is_k_reduction = true;
781 
782  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
783  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
784  {
786  }
787 };
788 
789 template <>
791 {
792  static constexpr index_t group_size = 4;
793  static constexpr index_t num_groups_per_blk = 4;
794  static constexpr index_t num_regs_per_blk = 16;
795  static constexpr index_t num_threads_per_blk = 32;
796  static constexpr index_t wave_size = 64;
797  static constexpr index_t num_input_blks = 2;
798  static constexpr index_t num_output_blks = 1;
799  static constexpr index_t m_per_blk = 32;
800  static constexpr index_t n_per_blk = 32;
801  static constexpr index_t k_per_blk = 8;
802  static constexpr bool is_k_reduction = true;
803 
804  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
805  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
806  {
808  }
809 };
810 
811 template <>
813 {
814  static constexpr index_t group_size = 4;
815  static constexpr index_t num_groups_per_blk = 1;
816  static constexpr index_t num_regs_per_blk = 4;
817  static constexpr index_t num_threads_per_blk = 16;
818  static constexpr index_t wave_size = 64;
819  static constexpr index_t num_input_blks = 4;
820  static constexpr index_t num_output_blks = 1;
821  static constexpr index_t m_per_blk = 16;
822  static constexpr index_t n_per_blk = 16;
823  static constexpr index_t k_per_blk = 8;
824  static constexpr bool is_k_reduction = true;
825 
826  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
827  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
828  {
830  }
831 };
832 
833 template <>
835 {
836  // clang-format off
837  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
838  static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
839  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
840  static constexpr index_t num_threads_per_blk = 32; // n_per_blk
841  static constexpr index_t wave_size = 64; // fixed
842  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
843  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
844  static constexpr index_t m_per_blk = 32; // from the instruction
845  static constexpr index_t n_per_blk = 32; // from the instruction
846  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
847  static constexpr bool is_k_reduction = true; // ???
848  // clang-format on
849 
850  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
851  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
852  {
854  }
855 };
856 
857 template <>
859 {
860  // clang-format off
861  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
862  static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
863  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
864  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
865  static constexpr index_t wave_size = 64; // fixed
866  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
867  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
868  static constexpr index_t m_per_blk = 16; // from the instruction
869  static constexpr index_t n_per_blk = 16; // from the instruction
870  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
871  static constexpr bool is_k_reduction = true; // ???
872  // clang-format on
873 
874  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
875  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
876  {
878  }
879 };
880 
881 template <>
883 {
884  // clang-format off
885  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
886  static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
887  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
888  static constexpr index_t num_threads_per_blk = 32; // n_per_blk
889  static constexpr index_t wave_size = 64; // fixed
890  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
891  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
892  static constexpr index_t m_per_blk = 32; // from the instruction
893  static constexpr index_t n_per_blk = 32; // from the instruction
894  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
895  static constexpr bool is_k_reduction = true; // ???
896  // clang-format on
897 
898  template <index_t MPerXdlops,
899  index_t NPerXdlops,
900  index_t OpselA,
901  index_t OpselB,
902  class FloatA,
903  class ScaleA,
904  class FloatB,
905  class ScaleB,
906  class FloatC>
907  __device__ void run(const FloatA& a,
908  const ScaleA& scale_a,
909  const FloatB& b,
910  const ScaleB& scale_b,
911  FloatC& reg_c) const
912  {
914  a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
915  }
916 };
917 
918 template <>
920 {
921  // clang-format off
922  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
923  static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
924  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
925  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
926  static constexpr index_t wave_size = 64; // fixed
927  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
928  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
929  static constexpr index_t m_per_blk = 16; // from the instruction
930  static constexpr index_t n_per_blk = 16; // from the instruction
931  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
932  static constexpr bool is_k_reduction = true; // ???
933  // clang-format on
934 
935  template <index_t MPerXdlops,
936  index_t NPerXdlops,
937  index_t OpselA,
938  index_t OpselB,
939  class FloatA,
940  class ScaleA,
941  class FloatB,
942  class ScaleB,
943  class FloatC>
944  __device__ void run(const FloatA& a,
945  const ScaleA& scale_a,
946  const FloatB& b,
947  const ScaleB& scale_b,
948  FloatC& reg_c) const
949  {
950 
952  a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
953  }
954 };
955 
975 template <>
977 {
978  static constexpr index_t wave_size = 64; // fixed
979  static constexpr index_t m_per_blk = 16; // from the instruction
980  static constexpr index_t n_per_blk = 16; // from the instruction
981  static constexpr index_t num_threads_per_blk = n_per_blk; // 16
982  static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 4
983  static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 4
984  static constexpr index_t group_size = 4;
985  static constexpr index_t num_groups_per_blk = 1;
986  static constexpr index_t num_output_blks = 1;
987  static constexpr index_t k_per_blk = 2; // k_per_blk(K1PerXdlops) should be 2.
988  static constexpr bool is_k_reduction = true;
989 
990  // AB register size : 2, register size: 4
991  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
992  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
993  {
995  }
996 };
997 
998 template <>
1000 {
1001  static constexpr index_t wave_size = 64; // fixed
1002  static constexpr index_t m_per_blk = 32; // from the instruction
1003  static constexpr index_t n_per_blk = 32; // from the instruction
1004  static constexpr index_t num_threads_per_blk = n_per_blk; // 32
1005  static constexpr index_t num_regs_per_blk = m_per_blk * n_per_blk / wave_size; // 16
1006  static constexpr index_t num_input_blks = m_per_blk / num_regs_per_blk; // 2
1007  static constexpr index_t group_size = 4; // corresponding to CD rows mapping
1008  static constexpr index_t num_groups_per_blk = 4;
1009  static constexpr index_t num_output_blks = 1;
1010  static constexpr index_t k_per_blk = 2;
1011  static constexpr bool is_k_reduction = true;
1012  // AB register size: 2, CD register size: 16
1013  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
1014  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1015  {
1017  }
1018 };
1019 
1020 template <>
1022 {
1023  // gfx950 specific: use bf16x3 simulate tf32
1024  static constexpr index_t group_size = 4;
1025  static constexpr index_t num_groups_per_blk = 4;
1026  static constexpr index_t num_regs_per_blk = 16;
1027  static constexpr index_t num_threads_per_blk = 32;
1028  static constexpr index_t wave_size = 64;
1029  static constexpr index_t num_input_blks = 2;
1030  static constexpr index_t num_output_blks = 1;
1031  static constexpr index_t m_per_blk = 32;
1032  static constexpr index_t n_per_blk = 32;
1033  static constexpr index_t k_per_blk = 8;
1034  static constexpr bool is_k_reduction = true;
1035 
1036  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
1037  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1038  {
1040  }
1041 };
1042 template <>
1044 {
1045  // gfx950 specific: use bf16x3 simulate tf32
1046  static constexpr index_t group_size = 4;
1047  static constexpr index_t num_groups_per_blk = 1;
1048  static constexpr index_t num_regs_per_blk = 4;
1049  static constexpr index_t num_threads_per_blk = 16;
1050  static constexpr index_t wave_size = 64;
1051  static constexpr index_t num_input_blks = 4;
1052  static constexpr index_t num_output_blks = 1;
1053  static constexpr index_t m_per_blk = 16;
1054  static constexpr index_t n_per_blk = 16;
1055  static constexpr index_t k_per_blk = 8;
1056  static constexpr bool is_k_reduction = true;
1057 
1058  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
1059  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1060  {
1062  }
1063 };
1064 
1065 // gfx11
1067 {
1068  static constexpr index_t group_size = 8;
1069  static constexpr index_t num_groups_per_blk = 1;
1070  static constexpr index_t num_regs_per_blk = 8;
1071  static constexpr index_t num_threads_per_blk = 16;
1072  static constexpr index_t wave_size = 32;
1073  static constexpr index_t num_input_blks = 1;
1074  static constexpr index_t num_output_blks = 1;
1075  static constexpr index_t m_per_blk = 16;
1076  static constexpr index_t n_per_blk = 16;
1077  static constexpr index_t k_per_blk = 16;
1078  static constexpr bool is_k_reduction = true;
1079 };
1080 
1081 template <>
1083 {
1084  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1085  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1086  {
1088  }
1089 };
1090 
1091 template <>
1093 {
1094  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1095  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1096  {
1098  }
1099 };
1100 
1101 template <>
1103 {
1104  template <index_t MPerWmma,
1105  index_t NPerWmma,
1106  class FloatA,
1107  class FloatB,
1108  class FloatC,
1109  bool neg_a = true,
1110  bool neg_b = true,
1111  bool clamp = false>
1112  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1113  {
1115  }
1116 };
1117 
1118 template <>
1120 {
1121  static constexpr index_t k_per_blk = 2;
1122  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1123  __device__ void run(const FloatA&, const FloatB&, FloatC&) const
1124  {
1125  // empty for all unsupported types.
1126  }
1127 };
1128 
1129 // gfx12
1131 {
1132  static constexpr index_t group_size = 8;
1133  static constexpr index_t num_groups_per_blk = 1;
1134  static constexpr index_t num_regs_per_blk = 8;
1135  static constexpr index_t num_threads_per_blk = 16;
1136  static constexpr index_t wave_size = 32;
1137  static constexpr index_t num_input_blks = 2;
1138  static constexpr index_t num_output_blks = 1;
1139  static constexpr index_t m_per_blk = 16;
1140  static constexpr index_t n_per_blk = 16;
1141  static constexpr index_t k_per_blk = 8;
1142  static constexpr bool is_k_reduction = true;
1143 };
1144 
1145 template <>
1147 {
1148  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1149  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1150  {
1152  }
1153 };
1154 
1155 template <>
1157 {
1158  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1159  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1160  {
1162  }
1163 };
1164 
1165 template <>
1167 {
1168  template <index_t MPerWmma,
1169  index_t NPerWmma,
1170  class FloatA,
1171  class FloatB,
1172  class FloatC,
1173  bool neg_a = true,
1174  bool neg_b = true,
1175  bool clamp = false>
1176  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1177  {
1179  a, b, reg_c);
1180  }
1181 };
1182 
1183 template <>
1185 {
1186  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1187  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1188  {
1190  }
1191 };
1192 
1193 template <>
1195 {
1196  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1197  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1198  {
1200  }
1201 };
1202 
1203 template <>
1205 {
1206  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1207  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1208  {
1210  }
1211 };
1212 
1213 template <>
1215 {
1216  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1217  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
1218  {
1220  }
1221 };
1222 
1223 template <>
1225 {
1226  static constexpr index_t k_per_blk = 2;
1227  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
1228  __device__ void run(const FloatA&, const FloatB&, FloatC&) const
1229  {
1230  // empty for all unsupported types.
1231  }
1232 };
1233 
1248 template <typename base_type,
1249  index_t MPerXdlops,
1250  index_t NPerXdlops,
1251  typename additional_type = base_type,
1252  bool is_single_rate_mfma = false,
1253  bool is_scale_mfma = false>
1255 {
1256  template <typename base_type_,
1257  index_t MPerXdlops_,
1258  index_t NPerXdlops_,
1259  typename additional_type_ = base_type_,
1260  bool is_single_rate_mfma_ = false,
1261  bool is_scale_mfma_ = false>
1262  static constexpr auto GetMfma();
1263 
1264  template <>
1265  constexpr auto GetMfma<double, 16, 16>()
1266  {
1267 #if defined(__gfx12__)
1269 #elif defined(__gfx11__)
1271 #else
1273 #endif
1274  }
1275 
1276  template <>
1277  constexpr auto GetMfma<float, 64, 64>()
1278  {
1280  }
1281 
1282  template <>
1283  constexpr auto GetMfma<float, 32, 64>()
1284  {
1286  }
1287 
1288  template <>
1289  constexpr auto GetMfma<float, 16, 64>()
1290  {
1292  }
1293 
1294  template <>
1295  constexpr auto GetMfma<float, 8, 64>()
1296  {
1298  }
1299 
1300  template <>
1301  constexpr auto GetMfma<float, 4, 64>()
1302  {
1304  }
1305 
1306  template <>
1307  constexpr auto GetMfma<float, 32, 32>()
1308  {
1310  }
1311 
1312  template <>
1313  constexpr auto GetMfma<float, 16, 16>()
1314  {
1315 #if defined(__gfx12__)
1317 #elif defined(__gfx11__)
1319 #else
1321 #endif
1322  }
1323 
1324  template <>
1325  constexpr auto GetMfma<tf32_t, 32, 32, tf32_t>()
1326  {
1327 #if defined(__gfx12__)
1329 #elif defined(__gfx11__)
1331 #elif defined(__gfx950__)
1333 #elif defined(__gfx942__)
1335 #else
1337 #endif
1338  }
1339 
1340  template <>
1341  constexpr auto GetMfma<tf32_t, 16, 16, tf32_t>()
1342  {
1343 #if defined(__gfx12__)
1345 #elif defined(__gfx11__)
1347 #elif defined(__gfx950__)
1349 #elif defined(__gfx942__)
1351 #else
1353 #endif
1354  }
1355 
1356  template <>
1357  constexpr auto GetMfma<half_t, 64, 64>()
1358  {
1360  }
1361 
1362  template <>
1363  constexpr auto GetMfma<half_t, 32, 64>()
1364  {
1366  }
1367 
1368  template <>
1369  constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
1370  {
1371 #if defined(__gfx950__)
1373 #else
1375 #endif
1376  }
1377  template <>
1378  constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
1379  {
1381  }
1382 
1383  template <>
1384  constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
1385  {
1386 #if defined(__gfx12__)
1388 #elif defined(__gfx11__)
1390 #elif defined(__gfx950__)
1392 #else
1394 #endif
1395  }
1396 
1397  template <>
1398  constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
1399  {
1400 #if defined(__gfx12__)
1402 #elif defined(__gfx11__)
1404 #else
1406 #endif
1407  }
1408 
1409  template <>
1410  constexpr auto GetMfma<half_t, 16, 64>()
1411  {
1413  }
1414 
1415  template <>
1416  constexpr auto GetMfma<half_t, 8, 64>()
1417  {
1419  }
1420 
1421  template <>
1422  constexpr auto GetMfma<half_t, 4, 64>()
1423  {
1425  }
1426 
1427  template <>
1428  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1429  {
1430 #if defined(__gfx950__)
1432 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1434 #else
1436 #endif
1437  }
1438 
1439  template <>
1440  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1441  {
1442 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1444 #else
1446 #endif
1447  }
1448 
1449  template <>
1450  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1451  {
1452 #if defined(__gfx12__)
1454 #elif defined(__gfx11__)
1456 #elif defined(__gfx950__)
1458 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1460 #else
1462 #endif
1463  }
1464 
1465  template <>
1466  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1467  {
1468 #if defined(__gfx12__)
1470 #elif defined(__gfx11__)
1472 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1474 #else
1476 #endif
1477  }
1478 
1479  template <>
1480  constexpr auto GetMfma<int8_t, 32, 32, int8_t, false>()
1481  {
1482 #if defined(__gfx950__)
1484 #elif defined(__gfx942__)
1486 #else
1488 #endif
1489  }
1490 
1491  template <>
1492  constexpr auto GetMfma<int8_t, 32, 32, int8_t, true>()
1493  {
1494 #if defined(__gfx942__) || defined(__gfx950__)
1496 #else
1498 #endif
1499  }
1500 
1501  template <>
1502  constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
1503  {
1504 #if defined(__gfx12__)
1506 #elif defined(__gfx11__)
1508 #elif defined(__gfx950__)
1510 #elif defined(__gfx942__)
1512 #else
1514 #endif
1515  }
1516 
1517  template <>
1518  constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
1519  {
1520 #if defined(__gfx12__)
1522 #elif defined(__gfx11__)
1524 #elif defined(__gfx942__) || defined(__gfx950__)
1526 #else
1528 #endif
1529  }
1530 
1531  template <>
1532  constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
1533  {
1535  }
1536 
1537  template <>
1538  constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
1539  {
1540 #if defined(__gfx950__)
1542 #else
1544 #endif
1545  }
1546 
1547  template <>
1548  constexpr auto GetMfma<f8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1549  {
1551  }
1552 
1553  template <>
1554  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, is_single_rate_mfma, true>()
1555  {
1557  }
1558  template <>
1559  constexpr auto GetMfma<f4_t, 32, 32, f4_t, is_single_rate_mfma, true>()
1560  {
1562  }
1563  template <>
1564  constexpr auto GetMfma<f4_t, 16, 16, f4_t, is_single_rate_mfma, true>()
1565  {
1566 #if defined(__gfx12__)
1568 #elif defined(__gfx11__)
1570 #else
1572 #endif
1573  }
1574 
1575  template <>
1576  constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
1577  {
1578 #if defined(__gfx12__)
1580 #elif defined(__gfx11__)
1582 #else
1584 #endif
1585  }
1586 
1587  template <>
1588  constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
1589  {
1590 #if defined(__gfx12__)
1592 #elif defined(__gfx11__)
1594 #elif defined(__gfx950__)
1596 #else
1598 #endif
1599  }
1600 
1601  template <>
1602  constexpr auto GetMfma<f8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1603  {
1604 #if defined(__gfx12__)
1606 #elif defined(__gfx11__)
1608 #else
1610 #endif
1611  }
1612 
1613  template <>
1614  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1615  {
1616 #if defined(__gfx12__)
1618 #elif defined(__gfx11__)
1620 #else
1622 #endif
1623  }
1624 
1625  template <>
1626  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, is_single_rate_mfma, true>()
1627  {
1628 #if defined(__gfx12__)
1630 #elif defined(__gfx11__)
1632 #else
1634 #endif
1635  }
1636 
1637  template <>
1638  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, is_single_rate_mfma, true>()
1639  {
1640 #if defined(__gfx12__)
1642 #elif defined(__gfx11__)
1644 #else
1646 #endif
1647  }
1648 
1649  template <>
1650  constexpr auto GetMfma<f6_t, 32, 32, f6_t, is_single_rate_mfma, true>()
1651  {
1653  }
1654  template <>
1655  constexpr auto GetMfma<f6_t, 16, 16, f6_t, is_single_rate_mfma, true>()
1656  {
1657 #if defined(__gfx12__)
1659 #elif defined(__gfx11__)
1661 #else
1663 #endif
1664  }
1665  template <>
1666  constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, is_single_rate_mfma, true>()
1667  {
1669  }
1670  template <>
1671  constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, is_single_rate_mfma, true>()
1672  {
1673 #if defined(__gfx12__)
1675 #elif defined(__gfx11__)
1677 #else
1679 #endif
1680  }
1681 
1682  template <>
1683  constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
1684  {
1686  }
1687 
1688  template <>
1689  constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
1690  {
1691 #if defined(__gfx950__)
1693 #else
1695 #endif
1696  }
1697 
1698  template <>
1699  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
1700  {
1701 #if defined(__gfx12__)
1703 #elif defined(__gfx11__)
1705 #else
1707 #endif
1708  }
1709 
1710  template <>
1711  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
1712  {
1713 #if defined(__gfx12__)
1715 #elif defined(__gfx11__)
1717 #elif defined(__gfx950__)
1719 #else
1721 #endif
1722  }
1723 
1724  template <>
1725  constexpr auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
1726  {
1728  }
1729 
1730  template <>
1731  constexpr auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
1732  {
1733 #if defined(__gfx950__)
1735 #else
1737 #endif
1738  }
1739 
1740  template <>
1741  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
1742  {
1743 #if defined(__gfx12__)
1745 #elif defined(__gfx11__)
1747 #else
1749 #endif
1750  }
1751 
1752  template <>
1753  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
1754  {
1755 #if defined(__gfx12__)
1757 #elif defined(__gfx11__)
1759 #elif defined(__gfx950__)
1761 #else
1763 #endif
1764  }
1765 
1766  template <>
1767  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
1768  {
1770  }
1771 
1772  template <>
1773  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
1774  {
1775 #if defined(__gfx950__)
1777 #else
1779 #endif
1780  }
1781 
1782  template <>
1783  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
1784  {
1785 #if defined(__gfx12__)
1787 #elif defined(__gfx11__)
1789 #else
1791 #endif
1792  }
1793 
1794  template <>
1795  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
1796  {
1797 #if defined(__gfx12__)
1799 #elif defined(__gfx11__)
1801 #elif defined(__gfx950__)
1803 #else
1805 #endif
1806  }
1807 
1809  MPerXdlops,
1810  NPerXdlops,
1812  is_single_rate_mfma,
1813  is_scale_mfma>()>{};
1814 
1815  __host__ __device__ constexpr MfmaSelector()
1816  {
1817  static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
1818  selected_mfma.num_regs_per_blk,
1819  "wrong! num_regs_per_blk");
1820 
1821  static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
1822  "n_per_blk != num_threads_per_blk");
1823 #if defined(__gfx11__)
1824  if constexpr(MPerXdlops == 16 && NPerXdlops == 16)
1825  {
1826  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks * 2 ==
1827  selected_mfma.m_per_blk,
1828  "m_per_blk != num_input_blks * num_regs_per_blk");
1829  }
1830 #else
1831  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
1832  selected_mfma.m_per_blk,
1833  "m_per_blk != num_input_blks * num_regs_per_blk");
1834 #endif
1835 
1836  static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
1837  selected_mfma.num_output_blks == 1,
1838  "incorrect num_output_blks");
1839 
1840  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
1841  selected_mfma.m_per_blk * selected_mfma.n_per_blk,
1842  "num_regs_per_blk incorrect");
1843 
1844  static_assert(selected_mfma.is_k_reduction ||
1845  (selected_mfma.num_input_blks == selected_mfma.num_output_blks),
1846  "is_k_reduction wrong!");
1847  }
1848 
1849  static constexpr bool IsABroadcast()
1850  {
1851  static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
1852  return true;
1853  }
1854 
1855  static constexpr index_t GetKPerXdlops()
1856  {
1857  return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
1858  selected_mfma.k_per_blk;
1859  }
1860 
1861  static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
1862 };
1863 
1864 template <typename base_type,
1865  index_t MPerXdlops,
1866  index_t NPerXdlops,
1867  index_t KPack,
1868  typename additional_type = base_type,
1869  bool TransposeC = false,
1870  bool is_scale_mfma = false>
1872 {
1873  static constexpr auto I0 = Number<0>{};
1874  static constexpr auto I1 = Number<1>{};
1875  static constexpr auto I2 = Number<2>{};
1876  static constexpr auto I3 = Number<3>{};
1877  static constexpr auto I4 = Number<4>{};
1878  static constexpr auto I5 = Number<5>{};
1879 
1882 
1883  __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
1884 
1885  __device__ static constexpr index_t GetNumXdlops()
1886  {
1887  return MPerXdlops * NPerXdlops /
1888  (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
1889  }
1890 
1891  __host__ __device__ constexpr XdlopsGemm()
1892  {
1893  static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1894  NPerXdlops == 64,
1895  "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1896 
1897  static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1898  MPerXdlops == 64,
1899  "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1900 #if defined(__HIP_DEVICE_COMPILE__)
1901  static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
1902 #endif
1903  }
1904 
1905  // XDL output supporting C = A * B
1906  // M2_N2 -> M2_M3_M4_N2
1907  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1908  __host__ __device__ static constexpr auto
1909  MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1910  {
1911  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1912  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1913  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1914  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1915  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1916 
1918  c_desc_m0_n0_m1_n1_m2_n2,
1924  Number<num_blks>{},
1925  Number<mfma_instr.group_size>{})),
1928  Sequence<1>{},
1929  Sequence<2>{},
1930  Sequence<3>{},
1931  Sequence<4>{},
1932  Sequence<5>{}),
1934  Sequence<1>{},
1935  Sequence<2>{},
1936  Sequence<3>{},
1938  Sequence<7>{}));
1939  }
1940 
1941  // XDL output supporting C = A * B
1942  // M3_N3 -> M3_M4_M5_N3
1943  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1944  __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
1945  const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1946  {
1947  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1948  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1949  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1950  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1951  const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
1952  const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
1953  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1954 
1956  c_desc_m0_n0_m1_n1_m2_n2,
1964  Number<num_blks>{},
1965  Number<mfma_instr.group_size>{})),
1968  Sequence<1>{},
1969  Sequence<2>{},
1970  Sequence<3>{},
1971  Sequence<4>{},
1972  Sequence<5>{},
1973  Sequence<6>{},
1974  Sequence<7>{}),
1976  Sequence<1>{},
1977  Sequence<2>{},
1978  Sequence<3>{},
1979  Sequence<4>{},
1980  Sequence<5>{},
1982  Sequence<9>{}));
1983  }
1984 
1985  // transposed XDL output supporting C' = B' * A'
1986  // M2_N2 -> M2_N2_N3_N4
1987  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1988  __host__ __device__ static constexpr auto
1989  MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1990  {
1991  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1992  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1993  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1994  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1995  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
1996 
1998  c_desc_m0_n0_m1_n1_m2_n2,
2005  Number<num_blks>{},
2006  Number<mfma_instr.group_size>{}))),
2008  Sequence<1>{},
2009  Sequence<2>{},
2010  Sequence<3>{},
2011  Sequence<4>{},
2012  Sequence<5>{}),
2014  Sequence<1>{},
2015  Sequence<2>{},
2016  Sequence<3>{},
2017  Sequence<4>{},
2018  Sequence<5, 6, 7>{}));
2019  }
2020 
2021  template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
2022  __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
2023  const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
2024  {
2025  const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
2026  const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
2027  const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
2028  const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
2029  const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
2030  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
2031 
2033  c_desc_g_m0_n0_m1_n1_m2_n2,
2040  mfma_instr.num_groups_per_blk, num_blks, mfma_instr.group_size)),
2041  make_pass_through_transform(mfma_instr.num_threads_per_blk)),
2043  Sequence<1>{},
2044  Sequence<2>{},
2045  Sequence<3>{},
2046  Sequence<4>{},
2047  Sequence<5>{},
2048  Sequence<6>{}),
2050  Sequence<1>{},
2051  Sequence<2>{},
2052  Sequence<3>{},
2053  Sequence<4>{},
2055  Sequence<8>{}));
2056  }
2057 
2058  __device__ __host__ static constexpr index_t GetRegSizePerXdlops()
2059  {
2060  return mfma_instr.num_regs_per_blk;
2061  }
2062 
2063  __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
2064 
2065  template <class FloatA, class FloatB, class FloatC>
2066  __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
2067  {
2068  static_assert(
2075  "base_type must be double, float, tf32_t, half, bfloat16, int8_t, f8_t or bf8_t!");
2076 
2077  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
2078  if constexpr(!TransposeC)
2079  {
2080  mfma_instr.template run<MPerXdlops, NPerXdlops>(
2081  p_a_wave[k], p_b_wave[k], p_c_thread);
2082  }
2083  else
2084  {
2085  mfma_instr.template run<MPerXdlops, NPerXdlops>(
2086  p_b_wave[k], p_a_wave[k], p_c_thread);
2087  }
2088  });
2089  }
2090 
2091  template <index_t OpselA,
2092  index_t OpselB,
2093  class FloatA,
2094  class ScaleA,
2095  class FloatB,
2096  class ScaleB,
2097  class FloatC>
2098  __device__ void Run(const FloatA& p_a_wave,
2099  const ScaleA& a_scale_thread,
2100  const FloatB& p_b_wave,
2101  const ScaleB& b_scale_thread,
2102  FloatC& p_c_thread) const
2103  {
2104  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
2105  if constexpr(!TransposeC)
2106  {
2107  mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
2108  p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
2109  }
2110  else
2111  {
2112  mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
2113  p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
2114  }
2115  });
2116  }
2117 
2118  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
2119 
2120  __device__ static auto GetBlkIdx()
2121  {
2122  const auto laneId = GetLaneId();
2123  constexpr auto num_blks = mfma_instr.m_per_blk / mfma_instr.num_regs_per_blk;
2124 
2125  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
2126  make_tuple(
2127  make_merge_transform(make_tuple(1, num_blks, mfma_instr.num_threads_per_blk))),
2129  make_tuple(Sequence<0>{}));
2130 
2131  const auto blk_idx =
2132  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
2133 
2134  const auto blk_id = blk_idx[I1];
2135  const auto blk_td = blk_idx[I2];
2136 
2137  return make_tuple(blk_id, blk_td);
2138  }
2139 
2140  template <bool SwizzleA>
2141  __device__ static auto GetGfx11InputBlkIdx()
2142  {
2143  auto laneId = GetLaneId() % mfma_instr.num_threads_per_blk;
2144  if constexpr(SwizzleA)
2145  {
2146  laneId = ((laneId & 1) << 3) | (laneId >> 1);
2147  }
2148  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
2150  make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
2152  make_tuple(Sequence<0>{}));
2153 
2154  const auto blk_idx =
2155  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
2156 
2157  const auto blk_id = blk_idx[I1];
2158  const auto blk_td = blk_idx[I2];
2159 
2160  return make_tuple(blk_id, blk_td);
2161  }
2162 
2163  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
2164  {
2165  const auto laneId = GetLaneId();
2166 #if defined(__gfx11__)
2167  const auto blk_idx = GetGfx11InputBlkIdx<!TransposeC>();
2168 #else
2169  const auto blk_idx = GetBlkIdx();
2170 #endif
2171 
2172  const auto blk_id = blk_idx[I0];
2173  const auto blk_td = blk_idx[I1];
2174 
2175  if constexpr(mfma_instr.is_k_reduction)
2176  {
2177  return make_tuple(blk_id, blk_td);
2178  }
2179  else
2180  {
2181  return make_tuple(0, laneId);
2182  }
2183  }
2184 
2185  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
2186  {
2187  const auto laneId = GetLaneId();
2188 #if defined(__gfx11__)
2189  const auto blk_idx = GetGfx11InputBlkIdx<TransposeC>();
2190 #else
2191  const auto blk_idx = GetBlkIdx();
2192 #endif
2193 
2194  const auto blk_id = blk_idx[I0];
2195  const auto blk_td = blk_idx[I1];
2196 
2197  if constexpr(mfma_instr.is_k_reduction)
2198  {
2199  return make_tuple(blk_id, blk_td);
2200  }
2201  else
2202  {
2203  return make_tuple(0, laneId);
2204  }
2205  }
2206 
2207  __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
2208  {
2209  const auto blk_idx = GetBlkIdx();
2210 
2211  const auto blk_id = blk_idx[I0];
2212  const auto blk_td = blk_idx[I1];
2213 
2214  index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
2215  index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
2216 
2217  return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
2218  }
2219 
2220  __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
2221  {
2222  const auto blk_idx = GetBlkIdx();
2223 
2224  const auto blk_id = blk_idx[I0];
2225  const auto blk_td = blk_idx[I1];
2226 
2227  return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
2228  }
2229 
2230  // Falls back to single rate instruction on gfx950 if KPack is single rate; no change on gfx942-
2231  // when base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t,
2232  // except Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
2233  static constexpr bool is_single_rate_mfma =
2235  KPack <= 4) ||
2236  (is_same<base_type, int8_t>::value && KPack <= 8) ||
2239 #if defined(__gfx950__)
2240  // tf32 on gfx950 is implemented as bf16x3, so it should be treated as bf16.
2241  || (is_same<base_type, tf32_t>::value && KPack <= 4)
2242 #endif
2243  ? true
2244  : false;
2245  static constexpr auto mfma = MfmaSelector<base_type,
2246  MPerXdlops,
2247  NPerXdlops,
2248  additional_type,
2250  is_scale_mfma>{};
2251 
2252  static constexpr auto mfma_instr = mfma.selected_mfma;
2253 
2254  static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
2255  static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
2256  static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
2257 
2258  __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
2259  {
2260  return make_tuple(
2262  }
2263 };
2264 
2265 } // namespace ck
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
Definition: ck.hpp:270
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:45
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_unsupport_16x16_gfx11
@ wmma_i32_16x16x16_iu8_gfx12
@ mfma_scale_f32_32x32x64f8f6f4
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_unsupport_16x16_gfx12
@ mfma_f32_16x16x16bf16_1k
@ wmma_f32_16x16x16_f8f8_gfx12
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x128f8f6f4
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:408
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
Definition: array.hpp:14
Selects the appropriate MFMA instruction type and configuration for given data types and tile sizes o...
Definition: xdlops_gemm.hpp:1255
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1815
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1849
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1861
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1808
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1855
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1872
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:2252
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1891
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2185
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:2120
__device__ static constexpr __host__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:2058
static constexpr auto I2
Definition: xdlops_gemm.hpp:1875
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1883
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:2118
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:2256
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1944
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1885
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:2163
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:2233
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:2220
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:2063
static __device__ auto GetGfx11InputBlkIdx()
Definition: xdlops_gemm.hpp:2141
static constexpr auto I5
Definition: xdlops_gemm.hpp:1878
static constexpr auto I3
Definition: xdlops_gemm.hpp:1876
static constexpr auto I0
Definition: xdlops_gemm.hpp:1873
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:2098
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1909
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:2022
static constexpr auto I1
Definition: xdlops_gemm.hpp:1874
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:2255
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:2254
static constexpr auto I4
Definition: xdlops_gemm.hpp:1877
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:2066
static constexpr auto mfma
Definition: xdlops_gemm.hpp:2245
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:2207
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1989
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:2258
Definition: integral_constant.hpp:20
Definition: amd_xdlops.hpp:1221
Definition: amd_xdlops.hpp:322
Definition: amd_xdlops.hpp:212
Definition: amd_xdlops.hpp:89
Definition: amd_xdlops.hpp:288
Definition: amd_xdlops.hpp:1502
Definition: amd_xdlops.hpp:1628
Definition: amd_xdlops.hpp:178
Definition: amd_xdlops.hpp:1565
Definition: amd_xdlops.hpp:1439
Definition: amd_xdlops.hpp:1711
Definition: amd_xdlops.hpp:226
Definition: amd_xdlops.hpp:75
Definition: amd_xdlops.hpp:350
Definition: amd_xdlops.hpp:1660
Definition: amd_xdlops.hpp:268
Definition: amd_xdlops.hpp:1470
Definition: amd_xdlops.hpp:1596
Definition: amd_xdlops.hpp:158
Definition: amd_xdlops.hpp:1533
Definition: amd_xdlops.hpp:1407
Definition: amd_xdlops.hpp:1754
Definition: amd_xdlops.hpp:34
Definition: amd_xdlops.hpp:61
Definition: amd_xdlops.hpp:336
Definition: amd_xdlops.hpp:131
Definition: amd_xdlops.hpp:1680
Definition: amd_xdlops.hpp:500
Definition: amd_xdlops.hpp:308
Definition: amd_xdlops.hpp:198
Definition: amd_xdlops.hpp:103
Definition: amd_xdlops.hpp:240
Definition: amd_xdlops.hpp:480
Definition: amd_xdlops.hpp:383
Definition: amd_xdlops.hpp:461
Definition: amd_xdlops.hpp:422
Definition: amd_xdlops.hpp:442
Definition: amd_xdlops.hpp:402
Definition: amd_xdlops.hpp:364
Definition: amd_xdlops.hpp:905
Definition: amd_xdlops.hpp:685
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: type.hpp:177
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:875
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:453
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:321
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:188
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:431
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:739
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:827
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:299
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:783
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:695
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1059
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:343
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:166
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:497
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:992
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:387
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:717
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:805
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:277
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:761
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:673
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1037
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:122
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:144
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:475
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:233
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1014
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:851
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:409
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:255
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:211
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:365
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:651
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:541
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:585
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:629
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:563
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:607
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:519
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:944
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:907
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1095
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1159
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1217
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1207
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1085
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1149
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1197
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1187
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1112
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:1176
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1123
__device__ void run(const FloatA &, const FloatB &, FloatC &) const
Definition: xdlops_gemm.hpp:1228
Definition: xdlops_gemm.hpp:1067
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1076
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1068
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1075
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1078
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1071
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1074
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1072
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1073
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1069
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1070
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1077
Definition: xdlops_gemm.hpp:1131
static constexpr index_t n_per_blk
Definition: xdlops_gemm.hpp:1140
static constexpr index_t group_size
Definition: xdlops_gemm.hpp:1132
static constexpr index_t num_output_blks
Definition: xdlops_gemm.hpp:1138
static constexpr index_t m_per_blk
Definition: xdlops_gemm.hpp:1139
static constexpr index_t num_threads_per_blk
Definition: xdlops_gemm.hpp:1135
static constexpr bool is_k_reduction
Definition: xdlops_gemm.hpp:1142
static constexpr index_t num_regs_per_blk
Definition: xdlops_gemm.hpp:1134
static constexpr index_t num_groups_per_blk
Definition: xdlops_gemm.hpp:1133
static constexpr index_t num_input_blks
Definition: xdlops_gemm.hpp:1137
static constexpr index_t wave_size
Definition: xdlops_gemm.hpp:1136
static constexpr index_t k_per_blk
Definition: xdlops_gemm.hpp:1141
Definition: xdlops_gemm.hpp:104
Definition: functional2.hpp:33