/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File
warp_gemm_attribute_mfma_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
// TODO: refactor warp-gemm
// Currently there is a discrepancy for vav/vva when C/D needs to be transposed.
// E.g. if we want A in AGPRs and B in VGPRs, we have to use Raw_vva in WGAttrCtlEnum,
// because the _impl code swaps the A/B pointers (and that is not known here).
// Selects how a warp-gemm MFMA is issued.
//   Default_ : the compiler builtin (__builtin_amdgcn_mfma_*) is used and the
//              compiler chooses register classes.
//   Raw_xyz  : the instruction is emitted through inline asm with fixed register
//              classes, where x/y/z give the class of C, A and B respectively
//              ('v' = VGPR, 'a' = AGPR).
enum class WGAttrCtlEnum
{
    Default_ = 0,
    Raw_vvv  = 1, // c-vgpr, a-vgpr, b-vgpr
    Raw_vaa  = 2, // c-vgpr, a-agpr, b-agpr
    Raw_vav  = 3, // c-vgpr, a-agpr, b-vgpr
    Raw_vva  = 4, // c-vgpr, a-vgpr, b-agpr
    Raw_avv  = 5, // c-agpr, a-vgpr, b-vgpr
    // Not provided: c-agpr, a-agpr, b-agpr (Raw_aaa). If added it needs an unused
    // value (e.g. 6); 3 is already taken by Raw_vav.
};
24 
// Emits a single MFMA instruction through inline asm: c_vec (read-write operand %0)
// is combined with a_vec (%1), b_vec (%2) and c_vec again as input (%3).
//   mfma_  : instruction mnemonic (string literal)
//   dmod_  : constraint of the read-write destination, e.g. "+v" or "+a"
//   amod_, bmod_, cmod_ : constraints of the A, B and C inputs
// When post_nop_ is true an "s_nop 3" is appended after the instruction
// (presumably to cover a hazard window; NOTE(review): confirm).
// The expansion site must have c_vec, a_vec, b_vec and post_nop_ in scope.
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)        \
    if constexpr(!post_nop_)                                    \
    {                                                           \
        asm volatile(mfma_ " %0, %1, %2, %3\n"                  \
                     : dmod_(c_vec)                             \
                     : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
                     :);                                        \
    }                                                           \
    else                                                        \
    {                                                           \
        asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n"            \
                     "s_nop 3"                                  \
                     : dmod_(c_vec)                             \
                     : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
                     :);                                        \
    }
