/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp Source File
xdlops_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
// Enumerates the MFMA instruction variants used as the template argument that
// selects an mfma_type specialization. The enumerators (listing lines 14-50)
// are missing from this extract.
12 enum struct MfmaInstr
13 {
51 };
52 
// Primary template. It is declared only, with no definition in this listing,
// so it is usable solely through the explicit specializations that follow.
53 template <MfmaInstr instr>
54 struct mfma_type;
55 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 1; is_k_reduction false; 2 output blocks.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
56 template <>
58 {
59  static constexpr index_t group_size = 4;
60  static constexpr index_t num_groups_per_blk = 4;
61  static constexpr index_t num_regs_per_blk = 16;
62  static constexpr index_t num_threads_per_blk = 32;
63  static constexpr index_t wave_size = 64;
64  static constexpr index_t num_input_blks = 2;
65  static constexpr index_t num_output_blks = 2;
66  static constexpr index_t m_per_blk = 32;
67  static constexpr index_t n_per_blk = 32;
68  static constexpr index_t k_per_blk = 1;
69  static constexpr bool is_k_reduction = false;
70 
71  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
72  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
73  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
75  }
76 };
77 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 1; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
78 template <>
80 {
81  static constexpr index_t group_size = 4;
82  static constexpr index_t num_groups_per_blk = 4;
83  static constexpr index_t num_regs_per_blk = 16;
84  static constexpr index_t num_threads_per_blk = 32;
85  static constexpr index_t wave_size = 64;
86  static constexpr index_t num_input_blks = 2;
87  static constexpr index_t num_output_blks = 1;
88  static constexpr index_t m_per_blk = 32;
89  static constexpr index_t n_per_blk = 32;
90  static constexpr index_t k_per_blk = 1;
91  static constexpr bool is_k_reduction = true;
92 
93  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
94  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
95  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
97  }
98 };
99 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 1; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
100 template <>
102 {
103  static constexpr index_t group_size = 4;
104  static constexpr index_t num_groups_per_blk = 1;
105  static constexpr index_t num_regs_per_blk = 4;
106  static constexpr index_t num_threads_per_blk = 16;
107  static constexpr index_t wave_size = 64;
108  static constexpr index_t num_input_blks = 4;
109  static constexpr index_t num_output_blks = 1;
110  static constexpr index_t m_per_blk = 16;
111  static constexpr index_t n_per_blk = 16;
112  static constexpr index_t k_per_blk = 1;
113  static constexpr bool is_k_reduction = true;
114 
115  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
116  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
117  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
119  }
120 };
121 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 1; is_k_reduction false; 4 output blocks.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
122 template <>
124 {
125  static constexpr index_t group_size = 4;
126  static constexpr index_t num_groups_per_blk = 1;
127  static constexpr index_t num_regs_per_blk = 4;
128  static constexpr index_t num_threads_per_blk = 16;
129  static constexpr index_t wave_size = 64;
130  static constexpr index_t num_input_blks = 4;
131  static constexpr index_t num_output_blks = 4;
132  static constexpr index_t m_per_blk = 16;
133  static constexpr index_t n_per_blk = 16;
134  static constexpr index_t k_per_blk = 1;
135  static constexpr bool is_k_reduction = false;
136 
137  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
138  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
139  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
141  }
142 };
143 
144 // treat 4x4x1 as a single-blk 4x64 mfma
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 4, n 64, k_per_blk 1; is_k_reduction false; 1 input/output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/64 == 1).
145 template <>
147 {
148  static constexpr index_t group_size = 4;
149  static constexpr index_t num_groups_per_blk = 1;
150  static constexpr index_t num_regs_per_blk = 4;
151  static constexpr index_t num_threads_per_blk = 64;
152  static constexpr index_t wave_size = 64;
153  static constexpr index_t num_input_blks = 1;
154  static constexpr index_t num_output_blks = 1;
155  static constexpr index_t m_per_blk = 4;
156  static constexpr index_t n_per_blk = 64;
157  static constexpr index_t k_per_blk = 1;
158  static constexpr bool is_k_reduction = false;
159 
160  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
161  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
162  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
164  }
165 };
166 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 4; is_k_reduction false; 2 output blocks.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
167 template <>
169 {
170  static constexpr index_t group_size = 4;
171  static constexpr index_t num_groups_per_blk = 4;
172  static constexpr index_t num_regs_per_blk = 16;
173  static constexpr index_t num_threads_per_blk = 32;
174  static constexpr index_t wave_size = 64;
175  static constexpr index_t num_input_blks = 2;
176  static constexpr index_t num_output_blks = 2;
177  static constexpr index_t m_per_blk = 32;
178  static constexpr index_t n_per_blk = 32;
179  static constexpr index_t k_per_blk = 4;
180  static constexpr bool is_k_reduction = false;
181 
182  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
183  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
184  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
186  }
187 };
188 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 4; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
189 template <>
191 {
192  static constexpr index_t group_size = 4;
193  static constexpr index_t num_groups_per_blk = 4;
194  static constexpr index_t num_regs_per_blk = 16;
195  static constexpr index_t num_threads_per_blk = 32;
196  static constexpr index_t wave_size = 64;
197  static constexpr index_t num_input_blks = 2;
198  static constexpr index_t num_output_blks = 1;
199  static constexpr index_t m_per_blk = 32;
200  static constexpr index_t n_per_blk = 32;
201  static constexpr index_t k_per_blk = 4;
202  static constexpr bool is_k_reduction = true;
203 
204  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
205  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
206  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
208  }
209 };
210 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
211 template <>
213 {
214  static constexpr index_t group_size = 4;
215  static constexpr index_t num_groups_per_blk = 4;
216  static constexpr index_t num_regs_per_blk = 16;
217  static constexpr index_t num_threads_per_blk = 32;
218  static constexpr index_t wave_size = 64;
219  static constexpr index_t num_input_blks = 2;
220  static constexpr index_t num_output_blks = 1;
221  static constexpr index_t m_per_blk = 32;
222  static constexpr index_t n_per_blk = 32;
223  static constexpr index_t k_per_blk = 8;
224  static constexpr bool is_k_reduction = true;
225 
226  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
227  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
228  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
230  }
231 };
232 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
233 template <>
235 {
236  static constexpr index_t group_size = 4;
237  static constexpr index_t num_groups_per_blk = 1;
238  static constexpr index_t num_regs_per_blk = 4;
239  static constexpr index_t num_threads_per_blk = 16;
240  static constexpr index_t wave_size = 64;
241  static constexpr index_t num_input_blks = 4;
242  static constexpr index_t num_output_blks = 1;
243  static constexpr index_t m_per_blk = 16;
244  static constexpr index_t n_per_blk = 16;
245  static constexpr index_t k_per_blk = 8;
246  static constexpr bool is_k_reduction = true;
247 
248  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
249  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
250  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
252  }
253 };
254 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 4; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
255 template <>
257 {
258  static constexpr index_t group_size = 4;
259  static constexpr index_t num_groups_per_blk = 1;
260  static constexpr index_t num_regs_per_blk = 4;
261  static constexpr index_t num_threads_per_blk = 16;
262  static constexpr index_t wave_size = 64;
263  static constexpr index_t num_input_blks = 4;
264  static constexpr index_t num_output_blks = 1;
265  static constexpr index_t m_per_blk = 16;
266  static constexpr index_t n_per_blk = 16;
267  static constexpr index_t k_per_blk = 4;
268  static constexpr bool is_k_reduction = true;
269 
270  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
271  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
272  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
274  }
275 };
276 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 4; is_k_reduction false; 4 output blocks.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
277 template <>
279 {
280  static constexpr index_t group_size = 4;
281  static constexpr index_t num_groups_per_blk = 1;
282  static constexpr index_t num_regs_per_blk = 4;
283  static constexpr index_t num_threads_per_blk = 16;
284  static constexpr index_t wave_size = 64;
285  static constexpr index_t num_input_blks = 4;
286  static constexpr index_t num_output_blks = 4;
287  static constexpr index_t m_per_blk = 16;
288  static constexpr index_t n_per_blk = 16;
289  static constexpr index_t k_per_blk = 4;
290  static constexpr bool is_k_reduction = false;
291 
292  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
293  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
294  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
296  }
297 };
298 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 4, n 64, k_per_blk 4; is_k_reduction false; 1 input/output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/64 == 1).
299 template <>
301 {
302  static constexpr index_t group_size = 4;
303  static constexpr index_t num_groups_per_blk = 1;
304  static constexpr index_t num_regs_per_blk = 4;
305  static constexpr index_t num_threads_per_blk = 64;
306  static constexpr index_t wave_size = 64;
307  static constexpr index_t num_input_blks = 1;
308  static constexpr index_t num_output_blks = 1;
309  static constexpr index_t m_per_blk = 4;
310  static constexpr index_t n_per_blk = 64;
311  static constexpr index_t k_per_blk = 4;
312  static constexpr bool is_k_reduction = false;
313 
314  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
315  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
316  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
318  }
319 };
320 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
321 template <>
323 {
324  static constexpr index_t group_size = 4;
325  static constexpr index_t num_groups_per_blk = 4;
326  static constexpr index_t num_regs_per_blk = 16;
327  static constexpr index_t num_threads_per_blk = 32;
328  static constexpr index_t wave_size = 64;
329  static constexpr index_t num_input_blks = 2;
330  static constexpr index_t num_output_blks = 1;
331  static constexpr index_t m_per_blk = 32;
332  static constexpr index_t n_per_blk = 32;
333  static constexpr index_t k_per_blk = 8;
334  static constexpr bool is_k_reduction = true;
335 
336  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
337  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
338  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
340  }
341 };
342 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 4; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
343 template <>
345 {
346  static constexpr index_t group_size = 4;
347  static constexpr index_t num_groups_per_blk = 4;
348  static constexpr index_t num_regs_per_blk = 16;
349  static constexpr index_t num_threads_per_blk = 32;
350  static constexpr index_t wave_size = 64;
351  static constexpr index_t num_input_blks = 2;
352  static constexpr index_t num_output_blks = 1;
353  static constexpr index_t m_per_blk = 32;
354  static constexpr index_t n_per_blk = 32;
355  static constexpr index_t k_per_blk = 4;
356  static constexpr bool is_k_reduction = true;
357 
358  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
359  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
360  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
362  }
363 };
364 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
365 template <>
367 {
368  static constexpr index_t group_size = 4;
369  static constexpr index_t num_groups_per_blk = 1;
370  static constexpr index_t num_regs_per_blk = 4;
371  static constexpr index_t num_threads_per_blk = 16;
372  static constexpr index_t wave_size = 64;
373  static constexpr index_t num_input_blks = 4;
374  static constexpr index_t num_output_blks = 1;
375  static constexpr index_t m_per_blk = 16;
376  static constexpr index_t n_per_blk = 16;
377  static constexpr index_t k_per_blk = 8;
378  static constexpr bool is_k_reduction = true;
379 
380  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
381  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
382  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
384  }
385 };
386 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 4; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
387 template <>
389 {
390  static constexpr index_t group_size = 4;
391  static constexpr index_t num_groups_per_blk = 1;
392  static constexpr index_t num_regs_per_blk = 4;
393  static constexpr index_t num_threads_per_blk = 16;
394  static constexpr index_t wave_size = 64;
395  static constexpr index_t num_input_blks = 4;
396  static constexpr index_t num_output_blks = 1;
397  static constexpr index_t m_per_blk = 16;
398  static constexpr index_t n_per_blk = 16;
399  static constexpr index_t k_per_blk = 4;
400  static constexpr bool is_k_reduction = true;
401 
402  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
403  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
404  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
406  }
407 };
408 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 2; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
409 template <>
411 {
412  static constexpr index_t group_size = 4;
413  static constexpr index_t num_groups_per_blk = 4;
414  static constexpr index_t num_regs_per_blk = 16;
415  static constexpr index_t num_threads_per_blk = 32;
416  static constexpr index_t wave_size = 64;
417  static constexpr index_t num_input_blks = 2;
418  static constexpr index_t num_output_blks = 1;
419  static constexpr index_t m_per_blk = 32;
420  static constexpr index_t n_per_blk = 32;
421  static constexpr index_t k_per_blk = 2;
422  static constexpr bool is_k_reduction = true;
423 
424  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
425  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
426  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
428  }
429 };
430 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 2; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
431 template <>
433 {
434  static constexpr index_t group_size = 4;
435  static constexpr index_t num_groups_per_blk = 1;
436  static constexpr index_t num_regs_per_blk = 4;
437  static constexpr index_t num_threads_per_blk = 16;
438  static constexpr index_t wave_size = 64;
439  static constexpr index_t num_input_blks = 4;
440  static constexpr index_t num_output_blks = 1;
441  static constexpr index_t m_per_blk = 16;
442  static constexpr index_t n_per_blk = 16;
443  static constexpr index_t k_per_blk = 2;
444  static constexpr bool is_k_reduction = true;
445 
446  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
447  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
448  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
450  }
451 };
452 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 4; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
453 template <>
455 {
456  static constexpr index_t group_size = 4;
457  static constexpr index_t num_groups_per_blk = 4;
458  static constexpr index_t num_regs_per_blk = 16;
459  static constexpr index_t num_threads_per_blk = 32;
460  static constexpr index_t wave_size = 64;
461  static constexpr index_t num_input_blks = 2;
462  static constexpr index_t num_output_blks = 1;
463  static constexpr index_t m_per_blk = 32;
464  static constexpr index_t n_per_blk = 32;
465  static constexpr index_t k_per_blk = 4;
466  static constexpr bool is_k_reduction = true;
467 
468  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
469  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
470  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
472  }
473 };
474 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 4; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
475 template <>
477 {
478  static constexpr index_t group_size = 4;
479  static constexpr index_t num_groups_per_blk = 1;
480  static constexpr index_t num_regs_per_blk = 4;
481  static constexpr index_t num_threads_per_blk = 16;
482  static constexpr index_t wave_size = 64;
483  static constexpr index_t num_input_blks = 4;
484  static constexpr index_t num_output_blks = 1;
485  static constexpr index_t m_per_blk = 16;
486  static constexpr index_t n_per_blk = 16;
487  static constexpr index_t k_per_blk = 4;
488  static constexpr bool is_k_reduction = true;
489 
490  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
491  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
492  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
494  }
495 };
496 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
497 template <>
499 {
500  static constexpr index_t group_size = 4;
501  static constexpr index_t num_groups_per_blk = 4;
502  static constexpr index_t num_regs_per_blk = 16;
503  static constexpr index_t num_threads_per_blk = 32;
504  static constexpr index_t wave_size = 64;
505  static constexpr index_t num_input_blks = 2;
506  static constexpr index_t num_output_blks = 1;
507  static constexpr index_t m_per_blk = 32;
508  static constexpr index_t n_per_blk = 32;
509  static constexpr index_t k_per_blk = 8;
510  static constexpr bool is_k_reduction = true;
511 
512  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
513  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
514  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
516  }
517 };
518 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
519 template <>
521 {
522  static constexpr index_t group_size = 4;
523  static constexpr index_t num_groups_per_blk = 1;
524  static constexpr index_t num_regs_per_blk = 4;
525  static constexpr index_t num_threads_per_blk = 16;
526  static constexpr index_t wave_size = 64;
527  static constexpr index_t num_input_blks = 4;
528  static constexpr index_t num_output_blks = 1;
529  static constexpr index_t m_per_blk = 16;
530  static constexpr index_t n_per_blk = 16;
531  static constexpr index_t k_per_blk = 8;
532  static constexpr bool is_k_reduction = true;
533 
534  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
535  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
536  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
538  }
539 };
540 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 16; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
541 template <>
543 {
544  static constexpr index_t group_size = 4;
545  static constexpr index_t num_groups_per_blk = 4;
546  static constexpr index_t num_regs_per_blk = 16;
547  static constexpr index_t num_threads_per_blk = 32;
548  static constexpr index_t wave_size = 64;
549  static constexpr index_t num_input_blks = 2;
550  static constexpr index_t num_output_blks = 1;
551  static constexpr index_t m_per_blk = 32;
552  static constexpr index_t n_per_blk = 32;
553  static constexpr index_t k_per_blk = 16;
554  static constexpr bool is_k_reduction = true;
555 
556  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
557  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
558  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
560  }
561 };
562 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 16; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
563 template <>
565 {
566  static constexpr index_t group_size = 4;
567  static constexpr index_t num_groups_per_blk = 1;
568  static constexpr index_t num_regs_per_blk = 4;
569  static constexpr index_t num_threads_per_blk = 16;
570  static constexpr index_t wave_size = 64;
571  static constexpr index_t num_input_blks = 4;
572  static constexpr index_t num_output_blks = 1;
573  static constexpr index_t m_per_blk = 16;
574  static constexpr index_t n_per_blk = 16;
575  static constexpr index_t k_per_blk = 16;
576  static constexpr bool is_k_reduction = true;
577 
578  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
579  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
580  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
582  }
583 };
584 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 1; is_k_reduction true; 1 output block.
// Unlike the other 16x16 entries, group_size is 1 and num_groups_per_blk is 4;
// the inline comments record the derivations, which hold: 1*4 == 4, 64/16 == 4.
585 template <>
587 {
588  static constexpr index_t group_size = 1;
589  static constexpr index_t num_groups_per_blk = 4;
590  static constexpr index_t num_regs_per_blk = 4; // group_size * num_groups_per_blk;
591  static constexpr index_t num_threads_per_blk = 16;
592  static constexpr index_t wave_size = 64;
593  static constexpr index_t num_input_blks = 4; // wave_size / num_threads_per_blk;
594  static constexpr index_t num_output_blks = 1;
595  static constexpr index_t m_per_blk = 16;
596  static constexpr index_t n_per_blk = 16;
597  static constexpr index_t k_per_blk = 1;
598  static constexpr bool is_k_reduction = true;
599 
600  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
601  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
602  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
604  }
605 };
606 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
607 template <>
609 {
610  static constexpr index_t group_size = 4;
611  static constexpr index_t num_groups_per_blk = 4;
612  static constexpr index_t num_regs_per_blk = 16;
613  static constexpr index_t num_threads_per_blk = 32;
614  static constexpr index_t wave_size = 64;
615  static constexpr index_t num_input_blks = 2;
616  static constexpr index_t num_output_blks = 1;
617  static constexpr index_t m_per_blk = 32;
618  static constexpr index_t n_per_blk = 32;
619  static constexpr index_t k_per_blk = 8;
620  static constexpr bool is_k_reduction = true;
621 
622  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
623  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
624  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
626  }
627 };
628 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
629 template <>
631 {
632  static constexpr index_t group_size = 4;
633  static constexpr index_t num_groups_per_blk = 1;
634  static constexpr index_t num_regs_per_blk = 4;
635  static constexpr index_t num_threads_per_blk = 16;
636  static constexpr index_t wave_size = 64;
637  static constexpr index_t num_input_blks = 4;
638  static constexpr index_t num_output_blks = 1;
639  static constexpr index_t m_per_blk = 16;
640  static constexpr index_t n_per_blk = 16;
641  static constexpr index_t k_per_blk = 8;
642  static constexpr bool is_k_reduction = true;
643 
644  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
645  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
646  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
648  }
649 };
650 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
651 template <>
653 {
654  static constexpr index_t group_size = 4;
655  static constexpr index_t num_groups_per_blk = 4;
656  static constexpr index_t num_regs_per_blk = 16;
657  static constexpr index_t num_threads_per_blk = 32;
658  static constexpr index_t wave_size = 64;
659  static constexpr index_t num_input_blks = 2;
660  static constexpr index_t num_output_blks = 1;
661  static constexpr index_t m_per_blk = 32;
662  static constexpr index_t n_per_blk = 32;
663  static constexpr index_t k_per_blk = 8;
664  static constexpr bool is_k_reduction = true;
665 
666  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
667  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
668  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
670  }
671 };
672 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
673 template <>
675 {
676  static constexpr index_t group_size = 4;
677  static constexpr index_t num_groups_per_blk = 1;
678  static constexpr index_t num_regs_per_blk = 4;
679  static constexpr index_t num_threads_per_blk = 16;
680  static constexpr index_t wave_size = 64;
681  static constexpr index_t num_input_blks = 4;
682  static constexpr index_t num_output_blks = 1;
683  static constexpr index_t m_per_blk = 16;
684  static constexpr index_t n_per_blk = 16;
685  static constexpr index_t k_per_blk = 8;
686  static constexpr bool is_k_reduction = true;
687 
688  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
689  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
690  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
692  }
693 };
694 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
695 template <>
697 {
698  static constexpr index_t group_size = 4;
699  static constexpr index_t num_groups_per_blk = 4;
700  static constexpr index_t num_regs_per_blk = 16;
701  static constexpr index_t num_threads_per_blk = 32;
702  static constexpr index_t wave_size = 64;
703  static constexpr index_t num_input_blks = 2;
704  static constexpr index_t num_output_blks = 1;
705  static constexpr index_t m_per_blk = 32;
706  static constexpr index_t n_per_blk = 32;
707  static constexpr index_t k_per_blk = 8;
708  static constexpr bool is_k_reduction = true;
709 
710  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
711  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
712  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
714  }
715 };
716 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
717 template <>
719 {
720  static constexpr index_t group_size = 4;
721  static constexpr index_t num_groups_per_blk = 1;
722  static constexpr index_t num_regs_per_blk = 4;
723  static constexpr index_t num_threads_per_blk = 16;
724  static constexpr index_t wave_size = 64;
725  static constexpr index_t num_input_blks = 4;
726  static constexpr index_t num_output_blks = 1;
727  static constexpr index_t m_per_blk = 16;
728  static constexpr index_t n_per_blk = 16;
729  static constexpr index_t k_per_blk = 8;
730  static constexpr bool is_k_reduction = true;
731 
732  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
733  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
734  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
736  }
737 };
738 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*4 == 16)
// and wave_size / num_threads_per_blk == num_input_blks (64/32 == 2).
739 template <>
741 {
742  static constexpr index_t group_size = 4;
743  static constexpr index_t num_groups_per_blk = 4;
744  static constexpr index_t num_regs_per_blk = 16;
745  static constexpr index_t num_threads_per_blk = 32;
746  static constexpr index_t wave_size = 64;
747  static constexpr index_t num_input_blks = 2;
748  static constexpr index_t num_output_blks = 1;
749  static constexpr index_t m_per_blk = 32;
750  static constexpr index_t n_per_blk = 32;
751  static constexpr index_t k_per_blk = 8;
752  static constexpr bool is_k_reduction = true;
753 
754  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
755  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
756  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
758  }
759 };
760 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 16, n 16, k_per_blk 8; is_k_reduction true; 1 output block.
// Values satisfy group_size * num_groups_per_blk == num_regs_per_blk (4*1 == 4)
// and wave_size / num_threads_per_blk == num_input_blks (64/16 == 4).
761 template <>
763 {
764  static constexpr index_t group_size = 4;
765  static constexpr index_t num_groups_per_blk = 1;
766  static constexpr index_t num_regs_per_blk = 4;
767  static constexpr index_t num_threads_per_blk = 16;
768  static constexpr index_t wave_size = 64;
769  static constexpr index_t num_input_blks = 4;
770  static constexpr index_t num_output_blks = 1;
771  static constexpr index_t m_per_blk = 16;
772  static constexpr index_t n_per_blk = 16;
773  static constexpr index_t k_per_blk = 8;
774  static constexpr bool is_k_reduction = true;
775 
776  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
777  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
778  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
780  }
781 };
782 
783 // TODO: fix mfma...f8f6f4 instructions
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Per block: m 32, n 32, k_per_blk 32; is_k_reduction true; 1 output block.
// The "???" inline notes are the original author's open questions. The
// derivations they record do hold for these values: 4*4 == 16,
// 32*32/64 == 16, 32/16 == 2, and 64/2 == 32.
784 template <>
786 {
787  // clang-format off
788  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
789  static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
790  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
791  static constexpr index_t num_threads_per_blk = 32; // n_per_blk
792  static constexpr index_t wave_size = 64; // fixed
793  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
794  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
795  static constexpr index_t m_per_blk = 32; // from the instruction
796  static constexpr index_t n_per_blk = 32; // from the instruction
797  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
798  static constexpr bool is_k_reduction = true; // ???
799  // clang-format on
800 
801  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
802  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
803  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
805  }
806 };
807 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// It presumably covers one of the f8f6f4 variants named by the preceding TODO;
// confirm against the real header.
// Per block: m 16, n 16, k_per_blk 32; is_k_reduction true; 1 output block.
// The "???" derivations hold for these values: 4*1 == 4, 16*16/64 == 4,
// 16/4 == 4, and 128/4 == 32.
808 template <>
810 {
811  // clang-format off
812  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
813  static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
814  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
815  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
816  static constexpr index_t wave_size = 64; // fixed
817  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
818  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
819  static constexpr index_t m_per_blk = 16; // from the instruction
820  static constexpr index_t n_per_blk = 16; // from the instruction
821  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
822  static constexpr bool is_k_reduction = true; // ???
823  // clang-format on
824 
825  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
826  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
827  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
829  }
830 };
831 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Its constants match the 32x32 f8f6f4 entry at listing line 784 exactly;
// presumably it is another variant of that instruction family — confirm.
// Per block: m 32, n 32, k_per_blk 32; is_k_reduction true; 1 output block.
// The "???" derivations hold: 4*4 == 16, 32*32/64 == 16, 32/16 == 2, 64/2 == 32.
832 template <>
834 {
835  // clang-format off
836  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
837  static constexpr index_t num_groups_per_blk = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
838  static constexpr index_t num_regs_per_blk = 16; // m_per_blk * n_per_blk / wave_size
839  static constexpr index_t num_threads_per_blk = 32; // n_per_blk
840  static constexpr index_t wave_size = 64; // fixed
841  static constexpr index_t num_input_blks = 2; // m_per_blk / num_regs_per_blk
842  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
843  static constexpr index_t m_per_blk = 32; // from the instruction
844  static constexpr index_t n_per_blk = 32; // from the instruction
845  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 64 / num_input_blks
846  static constexpr bool is_k_reduction = true; // ???
847  // clang-format on
848 
849  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
850  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
851  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
853  }
854 };
855 
// mfma_type specialization (the line naming its MfmaInstr is missing here).
// Its constants match the 16x16 entry at listing line 808 exactly; presumably
// it is another f8f6f4-family variant — confirm against the real header.
// Per block: m 16, n 16, k_per_blk 32; is_k_reduction true; 1 output block.
// The "???" derivations hold: 4*1 == 4, 16*16/64 == 4, 16/4 == 4, 128/4 == 32.
856 template <>
858 {
859  // clang-format off
860  static constexpr index_t group_size = 4; // ??? group_size * num_groups_per_blk == num_regs_per_blk
861  static constexpr index_t num_groups_per_blk = 1; // ??? group_size * num_groups_per_blk == num_regs_per_blk
862  static constexpr index_t num_regs_per_blk = 4; // m_per_blk * n_per_blk / wave_size
863  static constexpr index_t num_threads_per_blk = 16; // == n_per_blk
864  static constexpr index_t wave_size = 64; // fixed
865  static constexpr index_t num_input_blks = 4; // m_per_blk / num_regs_per_blk
866  static constexpr index_t num_output_blks = 1; // (is_k_reduction == true) ???
867  static constexpr index_t m_per_blk = 16; // from the instruction
868  static constexpr index_t n_per_blk = 16; // from the instruction
869  static constexpr index_t k_per_blk = 32; // (is_k_reduction == true) ? 128 / num_input_blks
870  static constexpr bool is_k_reduction = true; // ???
871  // clang-format on
872 
873  template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
874  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
875  {
    // NOTE(review): body line missing from this extract; as listed run() is empty.
877  }
878 };
879 
// Compile-time selector of an MFMA instruction descriptor.
// GetMfma<...>() maps (base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma)
// to an MfmaInstr enumerator through explicit specializations. selected_mfma is the resulting
// mfma_type<> descriptor.
// additional_type defaults to base_type. It only differs for the mixed f8_t/bf8_t entries
// (presumably the A and B element types -- confirm).
// is_single_rate_mfma defaults to false.
880 template <typename base_type,
881  index_t MPerXdlops,
882  index_t NPerXdlops,
883  typename additional_type = base_type,
884  bool is_single_rate_mfma = false>
886 {
     // Primary template: declared but never defined. Any parameter combination without an explicit
     // specialization below fails to compile when selected_mfma is instantiated.
     // Specializations written with fewer template arguments take the defaults
     // (additional_type_ = base_type_, is_single_rate_mfma_ = false). The only is_single_rate_mfma_ = true
     // entries are the half_t/bhalf_t 32x32 and 16x16 ones.
     // NOTE(review): these explicit specializations are written at class scope (CWG 727). Clang accepts
     // that, but some compilers (e.g. GCC) have historically rejected it.
887  template <typename base_type_,
888  index_t MPerXdlops_,
889  index_t NPerXdlops_,
890  typename additional_type_ = base_type_,
891  bool is_single_rate_mfma_ = false>
892  static constexpr auto GetMfma();
893 
894  template <>
895  constexpr auto GetMfma<double, 16, 16>()
896  {
898  }
899 
     // float entries: 64x64, 32x64, 16x64, 8x64, 4x64, 32x32 and 16x16 per-instruction tiles.
900  template <>
901  constexpr auto GetMfma<float, 64, 64>()
902  {
904  }
905 
906  template <>
907  constexpr auto GetMfma<float, 32, 64>()
908  {
910  }
911 
912  template <>
913  constexpr auto GetMfma<float, 16, 64>()
914  {
916  }
917 
918  template <>
919  constexpr auto GetMfma<float, 8, 64>()
920  {
922  }
923 
924  template <>
925  constexpr auto GetMfma<float, 4, 64>()
926  {
928  }
929 
930  template <>
931  constexpr auto GetMfma<float, 32, 32>()
932  {
934  }
935 
936  template <>
937  constexpr auto GetMfma<float, 16, 16>()
938  {
940  }
941 
     // half_t entries. For 32x32 and 16x16 the non-single-rate choice differs on gfx950. The
     // single-rate (is_single_rate_mfma_ = true) entries have no per-target branch.
942  template <>
943  constexpr auto GetMfma<half_t, 64, 64>()
944  {
946  }
947 
948  template <>
949  constexpr auto GetMfma<half_t, 32, 64>()
950  {
952  }
953 
954  template <>
955  constexpr auto GetMfma<half_t, 32, 32, half_t, false>()
956  {
957 #if defined(__gfx950__) // gfx950 selects a different instruction for this shape
959 #else
961 #endif
962  }
963  template <>
964  constexpr auto GetMfma<half_t, 32, 32, half_t, true>()
965  {
967  }
968 
969  template <>
970  constexpr auto GetMfma<half_t, 16, 16, half_t, false>()
971  {
972 #if defined(__gfx950__) // gfx950 selects a different instruction for this shape
974 #else
976 #endif
977  }
978 
979  template <>
980  constexpr auto GetMfma<half_t, 16, 16, half_t, true>()
981  {
983  }
984 
985  template <>
986  constexpr auto GetMfma<half_t, 16, 64>()
987  {
989  }
990 
991  template <>
992  constexpr auto GetMfma<half_t, 8, 64>()
993  {
995  }
996 
997  template <>
998  constexpr auto GetMfma<half_t, 4, 64>()
999  {
1001  }
1002 
     // bhalf_t entries. The choice depends on the target (gfx950) and on the CK_USE_AMD_MFMA_BF16_1K_OP
     // build macro (presumably selecting the "_1k" bf16 instructions -- confirm). The single-rate
     // entries only consult the macro.
1003  template <>
1004  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, false>()
1005  {
1006 #if defined(__gfx950__)
1008 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1010 #else
1012 #endif
1013  }
1014 
1015  template <>
1016  constexpr auto GetMfma<bhalf_t, 32, 32, bhalf_t, true>()
1017  {
1018 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1020 #else
1022 #endif
1023  }
1024 
1025  template <>
1026  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, false>()
1027  {
1028 #if defined(__gfx950__)
1030 #elif defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1032 #else
1034 #endif
1035  }
1036 
1037  template <>
1038  constexpr auto GetMfma<bhalf_t, 16, 16, bhalf_t, true>()
1039  {
1040 #if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
1042 #else
1044 #endif
1045  }
1046 
     // int8_t entries: exactly one 32x32/16x16 pair is compiled, chosen by the target
     // (gfx950, gfx942, or any other target).
1047 #if defined(__gfx950__)
1048  template <>
1049  constexpr auto GetMfma<int8_t, 32, 32>()
1050  {
1052  }
1053  template <>
1054  constexpr auto GetMfma<int8_t, 16, 16>()
1055  {
1057  }
1058 #elif defined(__gfx942__)
1059  template <>
1060  constexpr auto GetMfma<int8_t, 32, 32>()
1061  {
1063  }
1064  template <>
1065  constexpr auto GetMfma<int8_t, 16, 16>()
1066  {
1068  }
1069 #else
1070  template <>
1071  constexpr auto GetMfma<int8_t, 32, 32>()
1072  {
1074  }
1075  template <>
1076  constexpr auto GetMfma<int8_t, 16, 16>()
1077  {
1079  }
1080 #endif
1081 
     // 8-bit float entries: f8_t and bf8_t alone, plus the mixed pairs where additional_type
     // differs from base_type.
1082  template <>
1083  constexpr auto GetMfma<f8_t, 32, 32>()
1084  {
1086  }
1087 
1088  template <>
1089  constexpr auto GetMfma<f8_t, 16, 16>()
1090  {
1092  }
1093 
1094  template <>
1095  constexpr auto GetMfma<bf8_t, 32, 32>()
1096  {
1098  }
1099 
1100  template <>
1101  constexpr auto GetMfma<bf8_t, 16, 16>()
1102  {
1104  }
1105 
1106  template <>
1107  constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
1108  {
1110  }
1111 
1112  template <>
1113  constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
1114  {
1116  }
1117 
1118  template <>
1119  constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
1120  {
1122  }
1123 
1124  template <>
1125  constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
1126  {
1128  }
1129 
     // Descriptor (layout constants + run()) of the instruction chosen for this selector's parameters.
1130  static constexpr auto selected_mfma = mfma_type<
1131  GetMfma<base_type, MPerXdlops, NPerXdlops, additional_type, is_single_rate_mfma>()>{};
1132 
     // Consistency checks on selected_mfma's layout constants. They are checked when the constructor
     // is instantiated, i.e. when an MfmaSelector object is created. Taken together they imply
     // num_input_blks * num_threads_per_blk == wave_size: every lane of the wave falls in exactly one
     // input block.
1133  __host__ __device__ constexpr MfmaSelector()
1134  {
1135  static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk ==
1136  selected_mfma.num_regs_per_blk,
1137  "wrong! num_regs_per_blk");
1138 
1139  static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk,
1140  "n_per_blk != num_threads_per_blk");
1141 
1142  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks ==
1143  selected_mfma.m_per_blk,
1144  "m_per_blk != num_input_blks * num_regs_per_blk");
1145 
1146  static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks ||
1147  selected_mfma.num_output_blks == 1,
1148  "incorrect num_output_blks");
1149 
1150  static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size ==
1151  selected_mfma.m_per_blk * selected_mfma.n_per_blk,
1152  "num_regs_per_blk incorrect");
1153 
         // Non-k-reduction instructions must produce one output block per input block.
