/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp Source File
warp_gemm_attribute_mfma_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
// TODO: refactor warp-gemm
// There is currently a discrepancy for vav/vva when C/D must be transposed:
// e.g. to get A in AGPRs and B in VGPRs we have to request Raw_vva, because
// the _impl code swaps the A/B operands and that is not known at this level.
//
// WGAttrCtlEnum selects how a warp-gemm attribute issues its MFMA.
//  - Default_: no inline-asm override; the attribute falls back to its
//    compiler-builtin path.
//  - Raw_xyz: inline asm with the register class of the C (x), A (y) and
//    B (z) operands pinned; v = VGPR, a = AGPR.
enum class WGAttrCtlEnum
{
    Default_ = 0,
    Raw_vvv  = 1, // c-vgpr, a-vgpr, b-vgpr
    Raw_vaa  = 2, // c-vgpr, a-agpr, b-agpr
    Raw_vav  = 3, // c-vgpr, a-agpr, b-vgpr
    Raw_vva  = 4, // c-vgpr, a-vgpr, b-agpr
    Raw_avv  = 5, // c-agpr, a-vgpr, b-vgpr
    // No all-AGPR variant (c-agpr, a-agpr, b-agpr) is provided. Note that the
    // value 3 is already taken by Raw_vav, so a future variant needs a new value.
};
24 
// DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
//   Emits `mfma_ %0, %1, %2, %3` as inline asm, where %0 is c_vec (output,
//   constraint dmod_) and %1/%2/%3 are a_vec, b_vec and c_vec (inputs,
//   constraints amod_/bmod_/cmod_).
//   The expansion site must have `post_nop_` (a compile-time bool) and the
//   variables `c_vec`, `a_vec`, `b_vec` in scope.
//   When post_nop_ is true, the same asm statement also carries an `s_nop 3`
//   after the MFMA (presumably to pad a result hazard; confirm against the ISA).
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)      \
    if constexpr(!post_nop_)                                    \
    {                                                           \
        asm volatile(mfma_ " %0, %1, %2, %3\n"                  \
                     : dmod_(c_vec)                             \
                     : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
                     :);                                        \
    }                                                           \
    else                                                        \
    {                                                           \
        asm volatile(mfma_ " %0, %1, %2, %3 ; yyy\n"            \
                     "s_nop 3"                                  \
                     : dmod_(c_vec)                             \
                     : amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
                     :);                                        \
    }
