/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File
xdlops_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
9 
10 namespace ck {
14 template <typename T>
15 static constexpr bool is_scale_mfma_data_type()
16 {
17  using U = element_type_t<T>;
18  return is_same_v<U, f8_ocp_t> || is_same_v<U, bf8_ocp_t> || is_same_v<U, f6_t> ||
19  is_same_v<U, bf6_t> || is_same_v<U, f4_t>;
20 }
21 
25 template <typename T>
26 static constexpr bool is_scale_mfma_scale_type()
27 {
28  return is_same_v<T, e8m0_bexp_t>;
29 }
30 
34 template <typename ADataType, typename BDataType, typename AScaleDataType, typename BScaleDataType>
35 static constexpr bool scale_mfma_hw_support()
36 {
37  return is_scale_mfma_data_type<ADataType>() && is_scale_mfma_data_type<BDataType>() &&
38  is_scale_mfma_scale_type<AScaleDataType>() && is_scale_mfma_scale_type<BScaleDataType>();
39 }
40 
// Identifies an MFMA instruction variant; the mfma_type<> template declared below is
// parameterized on (and specialized per) MfmaInstr values.
// NOTE(review): the enumerator list (listing lines 43-79) is absent from this
// extraction, so as shown the enum has no enumerators.
41 enum struct MfmaInstr
42 {
80 };
81 
// Per-instruction traits. The primary template is declared but never defined: only the
// explicit specializations supply the layout constants (group_size, num_regs_per_blk,
// num_input_blks, m/n/k_per_blk, is_k_reduction, ...) and a __device__ run() member.
// Naming a member for an MfmaInstr value with no specialization is a compile error
// (incomplete type).
82 template <MfmaInstr instr>
83 struct mfma_type;
84 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 1.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = false with num_output_blks == num_input_blks (2): presumably each
// input block produces its own output block rather than being summed over K.
// NOTE(review): listing line 86 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
85 template <>
87 {
88  static constexpr index_t group_size = 4;
89  static constexpr index_t num_groups_per_blk = 4;
90  static constexpr index_t num_regs_per_blk = 16;
91  static constexpr index_t num_threads_per_blk = 32;
92  static constexpr index_t wave_size = 64;
93  static constexpr index_t num_input_blks = 2;
94  static constexpr index_t num_output_blks = 2;
95  static constexpr index_t m_per_blk = 32;
96  static constexpr index_t n_per_blk = 32;
97  static constexpr index_t k_per_blk = 1;
98  static constexpr bool is_k_reduction = false;
99
100  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
101  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
102  {
    // NOTE(review): listing line 103 (presumably the intrinsic call) is absent; run() is empty as shown.
104  }
105 };
106 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 1.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 108 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
107 template <>
109 {
110  static constexpr index_t group_size = 4;
111  static constexpr index_t num_groups_per_blk = 4;
112  static constexpr index_t num_regs_per_blk = 16;
113  static constexpr index_t num_threads_per_blk = 32;
114  static constexpr index_t wave_size = 64;
115  static constexpr index_t num_input_blks = 2;
116  static constexpr index_t num_output_blks = 1;
117  static constexpr index_t m_per_blk = 32;
118  static constexpr index_t n_per_blk = 32;
119  static constexpr index_t k_per_blk = 1;
120  static constexpr bool is_k_reduction = true;
121
122  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
123  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
124  {
    // NOTE(review): listing line 125 (presumably the intrinsic call) is absent; run() is empty as shown.
126  }
127 };
128 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 1.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 130 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
129 template <>
131 {
132  static constexpr index_t group_size = 4;
133  static constexpr index_t num_groups_per_blk = 1;
134  static constexpr index_t num_regs_per_blk = 4;
135  static constexpr index_t num_threads_per_blk = 16;
136  static constexpr index_t wave_size = 64;
137  static constexpr index_t num_input_blks = 4;
138  static constexpr index_t num_output_blks = 1;
139  static constexpr index_t m_per_blk = 16;
140  static constexpr index_t n_per_blk = 16;
141  static constexpr index_t k_per_blk = 1;
142  static constexpr bool is_k_reduction = true;
143
144  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
145  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
146  {
    // NOTE(review): listing line 147 (presumably the intrinsic call) is absent; run() is empty as shown.
148  }
149 };
150 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 1.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = false with num_output_blks == num_input_blks (4): presumably each
// input block produces its own output block rather than being summed over K.
// NOTE(review): listing line 152 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
151 template <>
153 {
154  static constexpr index_t group_size = 4;
155  static constexpr index_t num_groups_per_blk = 1;
156  static constexpr index_t num_regs_per_blk = 4;
157  static constexpr index_t num_threads_per_blk = 16;
158  static constexpr index_t wave_size = 64;
159  static constexpr index_t num_input_blks = 4;
160  static constexpr index_t num_output_blks = 4;
161  static constexpr index_t m_per_blk = 16;
162  static constexpr index_t n_per_blk = 16;
163  static constexpr index_t k_per_blk = 1;
164  static constexpr bool is_k_reduction = false;
165
166  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
167  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
168  {
    // NOTE(review): listing line 169 (presumably the intrinsic call) is absent; run() is empty as shown.
170  }
171 };
172 
173 // treat 4x4x1 as a single-blk 4x64 mfma
// i.e. one input block spanning the whole wave: 64 threads, m_per_blk = 4, n_per_blk = 64,
// k_per_blk = 1. Each lane holds 4 values (4 * 64 / wave_size) as 1 group of 4.
// is_k_reduction = false with num_output_blks == num_input_blks (1).
// NOTE(review): listing line 175 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
174 template <>
176 {
177  static constexpr index_t group_size = 4;
178  static constexpr index_t num_groups_per_blk = 1;
179  static constexpr index_t num_regs_per_blk = 4;
180  static constexpr index_t num_threads_per_blk = 64;
181  static constexpr index_t wave_size = 64;
182  static constexpr index_t num_input_blks = 1;
183  static constexpr index_t num_output_blks = 1;
184  static constexpr index_t m_per_blk = 4;
185  static constexpr index_t n_per_blk = 64;
186  static constexpr index_t k_per_blk = 1;
187  static constexpr bool is_k_reduction = false;
188
189  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
190  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
191  {
    // NOTE(review): listing line 192 (presumably the intrinsic call) is absent; run() is empty as shown.
193  }
194 };
195 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 4.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = false with num_output_blks == num_input_blks (2): presumably each
// input block produces its own output block rather than being summed over K.
// NOTE(review): listing line 197 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
196 template <>
198 {
199  static constexpr index_t group_size = 4;
200  static constexpr index_t num_groups_per_blk = 4;
201  static constexpr index_t num_regs_per_blk = 16;
202  static constexpr index_t num_threads_per_blk = 32;
203  static constexpr index_t wave_size = 64;
204  static constexpr index_t num_input_blks = 2;
205  static constexpr index_t num_output_blks = 2;
206  static constexpr index_t m_per_blk = 32;
207  static constexpr index_t n_per_blk = 32;
208  static constexpr index_t k_per_blk = 4;
209  static constexpr bool is_k_reduction = false;
210
211  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
212  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
213  {
    // NOTE(review): listing line 214 (presumably the intrinsic call) is absent; run() is empty as shown.
215  }
216 };
217 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 4.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 219 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
218 template <>
220 {
221  static constexpr index_t group_size = 4;
222  static constexpr index_t num_groups_per_blk = 4;
223  static constexpr index_t num_regs_per_blk = 16;
224  static constexpr index_t num_threads_per_blk = 32;
225  static constexpr index_t wave_size = 64;
226  static constexpr index_t num_input_blks = 2;
227  static constexpr index_t num_output_blks = 1;
228  static constexpr index_t m_per_blk = 32;
229  static constexpr index_t n_per_blk = 32;
230  static constexpr index_t k_per_blk = 4;
231  static constexpr bool is_k_reduction = true;
232
233  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
234  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
235  {
    // NOTE(review): listing line 236 (presumably the intrinsic call) is absent; run() is empty as shown.
237  }
238 };
239 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 241 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
240 template <>
242 {
243  static constexpr index_t group_size = 4;
244  static constexpr index_t num_groups_per_blk = 4;
245  static constexpr index_t num_regs_per_blk = 16;
246  static constexpr index_t num_threads_per_blk = 32;
247  static constexpr index_t wave_size = 64;
248  static constexpr index_t num_input_blks = 2;
249  static constexpr index_t num_output_blks = 1;
250  static constexpr index_t m_per_blk = 32;
251  static constexpr index_t n_per_blk = 32;
252  static constexpr index_t k_per_blk = 8;
253  static constexpr bool is_k_reduction = true;
254
255  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
256  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
257  {
    // NOTE(review): listing line 258 (presumably the intrinsic call) is absent; run() is empty as shown.
259  }
260 };
261 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 263 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
262 template <>
264 {
265  static constexpr index_t group_size = 4;
266  static constexpr index_t num_groups_per_blk = 1;
267  static constexpr index_t num_regs_per_blk = 4;
268  static constexpr index_t num_threads_per_blk = 16;
269  static constexpr index_t wave_size = 64;
270  static constexpr index_t num_input_blks = 4;
271  static constexpr index_t num_output_blks = 1;
272  static constexpr index_t m_per_blk = 16;
273  static constexpr index_t n_per_blk = 16;
274  static constexpr index_t k_per_blk = 8;
275  static constexpr bool is_k_reduction = true;
276
277  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
278  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
279  {
    // NOTE(review): listing line 280 (presumably the intrinsic call) is absent; run() is empty as shown.
281  }
282 };
283 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 4.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 285 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
284 template <>
286 {
287  static constexpr index_t group_size = 4;
288  static constexpr index_t num_groups_per_blk = 1;
289  static constexpr index_t num_regs_per_blk = 4;
290  static constexpr index_t num_threads_per_blk = 16;
291  static constexpr index_t wave_size = 64;
292  static constexpr index_t num_input_blks = 4;
293  static constexpr index_t num_output_blks = 1;
294  static constexpr index_t m_per_blk = 16;
295  static constexpr index_t n_per_blk = 16;
296  static constexpr index_t k_per_blk = 4;
297  static constexpr bool is_k_reduction = true;
298
299  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
300  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
301  {
    // NOTE(review): listing line 302 (presumably the intrinsic call) is absent; run() is empty as shown.
303  }
304 };
305 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 4.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = false with num_output_blks == num_input_blks (4): presumably each
// input block produces its own output block rather than being summed over K.
// NOTE(review): listing line 307 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
306 template <>
308 {
309  static constexpr index_t group_size = 4;
310  static constexpr index_t num_groups_per_blk = 1;
311  static constexpr index_t num_regs_per_blk = 4;
312  static constexpr index_t num_threads_per_blk = 16;
313  static constexpr index_t wave_size = 64;
314  static constexpr index_t num_input_blks = 4;
315  static constexpr index_t num_output_blks = 4;
316  static constexpr index_t m_per_blk = 16;
317  static constexpr index_t n_per_blk = 16;
318  static constexpr index_t k_per_blk = 4;
319  static constexpr bool is_k_reduction = false;
320
321  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
322  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
323  {
    // NOTE(review): listing line 324 (presumably the intrinsic call) is absent; run() is empty as shown.
325  }
326 };
327 
// Single-block 4x64 layout: one input block spanning the whole wave (64 threads),
// m_per_blk = 4, n_per_blk = 64, k_per_blk = 4. Each lane holds 4 values
// (4 * 64 / wave_size) as 1 group of 4.
// is_k_reduction = false with num_output_blks == num_input_blks (1).
// NOTE(review): listing line 329 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
328 template <>
330 {
331  static constexpr index_t group_size = 4;
332  static constexpr index_t num_groups_per_blk = 1;
333  static constexpr index_t num_regs_per_blk = 4;
334  static constexpr index_t num_threads_per_blk = 64;
335  static constexpr index_t wave_size = 64;
336  static constexpr index_t num_input_blks = 1;
337  static constexpr index_t num_output_blks = 1;
338  static constexpr index_t m_per_blk = 4;
339  static constexpr index_t n_per_blk = 64;
340  static constexpr index_t k_per_blk = 4;
341  static constexpr bool is_k_reduction = false;
342
343  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
344  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
345  {
    // NOTE(review): listing line 346 (presumably the intrinsic call) is absent; run() is empty as shown.
347  }
348 };
349 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 351 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
350 template <>
352 {
353  static constexpr index_t group_size = 4;
354  static constexpr index_t num_groups_per_blk = 4;
355  static constexpr index_t num_regs_per_blk = 16;
356  static constexpr index_t num_threads_per_blk = 32;
357  static constexpr index_t wave_size = 64;
358  static constexpr index_t num_input_blks = 2;
359  static constexpr index_t num_output_blks = 1;
360  static constexpr index_t m_per_blk = 32;
361  static constexpr index_t n_per_blk = 32;
362  static constexpr index_t k_per_blk = 8;
363  static constexpr bool is_k_reduction = true;
364
365  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
366  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
367  {
    // NOTE(review): listing line 368 (presumably the intrinsic call) is absent; run() is empty as shown.
369  }
370 };
371 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 4.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 373 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
372 template <>
374 {
375  static constexpr index_t group_size = 4;
376  static constexpr index_t num_groups_per_blk = 4;
377  static constexpr index_t num_regs_per_blk = 16;
378  static constexpr index_t num_threads_per_blk = 32;
379  static constexpr index_t wave_size = 64;
380  static constexpr index_t num_input_blks = 2;
381  static constexpr index_t num_output_blks = 1;
382  static constexpr index_t m_per_blk = 32;
383  static constexpr index_t n_per_blk = 32;
384  static constexpr index_t k_per_blk = 4;
385  static constexpr bool is_k_reduction = true;
386
387  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
388  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
389  {
    // NOTE(review): listing line 390 (presumably the intrinsic call) is absent; run() is empty as shown.
391  }
392 };
393 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 395 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
394 template <>
396 {
397  static constexpr index_t group_size = 4;
398  static constexpr index_t num_groups_per_blk = 1;
399  static constexpr index_t num_regs_per_blk = 4;
400  static constexpr index_t num_threads_per_blk = 16;
401  static constexpr index_t wave_size = 64;
402  static constexpr index_t num_input_blks = 4;
403  static constexpr index_t num_output_blks = 1;
404  static constexpr index_t m_per_blk = 16;
405  static constexpr index_t n_per_blk = 16;
406  static constexpr index_t k_per_blk = 8;
407  static constexpr bool is_k_reduction = true;
408
409  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
410  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
411  {
    // NOTE(review): listing line 412 (presumably the intrinsic call) is absent; run() is empty as shown.
413  }
414 };
415 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 4.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 417 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
416 template <>
418 {
419  static constexpr index_t group_size = 4;
420  static constexpr index_t num_groups_per_blk = 1;
421  static constexpr index_t num_regs_per_blk = 4;
422  static constexpr index_t num_threads_per_blk = 16;
423  static constexpr index_t wave_size = 64;
424  static constexpr index_t num_input_blks = 4;
425  static constexpr index_t num_output_blks = 1;
426  static constexpr index_t m_per_blk = 16;
427  static constexpr index_t n_per_blk = 16;
428  static constexpr index_t k_per_blk = 4;
429  static constexpr bool is_k_reduction = true;
430
431  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
432  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
433  {
    // NOTE(review): listing line 434 (presumably the intrinsic call) is absent; run() is empty as shown.
435  }
436 };
437 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 2.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 439 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
438 template <>
440 {
441  static constexpr index_t group_size = 4;
442  static constexpr index_t num_groups_per_blk = 4;
443  static constexpr index_t num_regs_per_blk = 16;
444  static constexpr index_t num_threads_per_blk = 32;
445  static constexpr index_t wave_size = 64;
446  static constexpr index_t num_input_blks = 2;
447  static constexpr index_t num_output_blks = 1;
448  static constexpr index_t m_per_blk = 32;
449  static constexpr index_t n_per_blk = 32;
450  static constexpr index_t k_per_blk = 2;
451  static constexpr bool is_k_reduction = true;
452
453  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
454  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
455  {
    // NOTE(review): listing line 456 (presumably the intrinsic call) is absent; run() is empty as shown.
457  }
458 };
459 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 2.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 461 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
460 template <>
462 {
463  static constexpr index_t group_size = 4;
464  static constexpr index_t num_groups_per_blk = 1;
465  static constexpr index_t num_regs_per_blk = 4;
466  static constexpr index_t num_threads_per_blk = 16;
467  static constexpr index_t wave_size = 64;
468  static constexpr index_t num_input_blks = 4;
469  static constexpr index_t num_output_blks = 1;
470  static constexpr index_t m_per_blk = 16;
471  static constexpr index_t n_per_blk = 16;
472  static constexpr index_t k_per_blk = 2;
473  static constexpr bool is_k_reduction = true;
474
475  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
476  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
477  {
    // NOTE(review): listing line 478 (presumably the intrinsic call) is absent; run() is empty as shown.
479  }
480 };
481 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 4.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 483 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
482 template <>
484 {
485  static constexpr index_t group_size = 4;
486  static constexpr index_t num_groups_per_blk = 4;
487  static constexpr index_t num_regs_per_blk = 16;
488  static constexpr index_t num_threads_per_blk = 32;
489  static constexpr index_t wave_size = 64;
490  static constexpr index_t num_input_blks = 2;
491  static constexpr index_t num_output_blks = 1;
492  static constexpr index_t m_per_blk = 32;
493  static constexpr index_t n_per_blk = 32;
494  static constexpr index_t k_per_blk = 4;
495  static constexpr bool is_k_reduction = true;
496
497  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
498  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
499  {
    // NOTE(review): listing line 500 (presumably the intrinsic call) is absent; run() is empty as shown.
501  }
502 };
503 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 4.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 505 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
504 template <>
506 {
507  static constexpr index_t group_size = 4;
508  static constexpr index_t num_groups_per_blk = 1;
509  static constexpr index_t num_regs_per_blk = 4;
510  static constexpr index_t num_threads_per_blk = 16;
511  static constexpr index_t wave_size = 64;
512  static constexpr index_t num_input_blks = 4;
513  static constexpr index_t num_output_blks = 1;
514  static constexpr index_t m_per_blk = 16;
515  static constexpr index_t n_per_blk = 16;
516  static constexpr index_t k_per_blk = 4;
517  static constexpr bool is_k_reduction = true;
518
519  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
520  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
521  {
    // NOTE(review): listing line 522 (presumably the intrinsic call) is absent; run() is empty as shown.
523  }
524 };
525 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 527 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
526 template <>
528 {
529  static constexpr index_t group_size = 4;
530  static constexpr index_t num_groups_per_blk = 4;
531  static constexpr index_t num_regs_per_blk = 16;
532  static constexpr index_t num_threads_per_blk = 32;
533  static constexpr index_t wave_size = 64;
534  static constexpr index_t num_input_blks = 2;
535  static constexpr index_t num_output_blks = 1;
536  static constexpr index_t m_per_blk = 32;
537  static constexpr index_t n_per_blk = 32;
538  static constexpr index_t k_per_blk = 8;
539  static constexpr bool is_k_reduction = true;
540
541  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
542  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
543  {
    // NOTE(review): listing line 544 (presumably the intrinsic call) is absent; run() is empty as shown.
545  }
546 };
547 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 549 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
548 template <>
550 {
551  static constexpr index_t group_size = 4;
552  static constexpr index_t num_groups_per_blk = 1;
553  static constexpr index_t num_regs_per_blk = 4;
554  static constexpr index_t num_threads_per_blk = 16;
555  static constexpr index_t wave_size = 64;
556  static constexpr index_t num_input_blks = 4;
557  static constexpr index_t num_output_blks = 1;
558  static constexpr index_t m_per_blk = 16;
559  static constexpr index_t n_per_blk = 16;
560  static constexpr index_t k_per_blk = 8;
561  static constexpr bool is_k_reduction = true;
562
563  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
564  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
565  {
    // NOTE(review): listing line 566 (presumably the intrinsic call) is absent; run() is empty as shown.
567  }
568 };
569 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 16.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 571 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
570 template <>
572 {
573  static constexpr index_t group_size = 4;
574  static constexpr index_t num_groups_per_blk = 4;
575  static constexpr index_t num_regs_per_blk = 16;
576  static constexpr index_t num_threads_per_blk = 32;
577  static constexpr index_t wave_size = 64;
578  static constexpr index_t num_input_blks = 2;
579  static constexpr index_t num_output_blks = 1;
580  static constexpr index_t m_per_blk = 32;
581  static constexpr index_t n_per_blk = 32;
582  static constexpr index_t k_per_blk = 16;
583  static constexpr bool is_k_reduction = true;
584
585  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
586  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
587  {
    // NOTE(review): listing line 588 (presumably the intrinsic call) is absent; run() is empty as shown.
589  }
590 };
591 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 16.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 593 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
592 template <>
594 {
595  static constexpr index_t group_size = 4;
596  static constexpr index_t num_groups_per_blk = 1;
597  static constexpr index_t num_regs_per_blk = 4;
598  static constexpr index_t num_threads_per_blk = 16;
599  static constexpr index_t wave_size = 64;
600  static constexpr index_t num_input_blks = 4;
601  static constexpr index_t num_output_blks = 1;
602  static constexpr index_t m_per_blk = 16;
603  static constexpr index_t n_per_blk = 16;
604  static constexpr index_t k_per_blk = 16;
605  static constexpr bool is_k_reduction = true;
606
607  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
608  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
609  {
    // NOTE(review): listing line 610 (presumably the intrinsic call) is absent; run() is empty as shown.
611  }
612 };
613 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 1.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size). Unlike most variants here, group_size is 1,
// so the 4 values form 4 groups of 1.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 615 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
614 template <>
616 {
617  static constexpr index_t group_size = 1;
618  static constexpr index_t num_groups_per_blk = 4;
619  static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
620  static constexpr index_t num_threads_per_blk = 16;
621  static constexpr index_t wave_size = 64;
622  static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
623  static constexpr index_t num_output_blks = 1;
624  static constexpr index_t m_per_blk = 16;
625  static constexpr index_t n_per_blk = 16;
626  static constexpr index_t k_per_blk = 1;
627  static constexpr bool is_k_reduction = true;
628
629  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
630  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
631  {
    // NOTE(review): listing line 632 (presumably the intrinsic call) is absent; run() is empty as shown.
633  }
634 };
635 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 637 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
636 template <>
638 {
639  static constexpr index_t group_size = 4;
640  static constexpr index_t num_groups_per_blk = 4;
641  static constexpr index_t num_regs_per_blk = 16;
642  static constexpr index_t num_threads_per_blk = 32;
643  static constexpr index_t wave_size = 64;
644  static constexpr index_t num_input_blks = 2;
645  static constexpr index_t num_output_blks = 1;
646  static constexpr index_t m_per_blk = 32;
647  static constexpr index_t n_per_blk = 32;
648  static constexpr index_t k_per_blk = 8;
649  static constexpr bool is_k_reduction = true;
650
651  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
652  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
653  {
    // NOTE(review): listing line 654 (presumably the intrinsic call) is absent; run() is empty as shown.
655  }
656 };
657 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 659 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
658 template <>
660 {
661  static constexpr index_t group_size = 4;
662  static constexpr index_t num_groups_per_blk = 1;
663  static constexpr index_t num_regs_per_blk = 4;
664  static constexpr index_t num_threads_per_blk = 16;
665  static constexpr index_t wave_size = 64;
666  static constexpr index_t num_input_blks = 4;
667  static constexpr index_t num_output_blks = 1;
668  static constexpr index_t m_per_blk = 16;
669  static constexpr index_t n_per_blk = 16;
670  static constexpr index_t k_per_blk = 8;
671  static constexpr bool is_k_reduction = true;
672
673  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
674  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
675  {
    // NOTE(review): listing line 676 (presumably the intrinsic call) is absent; run() is empty as shown.
677  }
678 };
679 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 681 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
680 template <>
682 {
683  static constexpr index_t group_size = 4;
684  static constexpr index_t num_groups_per_blk = 4;
685  static constexpr index_t num_regs_per_blk = 16;
686  static constexpr index_t num_threads_per_blk = 32;
687  static constexpr index_t wave_size = 64;
688  static constexpr index_t num_input_blks = 2;
689  static constexpr index_t num_output_blks = 1;
690  static constexpr index_t m_per_blk = 32;
691  static constexpr index_t n_per_blk = 32;
692  static constexpr index_t k_per_blk = 8;
693  static constexpr bool is_k_reduction = true;
694
695  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
696  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
697  {
    // NOTE(review): listing line 698 (presumably the intrinsic call) is absent; run() is empty as shown.
699  }
700 };
701 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 703 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
702 template <>
704 {
705  static constexpr index_t group_size = 4;
706  static constexpr index_t num_groups_per_blk = 1;
707  static constexpr index_t num_regs_per_blk = 4;
708  static constexpr index_t num_threads_per_blk = 16;
709  static constexpr index_t wave_size = 64;
710  static constexpr index_t num_input_blks = 4;
711  static constexpr index_t num_output_blks = 1;
712  static constexpr index_t m_per_blk = 16;
713  static constexpr index_t n_per_blk = 16;
714  static constexpr index_t k_per_blk = 8;
715  static constexpr bool is_k_reduction = true;
716
717  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
718  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
719  {
    // NOTE(review): listing line 720 (presumably the intrinsic call) is absent; run() is empty as shown.
721  }
722 };
723 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 725 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
724 template <>
726 {
727  static constexpr index_t group_size = 4;
728  static constexpr index_t num_groups_per_blk = 4;
729  static constexpr index_t num_regs_per_blk = 16;
730  static constexpr index_t num_threads_per_blk = 32;
731  static constexpr index_t wave_size = 64;
732  static constexpr index_t num_input_blks = 2;
733  static constexpr index_t num_output_blks = 1;
734  static constexpr index_t m_per_blk = 32;
735  static constexpr index_t n_per_blk = 32;
736  static constexpr index_t k_per_blk = 8;
737  static constexpr bool is_k_reduction = true;
738
739  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
740  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
741  {
    // NOTE(review): listing line 742 (presumably the intrinsic call) is absent; run() is empty as shown.
743  }
744 };
745 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 747 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
746 template <>
748 {
749  static constexpr index_t group_size = 4;
750  static constexpr index_t num_groups_per_blk = 1;
751  static constexpr index_t num_regs_per_blk = 4;
752  static constexpr index_t num_threads_per_blk = 16;
753  static constexpr index_t wave_size = 64;
754  static constexpr index_t num_input_blks = 4;
755  static constexpr index_t num_output_blks = 1;
756  static constexpr index_t m_per_blk = 16;
757  static constexpr index_t n_per_blk = 16;
758  static constexpr index_t k_per_blk = 8;
759  static constexpr bool is_k_reduction = true;
760
761  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
762  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
763  {
    // NOTE(review): listing line 764 (presumably the intrinsic call) is absent; run() is empty as shown.
765  }
766 };
767 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 2 input blocks x 32 threads = wave_size (64); each lane holds 16 values
// (m_per_blk * n_per_blk / wave_size) as 4 groups of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 769 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
768 template <>
770 {
771  static constexpr index_t group_size = 4;
772  static constexpr index_t num_groups_per_blk = 4;
773  static constexpr index_t num_regs_per_blk = 16;
774  static constexpr index_t num_threads_per_blk = 32;
775  static constexpr index_t wave_size = 64;
776  static constexpr index_t num_input_blks = 2;
777  static constexpr index_t num_output_blks = 1;
778  static constexpr index_t m_per_blk = 32;
779  static constexpr index_t n_per_blk = 32;
780  static constexpr index_t k_per_blk = 8;
781  static constexpr bool is_k_reduction = true;
782
783  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
784  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
785  {
    // NOTE(review): listing line 786 (presumably the intrinsic call) is absent; run() is empty as shown.
787  }
788 };
789 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 8.
// Lanes: 4 input blocks x 16 threads = wave_size (64); each lane holds 4 values
// (m_per_blk * n_per_blk / wave_size) as 1 group of 4.
// is_k_reduction = true with num_output_blks = 1: the input blocks presumably each
// cover k_per_blk of K and are summed into that single output block.
// NOTE(review): listing line 791 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
790 template <>
792 {
793  static constexpr index_t group_size = 4;
794  static constexpr index_t num_groups_per_blk = 1;
795  static constexpr index_t num_regs_per_blk = 4;
796  static constexpr index_t num_threads_per_blk = 16;
797  static constexpr index_t wave_size = 64;
798  static constexpr index_t num_input_blks = 4;
799  static constexpr index_t num_output_blks = 1;
800  static constexpr index_t m_per_blk = 16;
801  static constexpr index_t n_per_blk = 16;
802  static constexpr index_t k_per_blk = 8;
803  static constexpr bool is_k_reduction = true;
804
805  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
806  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
807  {
    // NOTE(review): listing line 808 (presumably the intrinsic call) is absent; run() is empty as shown.
809  }
810 };
811 
// Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 32; 2 input blocks of
// 32 threads fill the 64-lane wave, each lane holding 16 values as 4 groups of 4.
// NOTE(review): listing line 813 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
812 template <>
814 {
815  // clang-format off
816  static constexpr index_t group_size = 4; // group_size * num_groups_per_blk == num_regs_per_blk (4 * 4 == 16)
817  static constexpr index_t num_groups_per_blk = 4; // group_size * num_groups_per_blk == num_regs_per_blk (4 * 4 == 16)
818  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size (32 * 32 / 64)
819  static constexpr index_t num_threads_per_blk = 32; // == n_per_blk
820  static constexpr index_t wave_size = 64; // fixed
821  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk, also == wave_size / num_threads_per_blk
822  static constexpr index_t num_output_blks = 1; // 1 because is_k_reduction == true (otherwise it equals num_input_blks)
823  static constexpr index_t m_per_blk = 32; // from the instruction
824  static constexpr index_t n_per_blk = 32; // from the instruction
825  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
826  static constexpr bool is_k_reduction = true; // input blocks presumably split K and are summed into the single output block
827  // clang-format on
828
829  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
830  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
831  {
    // NOTE(review): listing line 832 (presumably the intrinsic call) is absent; run() is empty as shown.
833  }
834 };
835 
// Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 32; 4 input blocks of
// 16 threads fill the 64-lane wave, each lane holding 4 values as 1 group of 4.
// NOTE(review): listing line 837 (presumably `struct mfma_type<MfmaInstr::...>`) is absent
// from this extraction, so which instruction this specializes is not visible here.
836 template <>
838 {
839  // clang-format off
840  static constexpr index_t group_size = 4; // group_size * num_groups_per_blk == num_regs_per_blk (4 * 1 == 4)
841  static constexpr index_t num_groups_per_blk = 1; // group_size * num_groups_per_blk == num_regs_per_blk (4 * 1 == 4)
842  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size (16 * 16 / 64)
843  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
844  static constexpr index_t wave_size = 64; // fixed
845  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk, also == wave_size / num_threads_per_blk
846  static constexpr index_t num_output_blks = 1; // 1 because is_k_reduction == true (otherwise it equals num_input_blks)
847  static constexpr index_t m_per_blk = 16; // from the instruction
848  static constexpr index_t n_per_blk = 16; // from the instruction
849  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
850  static constexpr bool is_k_reduction = true; // input blocks presumably split K and are summed into the single output block
851  // clang-format on
852
853  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
854  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
855  {
    // NOTE(review): listing line 856 (presumably the intrinsic call) is absent; run() is empty as shown.
857  }
858 };
859 
// Scaled variant. Per-block tile: 32x32 (m_per_blk x n_per_blk), k_per_blk = 32; 2 input
// blocks of 32 threads fill the 64-lane wave, each lane holding 16 values as 4 groups of 4.
// run() takes a per-operand scale next to a and b and forwards a, b, reg_c plus each
// scale reinterpreted as a 32-bit value via bit_cast<uint32_t>.
// NOTE(review): listing line 861 (presumably `struct mfma_type<MfmaInstr::...>`) and
// line 892 (the start of the call run() forwards to, where OpselA/OpselB are presumably
// used) are absent from this extraction.
860 template <>
862 {
863  // clang-format off
864  static constexpr index_t group_size = 4; // group_size * num_groups_per_blk == num_regs_per_blk (4 * 4 == 16)
865  static constexpr index_t num_groups_per_blk = 4; // group_size * num_groups_per_blk == num_regs_per_blk (4 * 4 == 16)
866  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size (32 * 32 / 64)
867  static constexpr index_t num_threads_per_blk = 32; // == n_per_blk
868  static constexpr index_t wave_size = 64; // fixed
869  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk, also == wave_size / num_threads_per_blk
870  static constexpr index_t num_output_blks = 1; // 1 because is_k_reduction == true (otherwise it equals num_input_blks)
871  static constexpr index_t m_per_blk = 32; // from the instruction
872  static constexpr index_t n_per_blk = 32; // from the instruction
873  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
874  static constexpr bool is_k_reduction = true; // input blocks presumably split K and are summed into the single output block
875  // clang-format on
876
877  template <index_t MPerXdlops,
878  index_t NPerXdlops,
879  index_t OpselA,
880  index_t OpselB,
881  class FloatA,
882  class ScaleA,
883  class FloatB,
884  class ScaleB,
885  class FloatC>
886  __device__ void run(const FloatA& a,
887  const ScaleA& scale_a,
888  const FloatB& b,
889  const ScaleB& scale_b,
890  FloatC& reg_c) const
891  {
    // NOTE(review): the first line of this call (listing line 892) is missing; only its arguments remain.
893  a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
894  }
895 };
896 
// Scaled variant. Per-block tile: 16x16 (m_per_blk x n_per_blk), k_per_blk = 32; 4 input
// blocks of 16 threads fill the 64-lane wave, each lane holding 4 values as 1 group of 4.
// run() takes a per-operand scale next to a and b and forwards a, b, reg_c plus each
// scale reinterpreted as a 32-bit value via bit_cast<uint32_t>.
// NOTE(review): listing line 898 (presumably `struct mfma_type<MfmaInstr::...>`) and
// line 930 (the start of the call run() forwards to, where OpselA/OpselB are presumably
// used) are absent from this extraction.
897 template <>
899 {
900  // clang-format off
901  static constexpr index_t group_size = 4; // group_size * num_groups_per_blk == num_regs_per_blk (4 * 1 == 4)
902  static constexpr index_t num_groups_per_blk = 1; // group_size * num_groups_per_blk == num_regs_per_blk (4 * 1 == 4)
903  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size (16 * 16 / 64)
904  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
905  static constexpr index_t wave_size = 64; // fixed
906  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk, also == wave_size / num_threads_per_blk
907  static constexpr index_t num_output_blks = 1; // 1 because is_k_reduction == true (otherwise it equals num_input_blks)
908  static constexpr index_t m_per_blk = 16; // from the instruction
909  static constexpr index_t n_per_blk = 16; // from the instruction
910  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? KPerXdlops / num_input_blks
911  static constexpr bool is_k_reduction = true; // input blocks presumably split K and are summed into the single output block
912  // clang-format on
913
914  template <index_t MPerXdlops,
915  index_t NPerXdlops,
916  index_t OpselA,
917  index_t OpselB,
918  class FloatA,
919  class ScaleA,
920  class FloatB,
921  class ScaleB,
922  class FloatC>
923  __device__ void run(const FloatA& a,
924  const ScaleA& scale_a,
925  const FloatB& b,
926  const ScaleB& scale_b,
927  FloatC& reg_c) const
928  {
929
    // NOTE(review): the first line of this call (listing line 930) is missing; only its arguments remain.
931  a, bit_cast<uint32_t>(scale_a), b, bit_cast<uint32_t>(scale_b), reg_c);
932  }
933 };
934 
935 template <typename base_type,
936  index_t MPerXdlops,
937  index_t NPerXdlops,
938  typename additional_type = base_type,
939  bool is_single_rate_mfma = false,
940  bool is_scale_mfma = false>
942 {
// Compile-time lookup table used to pick the MFMA instruction for
// (base_type_, MPerXdlops_, NPerXdlops_, additional_type_, single-rate flag,
// scale flag). additional_type_ appears to be the B-operand element type
// (see the mixed f8_t/bf8_t specializations below). Its result is used as the
// template argument of mfma_type (presumably an MfmaInstr value -- confirm).
// Only the primary template is declared here, without a definition, so a
// combination that has no explicit specialization fails to compile (the auto
// return type cannot be deduced).
// NOTE(review): the scaled (is_scale_mfma_ == true) 16x16 entries cover
// f8*f8, bf8*bf8, f8*bf8 and bf8*f8, but the scaled 32x32 entries cover only
// f8*f8 and bf8*f8 -- confirm whether bf8*bf8 and f8*bf8 at 32x32 are
// intentionally unsupported.
943  template <typename base_type_,
944  index_t MPerXdlops_,
945  index_t NPerXdlops_,
946  typename additional_type_ = base_type_,
947  bool is_single_rate_mfma_ = false,
948  bool is_scale_mfma_ = false>
949  static constexpr auto GetMfma();
950 
951  template <>
952  constexpr auto GetMfma<double, 16, 16>()
953  {
955  }
956 
957  template <>
958  constexpr auto GetMfma<float, 64, 64>()
959  {
961  }
962 
963  template <>
964  constexpr auto GetMfma<float, 32, 64>()
965  {
967  }
968 
969  template <>
970  constexpr auto GetMfma<float, 16, 64>()
971  {
973  }
974 
975  template <>
976  constexpr auto GetMfma<float, 8, 64>()
977  {
979  }
980 
981  template <>
982  constexpr auto GetMfma<float, 4, 64>()
983  {
985  }
986 
987  template <>
988  constexpr auto GetMfma<float, 32, 32>()
989  {
991  }
992 
993  template <>
994  constexpr auto GetMfma<float, 16, 16>()
995  {
997  }
998 
999  template <>
1000  constexpr auto GetMfma<half_t, 64, 64>()
1001  {
1003  }
1004 
1005  template <>
1006  constexpr auto GetMfma<half_t, 32, 64>()
1007  {
1009  }
1010 
1011  template <>
1012  constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
1013  {
1014 #if defined(__gfx950__)
1016 #else
1018 #endif
1019  }
1020  template <>
1021  constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
1022  {
1024  }
1025 
1026  template <>
1027  constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
1028  {
1029 #if defined(__gfx950__)
1031 #else
1033 #endif
1034  }
1035 
1036  template <>
1037  constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
1038  {
1040  }
1041 
1042  template <>
1043  constexpr auto GetMfma<half_t, 16, 64>()
1044  {
1046  }
1047 
1048  template <>
1049  constexpr auto GetMfma<half_t, 8, 64>()
1050  {
1052  }
1053 
1054  template <>
1055  constexpr auto GetMfma<half_t, 4, 64>()
1056  {
1058  }
1059 
1060  template <>
1061  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1062  {
1063 #if defined(__gfx950__)
1065 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1067 #else
1069 #endif
1070  }
1071 
1072  template <>
1073  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1074  {
1075 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1077 #else
1079 #endif
1080  }
1081 
1082  template <>
1083  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1084  {
1085 #if defined(__gfx950__)
1087 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1089 #else
1091 #endif
1092  }
1093 
1094  template <>
1095  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1096  {
1097 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1099 #else
1101 #endif
1102  }
1103 
1104  template <>
1105  constexpr auto GetMfma<int8_t, 32, 32, int8_t, false>()
1106  {
1107 #if defined(__gfx950__)
1109 #elif defined(__gfx942__)
1111 #else
1113 #endif
1114  }
1115 
1116  template <>
1117  constexpr auto GetMfma<int8_t, 32, 32, int8_t, true>()
1118  {
1119 #if defined(__gfx942__) || defined(__gfx950__)
1121 #else
1123 #endif
1124  }
1125 
1126  template <>
1127  constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
1128  {
1129 #if defined(__gfx950__)
1131 #elif defined(__gfx942__)
1133 #else
1135 #endif
1136  }
1137 
1138  template <>
1139  constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
1140  {
1141 #if defined(__gfx942__) || defined(__gfx950__)
1143 #else
1145 #endif
1146  }
1147 
1148  template <>
1149  constexpr auto GetMfma<f8_t, 32, 32, f8_t, true, false>()
1150  {
1152  }
1153 
1154  template <>
1155  constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, false>()
1156  {
1157 #if defined(__gfx950__)
1159 #else
1161 #endif
1162  }
1163 
1164  template <>
1165  constexpr auto GetMfma<f8_t, 32, 32, f8_t, false, true>()
1166  {
1168  }
1169 
1170  template <>
1171  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, true>()
1172  {
1174  }
1175  template <>
1176  constexpr auto GetMfma<f4_t, 32, 32, f4_t, false, true>()
1177  {
1179  }
1180  template <>
1181  constexpr auto GetMfma<f4_t, 16, 16, f4_t, false, true>()
1182  {
1184  }
1185 
1186  template <>
1187  constexpr auto GetMfma<f8_t, 16, 16, f8_t, true, false>()
1188  {
1190  }
1191 
1192  template <>
1193  constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, false>()
1194  {
1195 #if defined(__gfx950__)
1197 #else
1199 #endif
1200  }
1201 
1202  template <>
1203  constexpr auto GetMfma<f8_t, 16, 16, f8_t, false, true>()
1204  {
1206  }
1207 
1208  template <>
1209  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, true>()
1210  {
1212  }
1213 
1214  template <>
1215  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, true>()
1216  {
1218  }
1219 
1220  template <>
1221  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, true>()
1222  {
1224  }
1225 
1226  template <>
1227  constexpr auto GetMfma<f6_t, 32, 32, f6_t, false, true>()
1228  {
1230  }
1231  template <>
1232  constexpr auto GetMfma<f6_t, 16, 16, f6_t, false, true>()
1233  {
1235  }
1236  template <>
1237  constexpr auto GetMfma<bf6_t, 32, 32, bf6_t, false, true>()
1238  {
1240  }
1241  template <>
1242  constexpr auto GetMfma<bf6_t, 16, 16, bf6_t, false, true>()
1243  {
1245  }
1246 
1247  template <>
1248  constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, true, false>()
1249  {
1251  }
1252 
1253  template <>
1254  constexpr auto GetMfma<bf8_t, 32, 32, bf8_t, false, false>()
1255  {
1256 #if defined(__gfx950__)
1258 #else
1260 #endif
1261  }
1262 
1263  template <>
1264  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, true, false>()
1265  {
1267  }
1268 
1269  template <>
1270  constexpr auto GetMfma<bf8_t, 16, 16, bf8_t, false, false>()
1271  {
1272 #if defined(__gfx950__)
1274 #else
1276 #endif
1277  }
1278 
1279  template <>
1280  constexpr auto GetMfma<f8_t, 32, 32, bf8_t, true, false>()
1281  {
1283  }
1284 
1285  template <>
1286  constexpr auto GetMfma<f8_t, 32, 32, bf8_t, false, false>()
1287  {
1288 #if defined(__gfx950__)
1290 #else
1292 #endif
1293  }
1294 
1295  template <>
1296  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, true, false>()
1297  {
1299  }
1300 
1301  template <>
1302  constexpr auto GetMfma<f8_t, 16, 16, bf8_t, false, false>()
1303  {
1304 #if defined(__gfx950__)
1306 #else
1308 #endif
1309  }
1310 
1311  template <>
1312  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, true, false>()
1313  {
1315  }
1316 
1317  template <>
1318  constexpr auto GetMfma<bf8_t, 32, 32, f8_t, false, false>()
1319  {
1320 #if defined(__gfx950__)
1322 #else
1324 #endif
1325  }
1326 
1327  template <>
1328  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, true, false>()
1329  {
1331  }
1332 
1333  template <>
1334  constexpr auto GetMfma<bf8_t, 16, 16, f8_t, false, false>()
1335  {
1336 #if defined(__gfx950__)
1338 #else
1340 #endif
1341  }
1342 
1344  MPerXdlops,
1345  NPerXdlops,
1347  is_single_rate_mfma,
1348  is_scale_mfma>()>{};
1349 
// Constructor that only performs compile-time consistency checks on the
// selected mfma_type descriptor; it does no runtime work. Constructing an
// MfmaSelector (e.g. the `mfma` member of XdlopsGemm) triggers these checks.
1350  __host__ __device__ constexpr MfmaSelector()
1351  {
// The per-block registers split into num_groups_per_blk groups of group_size.
1352  static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
1353  selected_mfma.num_regs_per_blk,
1354  "wrong! num_regs_per_blk");
1355 
// One lane per C column inside a block.
1356  static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
1357  "n_per_blk != num_threads_per_blk");
1358 
1359  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
1360  selected_mfma.m_per_blk,
1361  "m_per_blk != num_input_blks * num_regs_per_blk");
1362 
// Either one output block per input block, or a single (K-reduced) output block.
1363  static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
1364  selected_mfma.num_output_blks == 1,
1365  "incorrect num_output_blks");
1366 
// The whole wave's registers exactly cover one m_per_blk x n_per_blk block.
1367  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
1368  selected_mfma.m_per_blk * selected_mfma.n_per_blk,
1369  "num_regs_per_blk incorrect");
1370 
// Without K reduction, input and output block counts must match.
1371  static_assert(selected_mfma.is_k_reduction ||
1372  (selected_mfma.num_input_blks == selected_mfma.num_output_blks),
1373  "is_k_reduction wrong!");
1374  }
1375 
// Reports that the A operand is the broadcast one. It always returns true and
// fails to compile unless NPerXdlops >= MPerXdlops.
1376  static constexpr bool IsABroadcast()
1377  {
1378  static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
1379  return true;
1380  }
1381 
// K extent consumed by one selected instruction: num_input_blks * k_per_blk for
// K-reduction instructions (input blocks hold distinct K slices), otherwise just
// k_per_blk.
1382  static constexpr index_t GetKPerXdlops()
1383  {
1384  return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
1385  selected_mfma.k_per_blk;
1386  }
1387 
// K extent of a single input block (k_per_blk) of the selected instruction.
1388  static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
1389 };
1390 
1391 template <typename base_type,
1392  index_t MPerXdlops,
1393  index_t NPerXdlops,
1394  index_t KPack,
1395  typename additional_type = base_type,
1396  bool TransposeC = false,
1397  bool is_scale_mfma = false>
1399 {
// Compile-time index constants, used below for descriptor GetLength() queries and
// tuple element access.
1400  static constexpr auto I0 = Number<0>{};
1401  static constexpr auto I1 = Number<1>{};
1402  static constexpr auto I2 = Number<2>{};
1403  static constexpr auto I3 = Number<3>{};
1404  static constexpr auto I4 = Number<4>{};
1405  static constexpr auto I5 = Number<5>{};
1406 
1409 
// Number of output blocks produced by the selected instruction.
1410  __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
1411 
// Number of selected instructions needed to cover an MPerXdlops x NPerXdlops
// tile: the tile area divided by m_per_blk * n_per_blk * num_output_blks
// (presumably the C area one instruction produces -- confirm for non-reduction
// instructions).
1412  __device__ static constexpr index_t GetNumXdlops()
1413  {
1414  return MPerXdlops * NPerXdlops /
1415  (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
1416  }
1417 
// Constructor that only performs compile-time checks on the template parameters.
1418  __host__ __device__ constexpr XdlopsGemm()
1419  {
1420  static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1421  NPerXdlops == 64,
1422  "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1423 
1424  static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1425  MPerXdlops == 64,
1426  "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1427 
// Run() issues KPack / k_per_blk instructions, so KPack must divide evenly.
1428  static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack should be a multiple of k_per_blk");
1429  }
1430 
1431  // XDL output supporting C = A * B
1432  // M2_N2 -> M2_M3_M4_N2
1433  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1434  __host__ __device__ static constexpr auto
1435  MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1436  {
1437  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1438  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1439  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1440  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1441 
1443  c_desc_m0_n0_m1_n1_m2_n2,
1449  Number<mfma_instr.num_input_blks>{},
1450  Number<mfma_instr.group_size>{})),
1453  Sequence<1>{},
1454  Sequence<2>{},
1455  Sequence<3>{},
1456  Sequence<4>{},
1457  Sequence<5>{}),
1459  Sequence<1>{},
1460  Sequence<2>{},
1461  Sequence<3>{},
1463  Sequence<7>{}));
1464  }
1465 
1466  // XDL output supporting C = A * B
1467  // M3_N3 -> M3_M4_M5_N3
1468  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1469  __host__ __device__ static constexpr auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(
1470  const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1471  {
1472  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1473  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1474  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1475  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1476  const auto M2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
1477  const auto N2 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
1478 
1480  c_desc_m0_n0_m1_n1_m2_n2,
1488  Number<mfma_instr.num_input_blks>{},
1489  Number<mfma_instr.group_size>{})),
1492  Sequence<1>{},
1493  Sequence<2>{},
1494  Sequence<3>{},
1495  Sequence<4>{},
1496  Sequence<5>{},
1497  Sequence<6>{},
1498  Sequence<7>{}),
1500  Sequence<1>{},
1501  Sequence<2>{},
1502  Sequence<3>{},
1503  Sequence<4>{},
1504  Sequence<5>{},
1506  Sequence<9>{}));
1507  }
1508 
1509  // transposed XDL output supporting C' = B' * A'
1510  // M2_N2 -> M2_N2_N3_N4
1511  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1512  __host__ __device__ static constexpr auto
1513  MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1514  {
1515  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1516  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1517  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1518  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1519 
1521  c_desc_m0_n0_m1_n1_m2_n2,
1528  Number<mfma_instr.num_input_blks>{},
1529  Number<mfma_instr.group_size>{}))),
1531  Sequence<1>{},
1532  Sequence<2>{},
1533  Sequence<3>{},
1534  Sequence<4>{},
1535  Sequence<5>{}),
1537  Sequence<1>{},
1538  Sequence<2>{},
1539  Sequence<3>{},
1540  Sequence<4>{},
1541  Sequence<5, 6, 7>{}));
1542  }
1543 
1544  template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
1545  __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
1546  const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1547  {
1548  const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
1549  const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
1550  const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
1551  const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
1552  const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
1553 
1555  c_desc_g_m0_n0_m1_n1_m2_n2,
1561  make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
1562  mfma_instr.num_input_blks,
1563  mfma_instr.group_size)),
1564  make_pass_through_transform(mfma_instr.num_threads_per_blk)),
1566  Sequence<1>{},
1567  Sequence<2>{},
1568  Sequence<3>{},
1569  Sequence<4>{},
1570  Sequence<5>{},
1571  Sequence<6>{}),
1573  Sequence<1>{},
1574  Sequence<2>{},
1575  Sequence<3>{},
1576  Sequence<4>{},
1578  Sequence<8>{}));
1579  }
1580 
// Number of C values each lane holds for an MPerXdlops x NPerXdlops tile
// (tile area divided by the wave size).
1581  __device__ static constexpr index_t GetRegSizePerXdlops()
1582  {
1583  return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
1584  }
1585 
// Wave size assumed by the selected instruction descriptor.
1586  __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
1587 
1588  template <class FloatA, class FloatB, class FloatC>
1589  __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
1590  {
1591  static_assert(
1598  "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
1599 
1600  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
1601  if constexpr(!TransposeC)
1602  {
1603  mfma_instr.template run<MPerXdlops, NPerXdlops>(
1604  p_a_wave[k], p_b_wave[k], p_c_thread);
1605  }
1606  else
1607  {
1608  mfma_instr.template run<MPerXdlops, NPerXdlops>(
1609  p_b_wave[k], p_a_wave[k], p_c_thread);
1610  }
1611  });
1612  }
1613 
// Scaled GEMM step. Issues KPack / k_per_blk scaled instructions; iteration k
// uses the k-th A/B fragment and the k-th A/B scale, and every call receives the
// same p_c_thread.
// With TransposeC, the A and B roles are swapped (operands, scales and the
// OpselA/OpselB template arguments), matching the transposed C' = B' * A' output.
// NOTE(review): unlike the unscaled Run overload, no static_assert restricts
// base_type here.
1614  template <index_t OpselA,
1615  index_t OpselB,
1616  class FloatA,
1617  class ScaleA,
1618  class FloatB,
1619  class ScaleB,
1620  class FloatC>
1621  __device__ void Run(const FloatA& p_a_wave,
1622  const ScaleA& a_scale_thread,
1623  const FloatB& p_b_wave,
1624  const ScaleB& b_scale_thread,
1625  FloatC& p_c_thread) const
1626  {
1627  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
1628  if constexpr(!TransposeC)
1629  {
1630  mfma_instr.template run<MPerXdlops, NPerXdlops, OpselA, OpselB>(
1631  p_a_wave[k], a_scale_thread[k], p_b_wave[k], b_scale_thread[k], p_c_thread);
1632  }
1633  else
1634  {
1635  mfma_instr.template run<MPerXdlops, NPerXdlops, OpselB, OpselA>(
1636  p_b_wave[k], b_scale_thread[k], p_a_wave[k], a_scale_thread[k], p_c_thread);
1637  }
1638  });
1639  }
1640 
// Flat thread id within the block, taken modulo the wave size. This is the
// lane index only if waves are formed from consecutive flat ids (assumed).
1641  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
1642 
1643  __device__ static auto GetBlkIdx()
1644  {
1645  const auto laneId = GetLaneId();
1646 
1647  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
1649  make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
1651  make_tuple(Sequence<0>{}));
1652 
1653  const auto blk_idx =
1654  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
1655 
1656  const auto blk_id = blk_idx[I1];
1657  const auto blk_td = blk_idx[I2];
1658 
1659  return make_tuple(blk_id, blk_td);
1660  }
1661 
// Per-lane origin of the A fragment. For K-reduction instructions this is
// (blk_id, blk_td) as returned by GetBlkIdx (presumably: which input block the
// lane belongs to, and its position inside that block). Otherwise it is
// (0, laneId). laneId is only used on the non-reduction path.
1662  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
1663  {
1664  const auto laneId = GetLaneId();
1665  const auto blk_idx = GetBlkIdx();
1666 
1667  const auto blk_id = blk_idx[I0];
1668  const auto blk_td = blk_idx[I1];
1669 
1670  if constexpr(mfma_instr.is_k_reduction)
1671  {
1672  return make_tuple(blk_id, blk_td);
1673  }
1674  else
1675  {
1676  return make_tuple(0, laneId);
1677  }
1678  }
1679 
// Per-lane origin of the B fragment. The computation is the same as for A:
// (blk_id, blk_td) for K-reduction instructions, (0, laneId) otherwise.
1680  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
1681  {
1682  const auto laneId = GetLaneId();
1683  const auto blk_idx = GetBlkIdx();
1684 
1685  const auto blk_id = blk_idx[I0];
1686  const auto blk_td = blk_idx[I1];
1687 
1688  if constexpr(mfma_instr.is_k_reduction)
1689  {
1690  return make_tuple(blk_id, blk_td);
1691  }
1692  else
1693  {
1694  return make_tuple(0, laneId);
1695  }
1696  }
1697 
// First C coordinate owned by this lane for instruction xdlops_i and block blk_i.
// Row (M): xdlops_i * m_per_blk + blk_id * group_size.
// Column (N): blk_i * n_per_blk + blk_td.
// (blk_id, blk_td) come from GetBlkIdx. The pair is returned as (N, M) when
// TransposeC is set.
1698  __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
1699  {
1700  const auto blk_idx = GetBlkIdx();
1701 
1702  const auto blk_id = blk_idx[I0];
1703  const auto blk_td = blk_idx[I1];
1704 
1705  index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
1706  index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
1707 
1708  return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
1709  }
1710 
// 4-D variant. Both arguments are ignored. It places blk_id and blk_td into a
// 4-D index with zeros elsewhere: {I0, blk_id, I0, blk_td}, or
// {blk_td, I0, blk_id, I0} when TransposeC is set.
1711  __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
1712  {
1713  const auto blk_idx = GetBlkIdx();
1714 
1715  const auto blk_id = blk_idx[I0];
1716  const auto blk_td = blk_idx[I1];
1717 
1718  return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
1719  }
1720 
1721  // Falls back to single rate instruction on gfx950 if KPack is single rate; no change on gfx942-
1722  // when base_type is either f8_t or bf8_t, additional_type will always be either f8_t or bf8_t,
1723  // except Use single rate mfma instruction for this special case A (f8_t) * B (pk_i4_t)
1724  static constexpr bool is_single_rate_mfma =
1726  KPack <= 4) ||
1727  (is_same<base_type, int8_t>::value && KPack <= 8) ||
1730  ? true
1731  : false;
1732  static constexpr auto mfma = MfmaSelector<base_type,
1733  MPerXdlops,
1734  NPerXdlops,
1735  additional_type,
1737  is_scale_mfma>{};
1738 
// Instruction descriptor chosen by the MfmaSelector member `mfma`.
1739  static constexpr auto mfma_instr = mfma.selected_mfma;
1740 
// KPerXdlops: K consumed per instruction (num_input_blks * k_per_blk for
// K-reduction instructions, else k_per_blk).
// K1PerXdlops: k_per_blk.
// K0PerXdlops: their ratio, i.e. num_input_blks for K-reduction instructions,
// otherwise 1.
1741  static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
1742  static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
1743  static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
1744 
1745  __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
1746  {
1747  return make_tuple(
1749  }
1750 };
1751 
1752 } // namespace ck
Definition: ck.hpp:269
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:42
@ mfma_scale_f32_32x32x64f8f6f4
@ mfma_f32_16x16x16bf16_1k
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x128f8f6f4
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
typename packed_type_info< T >::element_type element_type_t
Definition: data_type.hpp:405
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:300
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: array.hpp:14
Definition: xdlops_gemm.hpp:942
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1350
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1376
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1388
static constexpr auto GetMfma()
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1343
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1382
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1399
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:1739
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1418
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1680
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:1643
static constexpr auto I2
Definition: xdlops_gemm.hpp:1402
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1410
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:1641
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:1743
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1469
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1412
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1662
static constexpr bool is_single_rate_mfma
Definition: xdlops_gemm.hpp:1724
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:1711
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:1586
static constexpr __device__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:1581
static constexpr auto I5
Definition: xdlops_gemm.hpp:1405
static constexpr auto I3
Definition: xdlops_gemm.hpp:1403
static constexpr auto I0
Definition: xdlops_gemm.hpp:1400
__device__ void Run(const FloatA &p_a_wave, const ScaleA &a_scale_thread, const FloatB &p_b_wave, const ScaleB &b_scale_thread, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1621
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1435
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1545
static constexpr auto I1
Definition: xdlops_gemm.hpp:1401
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:1742
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:1741
static constexpr auto I4
Definition: xdlops_gemm.hpp:1404
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1589
static constexpr auto mfma
Definition: xdlops_gemm.hpp:1732
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:1698
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1513
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:1745
Definition: integral_constant.hpp:20
Definition: amd_xdlops.hpp:1202
Definition: amd_xdlops.hpp:303
Definition: amd_xdlops.hpp:193
Definition: amd_xdlops.hpp:70
Definition: amd_xdlops.hpp:269
Definition: amd_xdlops.hpp:1483
Definition: amd_xdlops.hpp:1609
Definition: amd_xdlops.hpp:159
Definition: amd_xdlops.hpp:1546
Definition: amd_xdlops.hpp:1420
Definition: amd_xdlops.hpp:207
Definition: amd_xdlops.hpp:56
Definition: amd_xdlops.hpp:331
Definition: amd_xdlops.hpp:249
Definition: amd_xdlops.hpp:1451
Definition: amd_xdlops.hpp:1577
Definition: amd_xdlops.hpp:139
Definition: amd_xdlops.hpp:1514
Definition: amd_xdlops.hpp:1388
Definition: amd_xdlops.hpp:15
Definition: amd_xdlops.hpp:42
Definition: amd_xdlops.hpp:317
Definition: amd_xdlops.hpp:112
Definition: amd_xdlops.hpp:481
Definition: amd_xdlops.hpp:289
Definition: amd_xdlops.hpp:179
Definition: amd_xdlops.hpp:84
Definition: amd_xdlops.hpp:221
Definition: amd_xdlops.hpp:461
Definition: amd_xdlops.hpp:364
Definition: amd_xdlops.hpp:442
Definition: amd_xdlops.hpp:403
Definition: amd_xdlops.hpp:423
Definition: amd_xdlops.hpp:383
Definition: amd_xdlops.hpp:345
Definition: amd_xdlops.hpp:886
Definition: amd_xdlops.hpp:666
Definition: type.hpp:177
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:854
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:432
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:300
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:167
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:410
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:718
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:806
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:278
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:762
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:674
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:322
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:145
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:476
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:366
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:696
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:784
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:256
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:740
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:652
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:101
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:123
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:454
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:212
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:830
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:388
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:234
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:190
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:344
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:630
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:520
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:564
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:608
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:542
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:586
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:498
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:923
__device__ void run(const FloatA &a, const ScaleA &scale_a, const FloatB &b, const ScaleB &scale_b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:886
Definition: xdlops_gemm.hpp:83
Definition: functional2.hpp:33