1154  static_assert(selected_mfma.is_k_reduction ||
1155  (selected_mfma.num_input_blks == selected_mfma.num_output_blks),
1156  "is_k_reduction wrong!");
1157  }
1158 
     // Compile-time guard: requires NPerXdlops >= MPerXdlops; returns true whenever it compiles.
1159  static constexpr bool IsABroadcast()
1160  {
1161  static_assert(NPerXdlops >= MPerXdlops, "only support ABroadcast");
1162  return true;
1163  }
1164 
     // K covered by one instruction. For k-reduction instructions each of the num_input_blks lane
     // groups handles its own k_per_blk slice, so the slices add up; otherwise it is just k_per_blk.
1165  static constexpr index_t GetKPerXdlops()
1166  {
1167  return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) *
1168  selected_mfma.k_per_blk;
1169  }
1170 
     // K handled per input block (the K1 factor); GetKPerXdlops() / GetK1PerXdlops() gives K0.
1171  static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
1172 };
1173 
// Wave-level GEMM built on one MFMA (xdlops) instruction, picked by MfmaSelector from base_type,
// MPerXdlops x NPerXdlops and additional_type. The struct provides:
//   - lane -> C-tile coordinate mapping;
//   - C descriptor reshapes;
//   - Run(), which issues KPack / k_per_blk instructions per call.
// When TransposeC is true, Run swaps the A and B operands and the C index helpers return
// (n, m) instead of (m, n).
1174 template <typename base_type,
1175  index_t MPerXdlops,
1176  index_t NPerXdlops,
1177  index_t KPack,
1178  typename additional_type = base_type,
1179  bool TransposeC = false>
1181 {
1182  static constexpr auto I0 = Number<0>{};
1183  static constexpr auto I1 = Number<1>{};
1184  static constexpr auto I2 = Number<2>{};
1185  static constexpr auto I3 = Number<3>{};
1186  static constexpr auto I4 = Number<4>{};
1187  static constexpr auto I5 = Number<5>{};
1188 
1191 
     // Number of output blocks one instruction produces.
1192  __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; }
1193 
     // Number of instruction-sized pieces needed to cover the MPerXdlops x NPerXdlops tile: tile area
     // divided by the area one instruction produces (m_per_blk * n_per_blk * num_output_blks).
