/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File
wmma_gemm.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
12 enum struct WmmaInstr
13 {
14  // gfx11
21  // gfx12
29 };
30 
31 /*
32  * WMMA Wave Tile Always MxNxK = 16x16x16
33  * WAVE32
34  -----------------------------------
35  |RC0| | | | | | | | | | | | | | | | SubGroup 0
36  |RC1| | | | | | | | | | | | | | | |
37  |RC2| | | | | | | | | | | | | | | |
38  |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
39  |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
40  |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
41  |RC6| | | | | | | | | | | | | | | |
42  |RC7| | | | | | | | | | | | | | | |
43  -----------------------------------
44  | | | | | | | | | | | | | | | | | SubGroup 1
45  | | | | | | | | | | | | | | | | |
46  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
47  | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
48  | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
49  | | | | | | | | | | | | | | | | |
50  | | | | | | | | | | | | | | | | |
51  | | | | | | | | | | | | | | | | |
52  -----------------------------------
53 
54 
55  * WAVE64
56  -----------------------------------
57  |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
58  |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
59  |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
60  |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
61  -----------------------------------
62  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
63  | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
64  | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
65  | | | | | | | | | | | | | | | | |
66  -----------------------------------
67  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
68  | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
69  | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
70  | | | | | | | | | | | | | | | | |
71  -----------------------------------
72  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
73  | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
74  | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
75  | | | | | | | | | | | | | | | | |
76  -----------------------------------
77 
78 * RC = Register for storing accumulated result
79 * T = Thread ID
80 */
81 
82 template <WmmaInstr Instr, index_t WaveSize, typename = void>
83 struct wmma_type // Primary template: deliberately empty, so it exposes no members of its own.
84 { // NOTE(review): the unnamed third parameter looks like the std::enable_if_t slot used by the specializations that follow — confirm.
85 };
86 
87 // A-swizzled
88 template <index_t WaveSize>
90  WaveSize,
91  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
92 {
93  // Absolute fixing property
94  // * Data Pixel
95  static constexpr index_t m_per_wmma = 16;
96  static constexpr index_t n_per_wmma = 16;
97  static constexpr index_t k_per_wmma = 16;
98  static constexpr index_t k_per_blk = 8;
99  static constexpr index_t src_a_data_size = 2;
100  static constexpr index_t src_b_data_size = 2;
101  static constexpr index_t acc_data_size = 4;
102  static constexpr index_t acc_pack_number = 1;
103  // * Thread mapping inside wave, num_thread_per_subgroups always along N direction
104  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
105 
106  // Wave mode dependent property
107  static constexpr index_t wave_size = Number<WaveSize>{};
108  // * Fixed on gfx11, Will be wave mode dependent for future architectures
109  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
110  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
111  // * num_acc_vgprs_per_wave along M direction
112  // * num_subgroups along M direction
113  static constexpr index_t num_acc_vgprs_per_wave =
114  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
115  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
116 
117  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
118  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
119  {
120  if constexpr(wave_size == 32)
121  {
123  }
124  else if constexpr(wave_size == 64)
125  {
127  }
128  }
129 };
130 
131 template <index_t WaveSize>
133  WaveSize,
134  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
135 {
136  // Absolute fixing property
137  static constexpr index_t m_per_wmma = 16;
138  static constexpr index_t n_per_wmma = 16;
139  static constexpr index_t k_per_wmma = 16;
140  static constexpr index_t k_per_blk = 8;
141  static constexpr index_t src_a_data_size = 2;
142  static constexpr index_t src_b_data_size = 2;
143  static constexpr index_t acc_data_size = 4;
144  static constexpr index_t acc_pack_number = 1;
145  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
146 
147  // Wave mode dependent property
148  static constexpr index_t wave_size = Number<WaveSize>{};
149  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
150  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
151  static constexpr index_t num_acc_vgprs_per_wave =
152  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
153  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
154 
155  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
156  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
157  {
158  if constexpr(wave_size == 32)
159  {
161  }
162  else if constexpr(wave_size == 64)
163  {
165  }
166  }
167 };
168 
169 template <index_t WaveSize>
171  WaveSize,
172  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
173 {
174  // Absolute fixing property
175  static constexpr index_t m_per_wmma = 16;
176  static constexpr index_t n_per_wmma = 16;
177  static constexpr index_t k_per_wmma = 16;
178  static constexpr index_t k_per_blk = 8;
179  static constexpr index_t src_a_data_size = 2;
180  static constexpr index_t src_b_data_size = 2;
181  static constexpr index_t acc_data_size = 2;
182  static constexpr index_t acc_pack_number = 2;
183  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
184 
185  // Wave mode dependent property
186  static constexpr index_t wave_size = Number<WaveSize>{};
187  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
188  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
189  static constexpr index_t num_acc_vgprs_per_wave =
190  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
191  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
192 
193  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
194  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
195  {
196  if constexpr(wave_size == 32)
197  {
199  }
200  else if constexpr(wave_size == 64)
201  {
203  }
204  }
205 };
206 template <index_t WaveSize>
208  WaveSize,
209  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
210 {
211  // Absolute fixing property
212  static constexpr index_t m_per_wmma = 16;
213  static constexpr index_t n_per_wmma = 16;
214  static constexpr index_t k_per_wmma = 16;
215  static constexpr index_t k_per_blk = 8;
216  static constexpr index_t src_a_data_size = 2;
217  static constexpr index_t src_b_data_size = 2;
218  static constexpr index_t acc_data_size = 2;
219  static constexpr index_t acc_pack_number = 2;
220  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
221 
222  // Wave mode dependent property
223  static constexpr index_t wave_size = Number<WaveSize>{};
224  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
225  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
226  static constexpr index_t num_acc_vgprs_per_wave =
227  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
228  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
229 
230  template <index_t MPerWmma,
231  index_t NPerWmma,
232  index_t Opsel,
233  class FloatA,
234  class FloatB,
235  class FloatC>
236  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
237  {
238  if constexpr(wave_size == 32)
239  {
241  }
242  else if constexpr(wave_size == 64)
243  {
245  }
246  }
247 };
248 
249 template <index_t WaveSize>
251  WaveSize,
252  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
253 {
254  // Absolute fixing property
255  static constexpr index_t m_per_wmma = 16;
256  static constexpr index_t n_per_wmma = 16;
257  static constexpr index_t k_per_wmma = 16;
258  static constexpr index_t k_per_blk = 8;
259  static constexpr index_t src_a_data_size = 2;
260  static constexpr index_t src_b_data_size = 2;
261  static constexpr index_t acc_data_size = 4;
262  static constexpr index_t acc_pack_number = 1;
263  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
264 
265  // Wave mode dependent property
266  static constexpr index_t wave_size = Number<WaveSize>{};
267  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
268  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
269  static constexpr index_t num_acc_vgprs_per_wave =
270  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
271  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
272 
273  template <index_t MPerWmma,
274  index_t NPerWmma,
275  class FloatA,
276  class FloatB,
277  class FloatC,
278  bool neg_a = true,
279  bool neg_b = true,
280  bool clamp = false>
281  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
282  {
283  if constexpr(wave_size == 32)
284  {
286  a, b, reg_c);
287  }
288  else if constexpr(wave_size == 64)
289  {
291  a, b, reg_c);
292  }
293  }
294 };
295 
296 // gfx12
297 
298 // A-swizzled
299 template <index_t WaveSize>
301  WaveSize,
302  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
303 {
304  // Absolute fixing property
305  // * Data Pixel
306  static constexpr index_t m_per_wmma = 16;
307  static constexpr index_t n_per_wmma = 16;
308  static constexpr index_t k_per_wmma = 16;
309  static constexpr index_t k_per_blk = 8;
310  // static constexpr index_t src_a_data_size = 2;
311  // static constexpr index_t src_b_data_size = 2;
312  // static constexpr index_t acc_data_size = 4;
313  // * Thread mapping inside wave, num_thread_per_subgroups always along N direction
314  static constexpr index_t acc_data_size = 4;
315  static constexpr index_t acc_pack_number = 1;
316  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
317 
318  // Wave mode dependent property
319  static constexpr index_t wave_size = Number<WaveSize>{};
320  // * Fixed for gfx11, Will be wave mode dependent on gfx12
321  // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
322  // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
323  // * num_acc_vgprs_per_wave along M direction
324  // * num_subgroups along M direction
325  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
326  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
327 
328  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
329  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
330  {
331  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
332  if constexpr(wave_size == 32)
333  {
335  }
336  }
337 };
338 
339 template <index_t WaveSize>
341  WaveSize,
342  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
343 {
344  // Absolute fixing property
345  static constexpr index_t m_per_wmma = 16;
346  static constexpr index_t n_per_wmma = 16;
347  static constexpr index_t k_per_wmma = 16;
348  static constexpr index_t k_per_blk = 8;
349  // static constexpr index_t src_a_data_size = 2;
350  // static constexpr index_t src_b_data_size = 2;
351  static constexpr index_t acc_data_size = 4;
352  static constexpr index_t acc_pack_number = 1;
353  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
354 
355  // Wave mode dependent property
356  static constexpr index_t wave_size = Number<WaveSize>{};
357  // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
358  // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
359  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
360  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
361 
362  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
363  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
364  {
365  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
366  if constexpr(wave_size == 32)
367  {
369  }
370  }
371 };
372 
373 template <index_t WaveSize>
375  WaveSize,
376  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
377 {
378  // Absolute fixing property
379  static constexpr index_t m_per_wmma = 16;
380  static constexpr index_t n_per_wmma = 16;
381  static constexpr index_t k_per_wmma = 16;
382  static constexpr index_t k_per_blk = 8;
383  // static constexpr index_t src_a_data_size = 2;
384  // static constexpr index_t src_b_data_size = 2;
385  static constexpr index_t acc_data_size = 4;
386  static constexpr index_t acc_pack_number = 1;
387  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
388 
389  // Wave mode dependent property
390  static constexpr index_t wave_size = Number<WaveSize>{};
391  // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
392  // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
393  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
394  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
395 
396  template <index_t MPerWmma,
397  index_t NPerWmma,
398  class FloatA,
399  class FloatB,
400  class FloatC,
401  bool neg_a = true,
402  bool neg_b = true,
403  bool clamp = false>
404  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
405  {
406  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
407  if constexpr(wave_size == 32)
408  {
410  a, b, reg_c);
411  }
412  }
413 };
414 
415 template <index_t WaveSize>
417  WaveSize,
418  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
419 {
420  // Absolute fixing property
421  static constexpr index_t m_per_wmma = 16;
422  static constexpr index_t n_per_wmma = 16;
423  static constexpr index_t k_per_wmma = 16;
424  static constexpr index_t k_per_blk = 8;
425  static constexpr index_t acc_data_size = 4;
426  static constexpr index_t acc_pack_number = 1;
427  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
428 
429  // Wave mode dependent property
430  static constexpr index_t wave_size = Number<WaveSize>{};
431  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
432  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
433 
434  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
435  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
436  {
437  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
438  if constexpr(wave_size == 32)
439  {
440 #ifdef __gfx12__
442 #else
443  ignore = a;
444  ignore = b;
445  ignore = reg_c;
446 #endif
447  }
448  }
449 };
450 
451 template <index_t WaveSize>
453  WaveSize,
454  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
455 {
456  // Absolute fixing property
457  static constexpr index_t m_per_wmma = 16;
458  static constexpr index_t n_per_wmma = 16;
459  static constexpr index_t k_per_wmma = 16;
460  static constexpr index_t k_per_blk = 8;
461  static constexpr index_t acc_data_size = 4;
462  static constexpr index_t acc_pack_number = 1;
463  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
464 
465  // Wave mode dependent property
466  static constexpr index_t wave_size = Number<WaveSize>{};
467  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
468  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
469 
470  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
471  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
472  {
473  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
474  if constexpr(wave_size == 32)
475  {
476 #ifdef __gfx12__
478 #else
479  ignore = a;
480  ignore = b;
481  ignore = reg_c;
482 #endif
483  }
484  }
485 };
486 
487 template <index_t WaveSize>
489  WaveSize,
490  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
491 {
492  // Absolute fixing property
493  static constexpr index_t m_per_wmma = 16;
494  static constexpr index_t n_per_wmma = 16;
495  static constexpr index_t k_per_wmma = 16;
496  static constexpr index_t k_per_blk = 8;
497  static constexpr index_t acc_data_size = 4;
498  static constexpr index_t acc_pack_number = 1;
499  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
500 
501  // Wave mode dependent property
502  static constexpr index_t wave_size = Number<WaveSize>{};
503  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
504  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
505 
506  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
507  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
508  {
509  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
510  if constexpr(wave_size == 32)
511  {
512 #ifdef __gfx12__
514 #else
515  ignore = a;
516  ignore = b;
517  ignore = reg_c;
518 #endif
519  }
520  }
521 };
522 
523 template <index_t WaveSize>
525  WaveSize,
526  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
527 {
528  // Absolute fixing property
529  static constexpr index_t m_per_wmma = 16;
530  static constexpr index_t n_per_wmma = 16;
531  static constexpr index_t k_per_wmma = 16;
532  static constexpr index_t k_per_blk = 8;
533  static constexpr index_t acc_data_size = 4;
534  static constexpr index_t acc_pack_number = 1;
535  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
536 
537  // Wave mode dependent property
538  static constexpr index_t wave_size = Number<WaveSize>{};
539  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
540  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
541 
542  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
543  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
544  {
545  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
546  if constexpr(wave_size == 32)
547  {
548 #ifdef __gfx12__
550 #else
551  ignore = a;
552  ignore = b;
553  ignore = reg_c;
554 #endif
555  }
556  }
557 };
558 
559 template <typename src_type_a,
560  typename src_type_b,
561  typename dst_type,
562  index_t MPerWmma,
563  index_t NPerWmma>
565 {
566  template <typename src_type_a_,
567  typename src_type_b_,
568  typename dst_type_,
569  index_t MPerWmma_,
570  index_t NPerWmma_>
571  static constexpr auto GetWmma();
572 
573  template <>
574  constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
575  {
576 #ifdef __gfx12__
578 #else
580 #endif
581  }
582 
583  template <>
584  constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
585  {
586 #ifdef __gfx12__
588 #else
590 #endif
591  }
592 
593  template <>
594  constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
595  {
597  }
598 
599  template <>
600  constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
601  {
603  }
604 
605  template <>
606  constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
607  {
608 #ifdef __gfx12__
610 #else
612 #endif
613  }
614 
615 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
616  template <>
617  constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
618  {
620  }
621 #endif
622 
623  template <>
624  constexpr auto GetWmma<f8_t, f8_t, float, 16, 16>()
625  {
627  }
628 
629  template <>
630  constexpr auto GetWmma<f8_t, bf8_t, float, 16, 16>()
631  {
633  }
634 
635  template <>
636  constexpr auto GetWmma<bf8_t, f8_t, float, 16, 16>()
637  {
639  }
640 
641  template <>
642  constexpr auto GetWmma<bf8_t, bf8_t, float, 16, 16>()
643  {
645  }
646 
647  // get_warp_size does not return the correct wave size here, so it is hardcoded to 32 as a workaround
648  static constexpr auto selected_wmma =
650 
651  __host__ __device__ constexpr WmmaSelector()
652  { // Constructor body consists solely of compile-time checks on selected_wmma; it does no runtime work.
653  static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
654  // Validate the N tile size (this check used to be a second copy of the M check).
655  static_assert(selected_wmma.n_per_wmma == 16, "WRONG! WMMA_N must equal to 16");
656  // Validate the K tile size (the message used to name WMMA_M).
657  static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_K must equal to 16");
658  // wave_size * num_acc_vgprs_per_wave * acc_data_size * acc_pack_number must equal m_per_wmma * n_per_wmma * 4.
659  static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
660  selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
661  selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
662  "WRONG! Invalid Number of Accumulator Register");
663  }
664 };
665 
666 template <typename src_type_a,
667  typename src_type_b,
668  typename dst_type,
669  index_t MPerWmma,
670  index_t NPerWmma,
671  index_t KPack,
672  bool TransposeC = false,
673  bool AssemblyBackend = false>
674 struct WmmaGemm
675 {
676  static constexpr auto I0 = Number<0>{};
677  static constexpr auto I1 = Number<1>{};
678  static constexpr auto I2 = Number<2>{};
679  static constexpr auto I3 = Number<3>{};
680  static constexpr auto I4 = Number<4>{};
681  static constexpr auto I5 = Number<5>{};
682 
685 
686  __host__ __device__ constexpr WmmaGemm()
687  { // Constructor body consists solely of compile-time checks on the template arguments.
688  static_assert(NPerWmma == 16 && MPerWmma == 16,
689  "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
690  // KPack is consumed in steps of k_per_wmma by Run(), so it has to be a whole multiple of it.
691  static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
692  }
693 
694  // WMMA output supporting C = A * B
695  // Vector Write
696  // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
697  template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
698  __host__ __device__ static constexpr auto
700  const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
701  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
702  {
703  const auto MBlockxRepeat =
704  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
705  const auto NBlockxRepeat =
706  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
707  const auto MWave =
708  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
709  const auto NWave =
710  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
711 
713  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
714  make_tuple(
715  make_pass_through_transform(MBlockxRepeat),
718  Number<wmma_instr.num_acc_vgprs_per_wave>{})),
719  make_pass_through_transform(NBlockxRepeat),
723  Sequence<1>{},
724  Sequence<2>{},
725  Sequence<3>{},
726  Sequence<4>{},
727  Sequence<5>{}),
729  Sequence<1>{},
730  Sequence<2, 6>{},
731  Sequence<3>{},
732  Sequence<4>{},
733  Sequence<5>{}));
734  }
735 
736  // Transposed WMMA Output C' = B' * A'
737  template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
738  __host__ __device__ static constexpr auto
740  const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
741  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
742  {
743  const auto MBlockxRepeat =
744  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
745  const auto NBlockxRepeat =
746  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
747  const auto MWave =
748  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
749  const auto NWave =
750  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
751 
753  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
754  make_tuple(
755  make_pass_through_transform(MBlockxRepeat),
758  make_pass_through_transform(NBlockxRepeat),
761  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
763  Sequence<1>{},
764  Sequence<2>{},
765  Sequence<3>{},
766  Sequence<4>{},
767  Sequence<5>{}),
769  Sequence<1>{},
770  Sequence<2>{},
771  Sequence<3>{},
772  Sequence<4>{},
773  Sequence<5, 6>{}));
774  }
775 
776  __device__ static constexpr index_t GetRegSizePerWmma()
777  { // Returns num_acc_vgprs_per_wave scaled by acc_pack_number for the selected instruction descriptor.
778  return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
779  }
780 
781  __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; } // wave_size of the selected instruction descriptor
782 
783  __device__ static constexpr index_t GetKPerWaveBlk() { return wmma_instr.k_per_blk; } // k_per_blk of the selected instruction descriptor
784 
785  template <class FloatA, class FloatB, class FloatC>
786  __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
787  {
788  static_assert(
802 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
805 #endif
806  false,
807  "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
808  "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
809  static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
810  // Integer wmma operators need extra input flags to indicate if the input is signed or
811  // unsigned. At the moment CK supports only signed integer inputs, so these flags are
812  // hardcoded.
813  if constexpr(!TransposeC)
814  {
815  wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
816  }
817  else
818  {
819  wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
820  }
821  });
822  }
823 
824  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; } // get_thread_local_1d_id() reduced modulo the wave size, i.e. the caller's position inside its wave
825 
826  __device__ static auto GetSubGroupId()
827  { // Index of the subgroup (a run of num_thread_per_subgroups consecutive lanes) that contains the calling lane.
828  static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
829  wmma_instr.wave_size,
830  ""); // compile-time guard: the subgroups must tile the wave exactly
831  return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
832  }
833 
834  __device__ static auto GetLaneIdUnderSubGroup()
835  { // Position of the calling lane inside its subgroup: lane id modulo num_thread_per_subgroups.
836  return GetLaneId() % wmma_instr.num_thread_per_subgroups;
837  }
838  __device__ static auto GetSwizzledLaneIdLow()
839  { // Permutes the in-subgroup lane id: bit 0 is moved up to bit 3 and the remaining bits shift down by one (0->0, 1->8, 2->1, 3->9, ...).
840  return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
841  }
842 
843  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
844  { // Selects the in-subgroup index this lane uses for the A operand. NOTE(review): marked __host__ __device__ but every helper it calls is __device__-only — confirm host use is never intended.
845 #ifdef __gfx12__
846  return GetLaneIdUnderSubGroup(); // gfx12 build: plain in-subgroup lane id, no swizzle
847 #else
848  return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow(); // otherwise A is swizzled unless TransposeC is set
849 #endif
850  }
851 
852  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
853  { // Selects the in-subgroup index this lane uses for the B operand. NOTE(review): marked __host__ __device__ but every helper it calls is __device__-only — confirm host use is never intended.
854 #ifdef __gfx12__
855  return GetLaneIdUnderSubGroup(); // gfx12 build: plain in-subgroup lane id, no swizzle
856 #else
857  return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup(); // otherwise B is swizzled only when TransposeC is set (mirror of the A case)
858 #endif
859  }
860 
861  __device__ static CIndex GetBeginOfThreadBlk()
862  { // Returns the two-component start index of the calling lane's block; CIndex is declared on lines not shown in this listing.
863  index_t n_offset = GetLaneIdUnderSubGroup();
864  index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
865  // TransposeC swaps the order of the two components.
866  return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
867  }
868 
869  __device__ static CIndex3D GetBeginOfThreadBlk3D()
870  { // Three-component variant of GetBeginOfThreadBlk: the subgroup id is used unscaled and the third component is always I0.
871  index_t n_offset = GetLaneIdUnderSubGroup();
872  index_t m_offset = GetSubGroupId();
873  // TransposeC swaps the first two components; CIndex3D is declared on lines not shown in this listing.
874  return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
875  }
876 
877  static constexpr auto wmma =
879  static constexpr auto wmma_instr = wmma.selected_wmma;
880 
881  __host__ __device__ static constexpr auto
883  {
884  return make_tuple(I1,
885  I1,
887  Number<wmma_instr.acc_pack_number>{});
888  }
889 };
890 
891 } // namespace ck
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
Definition: ck.hpp:270
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
WmmaInstr
Definition: wmma_gemm.hpp:13
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
Definition: array.hpp:14
Definition: sequence.hpp:43
Definition: wmma_gemm.hpp:675
static constexpr auto I0
Definition: wmma_gemm.hpp:676
static __device__ auto GetLaneId()
Definition: wmma_gemm.hpp:824
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: wmma_gemm.hpp:786
static constexpr __device__ index_t GetWaveSize()
Definition: wmma_gemm.hpp:781
static constexpr auto wmma
Definition: wmma_gemm.hpp:877
__host__ static constexpr __device__ auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition: wmma_gemm.hpp:882
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: wmma_gemm.hpp:843
static __device__ auto GetSubGroupId()
Definition: wmma_gemm.hpp:826
static __device__ auto GetSwizzledLaneIdLow()
Definition: wmma_gemm.hpp:838
static constexpr __device__ index_t GetKPerWaveBlk()
Definition: wmma_gemm.hpp:783
static constexpr auto I3
Definition: wmma_gemm.hpp:679
static constexpr auto I5
Definition: wmma_gemm.hpp:681
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: wmma_gemm.hpp:852
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:739
static __device__ CIndex GetBeginOfThreadBlk()
Definition: wmma_gemm.hpp:861
static constexpr auto I4
Definition: wmma_gemm.hpp:680
static constexpr __device__ index_t GetRegSizePerWmma()
Definition: wmma_gemm.hpp:776
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:699
__host__ constexpr __device__ WmmaGemm()
Definition: wmma_gemm.hpp:686
static constexpr auto I2
Definition: wmma_gemm.hpp:678
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition: wmma_gemm.hpp:869
static constexpr auto I1
Definition: wmma_gemm.hpp:677
static __device__ auto GetLaneIdUnderSubGroup()
Definition: wmma_gemm.hpp:834
static constexpr auto wmma_instr
Definition: wmma_gemm.hpp:879
Definition: wmma_gemm.hpp:565
static constexpr auto selected_wmma
Definition: wmma_gemm.hpp:648
__host__ constexpr __device__ WmmaSelector()
Definition: wmma_gemm.hpp:651
static constexpr auto GetWmma()
Definition: integral_constant.hpp:20
Definition: amd_wmma.hpp:96
Definition: amd_wmma.hpp:216
Definition: amd_wmma.hpp:72
Definition: amd_wmma.hpp:192
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:170
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:149
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: amd_wmma.hpp:241
Definition: type.hpp:177
Definition: functional2.hpp:33
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:236
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:194
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:156
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:363
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:543
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:507
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:118
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:329
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:471
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:435
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:281
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:404
Definition: wmma_gemm.hpp:84