/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp Source File
wmma_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
12 enum struct WmmaInstr
13 {
14  // gfx11
21  // gfx12
25 };
26 
27 /*
28  * WMMA Wave Tile Always MxNxK = 16x16x16
29  * WAVE32
30  -----------------------------------
31  |RC0| | | | | | | | | | | | | | | | SubGroup 0
32  |RC1| | | | | | | | | | | | | | | |
33  |RC2| | | | | | | | | | | | | | | |
34  |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
35  |RC4|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
36  |RC5|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
37  |RC6| | | | | | | | | | | | | | | |
38  |RC7| | | | | | | | | | | | | | | |
39  -----------------------------------
40  | | | | | | | | | | | | | | | | | SubGroup 1
41  | | | | | | | | | | | | | | | | |
42  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
43  | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
44  | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
45  | | | | | | | | | | | | | | | | |
46  | | | | | | | | | | | | | | | | |
47  | | | | | | | | | | | | | | | | |
48  -----------------------------------
49 
50 
51  * WAVE64
52  -----------------------------------
53  |RC0|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 0
54  |RC1|0|0|0|0|0|0|0|0|0|1|1|1|1|1|1|
55  |RC2|1|2|3|4|5|6|7|8|9|0|1|2|3|4|5|
56  |RC3|T|T|T|T|T|T|T|T|T|T|T|T|T|T|T|
57  -----------------------------------
58  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 1
59  | 1 |1|1|1|2|2|2|2|2|2|2|2|2|2|3|3|
60  | 6 |7|8|9|0|1|2|3|4|5|6|7|8|9|0|1|
61  | | | | | | | | | | | | | | | | |
62  -----------------------------------
63  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 2
64  | 3 |3|3|3|3|3|3|3|4|4|4|4|4|4|4|4|
65  | 2 |3|4|5|6|7|8|9|0|1|2|3|4|5|6|7|
66  | | | | | | | | | | | | | | | | |
67  -----------------------------------
68  | T |T|T|T|T|T|T|T|T|T|T|T|T|T|T|T| SubGroup 3
69  | 4 |4|5|5|5|5|5|5|5|5|5|5|6|6|6|6|
70  | 8 |9|0|1|2|3|4|5|6|7|8|9|0|1|2|3|
71  | | | | | | | | | | | | | | | | |
72  -----------------------------------
73 
74 * RC = Register for storing the accumulated result
75 * T = Thread ID
76 */
77 
78 template <WmmaInstr Instr, index_t WaveSize, typename = void>
79 struct wmma_type
80 {
81 };
82 
83 // A-swizzled
84 template <index_t WaveSize>
86  WaveSize,
87  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
88 {
89  // Absolute fixing property
90  // * Data Pixel
91  static constexpr index_t m_per_wmma = 16;
92  static constexpr index_t n_per_wmma = 16;
93  static constexpr index_t k_per_wmma = 16;
94  static constexpr index_t src_a_data_size = 2;
95  static constexpr index_t src_b_data_size = 2;
96  static constexpr index_t acc_data_size = 4;
97  static constexpr index_t acc_pack_number = 1;
98  // * Thread mapping inside wave, num_thread_per_subgroups always along the N direction
99  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
100 
101  // Wave mode dependent property
102  static constexpr index_t wave_size = Number<WaveSize>{};
103  // * Fixed on gfx11, Will be wave mode dependent for future architectures
104  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
105  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
106  // * num_acc_vgprs_per_wave along the M direction
107  // * num_subgroups along the M direction
108  static constexpr index_t num_acc_vgprs_per_wave =
109  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
110  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
111 
112  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
113  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
114  {
115  if constexpr(wave_size == 32)
116  {
118  }
119  else if constexpr(wave_size == 64)
120  {
122  }
123  }
124 };
125 
126 template <index_t WaveSize>
128  WaveSize,
129  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
130 {
131  // Absolute fixing property
132  static constexpr index_t m_per_wmma = 16;
133  static constexpr index_t n_per_wmma = 16;
134  static constexpr index_t k_per_wmma = 16;
135  static constexpr index_t src_a_data_size = 2;
136  static constexpr index_t src_b_data_size = 2;
137  static constexpr index_t acc_data_size = 4;
138  static constexpr index_t acc_pack_number = 1;
139  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
140 
141  // Wave mode dependent property
142  static constexpr index_t wave_size = Number<WaveSize>{};
143  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
144  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
145  static constexpr index_t num_acc_vgprs_per_wave =
146  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
147  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
148 
149  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
150  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
151  {
152  if constexpr(wave_size == 32)
153  {
155  }
156  else if constexpr(wave_size == 64)
157  {
159  }
160  }
161 };
162 
163 template <index_t WaveSize>
165  WaveSize,
166  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
167 {
168  // Absolute fixing property
169  static constexpr index_t m_per_wmma = 16;
170  static constexpr index_t n_per_wmma = 16;
171  static constexpr index_t k_per_wmma = 16;
172  static constexpr index_t src_a_data_size = 2;
173  static constexpr index_t src_b_data_size = 2;
174  static constexpr index_t acc_data_size = 2;
175  static constexpr index_t acc_pack_number = 2;
176  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
177 
178  // Wave mode dependent property
179  static constexpr index_t wave_size = Number<WaveSize>{};
180  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
181  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
182  static constexpr index_t num_acc_vgprs_per_wave =
183  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
184  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
185 
186  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
187  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
188  {
189  if constexpr(wave_size == 32)
190  {
192  }
193  else if constexpr(wave_size == 64)
194  {
196  }
197  }
198 };
199 template <index_t WaveSize>
201  WaveSize,
202  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
203 {
204  // Absolute fixing property
205  static constexpr index_t m_per_wmma = 16;
206  static constexpr index_t n_per_wmma = 16;
207  static constexpr index_t k_per_wmma = 16;
208  static constexpr index_t src_a_data_size = 2;
209  static constexpr index_t src_b_data_size = 2;
210  static constexpr index_t acc_data_size = 2;
211  static constexpr index_t acc_pack_number = 2;
212  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
213 
214  // Wave mode dependent property
215  static constexpr index_t wave_size = Number<WaveSize>{};
216  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
217  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
218  static constexpr index_t num_acc_vgprs_per_wave =
219  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
220  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
221 
222  template <index_t MPerWmma,
223  index_t NPerWmma,
224  index_t Opsel,
225  class FloatA,
226  class FloatB,
227  class FloatC>
228  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
229  {
230  if constexpr(wave_size == 32)
231  {
233  }
234  else if constexpr(wave_size == 64)
235  {
237  }
238  }
239 };
240 
241 template <index_t WaveSize>
243  WaveSize,
244  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
245 {
246  // Absolute fixing property
247  static constexpr index_t m_per_wmma = 16;
248  static constexpr index_t n_per_wmma = 16;
249  static constexpr index_t k_per_wmma = 16;
250  static constexpr index_t src_a_data_size = 2;
251  static constexpr index_t src_b_data_size = 2;
252  static constexpr index_t acc_data_size = 4;
253  static constexpr index_t acc_pack_number = 1;
254  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
255 
256  // Wave mode dependent property
257  static constexpr index_t wave_size = Number<WaveSize>{};
258  static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
259  static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
260  static constexpr index_t num_acc_vgprs_per_wave =
261  m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
262  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
263 
264  template <index_t MPerWmma,
265  index_t NPerWmma,
266  class FloatA,
267  class FloatB,
268  class FloatC,
269  bool neg_a = false,
270  bool neg_b = false,
271  bool clamp = false>
272  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
273  {
274  if constexpr(wave_size == 32)
275  {
277  a, b, reg_c);
278  }
279  else if constexpr(wave_size == 64)
280  {
282  a, b, reg_c);
283  }
284  }
285 };
286 
287 // gfx12
288 
289 // A-swizzled
290 template <index_t WaveSize>
292  WaveSize,
293  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
294 {
295  // Absolute fixing property
296  // * Data Pixel
297  static constexpr index_t m_per_wmma = 16;
298  static constexpr index_t n_per_wmma = 16;
299  static constexpr index_t k_per_wmma = 16;
300  // static constexpr index_t src_a_data_size = 2;
301  // static constexpr index_t src_b_data_size = 2;
302  // static constexpr index_t acc_data_size = 4;
303  // * Thread mapping inside wave, num_thread_per_subgroups always along the N direction
304  static constexpr index_t acc_data_size = 4;
305  static constexpr index_t acc_pack_number = 1;
306  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
307 
308  // Wave mode dependent property
309  static constexpr index_t wave_size = Number<WaveSize>{};
310  // * Fixed for gfx11, Will be wave mode dependent on gfx12
311  // static constexpr index_t num_src_a_vgprs_per_wave = k_per_wmma / 2 * src_a_data_size / 4;
312  // static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
313  // * num_acc_vgprs_per_wave along the M direction
314  // * num_subgroups along the M direction
315  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
316  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
317 
318  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
319  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
320  {
321  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
322  if constexpr(wave_size == 32)
323  {
325  }
326  }
327 };
328 
329 template <index_t WaveSize>
331  WaveSize,
332  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
333 {
334  // Absolute fixing property
335  static constexpr index_t m_per_wmma = 16;
336  static constexpr index_t n_per_wmma = 16;
337  static constexpr index_t k_per_wmma = 16;
338  // static constexpr index_t src_a_data_size = 2;
339  // static constexpr index_t src_b_data_size = 2;
340  static constexpr index_t acc_data_size = 4;
341  static constexpr index_t acc_pack_number = 1;
342  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
343 
344  // Wave mode dependent property
345  static constexpr index_t wave_size = Number<WaveSize>{};
346  // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
347  // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
348  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
349  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
350 
351  template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
352  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
353  {
354  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
355  if constexpr(wave_size == 32)
356  {
358  }
359  }
360 };
361 
362 template <index_t WaveSize>
364  WaveSize,
365  typename std::enable_if_t<WaveSize == 32 || WaveSize == 64>>
366 {
367  // Absolute fixing property
368  static constexpr index_t m_per_wmma = 16;
369  static constexpr index_t n_per_wmma = 16;
370  static constexpr index_t k_per_wmma = 16;
371  // static constexpr index_t src_a_data_size = 2;
372  // static constexpr index_t src_b_data_size = 2;
373  static constexpr index_t acc_data_size = 4;
374  static constexpr index_t acc_pack_number = 1;
375  static constexpr index_t num_thread_per_subgroups = n_per_wmma;
376 
377  // Wave mode dependent property
378  static constexpr index_t wave_size = Number<WaveSize>{};
379  // static constexpr index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
380  // static constexpr index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
381  static constexpr index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
382  static constexpr index_t num_subgroups = wave_size / num_thread_per_subgroups;
383 
384  template <index_t MPerWmma,
385  index_t NPerWmma,
386  class FloatA,
387  class FloatB,
388  class FloatC,
389  bool neg_a = false,
390  bool neg_b = false,
391  bool clamp = false>
392  __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
393  {
394  static_assert(wave_size == 32, "only support wave32 for gfx12 wmma");
395  if constexpr(wave_size == 32)
396  {
398  a, b, reg_c);
399  }
400  }
401 };
402 
403 template <typename src_type_a,
404  typename src_type_b,
405  typename dst_type,
406  index_t MPerWmma,
407  index_t NPerWmma>
409 {
410  template <typename src_type_a_,
411  typename src_type_b_,
412  typename dst_type_,
413  index_t MPerWmma_,
414  index_t NPerWmma_>
415  static constexpr auto GetWmma();
416 
417  template <>
418  constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
419  {
420 #ifdef __gfx12__
422 #else
424 #endif
425  }
426 
427  template <>
428  constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
429  {
430 #ifdef __gfx12__
432 #else
434 #endif
435  }
436 
437  template <>
438  constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
439  {
441  }
442 
443  template <>
444  constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
445  {
447  }
448 
449  template <>
450  constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
451  {
452 #ifdef __gfx12__
454 #else
456 #endif
457  }
458 
459 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
460  template <>
461  constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
462  {
464  }
465 #endif
466  // get_warp_size do not return the correct wavesize, hardcode to 32 as workaround
467  static constexpr auto selected_wmma =
469 
470  __host__ __device__ constexpr WmmaSelector()
471  {
472  static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
473 
474  static_assert(selected_wmma.m_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
475 
476  static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to 16");
477 
478  static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
479  selected_wmma.acc_data_size * selected_wmma.acc_pack_number ==
480  selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
481  "WRONG! Invalid Number of Accumulator Register");
482  }
483 };
484 
485 template <typename src_type_a,
486  typename src_type_b,
487  typename dst_type,
488  index_t MPerWmma,
489  index_t NPerWmma,
490  index_t KPack,
491  bool TransposeC = false,
492  bool AssemblyBackend = false>
493 struct WmmaGemm
494 {
495  static constexpr auto I0 = Number<0>{};
496  static constexpr auto I1 = Number<1>{};
497  static constexpr auto I2 = Number<2>{};
498  static constexpr auto I3 = Number<3>{};
499  static constexpr auto I4 = Number<4>{};
500  static constexpr auto I5 = Number<5>{};
501 
504 
505  __host__ __device__ constexpr WmmaGemm()
506  {
507  static_assert(NPerWmma == 16 && MPerWmma == 16,
508  "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
509 
510  static_assert(KPack % wmma_instr.k_per_wmma == 0, "KPack should be multiple of k_per_wmma");
511  }
512 
513  // WMMA output supporting C = A * B
514  // Vector Write
515  // MPerWMMA_NPerWMMA -> MSubGroup_..._NPerWMMA_MAccVgprPerWave
516  template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
517  __host__ __device__ static constexpr auto
519  const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
520  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
521  {
522  const auto MBlockxRepeat =
523  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
524  const auto NBlockxRepeat =
525  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
526  const auto MWave =
527  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
528  const auto NWave =
529  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
530 
532  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
533  make_tuple(
534  make_pass_through_transform(MBlockxRepeat),
537  Number<wmma_instr.num_acc_vgprs_per_wave>{})),
538  make_pass_through_transform(NBlockxRepeat),
542  Sequence<1>{},
543  Sequence<2>{},
544  Sequence<3>{},
545  Sequence<4>{},
546  Sequence<5>{}),
548  Sequence<1>{},
549  Sequence<2, 6>{},
550  Sequence<3>{},
551  Sequence<4>{},
552  Sequence<5>{}));
553  }
554 
555  // Transposed WMMA Output C' = B' * A'
556  template <typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
557  __host__ __device__ static constexpr auto
559  const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
560  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
561  {
562  const auto MBlockxRepeat =
563  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I0);
564  const auto NBlockxRepeat =
565  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I3);
566  const auto MWave =
567  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I1);
568  const auto NWave =
569  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(I4);
570 
572  c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
573  make_tuple(
574  make_pass_through_transform(MBlockxRepeat),
577  make_pass_through_transform(NBlockxRepeat),
580  Number<wmma_instr.num_acc_vgprs_per_wave>{}))),
582  Sequence<1>{},
583  Sequence<2>{},
584  Sequence<3>{},
585  Sequence<4>{},
586  Sequence<5>{}),
588  Sequence<1>{},
589  Sequence<2>{},
590  Sequence<3>{},
591  Sequence<4>{},
592  Sequence<5, 6>{}));
593  }
594 
595  __device__ static constexpr index_t GetRegSizePerWmma()
596  {
597  return wmma_instr.num_acc_vgprs_per_wave * wmma_instr.acc_pack_number;
598  }
599 
600  __device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
601 
602  template <class FloatA, class FloatB, class FloatC>
603  __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
604  {
605  static_assert(
616 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
619 #endif
620  ,
621  "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
622  "(int8, int32) or (int4, int32)!");
623  static_for<0, KPack / wmma_instr.k_per_wmma, 1>{}([&](auto k) {
624  if constexpr(!TransposeC)
625  {
626  wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
627  }
628  else
629  {
630  wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
631  }
632  });
633  }
634 
635  __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }
636 
637  __device__ static auto GetSubGroupId()
638  {
639  static_assert(wmma_instr.num_thread_per_subgroups * wmma_instr.num_subgroups ==
640  wmma_instr.wave_size,
641  "");
642  return (GetLaneId() / wmma_instr.num_thread_per_subgroups) % wmma_instr.num_subgroups;
643  }
644 
645  __device__ static auto GetLaneIdUnderSubGroup()
646  {
647  return GetLaneId() % wmma_instr.num_thread_per_subgroups;
648  }
649  __device__ static auto GetSwizzledLaneIdLow()
650  {
651  return ((GetLaneIdUnderSubGroup() & 1) << 3) | (GetLaneIdUnderSubGroup() >> 1);
652  }
653 
654  __host__ __device__ static auto CalculateAThreadOriginDataIndex()
655  {
656 #ifdef __gfx12__
657  return GetLaneIdUnderSubGroup();
658 #else
659  return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
660 #endif
661  }
662 
663  __host__ __device__ static auto CalculateBThreadOriginDataIndex()
664  {
665 #ifdef __gfx12__
666  return GetLaneIdUnderSubGroup();
667 #else
668  return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
669 #endif
670  }
671 
672  __device__ static CIndex GetBeginOfThreadBlk()
673  {
674  index_t n_offset = GetLaneIdUnderSubGroup();
675  index_t m_offset = GetSubGroupId() * wmma_instr.num_acc_vgprs_per_wave;
676 
677  return TransposeC ? CIndex{n_offset, m_offset} : CIndex{m_offset, n_offset};
678  }
679 
680  __device__ static CIndex3D GetBeginOfThreadBlk3D()
681  {
682  index_t n_offset = GetLaneIdUnderSubGroup();
683  index_t m_offset = GetSubGroupId();
684 
685  return TransposeC ? CIndex3D{n_offset, m_offset, I0} : CIndex3D{m_offset, n_offset, I0};
686  }
687 
688  static constexpr auto wmma =
// Instruction descriptor picked by WmmaSelector. Every shape/layout constant
// used by the helpers above (k_per_wmma, wave_size, num_subgroups,
// num_acc_vgprs_per_wave, ...) is read from this object.
690  static constexpr auto wmma_instr = wmma.selected_wmma;
691 
692  __host__ __device__ static constexpr auto
694  {
695  return make_tuple(I1,
696  I1,
698  Number<wmma_instr.acc_pack_number>{});
699  }
700 };
701 
702 } // namespace ck
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:16
WmmaInstr
Definition: wmma_gemm.hpp:13
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: array.hpp:14
Definition: sequence.hpp:43
Definition: wmma_gemm.hpp:494
static constexpr auto I0
Definition: wmma_gemm.hpp:495
static __device__ auto GetLaneId()
Definition: wmma_gemm.hpp:635
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: wmma_gemm.hpp:603
static constexpr __device__ index_t GetWaveSize()
Definition: wmma_gemm.hpp:600
static constexpr auto wmma
Definition: wmma_gemm.hpp:688
__host__ static constexpr __device__ auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition: wmma_gemm.hpp:693
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: wmma_gemm.hpp:654
static __device__ auto GetSubGroupId()
Definition: wmma_gemm.hpp:637
static __device__ auto GetSwizzledLaneIdLow()
Definition: wmma_gemm.hpp:649
static constexpr auto I3
Definition: wmma_gemm.hpp:498
static constexpr auto I5
Definition: wmma_gemm.hpp:500
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: wmma_gemm.hpp:663
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:558
static __device__ CIndex GetBeginOfThreadBlk()
Definition: wmma_gemm.hpp:672
static constexpr auto I4
Definition: wmma_gemm.hpp:499
static constexpr __device__ index_t GetRegSizePerWmma()
Definition: wmma_gemm.hpp:595
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:518
__host__ constexpr __device__ WmmaGemm()
Definition: wmma_gemm.hpp:505
static constexpr auto I2
Definition: wmma_gemm.hpp:497
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition: wmma_gemm.hpp:680
static constexpr auto I1
Definition: wmma_gemm.hpp:496
static __device__ auto GetLaneIdUnderSubGroup()
Definition: wmma_gemm.hpp:645
static constexpr auto wmma_instr
Definition: wmma_gemm.hpp:690
Definition: wmma_gemm.hpp:409
static constexpr auto selected_wmma
Definition: wmma_gemm.hpp:467
__host__ constexpr __device__ WmmaSelector()
Definition: wmma_gemm.hpp:470
static constexpr auto GetWmma()
Definition: integral_constant.hpp:10
Definition: amd_wmma.hpp:96
Definition: amd_wmma.hpp:216
Definition: amd_wmma.hpp:72
Definition: amd_wmma.hpp:192
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:170
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:149
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: amd_wmma.hpp:241
Definition: type.hpp:177
Definition: functional2.hpp:31
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:228
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:187
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:150
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:352
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:113
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:319
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:272
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:392
Definition: wmma_gemm.hpp:80