1194  __device__ static constexpr index_t GetNumXdlops()
1195  {
1196  return MPerXdlops * NPerXdlops /
1197  (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks);
1198  }
1199 
     // Compile-time parameter checks. MPerXdlops and NPerXdlops must each be one of 4, 8, 16, 32 or 64.
     // KPack must be a multiple of k_per_blk, because Run issues exactly KPack / k_per_blk instructions.
1200  __host__ __device__ constexpr XdlopsGemm()
1201  {
1202  static_assert(NPerXdlops == 4 || NPerXdlops == 8 || NPerXdlops == 16 || NPerXdlops == 32 ||
1203  NPerXdlops == 64,
1204  "Only support GemmNPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1205 
1206  static_assert(MPerXdlops == 4 || MPerXdlops == 8 || MPerXdlops == 16 || MPerXdlops == 32 ||
1207  MPerXdlops == 64,
1208  "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops");
1209 
1210  static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
1211  }
1212 
1213  // XDL output supporting C = A * B
1214  // M2_N2 -> M2_M3_M4_N2
     // Reshapes the 6-D (M0, N0, M1, N1, M2, N2) C descriptor into 8 dimensions (output dims 0..7).
     // M2 is split into three dimensions whose innermost lengths are num_input_blks and group_size.
     // The outer length is presumably num_groups_per_blk, as in the G_ variant below.
1215  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1216  __host__ __device__ static constexpr auto
1217  MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1218  {
1219  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1220  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1221  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1222  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1223 
1225  c_desc_m0_n0_m1_n1_m2_n2,
1231  Number<mfma_instr.num_input_blks>{},
1232  Number<mfma_instr.group_size>{})),
1235  Sequence<1>{},
1236  Sequence<2>{},
1237  Sequence<3>{},
1238  Sequence<4>{},
1239  Sequence<5>{}),
1241  Sequence<1>{},
1242  Sequence<2>{},
1243  Sequence<3>{},
1245  Sequence<7>{}));
1246  }
1247 
1248  // transposed XDL output supporting C' = B' * A'
1249  // M2_N2 -> M2_N2_N3_N4
     // Counterpart of the descriptor above for TransposeC: the last input dimension (N2) is split into
     // output dims 5, 6 and 7, with innermost lengths num_input_blks and group_size.