41 
// Picks the inline-asm MFMA variant whose register classes match ctrl_
// (a WGAttrCtlEnum value). The expansion is an if-constexpr / else-if chain with
// NO trailing `else`: the call site must continue with `else { ... }` containing the
// Default_ (builtin) path. A bare `{ ... }` after the macro would run for every
// ctrl_, including the Raw_* ones that already emitted the instruction.
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)                   \
    if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv)          \
    {                                                      \
        DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a")         \
    }                                                      \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv)     \
    {                                                      \
        DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v")         \
    }                                                      \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa)     \
    {                                                      \
        DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v")         \
    }                                                      \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav)     \
    {                                                      \
        DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v")         \
    }                                                      \
    else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva)     \
    {                                                      \
        DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v")         \
    }
63 
64 // F32
65 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
67 {
68  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
69 
70  using ADataType = float;
71  using BDataType = float;
72  using CDataType = float;
73 
77 
78  static constexpr index_t kM = 16;
79  static constexpr index_t kN = 16;
80  static constexpr index_t kK = 4;
81 
82  static constexpr index_t kAMBlock = 1;
83  static constexpr index_t kBNBlock = 1;
84 
85  static constexpr index_t kAMLane = 16;
86  static constexpr index_t kBNLane = 16;
87  static constexpr index_t kABKLane = 4;
88  static constexpr index_t kABKPerLane = 1;
89 
90  static constexpr index_t kCMLane = 4;
91  static constexpr index_t kCNLane = 16;
92  static constexpr index_t kCM0PerLane = 1;
93  static constexpr index_t kCM1PerLane = 4;
94 
95  // c_vec += a_vec * b_vec
96  template <bool post_nop_ = false>
98  const AVecType& a_vec,
99  const BVecType& b_vec,
100  bool_constant<post_nop_> = {}) const
101  {
102  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x4f32", Ctrl)
103  else
104  {
105 #if defined(__gfx9__)
106  c_vec = __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
107 #else
108  ck_tile::ignore = c_vec;
109  ck_tile::ignore = a_vec;
110  ck_tile::ignore = b_vec;
111 #endif
112  }
113  }
114 
115  // c_vec = a_vec * b_vec
116  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
117  {
118 #if defined(__gfx9__)
119  return bit_cast<CVecType>(
120  __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
121 #else
122  ck_tile::ignore = a_vec;
123  ck_tile::ignore = b_vec;
124  return CVecType{0.f};
125 #endif
126  }
127 };
128 
129 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
131 {
132  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
133 
134  using ADataType = float;
135  using BDataType = float;
136  using CDataType = float;
137 
141 
142  static constexpr index_t kM = 32;
143  static constexpr index_t kN = 32;
144  static constexpr index_t kK = 2;
145 
146  static constexpr index_t kAMBlock = 1;
147  static constexpr index_t kBNBlock = 1;
148 
149  static constexpr index_t kAMLane = 32;
150  static constexpr index_t kBNLane = 32;
151  static constexpr index_t kABKLane = 2;
152  static constexpr index_t kABKPerLane = 1;
153 
154  static constexpr index_t kCMLane = 2;
155  static constexpr index_t kCNLane = 32;
156  static constexpr index_t kCM0PerLane = 4;
157  static constexpr index_t kCM1PerLane = 4;
158 
159  // c_vec += a_vec * b_vec
160  template <bool post_nop_ = false>
162  const AVecType& a_vec,
163  const BVecType& b_vec,
164  bool_constant<post_nop_> = {}) const
165  {
166  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x2f32", Ctrl)
167  else
168  {
169 #if defined(__gfx9__)
170  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
171 #else
172  ck_tile::ignore = c_vec;
173  ck_tile::ignore = a_vec;
174  ck_tile::ignore = b_vec;
175 #endif
176  }
177  }
178 
179  // c_vec = a_vec * b_vec
180  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
181  {
182 #if defined(__gfx9__)
183  return bit_cast<CVecType>(
184  __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
185 #else
186  ck_tile::ignore = a_vec;
187  ck_tile::ignore = b_vec;
188  return CVecType{0.f};
189 #endif
190  }
191 };
192 
193 // V_MFMA_F32_16x16x32_BF16
194 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
196 {
197  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
198  using ADataType = bf16_t;
199  using BDataType = bf16_t;
200  using CDataType = float;
201 
205 
206  static constexpr index_t kM = 16;
207  static constexpr index_t kN = 16;
208  static constexpr index_t kK = 32;
209 
210  static constexpr index_t kAMBlock = 1;
211  static constexpr index_t kBNBlock = 1;
212 
213  static constexpr index_t kAMLane = 16;
214  static constexpr index_t kBNLane = 16;
215  static constexpr index_t kABKLane = 4;
216  static constexpr index_t kABKPerLane = 8;
217 
218  static constexpr index_t kCMLane = 4;
219  static constexpr index_t kCNLane = 16;
220  static constexpr index_t kCM0PerLane = 1;
221  static constexpr index_t kCM1PerLane = 4;
222 
223  // c_vec += a_vec * b_vec
224  template <bool post_nop_ = false>
226  const AVecType& a_vec,
227  const BVecType& b_vec,
228  bool_constant<post_nop_> = {}) const
229  {
230  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl)
231  else
232  {
233 #if defined(__gfx950__)
234  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
235 #else
236  ck_tile::ignore = c_vec;
237  ck_tile::ignore = a_vec;
238  ck_tile::ignore = b_vec;
239 #endif
240  }
241  }
242 
243  // c_vec = a_vec * b_vec
244  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
245  {
246 #if defined(__gfx950__)
247  return bit_cast<CVecType>(
248  __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
249 #else
250  ck_tile::ignore = a_vec;
251  ck_tile::ignore = b_vec;
252  return CVecType{0.f};
253 #endif
254  }
255 };
256 // FP16
257 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
259 {
260  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
261  using ADataType = fp16_t;
262  using BDataType = fp16_t;
263  using CDataType = float;
264 
268 
269  static constexpr index_t kM = 32;
270  static constexpr index_t kN = 32;
271  static constexpr index_t kK = 8;
272 
273  static constexpr index_t kAMBlock = 1;
274  static constexpr index_t kBNBlock = 1;
275 
276  static constexpr index_t kAMLane = 32;
277  static constexpr index_t kBNLane = 32;
278  static constexpr index_t kABKLane = 2;
279  static constexpr index_t kABKPerLane = 4;
280 
281  static constexpr index_t kCMLane = 2;
282  static constexpr index_t kCNLane = 32;
283  static constexpr index_t kCM0PerLane = 4;
284  static constexpr index_t kCM1PerLane = 4;
285 
286  // c_vec += a_vec * b_vec
287  template <bool post_nop_ = false>
289  const AVecType& a_vec,
290  const BVecType& b_vec,
291  bool_constant<post_nop_> = {}) const
292  {
293  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
294  else
295  {
296 #if defined(__gfx9__)
297  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
298 #else
299  ck_tile::ignore = c_vec;
300  ck_tile::ignore = a_vec;
301  ck_tile::ignore = b_vec;
302 #endif
303  }
304  }
305 
306  // c_vec = a_vec * b_vec
307  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
308  {
309 #if defined(__gfx9__)
310  return bit_cast<CVecType>(
311  __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
312 #else
313  ck_tile::ignore = a_vec;
314  ck_tile::ignore = b_vec;
315  return CVecType{0.f};
316 #endif
317  }
318 };
319 
320 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
322 {
323  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
324  using ADataType = fp16_t;
325  using BDataType = fp16_t;
326  using CDataType = float;
327 
331 
332  static constexpr index_t kM = 16;
333  static constexpr index_t kN = 16;
334  static constexpr index_t kK = 16;
335 
336  static constexpr index_t kAMBlock = 1;
337  static constexpr index_t kBNBlock = 1;
338 
339  static constexpr index_t kAMLane = 16;
340  static constexpr index_t kBNLane = 16;
341  static constexpr index_t kABKLane = 4;
342  static constexpr index_t kABKPerLane = 4;
343 
344  static constexpr index_t kCMLane = 4;
345  static constexpr index_t kCNLane = 16;
346  static constexpr index_t kCM0PerLane = 1;
347  static constexpr index_t kCM1PerLane = 4;
348 
349  // c_vec += a_vec * b_vec
350  template <bool post_nop_ = false>
352  const AVecType& a_vec,
353  const BVecType& b_vec,
354  bool_constant<post_nop_> = {}) const
355  {
356  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
357  else
358  {
359 #if defined(__gfx9__)
360  c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
361 #else
362  ck_tile::ignore = c_vec;
363  ck_tile::ignore = a_vec;
364  ck_tile::ignore = b_vec;
365 #endif
366  }
367  }
368 
369  // c_vec = a_vec * b_vec
370  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
371  {
372 #if defined(__gfx9__)
373  return bit_cast<CVecType>(
374  __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
375 #else
376  ck_tile::ignore = a_vec;
377  ck_tile::ignore = b_vec;
378  return CVecType{0.f};
379 #endif
380  }
381 };
382 
383 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
385 {
386  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
387  using ADataType = fp16_t;
388  using BDataType = fp16_t;
389  using CDataType = float;
390 
394 
395  static constexpr index_t kM = 16;
396  static constexpr index_t kN = 16;
397  static constexpr index_t kK = 32;
398 
399  static constexpr index_t kAMBlock = 1;
400  static constexpr index_t kBNBlock = 1;
401 
402  static constexpr index_t kAMLane = 16;
403  static constexpr index_t kBNLane = 16;
404  static constexpr index_t kABKLane = 4;
405  static constexpr index_t kABKPerLane = 8;
406 
407  static constexpr index_t kCMLane = 4;
408  static constexpr index_t kCNLane = 16;
409  static constexpr index_t kCM0PerLane = 1;
410  static constexpr index_t kCM1PerLane = 4;
411 
412  // c_vec += a_vec * b_vec
413  template <bool post_nop_ = false>
415  const AVecType& a_vec,
416  const BVecType& b_vec,
417  bool_constant<post_nop_> = {}) const
418  {
419  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32f16", Ctrl)
420  else
421  {
422 #if defined(__gfx950__)
423  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0);
424 #else
425  ck_tile::ignore = c_vec;
426  ck_tile::ignore = a_vec;
427  ck_tile::ignore = b_vec;
428 #endif
429  }
430  }
431 
432  // c_vec = a_vec * b_vec
433  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
434  {
435 #if defined(__gfx950__)
436  return bit_cast<CVecType>(
437  __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
438 #else
439  ck_tile::ignore = a_vec;
440  ck_tile::ignore = b_vec;
441  return CVecType{0.f};
442 #endif
443  }
444 };
445 
446 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
448 {
449  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
450  using ADataType = fp16_t;
451  using BDataType = fp16_t;
452  using CDataType = float;
453 
457 
458  static constexpr index_t kM = 4;
459  static constexpr index_t kN = 64;
460  static constexpr index_t kK = 4;
461 
462  static constexpr index_t kAMBlock = 1;
463  static constexpr index_t kBNBlock = 16;
464 
465  // we only write down single block (4 threads) thread mapping here
466  static constexpr index_t kAMLane = 4;
467  static constexpr index_t kBNLane = 4;
468  static constexpr index_t kABKLane = 1;
469  static constexpr index_t kABKPerLane = 4;
470 
471  static constexpr index_t kCMLane = 1;
472  static constexpr index_t kCNLane = 4;
473  static constexpr index_t kCM0PerLane = 1;
474  static constexpr index_t kCM1PerLane = 4;
475 
476  // c_vec += a_vec * b_vec
477  template <bool post_nop_ = false>
479  const AVecType& a_vec,
480  const BVecType& b_vec,
481  bool_constant<post_nop_> = {}) const
482  {
483  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
484  else
485  {
486 #if defined(__gfx9__)
487  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
488 #else
489  ignore = c_vec;
490  ignore = a_vec;
491  ignore = b_vec;
492 #endif
493  }
494  }
495 
496  // c_vec = a_vec * b_vec
497  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
498  {
499 #if defined(__gfx9__)
500  return bit_cast<CVecType>(
501  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
502 #else
503  ignore = a_vec;
504  ignore = b_vec;
505  return CVecType{0.f};
506 #endif
507  }
508 };
509 
510 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
512 {
513  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
514  using ADataType = fp16_t;
515  using BDataType = fp16_t;
516  using CDataType = float;
517 
521 
522  static constexpr index_t kM = 64;
523  static constexpr index_t kN = 4;
524  static constexpr index_t kK = 4;
525 
526  static constexpr index_t kAMBlock = 16;
527  static constexpr index_t kBNBlock = 1;
528 
529  // we only write down single block (4 threads) thread mapping here
530  static constexpr index_t kAMLane = 4;
531  static constexpr index_t kBNLane = 4;
532  static constexpr index_t kABKLane = 1;
533  static constexpr index_t kABKPerLane = 4;
534 
535  static constexpr index_t kCMLane = 1;
536  static constexpr index_t kCNLane = 4;
537  static constexpr index_t kCM0PerLane = 1;
538  static constexpr index_t kCM1PerLane = 4;
539 
540  // c_vec += a_vec * b_vec
541  template <bool post_nop_ = false>
543  const AVecType& a_vec,
544  const BVecType& b_vec,
545  bool_constant<post_nop_> = {}) const
546  {
547  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
548  else
549  {
550 #if defined(__gfx9__)
551  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
552 #else
553  ignore = c_vec;
554  ignore = a_vec;
555  ignore = b_vec;
556 #endif
557  }
558  }
559 
560  // c_vec = a_vec * b_vec
561  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
562  {
563 #if defined(__gfx9__)
564  return bit_cast<CVecType>(
565  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
566 #else
567  ignore = a_vec;
568  ignore = b_vec;
569  return CVecType{0.f};
570 #endif
571  }
572 };
573 
574 // Bf16
575 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
577 {
578  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
579  using ADataType = bf16_t;
580  using BDataType = bf16_t;
581  using CDataType = float;
582 
586 
587  static constexpr index_t kM = 32;
588  static constexpr index_t kN = 32;
589  static constexpr index_t kK = 8;
590 
591  static constexpr index_t kAMBlock = 1;
592  static constexpr index_t kBNBlock = 1;
593 
594  static constexpr index_t kAMLane = 32;
595  static constexpr index_t kBNLane = 32;
596  static constexpr index_t kABKLane = 2;
597  static constexpr index_t kABKPerLane = 4;
598 
599  static constexpr index_t kCMLane = 2;
600  static constexpr index_t kCNLane = 32;
601  static constexpr index_t kCM0PerLane = 4;
602  static constexpr index_t kCM1PerLane = 4;
603 
604  // c_vec += a_vec * b_vec
605  template <bool post_nop_ = false>
607  const AVecType& a_vec,
608  const BVecType& b_vec,
609  bool_constant<post_nop_> = {}) const
610  {
611  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
612  else
613  {
614 #if defined(__gfx90a__) || defined(__gfx94__)
615  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
616 #elif defined(__gfx908__)
617  static_for<0, 2, 1>{}([&](auto k) {
618  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
619  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
620  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
621  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
622  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
623  c_vec,
624  0,
625  0,
626  0);
627  });
628 #else
629  ck_tile::ignore = c_vec;
630  ck_tile::ignore = a_vec;
631  ck_tile::ignore = b_vec;
632 #endif
633  }
634  }
635 
636  // c_vec = a_vec * b_vec
637  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
638  {
639 #if defined(__gfx90a__) || defined(__gfx94__)
640  return bit_cast<CVecType>(
641  __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
642 #elif defined(__gfx908__)
643  CVecType c_vec{0.f};
644  static_for<0, 2, 1>{}([&](auto k) {
645  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
646  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
647  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
648  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
649  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
650  c_vec,
651  0,
652  0,
653  0);
654  });
655  return c_vec;
656 #else
657  ck_tile::ignore = a_vec;
658  ck_tile::ignore = b_vec;
659  return CVecType{0.f};
660 #endif
661  }
662 };
663 
664 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
666 {
667  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
668  using ADataType = bf16_t;
669  using BDataType = bf16_t;
670  using CDataType = float;
671 
675 
676  static constexpr index_t kM = 16;
677  static constexpr index_t kN = 16;
678  static constexpr index_t kK = 16;
679 
680  static constexpr index_t kAMBlock = 1;
681  static constexpr index_t kBNBlock = 1;
682 
683  static constexpr index_t kAMLane = 16;
684  static constexpr index_t kBNLane = 16;
685  static constexpr index_t kABKLane = 4;
686  static constexpr index_t kABKPerLane = 4;
687 
688  static constexpr index_t kCMLane = 4;
689  static constexpr index_t kCNLane = 16;
690  static constexpr index_t kCM0PerLane = 1;
691  static constexpr index_t kCM1PerLane = 4;
692 
693  // c_vec += a_vec * b_vec
694  template <bool post_nop_ = false>
696  const AVecType& a_vec,
697  const BVecType& b_vec,
698  bool_constant<post_nop_> = {}) const
699  {
700  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
701  {
702 #if defined(__gfx90a__) || defined(__gfx94__)
703  c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
704 #elif defined(__gfx908__)
705  static_for<0, 2, 1>{}([&](auto k) {
706  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
707  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
708  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
709  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
710  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
711  c_vec,
712  0,
713  0,
714  0);
715  });
716 #else
717  ck_tile::ignore = c_vec;
718  ck_tile::ignore = a_vec;
719  ck_tile::ignore = b_vec;
720 #endif
721  }
722  }
723 
724  // c_vec = a_vec * b_vec
725  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
726  {
727 #if defined(__gfx90a__) || defined(__gfx94__)
728  return bit_cast<CVecType>(
729  __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
730 #elif defined(__gfx908__)
731  CVecType c_vec{0.f};
732  static_for<0, 2, 1>{}([&](auto k) {
733  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
734  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
735  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
736  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
737  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
738  c_vec,
739  0,
740  0,
741  0);
742  });
743  return c_vec;
744 #else
745  ck_tile::ignore = a_vec;
746  ck_tile::ignore = b_vec;
747  return CVecType{0.f};
748 #endif
749  }
750 };
751 
752 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
754 {
755  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
756  using ADataType = bf16_t;
757  using BDataType = bf16_t;
758  using CDataType = float;
759 
763 
764  static constexpr index_t kM = 4;
765  static constexpr index_t kN = 64;
766  static constexpr index_t kK = 4;
767 
768  static constexpr index_t kAMBlock = 1;
769  static constexpr index_t kBNBlock = 16;
770 
771  // we only write down single block (4 threads) thread mapping here
772  static constexpr index_t kAMLane = 4;
773  static constexpr index_t kBNLane = 4;
774  static constexpr index_t kABKLane = 1;
775  static constexpr index_t kABKPerLane = 4;
776 
777  static constexpr index_t kCMLane = 1;
778  static constexpr index_t kCNLane = 4;
779  static constexpr index_t kCM0PerLane = 1;
780  static constexpr index_t kCM1PerLane = 4;
781 
782  // c_vec += a_vec * b_vec
783  template <bool post_nop_ = false>
785  const AVecType& a_vec,
786  const BVecType& b_vec,
787  bool_constant<post_nop_> = {}) const
788  {
789  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
790  else
791  {
792 #if defined(__gfx90a__) || defined(__gfx94__)
793  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
794 #elif defined(__gfx908__)
795  static_for<0, 2, 1>{}([&](auto k) {
796  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
797  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
798  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
799  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
800  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
801  c_vec,
802  0,
803  0,
804  0);
805  });
806 #else
807  ignore = c_vec;
808  ignore = a_vec;
809  ignore = b_vec;
810 #endif
811  }
812  }
813 
814  // c_vec = a_vec * b_vec
815  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
816  {
817 #if defined(__gfx90a__) || defined(__gfx94__)
818  return bit_cast<CVecType>(
819  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
820 #elif defined(__gfx908__)
821  CVecType c_vec{0.f};
822  static_for<0, 2, 1>{}([&](auto k) {
823  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
824  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
825  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
826  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
827  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
828  c_vec,
829  0,
830  0,
831  0);
832  });
833  return c_vec;
834 #else
835  ignore = a_vec;
836  ignore = b_vec;
837  return CVecType{0.f};
838 #endif
839  }
840 };
841 
842 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
844 {
845  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
846  using ADataType = bf16_t;
847  using BDataType = bf16_t;
848  using CDataType = float;
849 
853 
854  static constexpr index_t kM = 64;
855  static constexpr index_t kN = 4;
856  static constexpr index_t kK = 4;
857 
858  static constexpr index_t kAMBlock = 16;
859  static constexpr index_t kBNBlock = 1;
860 
861  // we only write down single block (4 threads) thread mapping here
862  static constexpr index_t kAMLane = 4;
863  static constexpr index_t kBNLane = 4;
864  static constexpr index_t kABKLane = 1;
865  static constexpr index_t kABKPerLane = 4;
866 
867  static constexpr index_t kCMLane = 1;
868  static constexpr index_t kCNLane = 4;
869  static constexpr index_t kCM0PerLane = 1;
870  static constexpr index_t kCM1PerLane = 4;
871 
872  // c_vec += a_vec * b_vec
873  template <bool post_nop_ = false>
875  const AVecType& a_vec,
876  const BVecType& b_vec,
877  bool_constant<post_nop_> = {}) const
878  {
879  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
880  else
881  {
882 #if defined(__gfx90a__) || defined(__gfx94__)
883  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
884 #elif defined(__gfx908__)
885  static_for<0, 2, 1>{}([&](auto k) {
886  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
887  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
888  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
889  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
890  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
891  c_vec,
892  0,
893  0,
894  0);
895  });
896 #else
897  ignore = c_vec;
898  ignore = a_vec;
899  ignore = b_vec;
900 #endif
901  }
902  }
903 
904  // c_vec = a_vec * b_vec
905  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
906  {
907 #if defined(__gfx90a__) || defined(__gfx94__)
908  return bit_cast<CVecType>(
909  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
910 #elif defined(__gfx908__)
911  CVecType c_vec{0.f};
912  static_for<0, 2, 1>{}([&](auto k) {
913  c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
914  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
915  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
916  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
917  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
918  c_vec,
919  0,
920  0,
921  0);
922  });
923  return c_vec;
924 #else
925  ignore = a_vec;
926  ignore = b_vec;
927  return CVecType{0.f};
928 #endif
929  }
930 };
931 
932 // gfx950
933 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
935 {
936  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
937  using ADataType = fp16_t;
938  using BDataType = fp16_t;
939  using CDataType = float;
940 
944 
945  static constexpr index_t kM = 32;
946  static constexpr index_t kN = 32;
947  static constexpr index_t kK = 16;
948 
949  static constexpr index_t kAMBlock = 1;
950  static constexpr index_t kBNBlock = 1;
951 
952  static constexpr index_t kAMLane = 32;
953  static constexpr index_t kBNLane = 32;
954  static constexpr index_t kABKLane = 2;
955  static constexpr index_t kABKPerLane = 8;
956 
957  static constexpr index_t kCMLane = 2;
958  static constexpr index_t kCNLane = 32;
959  static constexpr index_t kCM0PerLane = 4;
960  static constexpr index_t kCM1PerLane = 4;
961 
962  // c_vec += a_vec * b_vec
963  template <bool post_nop_ = false>
965  const AVecType& a_vec,
966  const BVecType& b_vec,
967  bool_constant<post_nop_> = {}) const
968  {
969  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl)
970  else
971  {
972 #if defined(__gfx950__)
973  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0);
974 #elif defined(__gfx90a__) || defined(__gfx94__)
975  static_for<0, 2, 1>{}([&](auto k) {
976  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
977  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
978  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
979  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
980  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
981  c_vec,
982  0,
983  0,
984  0);
985  });
986 #elif defined(__gfx908__)
987  static_for<0, 4, 1>{}([&](auto k) {
988  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
989  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
990  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
991  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
992  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
993  c_vec,
994  0,
995  0,
996  0);
997  });
998 #else
999  ck_tile::ignore = c_vec;
1000  ck_tile::ignore = a_vec;
1001  ck_tile::ignore = b_vec;
1002 #endif
1003  }
1004  }
1005 
1006  // c_vec = a_vec * b_vec
1007  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1008  {
1009 #if defined(__gfx950__)
1010  return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
1011 #elif defined(__gfx90a__) || defined(__gfx94__)
1012  CVecType c_vec{0.f};
1013  static_for<0, 2, 1>{}([&](auto k) {
1014  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
1015  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1016  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
1017  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1018  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
1019  c_vec,
1020  0,
1021  0,
1022  0);
1023  });
1024  return c_vec;
1025 #elif defined(__gfx908__)
1026  CVecType c_vec{0.f};
1027  static_for<0, 4, 1>{}([&](auto k) {
1028  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
1029  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1030  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
1031  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1032  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
1033  c_vec,
1034  0,
1035  0,
1036  0);
1037  });
1038  return c_vec;
1039 #else
1040  ck_tile::ignore = a_vec;
1041  ck_tile::ignore = b_vec;
1042  return CVecType{0.f};
1043 #endif
1044  }
1045 };
1046 
1047 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1049 {
1050  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1053  using CDataType = float;
1054 
1058 
1059  static constexpr index_t kM = 32;
1060  static constexpr index_t kN = 32;
1061  static constexpr index_t kK = 16;
1062 
1063  static constexpr index_t kAMBlock = 1;
1064  static constexpr index_t kBNBlock = 1;
1065 
1066  static constexpr index_t kAMLane = 32;
1067  static constexpr index_t kBNLane = 32;
1068  static constexpr index_t kABKLane = 2;
1069  static constexpr index_t kABKPerLane = 8;
1070 
1071  static constexpr index_t kCMLane = 2;
1072  static constexpr index_t kCNLane = 32;
1073  static constexpr index_t kCM0PerLane = 4;
1074  static constexpr index_t kCM1PerLane = 4;
1075 
1076  // c_vec += a_vec * b_vec
1077  template <bool post_nop_ = false>
1079  const AVecType& a_vec,
1080  const BVecType& b_vec,
1081  bool_constant<post_nop_> = {}) const
1082  {
1083  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl)
1084  else
1085  {
1086 #if defined(__gfx950__)
1087  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
1088 #elif defined(__gfx90a__) || defined(__gfx94__)
1089  static_for<0, 2, 1>{}([&](auto k) {
1090  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1091  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1092  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1093  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1094  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1095  c_vec,
1096  0,
1097  0,
1098  0);
1099  });
1100 #elif defined(__gfx908__)
1101  static_for<0, 4, 1>{}([&](auto k) {
1102  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1103  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1104  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1105  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1106  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1107  c_vec,
1108  0,
1109  0,
1110  0);
1111  });
1112 #else
1113  ck_tile::ignore = c_vec;
1114  ck_tile::ignore = a_vec;
1115  ck_tile::ignore = b_vec;
1116 #endif
1117  }
1118  }
1119 
1120  // c_vec = a_vec * b_vec
1121  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1122  {
1123 #if defined(__gfx950__)
1124  return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
1125 #elif defined(__gfx90a__) || defined(__gfx94__)
1126  CVecType c_vec{0.f};
1127  static_for<0, 2, 1>{}([&](auto k) {
1128  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
1129  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1130  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1131  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1132  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
1133  c_vec,
1134  0,
1135  0,
1136  0);
1137  });
1138  return c_vec;
1139 #elif defined(__gfx908__)
1140  CVecType c_vec{0.f};
1141  static_for<0, 4, 1>{}([&](auto k) {
1142  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
1143  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
1144  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1145  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
1146  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
1147  c_vec,
1148  0,
1149  0,
1150  0);
1151  });
1152  return c_vec;
1153 #else
1154  ck_tile::ignore = a_vec;
1155  ck_tile::ignore = b_vec;
1156  return CVecType{0.f};
1157 #endif
1158  }
1159 };
1160 
1161 // FP8
1162 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1164 {
1165  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1166  using ADataType = AType_;
1167  using BDataType = BType_;
1168  using CDataType = float;
1169 
1173 
1174  static constexpr index_t kM = 16;
1175  static constexpr index_t kN = 16;
1176  static constexpr index_t kK = 32;
1177 
1178  static constexpr index_t kAMBlock = 1;
1179  static constexpr index_t kBNBlock = 1;
1180 
1181  static constexpr index_t kAMLane = 16;
1182  static constexpr index_t kBNLane = 16;
1183  static constexpr index_t kABKLane = 4;
1184  static constexpr index_t kABKPerLane = 8;
1185 
1186  static constexpr index_t kCMLane = 4;
1187  static constexpr index_t kCNLane = 16;
1188  static constexpr index_t kCM0PerLane = 1;
1189  static constexpr index_t kCM1PerLane = 4;
1190 
1191  // c_vec += a_vec * b_vec
1192  template <bool post_nop_ = false>
1194  const AVecType& a_vec,
1195  const BVecType& b_vec,
1196  bool_constant<post_nop_> = {}) const
1197  {
1198  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1199  {
1200  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1201  {
1202  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v")
1203  }
1204  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1205  {
1206  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v")
1207  }
1208  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1209  {
1210  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v")
1211  }
1212  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1213  {
1214  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v")
1215  }
1216  }
1217  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1218  {
1219  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1220  {
1221  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v")
1222  }
1223  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1224  {
1225  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v")
1226  }
1227  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1228  {
1229  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v")
1230  }
1231  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1232  {
1233  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v")
1234  }
1235  }
1236  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1237  {
1238  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1239  {
1240  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v")
1241  }
1242  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1243  {
1244  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v")
1245  }
1246  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1247  {
1248  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v")
1249  }
1250  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1251  {
1252  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v")
1253  }
1254  }
1255  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1256  {
1257  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1258  {
1259  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v")
1260  }
1261  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1262  {
1263  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v")
1264  }
1265  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1266  {
1267  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v")
1268  }
1269  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1270  {
1271  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v")
1272  }
1273  }
1274  else
1275  {
1276 #if defined(__gfx94__) or defined(__gfx95__)
1277  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1278  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1279  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1280  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1281  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1282  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1283  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1284  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1285  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1286  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1287  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1288  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1289 #else
1290  ck_tile::ignore = c_vec;
1291  ck_tile::ignore = a_vec;
1292  ck_tile::ignore = b_vec;
1293 #endif
1294  }
1295  }
1296 
1297  // c_vec = a_vec * b_vec
1298  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1299  {
1300 #if defined(__gfx94__) or defined(__gfx95__)
1301  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1302  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1303  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1304  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1305  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1306  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1307  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1308  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1309  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1310  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1311  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1312  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1313 #else
1314  ck_tile::ignore = a_vec;
1315  ck_tile::ignore = b_vec;
1316  return CVecType{0.f};
1317 #endif
1318  }
1319 };
1320 
1321 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1323 {
1324  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1325  using ADataType = AType_;
1326  using BDataType = BType_;
1327  using CDataType = float;
1328 
1332 
1333  static constexpr index_t kM = 32;
1334  static constexpr index_t kN = 32;
1335  static constexpr index_t kK = 16;
1336 
1337  static constexpr index_t kAMBlock = 1;
1338  static constexpr index_t kBNBlock = 1;
1339 
1340  static constexpr index_t kAMLane = 32;
1341  static constexpr index_t kBNLane = 32;
1342  static constexpr index_t kABKLane = 2;
1343  static constexpr index_t kABKPerLane = 8;
1344 
1345  static constexpr index_t kCMLane = 2;
1346  static constexpr index_t kCNLane = 32;
1347  static constexpr index_t kCM0PerLane = 4;
1348  static constexpr index_t kCM1PerLane = 4;
1349 
1350  // c_vec += a_vec * b_vec
1351  template <bool post_nop_ = false>
1353  const AVecType& a_vec,
1354  const BVecType& b_vec,
1355  bool_constant<post_nop_> = {}) const
1356  {
1357  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1358  {
1359  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1360  {
1361  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
1362  }
1363  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1364  {
1365  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
1366  }
1367  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1368  {
1369  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
1370  }
1371  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1372  {
1373  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
1374  }
1375  }
1376  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1377  {
1378  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1379  {
1380  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
1381  }
1382  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1383  {
1384  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
1385  }
1386  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1387  {
1388  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
1389  }
1390  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1391  {
1392  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
1393  }
1394  }
1395  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1396  {
1397  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1398  {
1399  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
1400  }
1401  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1402  {
1403  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
1404  }
1405  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1406  {
1407  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
1408  }
1409  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1410  {
1411  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
1412  }
1413  }
1414  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1415  {
1416  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1417  {
1418  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
1419  }
1420  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1421  {
1422  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
1423  }
1424  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1425  {
1426  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
1427  }
1428  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1429  {
1430  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
1431  }
1432  }
1433  else
1434  {
1435 #if defined(__gfx94__) or defined(__gfx95__)
1436  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1437  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1438  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1439  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1440  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1441  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1442  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1443  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1444  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1445  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1446  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1447  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1448 #elif defined(__gfx908__) || defined(__gfx90a__)
1449  static_for<0, 8, 1>{}([&](auto k) {
1450  float a_f32 =
1451  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1452  .template get_as<ADataType>()[number<k>{}]);
1453  float b_f32 =
1454  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1455  .template get_as<BDataType>()[number<k>{}]);
1456 
1457  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1458  });
1459 #else
1460  ck_tile::ignore = c_vec;
1461  ck_tile::ignore = a_vec;
1462  ck_tile::ignore = b_vec;
1463 #endif
1464  }
1465  }
1466 
1467  // c_vec = a_vec * b_vec
1468  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1469  {
1470 #if defined(__gfx94__) or defined(__gfx95__)
1471  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1472  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1473  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1474  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1475  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1476  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1477  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1478  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1479  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1480  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1481  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1482  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), CVecType{0.f}, 0, 0, 0));
1483 #elif defined(__gfx908__) || defined(__gfx90a__)
1484  CVecType c_vec{0.f};
1485  static_for<0, 8, 1>{}([&](auto k) {
1486  float a_f32 =
1487  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1488  .template get_as<ADataType>()[number<k>{}]);
1489  float b_f32 =
1490  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1491  .template get_as<BDataType>()[number<k>{}]);
1492 
1493  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1494  });
1495  return c_vec;
1496 #else
1497  ck_tile::ignore = a_vec;
1498  ck_tile::ignore = b_vec;
1499  return CVecType{0.f};
1500 #endif
1501  }
1502 };
1503 
1504 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1507 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1510 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1513 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1516 
1517 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1520 
1521 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1524 
1525 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1528 
1529 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1531 {
1532  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1533  using ADataType = AType_;
1534  using BDataType = BType_;
1535  using CDataType = float;
1536 
1540 
1541  static constexpr index_t kM = 16;
1542  static constexpr index_t kN = 16;
1543  static constexpr index_t kK = 128;
1544 
1545  static constexpr index_t kAMBlock = 1;
1546  static constexpr index_t kBNBlock = 1;
1547 
1548  static constexpr index_t kAMLane = 16;
1549  static constexpr index_t kBNLane = 16;
1550  static constexpr index_t kABKLane = 4;
1551  static constexpr index_t kABKPerLane = 32;
1552 
1553  static constexpr index_t kCMLane = 4;
1554  static constexpr index_t kCNLane = 16;
1555  static constexpr index_t kCM0PerLane = 1;
1556  static constexpr index_t kCM1PerLane = 4;
1557 
1558  // c_vec += a_vec * b_vec
1559  template <index_t opselA, index_t opselB, bool post_nop_ = false>
1561  const AVecType& a_vec,
1562  const int32_t& a_scale,
1563  const BVecType& b_vec,
1564  const int32_t& b_scale,
1565  bool_constant<post_nop_> = {}) const
1566  {
1567 #if defined(__gfx950__)
1568  auto dtype2conf = [](auto dtype) {
1569  if constexpr(std::is_same_v<decltype(dtype), fp8_t>)
1570  return make_tuple(number<0>{}, int32x8_t{});
1571  else if constexpr(std::is_same_v<decltype(dtype), bf8_t>)
1572  return make_tuple(number<1>{}, int32x8_t{});
1573  // else if e2m3 => make_tuple(number<2>{}, int32x6_t{})
1574  // else if e3m2 => make_tuple(number<3>{}, int32x6_t{})
1575  else if constexpr(std::is_same_v<decltype(dtype), pk_fp4_t>)
1576  return make_tuple(number<4>{}, int32x4_t{});
1577  else
1578  static_assert(false, "Unsupported data type for mfma scale");
1579  };
1580  auto dtype2code = [&](auto dtype) { return dtype2conf(dtype)(number<0>{}); };
1581  auto dtype2vec = [&](auto dtype) { return dtype2conf(dtype)(number<1>{}); };
1582  auto arg256 = [&](auto x) {
1583  if constexpr(sizeof(x) == 16)
1584  return int32x8_t{x[0], x[1], x[2], x[3], 0, 0, 0, 0};
1585  else if constexpr(sizeof(x) == 24)
1586  return int32x8_t{x[0], x[1], x[2], x[3], x[4], x[5], 0, 0};
1587  else if constexpr(sizeof(x) == 32)
1588  return x;
1589  else
1590  static_assert(false, "Unexpected vector size for mfma scale");
1591  };
1592 
1593  auto arg_a = bit_cast<decltype(dtype2vec(ADataType{}))>(a_vec);
1594  auto arg_b = bit_cast<decltype(dtype2vec(BDataType{}))>(b_vec);
1595  constexpr int cbsz = decltype(dtype2code(ADataType{}))::value;
1596  constexpr int blgp = decltype(dtype2code(BDataType{}))::value;
1597  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1598  arg256(arg_a), arg256(arg_b), c_vec, cbsz, blgp, opselA, a_scale, opselB, b_scale);
1599 #else
1600  ck_tile::ignore = c_vec;
1601  ck_tile::ignore = a_vec;
1602  ck_tile::ignore = b_vec;
1603  ck_tile::ignore = a_scale;
1604  ck_tile::ignore = b_scale;
1605 #endif
1606  }
1607 
1608  // c_vec = a_vec * b_vec
1609  template <index_t opselA, index_t opselB>
1611  const int32_t& a_scale,
1612  const BVecType& b_vec,
1613  const int32_t& b_scale) const
1614  {
1615  CVecType c_vec{0.f};
1616  operator()<opselA, opselB>(c_vec, a_vec, a_scale, b_vec, b_scale);
1617  return c_vec;
1618  }
1619 
1620  // c_vec += a_vec * b_vec
1621  template <bool post_nop_ = false>
1623  const AVecType& a_vec,
1624  const BVecType& b_vec,
1625  bool_constant<post_nop_> = {}) const
1626  {
1627  operator()<0, 0>(c_vec, a_vec, 0, b_vec, 0);
1628  }
1629 
1630  // c_vec = a_vec * b_vec
1631  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1632  {
1633  return operator()<0, 0>(a_vec, 0, b_vec, 0);
1634  }
1635 };
1636 
1637 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1639 {
1640  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1641  using ADataType = AType_;
1642  using BDataType = BType_;
1643  using CDataType = float;
1644 
1648 
1649  static constexpr index_t kM = 32;
1650  static constexpr index_t kN = 32;
1651  static constexpr index_t kK = 64;
1652 
1653  static constexpr index_t kAMBlock = 1;
1654  static constexpr index_t kBNBlock = 1;
1655 
1656  static constexpr index_t kAMLane = 32;
1657  static constexpr index_t kBNLane = 32;
1658  static constexpr index_t kABKLane = 2;
1659  static constexpr index_t kABKPerLane = 32;
1660 
1661  static constexpr index_t kCMLane = 2;
1662  static constexpr index_t kCNLane = 32;
1663  static constexpr index_t kCM0PerLane = 4;
1664  static constexpr index_t kCM1PerLane = 4;
1665 
1666  // c_vec += a_vec * b_vec
1667  template <bool post_nop_ = false>
1669  const AVecType& a_vec,
1670  const BVecType& b_vec,
1671  bool_constant<post_nop_> = {}) const
1672  {
1673  //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1674  // opsel, scale_b)
1675 #if defined(__gfx950__)
1676  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1677  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1678  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1679  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1680  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1681  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1682  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1683  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1684  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1685  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1686  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1687  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1688 #else
1689  ck_tile::ignore = c_vec;
1690  ck_tile::ignore = a_vec;
1691  ck_tile::ignore = b_vec;
1692 #endif
1693  }
1694 
1695  // c_vec = a_vec * b_vec
1696  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1697  {
1698 #if defined(__gfx950__)
1699  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1700  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1701  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1702  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1703  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1704  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1705  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1706  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1707  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1708  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1709  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1710  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1711 #else
1712  ck_tile::ignore = a_vec;
1713  ck_tile::ignore = b_vec;
1714  return CVecType{0.f};
1715 #endif
1716  }
1717 };
1718 
1719 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1722 
1723 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1726 
1727 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1730 
1731 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1734 
1735 // int8
1736 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1738 {
1739  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1743 
1747 
1748  static constexpr index_t kM = 32;
1749  static constexpr index_t kN = 32;
1750  static constexpr index_t kK = 16;
1751 
1752  static constexpr index_t kAMBlock = 1;
1753  static constexpr index_t kBNBlock = 1;
1754 
1755  static constexpr index_t kAMLane = 32;
1756  static constexpr index_t kBNLane = 32;
1757  static constexpr index_t kABKLane = 2;
1758  static constexpr index_t kABKPerLane = 8;
1759 
1760  static constexpr index_t kCMLane = 2;
1761  static constexpr index_t kCNLane = 32;
1762  static constexpr index_t kCM0PerLane = 4;
1763  static constexpr index_t kCM1PerLane = 4;
1764 
1765  // c_vec += a_vec * b_vec
1766  template <bool post_nop_ = false>
1768  const AVecType& a_vec,
1769  const BVecType& b_vec,
1770  bool_constant<post_nop_> = {}) const
1771  {
1772  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
1773  else
1774  {
1775 #if defined(__gfx94__) or defined(__gfx95__)
1776  c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
1777  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1778 #elif defined(__gfx908__) || defined(__gfx90a__)
1779  static_for<0, 8, 1>{}([&](auto k) {
1780  float a_f32 =
1781  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1782  .template get_as<ADataType>()[number<k>{}]);
1783  float b_f32 =
1784  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1785  .template get_as<BDataType>()[number<k>{}]);
1786 
1787  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1788  });
1789 #else
1790  ck_tile::ignore = c_vec;
1791  ck_tile::ignore = a_vec;
1792  ck_tile::ignore = b_vec;
1793 #endif
1794  }
1795  }
1796 
1797  // c_vec = a_vec * b_vec
1798  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1799  {
1800  CVecType c_vec{0};
1801  operator()(c_vec, a_vec, b_vec);
1802  return c_vec;
1803  }
1804 };
1805 
1806 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1808 {
1809  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1813 
1817 
1818  static constexpr index_t kM = 16;
1819  static constexpr index_t kN = 16;
1820  static constexpr index_t kK = 32;
1821 
1822  static constexpr index_t kAMBlock = 1;
1823  static constexpr index_t kBNBlock = 1;
1824 
1825  static constexpr index_t kAMLane = 16;
1826  static constexpr index_t kBNLane = 16;
1827  static constexpr index_t kABKLane = 4;
1828  static constexpr index_t kABKPerLane = 8;
1829 
1830  static constexpr index_t kCMLane = 4;
1831  static constexpr index_t kCNLane = 16;
1832  static constexpr index_t kCM0PerLane = 1;
1833  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1834 
1835  // c_vec += a_vec * b_vec
1836  template <bool post_nop_ = false>
1838  const AVecType& a_vec,
1839  const BVecType& b_vec,
1840  bool_constant<post_nop_> = {}) const
1841  {
1842  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
1843  else
1844  {
1845 #if defined(__gfx94__) or defined(__gfx95__)
1846  c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
1847  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1848 #else
1849  ck_tile::ignore = c_vec;
1850  ck_tile::ignore = a_vec;
1851  ck_tile::ignore = b_vec;
1852 #endif
1853  }
1854  }
1855 
1856  // c_vec = a_vec * b_vec
1857  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1858  {
1859  CVecType c_vec{0};
1860  operator()(c_vec, a_vec, b_vec);
1861  return c_vec;
1862  }
1863 };
1864 
1865 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1867 {
1868  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1872 
1876 
1877  static constexpr index_t kM = 16;
1878  static constexpr index_t kN = 16;
1879  static constexpr index_t kK = 64;
1880 
1881  static constexpr index_t kAMBlock = 1;
1882  static constexpr index_t kBNBlock = 1;
1883 
1884  static constexpr index_t kAMLane = 16;
1885  static constexpr index_t kBNLane = 16;
1886  static constexpr index_t kABKLane = 4;
1887  static constexpr index_t kABKPerLane = 16;
1888 
1889  static constexpr index_t kCMLane = 4;
1890  static constexpr index_t kCNLane = 16;
1891  static constexpr index_t kCM0PerLane = 1;
1892  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1893 
1894  // c_vec += a_vec * b_vec
1895  template <bool post_nop_ = false>
1897  const AVecType& a_vec,
1898  const BVecType& b_vec,
1899  bool_constant<post_nop_> = {}) const
1900  {
1901  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
1902  else
1903  {
1904 #if defined(__gfx95__)
1905  c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
1906  bit_cast<int64_t>(a_vec), bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1907 #else
1908  ck_tile::ignore = c_vec;
1909  ck_tile::ignore = a_vec;
1910  ck_tile::ignore = b_vec;
1911 #endif
1912  }
1913  }
1914 
1915  // c_vec = a_vec * b_vec
1916  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1917  {
1918  CVecType c_vec{0};
1919  operator()(c_vec, a_vec, b_vec);
1920  return c_vec;
1921  }
1922 };
1923 
1924 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1926 {
1927  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
1931 
1935 
1936  static constexpr index_t kM = 32;
1937  static constexpr index_t kN = 32;
1938  static constexpr index_t kK = 32;
1939 
1940  static constexpr index_t kAMBlock = 1;
1941  static constexpr index_t kBNBlock = 1;
1942 
1943  static constexpr index_t kAMLane = 32;
1944  static constexpr index_t kBNLane = 32;
1945  static constexpr index_t kABKLane = 2;
1946  static constexpr index_t kABKPerLane = 16;
1947 
1948  static constexpr index_t kCMLane = 2;
1949  static constexpr index_t kCNLane = 32;
1950  static constexpr index_t kCM0PerLane = 4;
1951  static constexpr index_t kCM1PerLane = 4;
1952 
1953  // c_vec += a_vec * b_vec
1954  template <bool post_nop_ = false>
1956  const AVecType& a_vec,
1957  const BVecType& b_vec,
1958  bool_constant<post_nop_> = {}) const
1959  {
1960  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
1961  else
1962  {
1963 #if defined(__gfx95__)
1964  c_vec = __builtin_amdgcn_mfma_i32_32x32x32_i8(
1965  a_vec, bit_cast<int64_t>(b_vec), c_vec, 0, 0, 0);
1966 #else
1967  ck_tile::ignore = c_vec;
1968  ck_tile::ignore = a_vec;
1969  ck_tile::ignore = b_vec;
1970 #endif
1971  }
1972  }
1973 
1974  // c_vec = a_vec * b_vec
1975  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1976  {
1977  CVecType c_vec{0};
1978  operator()(c_vec, a_vec, b_vec);
1979  return c_vec;
1980  }
1981 };
1982 
// Both dispatch helpers are private to this header. DISPATCH_MFMA_CTRL_ expands
// to DISPATCH_MFMA_, so neither is usable past this point; undefine both so
// neither leaks into files that include this header.
#undef DISPATCH_MFMA_CTRL_
#undef DISPATCH_MFMA_
1984 
1985 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
WGAttrCtlEnum
Definition: warp_gemm_attribute_mfma_impl.hpp:15
int32_t int32x8_t
Definition: vector_type.hpp:156
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:151
constant< v > number
Definition: integral_constant.hpp:37
constexpr detail::ignore_t ignore
Definition: ignore.hpp:24
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:84
int32_t int32_t
Definition: integer.hpp:10
float fp32x16_t
Definition: vector_type.hpp:130
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
int32_t int32x4_t
Definition: vector_type.hpp:155
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
float fp32x4_t
Definition: vector_type.hpp:128
constexpr bool is_same_v
Definition: type.hpp:283
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
Definition: warp_gemm_attribute_mfma_impl.hpp:1531
ext_vector_t< BDataType, 32/numeric_traits< BDataType >::PackedSize > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1538
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1545
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1543
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1535
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1631
ext_vector_t< ADataType, 32/numeric_traits< ADataType >::PackedSize > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1537
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1539
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1541
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1622
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1560
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1550
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1551
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1554
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1555
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1548
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1532
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1542
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1549
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1546
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1533
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1556
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1610
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1534
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1553
Definition: warp_gemm_attribute_mfma_impl.hpp:1164
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1178
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1179
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1188
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1174
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1166
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1172
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1170
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1165
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1186
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1183
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1193
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1189
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1168
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1298
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1182
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1167
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1187
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1184
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1176
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1181
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1175
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1171
Definition: warp_gemm_attribute_mfma_impl.hpp:1323
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1345
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1324
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1329
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1326
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1348
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1331
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1343
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1327
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1337
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1346
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1352
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1468
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1340
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1333
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1325
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1330
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1335
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1334
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1347
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1338
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1342
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1341
Definition: warp_gemm_attribute_mfma_impl.hpp:1639
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1696
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1642
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1647
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1643
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1668
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1646
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1641
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1645
Definition: warp_gemm_attribute_mfma_impl.hpp:1808
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1857
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1816
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1815
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1814
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1837
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1810
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1812
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1811
Definition: warp_gemm_attribute_mfma_impl.hpp:1867
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1871
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1873
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1896
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1870
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1874
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1916
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1869
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1875
Definition: warp_gemm_attribute_mfma_impl.hpp:1738
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1740
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1767
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1744
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1745
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1742
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1798
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1741
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1746
Definition: warp_gemm_attribute_mfma_impl.hpp:1926
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1928
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1955
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1929
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1930
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1933
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1932
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1934
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1975
Definition: warp_gemm_attribute_mfma_impl.hpp:666
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:674
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:684
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:673
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:695
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:688
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:725
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:678
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:683
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:667
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:689
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:690
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:685
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:680
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:670
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:691
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:677
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:669
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:681
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:668
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:672
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:686
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:676
Definition: warp_gemm_attribute_mfma_impl.hpp:196
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:219
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:199
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:204
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:207
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:210
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:220
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:215
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:206
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:208
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:202
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:197
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:216
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:200
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:214
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:203
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:218
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:198
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:225
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:221
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:244
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:211
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:213
Definition: warp_gemm_attribute_mfma_impl.hpp:1049
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1072
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1078
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1053
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1050
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1052
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1063
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1064
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1067
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1074
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1121
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1068
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1073
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1066
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1059
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1055
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1060
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1071
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1056
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1051
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1069
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1057
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1061
Definition: warp_gemm_attribute_mfma_impl.hpp:577
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:584
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:580
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:581
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:592
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:585
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:583
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:637
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:588
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:601
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:587
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:600
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:591
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:596
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:595
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:594
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:589
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:599
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:579
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:578
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:606
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:602
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:597
Definition: warp_gemm_attribute_mfma_impl.hpp:754
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:775
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:756
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:778
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:777
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:779
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:815
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:757
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:766
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:755
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:773
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:784
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:760
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:765
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:764
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:774
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:769
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:761
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:780
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:762
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:772
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:768
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:758
Definition: warp_gemm_attribute_mfma_impl.hpp:844
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:847
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:867
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:862
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:905
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:845
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:868
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:855
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:846
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:858
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:869
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:859
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:863
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:856
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:848
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:864
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:870
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:865
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:850
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:874
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:851
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:854
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:852
Definition: warp_gemm_attribute_mfma_impl.hpp:322
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:323
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:339
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:326
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:345
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:336
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:325
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:344
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:347
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:328
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:337
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:333
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:324
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:341
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:340
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:330
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:346
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:342
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:370
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:351
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:332
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:334
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:329
Definition: warp_gemm_attribute_mfma_impl.hpp:385
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:396
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:408
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:410
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:387
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:407
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:391
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:404
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:405
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:392
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:395
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:393
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:433
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:388
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:389
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:402
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:409
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:400
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:399
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:386
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:403
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:397
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:414
Definition: warp_gemm_attribute_mfma_impl.hpp:935
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:936
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:937
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:949
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:957
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:941
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:958
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:942
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:945
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:939
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:964
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:946
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:955
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1007
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:954
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:953
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:960
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:950
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:947
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:952
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:959
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:943
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:938
Definition: warp_gemm_attribute_mfma_impl.hpp:259
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:282
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:276
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:283
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:281
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:288
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:284
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:263
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:271
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:262
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:277
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:267
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:274
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:260
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:269
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:266
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:270
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:307
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:273
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:265
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:261
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:278
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:279
Definition: warp_gemm_attribute_mfma_impl.hpp:448
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:459
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:471
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:450
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:454
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:497
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:455
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:463
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:452
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:478
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:456
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:466
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:467
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:472
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:451
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:460
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:468
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:458
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:473
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:449
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:469
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:474
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:462
Definition: warp_gemm_attribute_mfma_impl.hpp:512
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:533
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:516
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:522
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:530
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:515
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:513
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:527
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:542
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:524
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:526
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:520
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:523
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:561
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:535
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:518
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:531
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:536
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:537
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:532
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:514
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:519
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:538
Definition: warp_gemm_attribute_mfma_impl.hpp:67
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:68
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:87
ext_vector_t< ADataType, 1 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:74
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:80
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:78
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:97
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:90
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:93
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:85
float ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:70
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:76
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:83
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:82
float BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:71
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:92
ext_vector_t< BDataType, 1 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:75
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:88
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:72
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:116
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:91
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:79
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:86
Definition: warp_gemm_attribute_mfma_impl.hpp:131
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:156
float BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:135
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:142
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:157
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:155
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:132
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:161
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:136
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:149
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:140
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:154
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:150
ext_vector_t< ADataType, 1 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:138
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:180
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:152
ext_vector_t< BDataType, 1 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:139
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:143
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:146
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:151
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:147
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:144
float ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:134
Definition: integral_constant.hpp:13
Definition: numeric.hpp:81
Definition: functional.hpp:43
Definition: debug.hpp:27
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
Definition: warp_gemm_attribute_mfma_impl.hpp:25
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
Definition: warp_gemm_attribute_mfma_impl.hpp:42