41 
// DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
//   Maps a Raw_* WGAttrCtlEnum value to the register-class constraints given
//   to DISPATCH_MFMA_ (v = VGPR, a = AGPR; the C operand is read-write "+").
//   Expands to an open `if constexpr ... else if constexpr ...` chain with NO
//   final `else`. Callers append `else { ... }` immediately after the
//   invocation to provide the path for every other control value (e.g.
//   Default_). If that `else` is left out, the following block executes
//   unconditionally, in addition to any inline asm selected here.
//   The conditions are mutually exclusive, so branch order does not matter.
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)                \
    if constexpr((ctrl_) == WGAttrCtlEnum::Raw_avv)      \
    {                                                    \
        DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a")       \
    }                                                    \
    else if constexpr((ctrl_) == WGAttrCtlEnum::Raw_vvv) \
    {                                                    \
        DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v")       \
    }                                                    \
    else if constexpr((ctrl_) == WGAttrCtlEnum::Raw_vaa) \
    {                                                    \
        DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v")       \
    }                                                    \
    else if constexpr((ctrl_) == WGAttrCtlEnum::Raw_vav) \
    {                                                    \
        DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v")       \
    }                                                    \
    else if constexpr((ctrl_) == WGAttrCtlEnum::Raw_vva) \
    {                                                    \
        DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v")       \
    }
63 
64 // V_MFMA_F32_16x16x32_BF16
65 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
67 {
68  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
69  using ADataType = bf16_t;
70  using BDataType = bf16_t;
71  using CDataType = float;
72 
76 
77  static constexpr index_t kM = 16;
78  static constexpr index_t kN = 16;
79  static constexpr index_t kK = 32;
80 
81  static constexpr index_t kAMBlock = 1;
82  static constexpr index_t kBNBlock = 1;
83 
84  static constexpr index_t kAMLane = 16;
85  static constexpr index_t kBNLane = 16;
86  static constexpr index_t kABKLane = 4;
87  static constexpr index_t kABKPerLane = 8;
88 
89  static constexpr index_t kCMLane = 4;
90  static constexpr index_t kCNLane = 16;
91  static constexpr index_t kCM0PerLane = 1;
92  static constexpr index_t kCM1PerLane = 4;
93 
94  // c_vec += a_vec * b_vec
95  template <bool post_nop_ = false>
97  const AVecType& a_vec,
98  const BVecType& b_vec,
99  bool_constant<post_nop_> = {}) const
100  {
    // c_vec += a_vec * b_vec.
    // Raw_* control values are emitted as inline asm by DISPATCH_MFMA_CTRL_.
    // The `else` that follows the macro binds to its last `if constexpr`, so
    // the builtin path runs only when no Raw_* value matched (e.g. Default_).
    // It must stay directly after the macro invocation.
101  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32_bf16", Ctrl)
102  else
103  {
104 #if defined(__gfx950__)
105  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
106 #else
    // Not gfx950: no MFMA is issued and c_vec is left unchanged.
107  ck_tile::ignore = c_vec;
108  ck_tile::ignore = a_vec;
109  ck_tile::ignore = b_vec;
110 #endif
111  }
112  }
113 
114  // c_vec = a_vec * b_vec
115  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
116  {
117 #if defined(__gfx950__)
118  return bit_cast<CVecType>(
119  __builtin_amdgcn_mfma_f32_16x16x32_bf16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
120 #else
121  ck_tile::ignore = a_vec;
122  ck_tile::ignore = b_vec;
123  return CVecType{0.f};
124 #endif
125  }
126 };
127 // FP16
128 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
130 {
131  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
132  using ADataType = fp16_t;
133  using BDataType = fp16_t;
134  using CDataType = float;
135 
139 
140  static constexpr index_t kM = 32;
141  static constexpr index_t kN = 32;
142  static constexpr index_t kK = 8;
143 
144  static constexpr index_t kAMBlock = 1;
145  static constexpr index_t kBNBlock = 1;
146 
147  static constexpr index_t kAMLane = 32;
148  static constexpr index_t kBNLane = 32;
149  static constexpr index_t kABKLane = 2;
150  static constexpr index_t kABKPerLane = 4;
151 
152  static constexpr index_t kCMLane = 2;
153  static constexpr index_t kCNLane = 32;
154  static constexpr index_t kCM0PerLane = 4;
155  static constexpr index_t kCM1PerLane = 4;
156 
157  // c_vec += a_vec * b_vec
158  template <bool post_nop_ = false>
160  const AVecType& a_vec,
161  const BVecType& b_vec,
162  bool_constant<post_nop_> = {}) const
163  {
    // c_vec += a_vec * b_vec.
    // Raw_* control values are emitted as inline asm by DISPATCH_MFMA_CTRL_.
    // The `else` that follows the macro binds to its last `if constexpr`, so
    // the builtin path runs only when no Raw_* value matched (e.g. Default_).
164  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
165  else
166  {
167 #if defined(__gfx9__)
168  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
169 #else
    // Not a __gfx9__ target: no MFMA is issued and c_vec is left unchanged.
170  ck_tile::ignore = c_vec;
171  ck_tile::ignore = a_vec;
172  ck_tile::ignore = b_vec;
173 #endif
174  }
175  }
176 
177  // c_vec = a_vec * b_vec
178  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
179  {
180 #if defined(__gfx9__)
181  return bit_cast<CVecType>(
182  __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
183 #else
184  ck_tile::ignore = a_vec;
185  ck_tile::ignore = b_vec;
186  return CVecType{0.f};
187 #endif
188  }
189 };
190 
191 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
193 {
194  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
195  using ADataType = fp16_t;
196  using BDataType = fp16_t;
197  using CDataType = float;
198 
202 
203  static constexpr index_t kM = 16;
204  static constexpr index_t kN = 16;
205  static constexpr index_t kK = 16;
206 
207  static constexpr index_t kAMBlock = 1;
208  static constexpr index_t kBNBlock = 1;
209 
210  static constexpr index_t kAMLane = 16;
211  static constexpr index_t kBNLane = 16;
212  static constexpr index_t kABKLane = 4;
213  static constexpr index_t kABKPerLane = 4;
214 
215  static constexpr index_t kCMLane = 4;
216  static constexpr index_t kCNLane = 16;
217  static constexpr index_t kCM0PerLane = 1;
218  static constexpr index_t kCM1PerLane = 4;
219 
220  // c_vec += a_vec * b_vec
221  template <bool post_nop_ = false>
223  const AVecType& a_vec,
224  const BVecType& b_vec,
225  bool_constant<post_nop_> = {}) const
226  {
    // c_vec += a_vec * b_vec.
    // Raw_* control values are emitted as inline asm by DISPATCH_MFMA_CTRL_.
    // The `else` that follows the macro binds to its last `if constexpr`, so
    // the builtin path runs only when no Raw_* value matched (e.g. Default_).
227  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
228  else
229  {
230 #if defined(__gfx9__)
231  c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
232 #else
    // Not a __gfx9__ target: no MFMA is issued and c_vec is left unchanged.
233  ck_tile::ignore = c_vec;
234  ck_tile::ignore = a_vec;
235  ck_tile::ignore = b_vec;
236 #endif
237  }
238  }
239 
240  // c_vec = a_vec * b_vec
241  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
242  {
243 #if defined(__gfx9__)
244  return bit_cast<CVecType>(
245  __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
246 #else
247  ck_tile::ignore = a_vec;
248  ck_tile::ignore = b_vec;
249  return CVecType{0.f};
250 #endif
251  }
252 };
253 
254 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
256 {
257  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
258  using ADataType = fp16_t;
259  using BDataType = fp16_t;
260  using CDataType = float;
261 
265 
266  static constexpr index_t kM = 16;
267  static constexpr index_t kN = 16;
268  static constexpr index_t kK = 32;
269 
270  static constexpr index_t kAMBlock = 1;
271  static constexpr index_t kBNBlock = 1;
272 
273  static constexpr index_t kAMLane = 16;
274  static constexpr index_t kBNLane = 16;
275  static constexpr index_t kABKLane = 4;
276  static constexpr index_t kABKPerLane = 8;
277 
278  static constexpr index_t kCMLane = 4;
279  static constexpr index_t kCNLane = 16;
280  static constexpr index_t kCM0PerLane = 1;
281  static constexpr index_t kCM1PerLane = 4;
282 
283  // c_vec += a_vec * b_vec
284  template <bool post_nop_ = false>
286  const AVecType& a_vec,
287  const BVecType& b_vec,
288  bool_constant<post_nop_> = {}) const
289  {
290  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x32f16", Ctrl)
291  else
292  {
293 #if defined(__gfx950__)
294  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, c_vec, 0, 0, 0);
295 #else
296  ck_tile::ignore = c_vec;
297  ck_tile::ignore = a_vec;
298  ck_tile::ignore = b_vec;
299 #endif
300  }
301  }
302 
303  // c_vec = a_vec * b_vec
304  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
305  {
306 #if defined(__gfx950__)
307  return bit_cast<CVecType>(
308  __builtin_amdgcn_mfma_f32_16x16x32_f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
309 #else
310  ck_tile::ignore = a_vec;
311  ck_tile::ignore = b_vec;
312  return CVecType{0.f};
313 #endif
314  }
315 };
316 
317 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
319 {
320  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
321  using ADataType = fp16_t;
322  using BDataType = fp16_t;
323  using CDataType = float;
324 
328 
329  static constexpr index_t kM = 4;
330  static constexpr index_t kN = 64;
331  static constexpr index_t kK = 4;
332 
333  static constexpr index_t kAMBlock = 1;
334  static constexpr index_t kBNBlock = 16;
335 
336  // we only write down single block (4 threads) thread mapping here
337  static constexpr index_t kAMLane = 4;
338  static constexpr index_t kBNLane = 4;
339  static constexpr index_t kABKLane = 1;
340  static constexpr index_t kABKPerLane = 4;
341 
342  static constexpr index_t kCMLane = 1;
343  static constexpr index_t kCNLane = 4;
344  static constexpr index_t kCM0PerLane = 1;
345  static constexpr index_t kCM1PerLane = 4;
346 
347  // c_vec += a_vec * b_vec
348  template <bool post_nop_ = false>
350  const AVecType& a_vec,
351  const BVecType& b_vec,
352  bool_constant<post_nop_> = {}) const
353  {
    // c_vec += a_vec * b_vec.
    // Raw_* control values are emitted as inline asm by DISPATCH_MFMA_CTRL_.
    // The `else` that follows the macro binds to its last `if constexpr`, so
    // the builtin path runs only when no Raw_* value matched (e.g. Default_).
354  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
355  else
356  {
357 #if defined(__gfx9__)
358  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
359 #else
    // Not a __gfx9__ target: no MFMA is issued and c_vec is left unchanged.
360  ignore = c_vec;
361  ignore = a_vec;
362  ignore = b_vec;
363 #endif
364  }
365  }
366 
367  // c_vec = a_vec * b_vec
368  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
369  {
370 #if defined(__gfx9__)
371  return bit_cast<CVecType>(
372  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
373 #else
374  ignore = a_vec;
375  ignore = b_vec;
376  return CVecType{0.f};
377 #endif
378  }
379 };
380 
381 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
383 {
384  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
385  using ADataType = fp16_t;
386  using BDataType = fp16_t;
387  using CDataType = float;
388 
392 
393  static constexpr index_t kM = 64;
394  static constexpr index_t kN = 4;
395  static constexpr index_t kK = 4;
396 
397  static constexpr index_t kAMBlock = 16;
398  static constexpr index_t kBNBlock = 1;
399 
400  // we only write down single block (4 threads) thread mapping here
401  static constexpr index_t kAMLane = 4;
402  static constexpr index_t kBNLane = 4;
403  static constexpr index_t kABKLane = 1;
404  static constexpr index_t kABKPerLane = 4;
405 
406  static constexpr index_t kCMLane = 1;
407  static constexpr index_t kCNLane = 4;
408  static constexpr index_t kCM0PerLane = 1;
409  static constexpr index_t kCM1PerLane = 4;
410 
411  // c_vec += a_vec * b_vec
412  template <bool post_nop_ = false>
414  const AVecType& a_vec,
415  const BVecType& b_vec,
416  bool_constant<post_nop_> = {}) const
417  {
    // c_vec += a_vec * b_vec.
    // Raw_* control values are emitted as inline asm by DISPATCH_MFMA_CTRL_.
    // The `else` that follows the macro binds to its last `if constexpr`, so
    // the builtin path runs only when no Raw_* value matched (e.g. Default_).
418  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4f16", Ctrl)
419  else
420  {
421 #if defined(__gfx9__)
422  c_vec = __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, c_vec, 0, 0, 0);
423 #else
    // Not a __gfx9__ target: no MFMA is issued and c_vec is left unchanged.
424  ignore = c_vec;
425  ignore = a_vec;
426  ignore = b_vec;
427 #endif
428  }
429  }
430 
431  // c_vec = a_vec * b_vec
432  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
433  {
434 #if defined(__gfx9__)
435  return bit_cast<CVecType>(
436  __builtin_amdgcn_mfma_f32_4x4x4f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
437 #else
438  ignore = a_vec;
439  ignore = b_vec;
440  return CVecType{0.f};
441 #endif
442  }
443 };
444 
445 // Bf16
446 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
448 {
449  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
450  using ADataType = bf16_t;
451  using BDataType = bf16_t;
452  using CDataType = float;
453 
457 
458  static constexpr index_t kM = 32;
459  static constexpr index_t kN = 32;
460  static constexpr index_t kK = 8;
461 
462  static constexpr index_t kAMBlock = 1;
463  static constexpr index_t kBNBlock = 1;
464 
465  static constexpr index_t kAMLane = 32;
466  static constexpr index_t kBNLane = 32;
467  static constexpr index_t kABKLane = 2;
468  static constexpr index_t kABKPerLane = 4;
469 
470  static constexpr index_t kCMLane = 2;
471  static constexpr index_t kCNLane = 32;
472  static constexpr index_t kCM0PerLane = 4;
473  static constexpr index_t kCM1PerLane = 4;
474 
475  // c_vec += a_vec * b_vec
476  template <bool post_nop_ = false>
478  const AVecType& a_vec,
479  const BVecType& b_vec,
480  bool_constant<post_nop_> = {}) const
481  {
    // c_vec += a_vec * b_vec.
    // Raw_* control values are emitted as inline asm by DISPATCH_MFMA_CTRL_.
    // The `else` that follows the macro binds to its last `if constexpr`, so
    // the fallback below runs only when no Raw_* value matched (e.g. Default_).
482  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
483  else
484  {
485 #if defined(__gfx90a__) || defined(__gfx94__)
486  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
487 #elif defined(__gfx908__)
    // gfx908 path: two 32x32x4bf16 steps. Each step takes the k-th pair of
    // bf16 values from the 4-element a_vec / b_vec.
488  static_for<0, 2, 1>{}([&](auto k) {
489  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
490  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
491  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
492  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
493  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
494  c_vec,
495  0,
496  0,
497  0);
498  });
499 #else
    // Other targets: no MFMA is issued and c_vec is left unchanged.
500  ck_tile::ignore = c_vec;
501  ck_tile::ignore = a_vec;
502  ck_tile::ignore = b_vec;
503 #endif
504  }
505  }
506 
507  // c_vec = a_vec * b_vec
508  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
509  {
510 #if defined(__gfx90a__) || defined(__gfx94__)
511  return bit_cast<CVecType>(
512  __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
513 #elif defined(__gfx908__)
514  CVecType c_vec{0.f};
515  static_for<0, 2, 1>{}([&](auto k) {
516  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
517  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
518  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
519  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
520  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
521  c_vec,
522  0,
523  0,
524  0);
525  });
526  return c_vec;
527 #else
528  ck_tile::ignore = a_vec;
529  ck_tile::ignore = b_vec;
530  return CVecType{0.f};
531 #endif
532  }
533 };
534 
535 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
537 {
538  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
539  using ADataType = bf16_t;
540  using BDataType = bf16_t;
541  using CDataType = float;
542 
546 
547  static constexpr index_t kM = 16;
548  static constexpr index_t kN = 16;
549  static constexpr index_t kK = 16;
550 
551  static constexpr index_t kAMBlock = 1;
552  static constexpr index_t kBNBlock = 1;
553 
554  static constexpr index_t kAMLane = 16;
555  static constexpr index_t kBNLane = 16;
556  static constexpr index_t kABKLane = 4;
557  static constexpr index_t kABKPerLane = 4;
558 
559  static constexpr index_t kCMLane = 4;
560  static constexpr index_t kCNLane = 16;
561  static constexpr index_t kCM0PerLane = 1;
562  static constexpr index_t kCM1PerLane = 4;
563 
564  // c_vec += a_vec * b_vec
565  template <bool post_nop_ = false>
567  const AVecType& a_vec,
568  const BVecType& b_vec,
569  bool_constant<post_nop_> = {}) const
570  {
571  DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
572  {
573 #if defined(__gfx90a__) || defined(__gfx94__)
574  c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
575 #elif defined(__gfx908__)
576  static_for<0, 2, 1>{}([&](auto k) {
577  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
578  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
579  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
580  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
581  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
582  c_vec,
583  0,
584  0,
585  0);
586  });
587 #else
588  ck_tile::ignore = c_vec;
589  ck_tile::ignore = a_vec;
590  ck_tile::ignore = b_vec;
591 #endif
592  }
593  }
594 
595  // c_vec = a_vec * b_vec
596  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
597  {
598 #if defined(__gfx90a__) || defined(__gfx94__)
599  return bit_cast<CVecType>(
600  __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
601 #elif defined(__gfx908__)
602  CVecType c_vec{0.f};
603  static_for<0, 2, 1>{}([&](auto k) {
604  c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
605  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
606  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
607  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
608  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
609  c_vec,
610  0,
611  0,
612  0);
613  });
614  return c_vec;
615 #else
616  ck_tile::ignore = a_vec;
617  ck_tile::ignore = b_vec;
618  return CVecType{0.f};
619 #endif
620  }
621 };
622 
623 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
625 {
626  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
627  using ADataType = bf16_t;
628  using BDataType = bf16_t;
629  using CDataType = float;
630 
634 
635  static constexpr index_t kM = 4;
636  static constexpr index_t kN = 64;
637  static constexpr index_t kK = 4;
638 
639  static constexpr index_t kAMBlock = 1;
640  static constexpr index_t kBNBlock = 16;
641 
642  // we only write down single block (4 threads) thread mapping here
643  static constexpr index_t kAMLane = 4;
644  static constexpr index_t kBNLane = 4;
645  static constexpr index_t kABKLane = 1;
646  static constexpr index_t kABKPerLane = 4;
647 
648  static constexpr index_t kCMLane = 1;
649  static constexpr index_t kCNLane = 4;
650  static constexpr index_t kCM0PerLane = 1;
651  static constexpr index_t kCM1PerLane = 4;
652 
653  // c_vec += a_vec * b_vec
654  template <bool post_nop_ = false>
656  const AVecType& a_vec,
657  const BVecType& b_vec,
658  bool_constant<post_nop_> = {}) const
659  {
    // c_vec += a_vec * b_vec.
    // Raw_* control values are emitted as inline asm by DISPATCH_MFMA_CTRL_.
    // The `else` that follows the macro binds to its last `if constexpr`, so
    // the builtin path runs only when no Raw_* value matched (e.g. Default_).
660  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
661  else
662  {
    // NOTE(review): the other *_1k bf16 builtins in this file are guarded by
    // __gfx90a__ || __gfx94__ only; confirm 4x4x4bf16_1k is available on
    // every __gfx9__ target (e.g. gfx908).
663 #if defined(__gfx9__)
664  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
665 #else
666  ignore = c_vec;
667  ignore = a_vec;
668  ignore = b_vec;
669 #endif
670  }
671  }
672 
673  // c_vec = a_vec * b_vec
674  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
675  {
676 #if defined(__gfx9__)
677  return bit_cast<CVecType>(
678  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
679 #else
680  ignore = a_vec;
681  ignore = b_vec;
682  return CVecType{0.f};
683 #endif
684  }
685 };
686 
687 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
689 {
690  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
691  using ADataType = bf16_t;
692  using BDataType = bf16_t;
693  using CDataType = float;
694 
698 
699  static constexpr index_t kM = 64;
700  static constexpr index_t kN = 4;
701  static constexpr index_t kK = 4;
702 
703  static constexpr index_t kAMBlock = 16;
704  static constexpr index_t kBNBlock = 1;
705 
706  // we only write down single block (4 threads) thread mapping here
707  static constexpr index_t kAMLane = 4;
708  static constexpr index_t kBNLane = 4;
709  static constexpr index_t kABKLane = 1;
710  static constexpr index_t kABKPerLane = 4;
711 
712  static constexpr index_t kCMLane = 1;
713  static constexpr index_t kCNLane = 4;
714  static constexpr index_t kCM0PerLane = 1;
715  static constexpr index_t kCM1PerLane = 4;
716 
717  // c_vec += a_vec * b_vec
718  template <bool post_nop_ = false>
720  const AVecType& a_vec,
721  const BVecType& b_vec,
722  bool_constant<post_nop_> = {}) const
723  {
    // c_vec += a_vec * b_vec.
    // Raw_* control values are emitted as inline asm by DISPATCH_MFMA_CTRL_.
    // The `else` that follows the macro binds to its last `if constexpr`, so
    // the builtin path runs only when no Raw_* value matched (e.g. Default_).
724  DISPATCH_MFMA_CTRL_("v_mfma_f32_4x4x4bf16_1k", Ctrl)
725  else
726  {
    // NOTE(review): the other *_1k bf16 builtins in this file are guarded by
    // __gfx90a__ || __gfx94__ only; confirm 4x4x4bf16_1k is available on
    // every __gfx9__ target (e.g. gfx908).
727 #if defined(__gfx9__)
728  c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
729 #else
730  ignore = c_vec;
731  ignore = a_vec;
732  ignore = b_vec;
733 #endif
734  }
735  }
736 
737  // c_vec = a_vec * b_vec
738  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
739  {
740 #if defined(__gfx9__)
741  return bit_cast<CVecType>(
742  __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
743 #else
744  ignore = a_vec;
745  ignore = b_vec;
746  return CVecType{0.f};
747 #endif
748  }
749 };
750 
751 // gfx950
752 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
754 {
755  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
756  using ADataType = fp16_t;
757  using BDataType = fp16_t;
758  using CDataType = float;
759 
763 
764  static constexpr index_t kM = 32;
765  static constexpr index_t kN = 32;
766  static constexpr index_t kK = 16;
767 
768  static constexpr index_t kAMBlock = 1;
769  static constexpr index_t kBNBlock = 1;
770 
771  static constexpr index_t kAMLane = 32;
772  static constexpr index_t kBNLane = 32;
773  static constexpr index_t kABKLane = 2;
774  static constexpr index_t kABKPerLane = 8;
775 
776  static constexpr index_t kCMLane = 2;
777  static constexpr index_t kCNLane = 32;
778  static constexpr index_t kCM0PerLane = 4;
779  static constexpr index_t kCM1PerLane = 4;
780 
781  // c_vec += a_vec * b_vec
782  template <bool post_nop_ = false>
784  const AVecType& a_vec,
785  const BVecType& b_vec,
786  bool_constant<post_nop_> = {}) const
787  {
788  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_f16", Ctrl)
789  else
790  {
791 #if defined(__gfx950__)
792  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, c_vec, 0, 0, 0);
793 #elif defined(__gfx90a__) || defined(__gfx94__)
794  static_for<0, 2, 1>{}([&](auto k) {
795  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
796  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
797  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
798  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
799  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
800  c_vec,
801  0,
802  0,
803  0);
804  });
805 #elif defined(__gfx908__)
806  static_for<0, 4, 1>{}([&](auto k) {
807  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
808  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
809  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
810  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
811  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
812  c_vec,
813  0,
814  0,
815  0);
816  });
817 #else
818  ck_tile::ignore = c_vec;
819  ck_tile::ignore = a_vec;
820  ck_tile::ignore = b_vec;
821 #endif
822  }
823  }
824 
825  // c_vec = a_vec * b_vec
826  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
827  {
828 #if defined(__gfx950__)
829  return __builtin_amdgcn_mfma_f32_32x32x16_f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
830 #elif defined(__gfx90a__) || defined(__gfx94__)
831  CVecType c_vec{0.f};
832  static_for<0, 2, 1>{}([&](auto k) {
833  c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(
834  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
835  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
836  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
837  .template get_as<ext_vector_t<fp16_t, 4>>()[number<k>{}],
838  c_vec,
839  0,
840  0,
841  0);
842  });
843  return c_vec;
844 #elif defined(__gfx908__)
845  CVecType c_vec{0.f};
846  static_for<0, 4, 1>{}([&](auto k) {
847  c_vec = __builtin_amdgcn_mfma_f32_32x32x4f16(
848  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
849  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
850  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
851  .template get_as<ext_vector_t<fp16_t, 2>>()[number<k>{}],
852  c_vec,
853  0,
854  0,
855  0);
856  });
857  return c_vec;
858 #else
859  ck_tile::ignore = a_vec;
860  ck_tile::ignore = b_vec;
861  return CVecType{0.f};
862 #endif
863  }
864 };
865 
866 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
868 {
869  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
870  using ADataType = bf16_t;
871  using BDataType = bf16_t;
872  using CDataType = float;
873 
877 
878  static constexpr index_t kM = 32;
879  static constexpr index_t kN = 32;
880  static constexpr index_t kK = 16;
881 
882  static constexpr index_t kAMBlock = 1;
883  static constexpr index_t kBNBlock = 1;
884 
885  static constexpr index_t kAMLane = 32;
886  static constexpr index_t kBNLane = 32;
887  static constexpr index_t kABKLane = 2;
888  static constexpr index_t kABKPerLane = 8;
889 
890  static constexpr index_t kCMLane = 2;
891  static constexpr index_t kCNLane = 32;
892  static constexpr index_t kCM0PerLane = 4;
893  static constexpr index_t kCM1PerLane = 4;
894 
895  // c_vec += a_vec * b_vec
896  template <bool post_nop_ = false>
898  const AVecType& a_vec,
899  const BVecType& b_vec,
900  bool_constant<post_nop_> = {}) const
901  {
902  DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x16_bf16", Ctrl)
903  else
904  {
905 #if defined(__gfx950__)
906  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, c_vec, 0, 0, 0);
907 #elif defined(__gfx90a__) || defined(__gfx94__)
908  static_for<0, 2, 1>{}([&](auto k) {
909  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
910  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
911  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
912  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
913  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
914  c_vec,
915  0,
916  0,
917  0);
918  });
919 #elif defined(__gfx908__)
920  static_for<0, 4, 1>{}([&](auto k) {
921  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
922  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
923  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
924  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
925  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
926  c_vec,
927  0,
928  0,
929  0);
930  });
931 #else
932  ck_tile::ignore = c_vec;
933  ck_tile::ignore = a_vec;
934  ck_tile::ignore = b_vec;
935 #endif
936  }
937  }
938 
939  // c_vec = a_vec * b_vec
940  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
941  {
942 #if defined(__gfx950__)
943  return __builtin_amdgcn_mfma_f32_32x32x16_bf16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0);
944 #elif defined(__gfx90a__) || defined(__gfx94__)
945  CVecType c_vec{0.f};
946  static_for<0, 2, 1>{}([&](auto k) {
947  c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
948  reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
949  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
950  reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
951  .template get_as<ext_vector_t<bf16_t, 4>>()[number<k>{}],
952  c_vec,
953  0,
954  0,
955  0);
956  });
957  return c_vec;
958 #elif defined(__gfx908__)
959  CVecType c_vec{0.f};
960  static_for<0, 4, 1>{}([&](auto k) {
961  c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
962  reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
963  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
964  reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
965  .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
966  c_vec,
967  0,
968  0,
969  0);
970  });
971  return c_vec;
972 #else
973  ck_tile::ignore = a_vec;
974  ck_tile::ignore = b_vec;
975  return CVecType{0.f};
976 #endif
977  }
978 };
979 
980 // FP8
981 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
983 {
984  static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
985  using ADataType = AType_;
986  using BDataType = BType_;
987  using CDataType = float;
988 
992 
993  static constexpr index_t kM = 16;
994  static constexpr index_t kN = 16;
995  static constexpr index_t kK = 32;
996 
997  static constexpr index_t kAMBlock = 1;
998  static constexpr index_t kBNBlock = 1;
999 
1000  static constexpr index_t kAMLane = 16;
1001  static constexpr index_t kBNLane = 16;
1002  static constexpr index_t kABKLane = 4;
1003  static constexpr index_t kABKPerLane = 8;
1004 
1005  static constexpr index_t kCMLane = 4;
1006  static constexpr index_t kCNLane = 16;
1007  static constexpr index_t kCM0PerLane = 1;
1008  static constexpr index_t kCM1PerLane = 4;
1009 
1010  // c_vec += a_vec * b_vec
1011  template <bool post_nop_ = false>
1013  const AVecType& a_vec,
1014  const BVecType& b_vec,
1015  bool_constant<post_nop_> = {}) const
1016  {
1017  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1018  {
1019  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1020  {
1021  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "v", "v")
1022  }
1023  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1024  {
1025  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "v", "v")
1026  }
1027  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1028  {
1029  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "v", "v")
1030  }
1031  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1032  {
1033  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "v", "v")
1034  }
1035  }
1036  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1037  {
1038  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1039  {
1040  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "a", "v")
1041  }
1042  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1043  {
1044  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "a", "v")
1045  }
1046  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1047  {
1048  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "a", "v")
1049  }
1050  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1051  {
1052  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "a", "v")
1053  }
1054  }
1055  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1056  {
1057  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1058  {
1059  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "a", "v", "v")
1060  }
1061  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1062  {
1063  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "a", "v", "v")
1064  }
1065  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1066  {
1067  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "a", "v", "v")
1068  }
1069  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1070  {
1071  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "a", "v", "v")
1072  }
1073  }
1074  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1075  {
1076  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1077  {
1078  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_fp8", "+v", "v", "a", "v")
1079  }
1080  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1081  {
1082  DISPATCH_MFMA_("mfma_f32_16x16x32_fp8_bf8", "+v", "v", "a", "v")
1083  }
1084  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1085  {
1086  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_fp8", "+v", "v", "a", "v")
1087  }
1088  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1089  {
1090  DISPATCH_MFMA_("mfma_f32_16x16x32_bf8_bf8", "+v", "v", "a", "v")
1091  }
1092  }
1093  else
1094  {
1095 #if defined(__gfx94__) or defined(__gfx95__)
1096  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1097  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1098  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1099  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1100  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1101  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1102  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1103  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1104  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1105  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1106  c_vec = __builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1107  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1108 #else
1109  ck_tile::ignore = c_vec;
1110  ck_tile::ignore = a_vec;
1111  ck_tile::ignore = b_vec;
1112 #endif
1113  }
1114  }
1115 
1116  // c_vec = a_vec * b_vec
1117  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1118  {
1119 #if defined(__gfx94__) or defined(__gfx95__)
1120  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1121  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(
1122  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
1123  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1124  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_fp8_bf8(
1125  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
1126  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1127  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_fp8(
1128  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
1129  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1130  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x32_bf8_bf8(
1131  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
1132 #else
1133  ck_tile::ignore = a_vec;
1134  ck_tile::ignore = b_vec;
1135  return CVecType{0.f};
1136 #endif
1137  }
1138 };
1139 
1140 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> // 32x32x16 fp8/bf8 MFMA with float accumulator; AType_/BType_ are expected to be fp8_t or bf8_t
1142 {
1143  static constexpr WGAttrCtlEnum Ctrl = Ctrl_; // Raw_vvv/vaa/vav/vva select inline asm in operator(); other values take the builtin path
1144  using ADataType = AType_;
1145  using BDataType = BType_;
1146  using CDataType = float; // accumulator element type
1147 
1151 
1152  static constexpr index_t kM = 32; // C rows produced per MFMA
1153  static constexpr index_t kN = 32; // C columns produced per MFMA
1154  static constexpr index_t kK = 16; // reduction depth per MFMA (= kABKLane * kABKPerLane)
1155 
1156  static constexpr index_t kAMBlock = 1;
1157  static constexpr index_t kBNBlock = 1;
1158 
1159  static constexpr index_t kAMLane = 32; // kAMLane * kABKLane = 64 lanes
1160  static constexpr index_t kBNLane = 32;
1161  static constexpr index_t kABKLane = 2; // lane groups along K
1162  static constexpr index_t kABKPerLane = 8; // K elements held per lane (8 x 8-bit = 64 bits, matching bit_cast<long> in operator())
1163 
1164  static constexpr index_t kCMLane = 2;
1165  static constexpr index_t kCNLane = 32;
1166  static constexpr index_t kCM0PerLane = 4;
1167  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane = 32 = kM
1168 
1169  // c_vec += a_vec * b_vec
1170  template <bool post_nop_ = false>
1172  const AVecType& a_vec,
1173  const BVecType& b_vec,
1174  bool_constant<post_nop_> = {}) const
1175  {
1176  if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
1177  {
1178  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1179  {
1180  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
1181  }
1182  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1183  {
1184  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
1185  }
1186  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1187  {
1188  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
1189  }
1190  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1191  {
1192  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
1193  }
1194  }
1195  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
1196  {
1197  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1198  {
1199  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
1200  }
1201  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1202  {
1203  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
1204  }
1205  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1206  {
1207  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
1208  }
1209  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1210  {
1211  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
1212  }
1213  }
1214  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
1215  {
1216  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1217  {
1218  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
1219  }
1220  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1221  {
1222  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
1223  }
1224  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1225  {
1226  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
1227  }
1228  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1229  {
1230  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
1231  }
1232  }
1233  else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
1234  {
1235  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1236  {
1237  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
1238  }
1239  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1240  {
1241  DISPATCH_MFMA_("mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
1242  }
1243  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1244  {
1245  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
1246  }
1247  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1248  {
1249  DISPATCH_MFMA_("mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
1250  }
1251  }
1252  else
1253  {
1254 #if defined(__gfx94__) or defined(__gfx95__)
1255  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1256  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1257  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1258  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1259  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1260  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1261  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1262  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1263  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1264  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1265  c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1266  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1267 #elif defined(__gfx908__) || defined(__gfx90a__)
1268  static_for<0, 8, 1>{}([&](auto k) {
1269  float a_f32 =
1270  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1271  .template get_as<ADataType>()[number<k>{}]);
1272  float b_f32 =
1273  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1274  .template get_as<BDataType>()[number<k>{}]);
1275 
1276  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1277  });
1278 #else
1279  ck_tile::ignore = c_vec;
1280  ck_tile::ignore = a_vec;
1281  ck_tile::ignore = b_vec;
1282 #endif
1283  }
1284  }
1285 
1286  // c_vec = a_vec * b_vec
1287  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1288  {
1289 #if defined(__gfx94__) or defined(__gfx95__)
1290  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1291  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
1292  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
1293  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1294  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
1295  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
1296  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1297  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
1298  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
1299  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1300  return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
1301  bit_cast<long>(a_vec), bit_cast<long>(b_vec), CVecType{0.f}, 0, 0, 0));
1302 #elif defined(__gfx908__) || defined(__gfx90a__)
1303  CVecType c_vec{0.f};
1304  static_for<0, 8, 1>{}([&](auto k) {
1305  float a_f32 =
1306  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1307  .template get_as<ADataType>()[number<k>{}]);
1308  float b_f32 =
1309  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1310  .template get_as<BDataType>()[number<k>{}]);
1311 
1312  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1313  });
1314  return c_vec;
1315 #else
1316  ck_tile::ignore = a_vec;
1317  ck_tile::ignore = b_vec;
1318  return CVecType{0.f};
1319 #endif
1320  }
1321 };
1322 
1323 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1326 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1329 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1332 
1333 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1336 
1337 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1340 
1341 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1344 
1345 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> // 16x16x128 fp8/bf8 MFMA via the gfx950 scaled f8f6f4 builtin, float accumulator
1347 {
1348  static constexpr WGAttrCtlEnum Ctrl = Ctrl_; // not consulted by this struct's operator()s (builtin path only)
1349  using ADataType = AType_;
1350  using BDataType = BType_;
1351  using CDataType = float; // accumulator element type
1352 
1356 
1357  static constexpr index_t kM = 16; // C rows produced per MFMA
1358  static constexpr index_t kN = 16; // C columns produced per MFMA
1359  static constexpr index_t kK = 128; // reduction depth per MFMA (= kABKLane * kABKPerLane)
1360 
1361  static constexpr index_t kAMBlock = 1;
1362  static constexpr index_t kBNBlock = 1;
1363 
1364  static constexpr index_t kAMLane = 16; // kAMLane * kABKLane = 64 lanes
1365  static constexpr index_t kBNLane = 16;
1366  static constexpr index_t kABKLane = 4; // lane groups along K
1367  static constexpr index_t kABKPerLane = 32; // K elements held per lane
1368 
1369  static constexpr index_t kCMLane = 4;
1370  static constexpr index_t kCNLane = 16;
1371  static constexpr index_t kCM0PerLane = 1;
1372  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane = 16 = kM
1373 
1374  // c_vec += a_vec * b_vec
1375  template <bool post_nop_ = false>
1377  const AVecType& a_vec,
1378  const BVecType& b_vec,
1379  bool_constant<post_nop_> = {}) const
1380  {
1381  //__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1382  // opsel, scale_b)
1383 #if defined(__gfx950__)
1384  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1385  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1386  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1387  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1388  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1389  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1390  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1391  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1392  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1393  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1394  c_vec = __builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1395  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1396 #else
1397  ck_tile::ignore = c_vec;
1398  ck_tile::ignore = a_vec;
1399  ck_tile::ignore = b_vec;
1400 #endif
1401  }
1402 
1403  // c_vec = a_vec * b_vec
1404  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1405  {
1406 #if defined(__gfx950__)
1407  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1408  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1409  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1410  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1411  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1412  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1413  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1414  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1415  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1416  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1417  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
1418  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1419 #else
1420  ck_tile::ignore = a_vec;
1421  ck_tile::ignore = b_vec;
1422  return CVecType{0.f};
1423 #endif
1424  }
1425 };
1426 
1427 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1430 
1431 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1434 
1435 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1438 
1439 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1442 
1443 template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> // 32x32x64 fp8/bf8 MFMA via the gfx950 scaled f8f6f4 builtin, float accumulator
1445 {
1446  static constexpr WGAttrCtlEnum Ctrl = Ctrl_; // not consulted by this struct's operator()s (builtin path only)
1447  using ADataType = AType_;
1448  using BDataType = BType_;
1449  using CDataType = float; // accumulator element type
1450 
1454 
1455  static constexpr index_t kM = 32; // C rows produced per MFMA
1456  static constexpr index_t kN = 32; // C columns produced per MFMA
1457  static constexpr index_t kK = 64; // reduction depth per MFMA (= kABKLane * kABKPerLane)
1458 
1459  static constexpr index_t kAMBlock = 1;
1460  static constexpr index_t kBNBlock = 1;
1461 
1462  static constexpr index_t kAMLane = 32; // kAMLane * kABKLane = 64 lanes
1463  static constexpr index_t kBNLane = 32;
1464  static constexpr index_t kABKLane = 2; // lane groups along K
1465  static constexpr index_t kABKPerLane = 32; // K elements held per lane
1466 
1467  static constexpr index_t kCMLane = 2;
1468  static constexpr index_t kCNLane = 32;
1469  static constexpr index_t kCM0PerLane = 4;
1470  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane = 32 = kM
1471 
1472  // c_vec += a_vec * b_vec
1473  template <bool post_nop_ = false>
1475  const AVecType& a_vec,
1476  const BVecType& b_vec,
1477  bool_constant<post_nop_> = {}) const
1478  {
1479  //__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(a, b, c, cbsz, blgp, opsel, scale_a,
1480  // opsel, scale_b)
1481 #if defined(__gfx950__)
1482  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1483  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1484  a_vec, b_vec, c_vec, 0, 0, 0, 0, 0, 0);
1485  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1486  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1487  a_vec, b_vec, c_vec, 0, 1, 0, 0, 0, 0);
1488  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1489  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1490  a_vec, b_vec, c_vec, 1, 0, 0, 0, 0, 0);
1491  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1492  c_vec = __builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1493  a_vec, b_vec, c_vec, 1, 1, 0, 0, 0, 0);
1494 #else
1495  ck_tile::ignore = c_vec;
1496  ck_tile::ignore = a_vec;
1497  ck_tile::ignore = b_vec;
1498 #endif
1499  }
1500 
1501  // c_vec = a_vec * b_vec
1502  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1503  {
1504 #if defined(__gfx950__)
1505  if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
1506  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1507  a_vec, b_vec, CVecType{0.f}, 0, 0, 0, 0, 0, 0));
1508  else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
1509  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1510  a_vec, b_vec, CVecType{0.f}, 0, 1, 0, 0, 0, 0));
1511  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
1512  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1513  a_vec, b_vec, CVecType{0.f}, 1, 0, 0, 0, 0, 0));
1514  else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
1515  return bit_cast<CVecType>(__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
1516  a_vec, b_vec, CVecType{0.f}, 1, 1, 0, 0, 0, 0));
1517 #else
1518  ck_tile::ignore = a_vec;
1519  ck_tile::ignore = b_vec;
1520  return CVecType{0.f};
1521 #endif
1522  }
1523 };
1524 
1525 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1528 
1529 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1532 
1533 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1536 
1537 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
1540 
1541 // int8
1542 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> // 32x32x16 int8 MFMA with int32 accumulator
1544 {
1545  static constexpr WGAttrCtlEnum Ctrl = Ctrl_; // passed to DISPATCH_MFMA_CTRL_ to choose inline asm vs. the builtin path
1549 
1553 
1554  static constexpr index_t kM = 32; // C rows produced per MFMA
1555  static constexpr index_t kN = 32; // C columns produced per MFMA
1556  static constexpr index_t kK = 16; // reduction depth per MFMA (= kABKLane * kABKPerLane)
1557 
1558  static constexpr index_t kAMBlock = 1;
1559  static constexpr index_t kBNBlock = 1;
1560 
1561  static constexpr index_t kAMLane = 32; // kAMLane * kABKLane = 64 lanes
1562  static constexpr index_t kBNLane = 32;
1563  static constexpr index_t kABKLane = 2; // lane groups along K
1564  static constexpr index_t kABKPerLane = 8; // int8 values held per lane (64 bits, matching bit_cast<long> in operator())
1565 
1566  static constexpr index_t kCMLane = 2;
1567  static constexpr index_t kCNLane = 32;
1568  static constexpr index_t kCM0PerLane = 4;
1569  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane = 32 = kM
1570 
1571  // c_vec += a_vec * b_vec
1572  template <bool post_nop_ = false>
1574  const AVecType& a_vec,
1575  const BVecType& b_vec,
1576  bool_constant<post_nop_> = {}) const
1577  {
1578  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
1579  else
1580  {
1581 #if defined(__gfx94__) or defined(__gfx95__)
1582  c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
1583  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1584 #elif defined(__gfx908__) || defined(__gfx90a__)
1585  static_for<0, 8, 1>{}([&](auto k) {
1586  float a_f32 =
1587  type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
1588  .template get_as<ADataType>()[number<k>{}]);
1589  float b_f32 =
1590  type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
1591  .template get_as<BDataType>()[number<k>{}]);
1592 
1593  c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
1594  });
1595 #else
1596  ck_tile::ignore = c_vec;
1597  ck_tile::ignore = a_vec;
1598  ck_tile::ignore = b_vec;
1599 #endif
1600  }
1601  }
1602 
1603  // c_vec = a_vec * b_vec
1604  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1605  {
1606  CVecType c_vec{0};
1607  operator()(c_vec, a_vec, b_vec);
1608  return c_vec;
1609  }
1610 };
1611 
1612 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> // 16x16x32 int8 MFMA with int32 accumulator
1614 {
1615  static constexpr WGAttrCtlEnum Ctrl = Ctrl_; // passed to DISPATCH_MFMA_CTRL_ to choose inline asm vs. the builtin path
1619 
1623 
1624  static constexpr index_t kM = 16; // C rows produced per MFMA
1625  static constexpr index_t kN = 16; // C columns produced per MFMA
1626  static constexpr index_t kK = 32; // reduction depth per MFMA (= kABKLane * kABKPerLane)
1627 
1628  static constexpr index_t kAMBlock = 1;
1629  static constexpr index_t kBNBlock = 1;
1630 
1631  static constexpr index_t kAMLane = 16; // kAMLane * kABKLane = 64 lanes
1632  static constexpr index_t kBNLane = 16;
1633  static constexpr index_t kABKLane = 4; // lane groups along K
1634  static constexpr index_t kABKPerLane = 8; // int8 values held per lane (64 bits, matching bit_cast<long> in operator())
1635 
1636  static constexpr index_t kCMLane = 4;
1637  static constexpr index_t kCNLane = 16;
1638  static constexpr index_t kCM0PerLane = 1;
1639  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1640 
1641  // c_vec += a_vec * b_vec
1642  template <bool post_nop_ = false>
1644  const AVecType& a_vec,
1645  const BVecType& b_vec,
1646  bool_constant<post_nop_> = {}) const
1647  {
1648  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x32_i8", Ctrl)
1649  else
1650  {
1651 #if defined(__gfx94__) or defined(__gfx95__)
1652  c_vec = __builtin_amdgcn_mfma_i32_16x16x32_i8(
1653  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1654 #else
1655  ck_tile::ignore = c_vec;
1656  ck_tile::ignore = a_vec;
1657  ck_tile::ignore = b_vec;
1658 #endif
1659  }
1660  }
1661 
1662  // c_vec = a_vec * b_vec
1663  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1664  {
1665  CVecType c_vec{0};
1666  operator()(c_vec, a_vec, b_vec);
1667  return c_vec;
1668  }
1669 };
1670 
1671 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> // 16x16x64 int8 MFMA with int32 accumulator (builtin path: gfx950 only)
1673 {
1674  static constexpr WGAttrCtlEnum Ctrl = Ctrl_; // passed to DISPATCH_MFMA_CTRL_ to choose inline asm vs. the builtin path
1678 
1682 
1683  static constexpr index_t kM = 16; // C rows produced per MFMA
1684  static constexpr index_t kN = 16; // C columns produced per MFMA
1685  static constexpr index_t kK = 64; // reduction depth per MFMA (= kABKLane * kABKPerLane)
1686 
1687  static constexpr index_t kAMBlock = 1;
1688  static constexpr index_t kBNBlock = 1;
1689 
1690  static constexpr index_t kAMLane = 16; // kAMLane * kABKLane = 64 lanes
1691  static constexpr index_t kBNLane = 16;
1692  static constexpr index_t kABKLane = 4; // lane groups along K
1693  static constexpr index_t kABKPerLane = 16; // int8 values held per lane (128 bits)
1694 
1695  static constexpr index_t kCMLane = 4;
1696  static constexpr index_t kCNLane = 16;
1697  static constexpr index_t kCM0PerLane = 1;
1698  static constexpr index_t kCM1PerLane = 4; // write to 4x AccVGPRs
1699 
1700  // c_vec += a_vec * b_vec
1701  template <bool post_nop_ = false>
1703  const AVecType& a_vec,
1704  const BVecType& b_vec,
1705  bool_constant<post_nop_> = {}) const
1706  {
1707  DISPATCH_MFMA_CTRL_("v_mfma_i32_16x16x64_i8", Ctrl)
1708  else
1709  {
1710 #if defined(__gfx95__)
1711  c_vec = __builtin_amdgcn_mfma_i32_16x16x64_i8(
1712  bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1713 #else
1714  ck_tile::ignore = c_vec;
1715  ck_tile::ignore = a_vec;
1716  ck_tile::ignore = b_vec;
1717 #endif
1718  }
1719  }
1720 
1721  // c_vec = a_vec * b_vec
1722  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1723  {
1724  CVecType c_vec{0};
1725  operator()(c_vec, a_vec, b_vec);
1726  return c_vec;
1727  }
1728 };
1729 
1730 template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_> // 32x32x32 int8 MFMA with int32 accumulator (builtin path: gfx950 only)
1732 {
1733  static constexpr WGAttrCtlEnum Ctrl = Ctrl_; // passed to DISPATCH_MFMA_CTRL_ to choose inline asm vs. the builtin path
1737 
1741 
1742  static constexpr index_t kM = 32; // C rows produced per MFMA
1743  static constexpr index_t kN = 32; // C columns produced per MFMA
1744  static constexpr index_t kK = 32; // reduction depth per MFMA (= kABKLane * kABKPerLane)
1745 
1746  static constexpr index_t kAMBlock = 1;
1747  static constexpr index_t kBNBlock = 1;
1748 
1749  static constexpr index_t kAMLane = 32; // kAMLane * kABKLane = 64 lanes
1750  static constexpr index_t kBNLane = 32;
1751  static constexpr index_t kABKLane = 2; // lane groups along K
1752  static constexpr index_t kABKPerLane = 16; // int8 values held per lane (128 bits)
1753 
1754  static constexpr index_t kCMLane = 2;
1755  static constexpr index_t kCNLane = 32;
1756  static constexpr index_t kCM0PerLane = 4;
1757  static constexpr index_t kCM1PerLane = 4; // kCMLane * kCM0PerLane * kCM1PerLane = 32 = kM
1758 
1759  // c_vec += a_vec * b_vec
1760  template <bool post_nop_ = false>
1762  const AVecType& a_vec,
1763  const BVecType& b_vec,
1764  bool_constant<post_nop_> = {}) const
1765  {
1766  DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x32_i8", Ctrl)
1767  else
1768  {
1769 #if defined(__gfx95__)
1770  c_vec =
1771  __builtin_amdgcn_mfma_i32_32x32x32_i8(a_vec, bit_cast<long>(b_vec), c_vec, 0, 0, 0);
1772 #else
1773  ck_tile::ignore = c_vec;
1774  ck_tile::ignore = a_vec;
1775  ck_tile::ignore = b_vec;
1776 #endif
1777  }
1778  }
1779 
1780  // c_vec = a_vec * b_vec
1781  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
1782  {
1783  CVecType c_vec{0};
1784  operator()(c_vec, a_vec, b_vec);
1785  return c_vec;
1786  }
1787 };
1788 
// Both dispatch helpers are local to this header. DISPATCH_MFMA_CTRL_ expands to DISPATCH_MFMA_,
// so it is unusable once DISPATCH_MFMA_ is gone; undefine it too rather than leak it to includers.
#undef DISPATCH_MFMA_
#undef DISPATCH_MFMA_CTRL_
1790 
1791 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
WGAttrCtlEnum
Definition: warp_gemm_attribute_mfma_impl.hpp:15
_Float16 fp16_t
Definition: half.hpp:110
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
int32_t int32_t
Definition: integer.hpp:10
float fp32x16_t
Definition: vector_type.hpp:119
float fp32x4_t
Definition: vector_type.hpp:117
Definition: warp_gemm_attribute_mfma_impl.hpp:1347
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1357
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1371
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1362
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1348
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1355
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1366
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1353
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1369
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1351
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1364
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1350
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1372
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1359
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1367
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1361
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1370
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1365
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1358
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1376
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1404
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1354
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1349
Definition: warp_gemm_attribute_mfma_impl.hpp:983
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:997
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:998
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1007
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:993
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:985
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:991
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:989
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:984
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1005
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1002
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1012
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1008
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:987
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1117
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1001
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:986
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1006
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1003
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:995
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1000
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:994
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:990
Definition: warp_gemm_attribute_mfma_impl.hpp:1142
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1164
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1143
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1148
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1145
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1167
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1150
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1162
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1146
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1156
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1165
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1171
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1287
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1159
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1152
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1144
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1149
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1154
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1153
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1166
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1157
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1161
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1160
Definition: warp_gemm_attribute_mfma_impl.hpp:1445
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1502
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1456
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1467
BType_ BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1448
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1465
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1455
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1463
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1462
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1453
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1449
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1474
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1469
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1460
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1459
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1468
ext_vector_t< BDataType, 32 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1452
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1446
AType_ ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1447
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1464
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1470
ext_vector_t< ADataType, 32 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1451
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1457
Definition: warp_gemm_attribute_mfma_impl.hpp:1614
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1663
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1622
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1621
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1620
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1631
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1643
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1615
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1625
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1629
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1616
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1632
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1634
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1633
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1628
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1636
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1638
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1624
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1639
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1637
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1618
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1626
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1617
Definition: warp_gemm_attribute_mfma_impl.hpp:1673
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1677
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1698
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1685
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1684
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1691
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1679
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1693
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1697
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1683
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1690
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1702
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1696
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1688
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1692
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1674
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1676
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1680
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1722
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1675
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1695
ext_vector_t< CDataType, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1681
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1687
Definition: warp_gemm_attribute_mfma_impl.hpp:1544
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1546
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1573
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1566
ext_vector_t< ADataType, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1550
ext_vector_t< BDataType, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1551
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1568
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1548
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1545
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1604
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1547
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1561
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1555
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1559
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1556
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1569
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1552
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1567
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1564
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1558
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1562
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1563
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1554
Definition: warp_gemm_attribute_mfma_impl.hpp:1732
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:1742
int8_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1734
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1756
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1751
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1761
int8_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1735
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1755
int32_t CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:1736
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1750
ext_vector_t< BDataType, 16 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1739
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1747
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1749
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1754
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:1744
ext_vector_t< ADataType, 16 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1738
ext_vector_t< CDataType, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:1740
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:1781
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:1743
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1757
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:1752
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:1733
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:1746
Definition: warp_gemm_attribute_mfma_impl.hpp:537
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:545
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:555
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:544
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:566
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:559
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:596
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:549
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:554
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:538
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:560
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:561
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:556
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:551
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:541
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:562
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:548
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:540
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:552
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:539
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:543
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:557
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:547
Definition: warp_gemm_attribute_mfma_impl.hpp:67
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:90
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:70
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:75
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:78
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:81
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:91
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:86
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:77
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:79
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:73
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:68
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:87
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:71
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:85
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:74
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:89
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:69
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:96
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:92
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:115
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:82
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:84
Definition: warp_gemm_attribute_mfma_impl.hpp:868
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:891
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:897
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:872
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:869
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:871
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:882
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:883
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:886
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:893
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:940
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:887
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:892
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:885
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:878
ext_vector_t< bf16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:874
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:879
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:890
ext_vector_t< bf16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:875
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:870
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:888
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:876
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:880
Definition: warp_gemm_attribute_mfma_impl.hpp:448
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:455
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:451
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:452
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:463
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:456
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:454
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:508
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:459
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:472
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:458
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:471
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:462
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:467
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:466
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:465
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:460
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:470
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:450
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:449
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:477
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:473
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:468
Definition: warp_gemm_attribute_mfma_impl.hpp:625
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:646
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:627
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:649
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:648
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:650
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:674
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:628
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:637
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:626
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:644
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:655
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:631
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:636
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:635
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:645
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:640
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:632
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:651
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:633
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:643
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:639
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:629
Definition: warp_gemm_attribute_mfma_impl.hpp:689
bf16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:692
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:712
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:707
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:738
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:690
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:713
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:700
bf16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:691
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:703
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:714
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:704
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:708
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:701
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:693
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:709
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:715
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:710
ext_vector_t< bf16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:695
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:719
ext_vector_t< bf16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:696
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:699
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:697
Definition: warp_gemm_attribute_mfma_impl.hpp:193
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:194
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:210
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:197
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:216
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:207
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:196
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:215
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:218
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:199
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:208
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:204
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:195
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:212
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:211
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:201
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:217
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:213
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:241
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:222
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:203
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:205
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:200
Definition: warp_gemm_attribute_mfma_impl.hpp:256
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:267
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:279
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:281
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:258
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:278
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:262
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:275
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:276
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:263
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:266
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:264
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:304
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:259
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:260
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:273
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:280
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:271
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:270
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:257
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:274
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:268
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:285
Definition: warp_gemm_attribute_mfma_impl.hpp:754
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:755
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:756
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:768
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:776
ext_vector_t< fp16_t, 8 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:760
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:777
ext_vector_t< fp16_t, 8 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:761
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:764
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:758
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:783
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:765
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:774
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:826
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:773
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:772
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:779
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:769
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:766
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:771
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:778
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:762
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:757
Definition: warp_gemm_attribute_mfma_impl.hpp:130
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:153
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:147
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:154
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:152
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:159
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:155
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:134
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:142
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:133
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:148
ext_vector_t< float, 16 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:138
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:145
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:131
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:140
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:137
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:141
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:178
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:144
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:136
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:132
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:149
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:150
Definition: warp_gemm_attribute_mfma_impl.hpp:319
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:330
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:342
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:321
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:325
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:368
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:326
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:334
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:323
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:349
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:327
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:337
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:338
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:343
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:322
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:331
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:339
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:329
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:344
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:320
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:340
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:345
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:333
Definition: warp_gemm_attribute_mfma_impl.hpp:383
static constexpr index_t kABKPerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:404
float CDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:387
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma_impl.hpp:393
static constexpr index_t kAMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:401
fp16_t BDataType
Definition: warp_gemm_attribute_mfma_impl.hpp:386
static constexpr WGAttrCtlEnum Ctrl
Definition: warp_gemm_attribute_mfma_impl.hpp:384
static constexpr index_t kBNBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:398
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma_impl.hpp:413
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma_impl.hpp:395
static constexpr index_t kAMBlock
Definition: warp_gemm_attribute_mfma_impl.hpp:397
ext_vector_t< float, 4 > CVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:391
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma_impl.hpp:394
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma_impl.hpp:432
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma_impl.hpp:406
ext_vector_t< fp16_t, 4 > AVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:389
static constexpr index_t kBNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:402
static constexpr index_t kCNLane
Definition: warp_gemm_attribute_mfma_impl.hpp:407
static constexpr index_t kCM0PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:408
static constexpr index_t kABKLane
Definition: warp_gemm_attribute_mfma_impl.hpp:403
fp16_t ADataType
Definition: warp_gemm_attribute_mfma_impl.hpp:385
ext_vector_t< fp16_t, 4 > BVecType
Definition: warp_gemm_attribute_mfma_impl.hpp:390
static constexpr index_t kCM1PerLane
Definition: warp_gemm_attribute_mfma_impl.hpp:409
Definition: integral_constant.hpp:13
Definition: functional.hpp:43
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_)
Definition: warp_gemm_attribute_mfma_impl.hpp:25
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_)
Definition: warp_gemm_attribute_mfma_impl.hpp:42