1250  template <typename CDesc_M0_N0_M1_N1_M2_N2>
1251  __host__ __device__ static constexpr auto
1252  MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
1253  {
1254  const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
1255  const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
1256  const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
1257  const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
1258 
1260  c_desc_m0_n0_m1_n1_m2_n2,
1267  Number<mfma_instr.num_input_blks>{},
1268  Number<mfma_instr.group_size>{}))),
1270  Sequence<1>{},
1271  Sequence<2>{},
1272  Sequence<3>{},
1273  Sequence<4>{},
1274  Sequence<5>{}),
1276  Sequence<1>{},
1277  Sequence<2>{},
1278  Sequence<3>{},
1279  Sequence<4>{},
1280  Sequence<5, 6, 7>{}));
1281  }
1282 
     // Batched (leading G dimension) version of the M2 -> M2_M3_M4 reshape: 7 input dims become 9.
     // M2 is unmerged into (num_groups_per_blk, num_input_blks, group_size), and N2 is passed through
     // with length num_threads_per_blk.
     // NOTE(review): these lengths are passed as plain values rather than Number<>{} as in the
     // non-batched variants, so they may not be compile-time constants in the resulting descriptor.
1283  template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
1284  __host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
1285  const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
1286  {
1287  const auto G = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I0);
1288  const auto M0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I1);
1289  const auto N0 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I2);
1290  const auto M1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I3);
1291  const auto N1 = c_desc_g_m0_n0_m1_n1_m2_n2.GetLength(I4);
1292 
1294  c_desc_g_m0_n0_m1_n1_m2_n2,
1300  make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
1301  mfma_instr.num_input_blks,
1302  mfma_instr.group_size)),
1303  make_pass_through_transform(mfma_instr.num_threads_per_blk)),
1305  Sequence<1>{},
1306  Sequence<2>{},
1307  Sequence<3>{},
1308  Sequence<4>{},
1309  Sequence<5>{},
1310  Sequence<6>{}),
1312  Sequence<1>{},
1313  Sequence<2>{},
1314  Sequence<3>{},
1315  Sequence<4>{},
1317  Sequence<8>{}));
1318  }
1319 
     // Accumulator elements each lane holds for the whole MPerXdlops x NPerXdlops tile.
1320  __device__ static constexpr index_t GetRegSizePerXdlops()
1321  {
1322  return MPerXdlops * NPerXdlops / mfma_instr.wave_size;
1323  }
1324 
1325  __device__ static constexpr index_t GetWaveSize() { return mfma_instr.wave_size; }
1326 
     // Issues KPack / k_per_blk instructions, the k-th one consuming p_a_wave[k] and p_b_wave[k] and
     // accumulating into p_c_thread. With TransposeC the operands are swapped so the instruction
     // computes C' = B' * A'.
     // NOTE(review): the static_assert message below contains a duplicated word ("base base_type").
1327  template <class FloatA, class FloatB, class FloatC>
1328  __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
1329  {
1330  static_assert(
1337  "base base_type must be double, float, half, bfloat16, int8_t, f8_t or bf8_t!");
1338 
1339  static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
1340  if constexpr(!TransposeC)
1341  {
1342  mfma_instr.template run<MPerXdlops, NPerXdlops>(
1343  p_a_wave[k], p_b_wave[k], p_c_thread);
1344  }
1345  else
1346  {
1347  mfma_instr.template run<MPerXdlops, NPerXdlops>(
1348  p_b_wave[k], p_a_wave[k], p_c_thread);
1349  }
1350  });
1351  }
1352 
     // Lane index within the wave: flat thread id modulo wave_size.
1353  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; }
1354 
     // Splits the lane id with a merge over lengths (1, num_input_blks, num_threads_per_blk).
     // Returns (blk_id, blk_td): the input block the lane belongs to and its position inside that block.
     // Presumably blk_id = laneId / num_threads_per_blk and blk_td = laneId % num_threads_per_blk.
     // MfmaSelector's static_asserts imply num_input_blks * num_threads_per_blk == wave_size, so every
     // lane maps to exactly one (blk_id, blk_td).
1355  __device__ static auto GetBlkIdx()
1356  {
1357  const auto laneId = GetLaneId();
1358 
1359  constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
1361  make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))),
1363  make_tuple(Sequence<0>{}));
1364 
1365  const auto blk_idx =
1366  threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
1367 
1368  const auto blk_id = blk_idx[I1];
1369  const auto blk_td = blk_idx[I2];
1370 
1371  return make_tuple(blk_id, blk_td);
1372  }
1373 
     // Origin of this lane's A data as a 2-tuple. For k-reduction instructions it is (blk_id, blk_td),
     // since the K slice depends on the input block; otherwise it is (0, laneId).
     // NOTE(review): declared __host__ __device__ but calls the __device__-only GetLaneId/GetBlkIdx, so it
     // is only usable in device code. blk_id/blk_td are unused in the else branch.
1374  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
1375  {
1376  const auto laneId = GetLaneId();
1377  const auto blk_idx = GetBlkIdx();
1378 
1379  const auto blk_id = blk_idx[I0];
1380  const auto blk_td = blk_idx[I1];
1381 
1382  if constexpr(mfma_instr.is_k_reduction)
1383  {
1384  return make_tuple(blk_id, blk_td);
1385  }
1386  else
1387  {
1388  return make_tuple(0, laneId);
1389  }
1390  }
1391 
     // Origin of this lane's B data. Currently identical to CalculateAThreadOriginDataIndex
     // (same device-only caveat).
1392  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
1393  {
1394  const auto laneId = GetLaneId();
1395  const auto blk_idx = GetBlkIdx();
1396 
1397  const auto blk_id = blk_idx[I0];
1398  const auto blk_td = blk_idx[I1];
1399 
1400  if constexpr(mfma_instr.is_k_reduction)
1401  {
1402  return make_tuple(blk_id, blk_td);
1403  }
1404  else
1405  {
1406  return make_tuple(0, laneId);
1407  }
1408  }
1409 
     // Starting C coordinate of this lane's results for instruction xdlops_i and output block blk_i.
     //   n = blk_i * n_per_blk + blk_td
     //   m = xdlops_i * m_per_blk + blk_id * group_size
     // Returned as (m, n), or as (n, m) when TransposeC.
1410  __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
1411  {
1412  const auto blk_idx = GetBlkIdx();
1413 
1414  const auto blk_id = blk_idx[I0];
1415  const auto blk_td = blk_idx[I1];
1416 
1417  index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td;
1418  index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size;
1419 
1420  return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
1421  }
1422 
     // 4-D variant of the lane's starting C index. Both arguments are ignored. Returns
     // (0, blk_id, 0, blk_td), or (blk_td, 0, blk_id, 0) when TransposeC.
1423  __device__ static CIndex4D GetBeginOfThreadBlk4D(index_t /* xdlops_i */, index_t /* blk_i */)
1424  {
1425  const auto blk_idx = GetBlkIdx();
1426 
1427  const auto blk_id = blk_idx[I0];
1428  const auto blk_td = blk_idx[I1];
1429 
1430  return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
1431  }
1432 
1433  // Falls back to single rate instruction on gfx950 if KPack <= 4; no change on gfx942-
1434  static constexpr auto
1435  mfma = MfmaSelector < base_type,
1436  MPerXdlops, NPerXdlops, additional_type,
1438  ? true
1439  : false > {};
1440 
     // Layout descriptor of the chosen instruction, used by every helper above.
1441  static constexpr auto mfma_instr = mfma.selected_mfma;
1442 
     // K per instruction, K per input block (K1), and their ratio (K0).
1443  static constexpr auto KPerXdlops = mfma.GetKPerXdlops();
1444  static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
1445  static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;
1446 
     // Lengths of this thread's C sub-block along (M0, M1, M2, N), going by the name.
     // NOTE(review): the returned tuple's elements are not visible here -- confirm against the header.
1447  __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
1448  {
1449  return make_tuple(
1451  }
1452 };
1453 
1454 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
MfmaInstr
Definition: xdlops_gemm.hpp:13
@ mfma_scale_f32_32x32x64f8f6f4
@ mfma_f32_16x16x16bf16_1k
@ mfma_scale_f32_16x16x128f8f6f4
@ mfma_f32_16x16x128f8f6f4
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:16
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: array.hpp:14
Definition: xdlops_gemm.hpp:886
static constexpr bool IsABroadcast()
Definition: xdlops_gemm.hpp:1159
static constexpr index_t GetKPerXdlops()
Definition: xdlops_gemm.hpp:1165
static constexpr auto GetMfma()
__host__ constexpr __device__ MfmaSelector()
Definition: xdlops_gemm.hpp:1133
static constexpr auto selected_mfma
Definition: xdlops_gemm.hpp:1130
static constexpr index_t GetK1PerXdlops()
Definition: xdlops_gemm.hpp:1171
Definition: sequence.hpp:43
Definition: xdlops_gemm.hpp:1181
static __device__ auto GetLaneId()
Definition: xdlops_gemm.hpp:1353
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1392
static constexpr __device__ index_t GetNumBlks()
Definition: xdlops_gemm.hpp:1192
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1252
static constexpr auto KPerXdlops
Definition: xdlops_gemm.hpp:1443
__host__ constexpr __device__ XdlopsGemm()
Definition: xdlops_gemm.hpp:1200
static constexpr auto mfma_instr
Definition: xdlops_gemm.hpp:1441
static __device__ auto GetBlkIdx()
Definition: xdlops_gemm.hpp:1355
static constexpr __device__ index_t GetRegSizePerXdlops()
Definition: xdlops_gemm.hpp:1320
static constexpr auto I5
Definition: xdlops_gemm.hpp:1187
static constexpr auto K0PerXdlops
Definition: xdlops_gemm.hpp:1445
static constexpr auto I4
Definition: xdlops_gemm.hpp:1186
static constexpr __device__ index_t GetWaveSize()
Definition: xdlops_gemm.hpp:1325
static constexpr auto K1PerXdlops
Definition: xdlops_gemm.hpp:1444
static __device__ CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)
Definition: xdlops_gemm.hpp:1410
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: xdlops_gemm.hpp:1374
static constexpr auto mfma
Definition: xdlops_gemm.hpp:1435
static constexpr auto I3
Definition: xdlops_gemm.hpp:1185
static constexpr auto I2
Definition: xdlops_gemm.hpp:1184
__host__ static constexpr __device__ auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_G_M0_N0_M1_N1_M2_N2 &c_desc_g_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1284
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: xdlops_gemm.hpp:1328
__host__ static constexpr __device__ auto GetCM0M1M2NThreadBlkLengths()
Definition: xdlops_gemm.hpp:1447
__host__ static constexpr __device__ auto MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2 &c_desc_m0_n0_m1_n1_m2_n2)
Definition: xdlops_gemm.hpp:1217
static constexpr auto I1
Definition: xdlops_gemm.hpp:1183
static constexpr auto I0
Definition: xdlops_gemm.hpp:1182
static __device__ CIndex4D GetBeginOfThreadBlk4D(index_t, index_t)
Definition: xdlops_gemm.hpp:1423
static constexpr __device__ index_t GetNumXdlops()
Definition: xdlops_gemm.hpp:1194
Definition: integral_constant.hpp:10
Definition: amd_xdlops.hpp:587
Definition: amd_xdlops.hpp:302
Definition: amd_xdlops.hpp:192
Definition: amd_xdlops.hpp:69
Definition: amd_xdlops.hpp:268
Definition: amd_xdlops.hpp:718
Definition: amd_xdlops.hpp:844
Definition: amd_xdlops.hpp:158
Definition: amd_xdlops.hpp:781
Definition: amd_xdlops.hpp:655
Definition: amd_xdlops.hpp:206
Definition: amd_xdlops.hpp:55
Definition: amd_xdlops.hpp:330
Definition: amd_xdlops.hpp:248
Definition: amd_xdlops.hpp:686
Definition: amd_xdlops.hpp:812
Definition: amd_xdlops.hpp:138
Definition: amd_xdlops.hpp:749
Definition: amd_xdlops.hpp:623
Definition: amd_xdlops.hpp:14
Definition: amd_xdlops.hpp:41
Definition: amd_xdlops.hpp:316
Definition: amd_xdlops.hpp:111
Definition: amd_xdlops.hpp:480
Definition: amd_xdlops.hpp:288
Definition: amd_xdlops.hpp:178
Definition: amd_xdlops.hpp:83
Definition: amd_xdlops.hpp:220
Definition: amd_xdlops.hpp:460
Definition: amd_xdlops.hpp:363
Definition: amd_xdlops.hpp:441
Definition: amd_xdlops.hpp:402
Definition: amd_xdlops.hpp:422
Definition: amd_xdlops.hpp:382
Definition: amd_xdlops.hpp:344
Definition: amd_xdlops.hpp:551
Definition: amd_xdlops.hpp:515
Definition: type.hpp:177
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:826
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:403
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:271
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:138
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:381
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:689
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:777
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:249
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:733
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:645
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:293
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:116
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:447
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:337
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:667
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:755
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:227
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:711
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:623
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:72
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:94
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:425
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:183
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:802
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:359
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:205
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:161
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:315
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:601
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:491
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:535
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:579
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:513
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:557
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:469
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:874
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: xdlops_gemm.hpp:850
Definition: xdlops_gemm.hpp:54
Definition: functional2.hpp:31