/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File
warp_gemm_attribute_mfma.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Number of groups of consecutive elements to fill in an ABKLane
13 {
14  Single = 1,
15  Double = 2,
16  Quad = 4,
17  Invalid = -1
18 };
19 
20 template <typename WarpGemmAttributeMfmaImpl_,
23 {
25  static constexpr auto AttrNumAccess = AttrNumAccess_;
26  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
27 
28  using ADataType = typename Impl::ADataType;
29  using BDataType = typename Impl::BDataType;
30  using CDataType = typename Impl::CDataType;
31 
32  using AVecType = typename Impl::AVecType;
33  using BVecType = typename Impl::BVecType;
34  using CVecType = typename Impl::CVecType;
35 
36  static constexpr index_t kM = Impl::kM;
37  static constexpr index_t kN = Impl::kN;
38  static constexpr index_t kK = Impl::kK;
39  static constexpr index_t kKPerThread = Impl::kABKPerLane;
40  static constexpr index_t kCMLane = Impl::kCMLane;
41 
42  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
43 
44  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
45  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
46 
47  template <index_t kMNLane>
48  static constexpr auto get_warp_dstr_encoding()
49  {
50  static_assert(kKPerThread % AttrNumAccessV == 0,
51  "kKPerThread must be divisible by NumAccess");
52  if constexpr(AttrNumAccessV == 1)
54  sequence<>,
59  sequence<1>>{};
60  else
62  sequence<>,
64  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
68  sequence<0, 2>>{};
69  }
70  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
71  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
72 
74  sequence<>,
81 
82  // c_vec += a_vec * b_vec
83  template <bool post_nop_ = false>
85  const AVecType& a_vec,
86  const BVecType& b_vec,
87  bool_constant<post_nop_> = {}) const
88  {
89  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
90  }
91 
92  // c_vec += a_vec * b_vec
93  template <index_t opselA, index_t opselB, bool post_nop_ = false>
95  const AVecType& a_vec,
96  const int32_t& a_scale,
97  const BVecType& b_vec,
98  const int32_t& b_scale,
99  bool_constant<post_nop_> = {}) const
100  {
101  Impl{}.template operator()<opselA, opselB>(
102  c_vec, a_vec, a_scale, b_vec, b_scale, bool_constant<post_nop_>{});
103  }
104 
105  // c_vec = a_vec * b_vec
106  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
107  {
108  return Impl{}(a_vec, b_vec);
109  }
110 
111  // c_vec = a_vec * b_vec
112  template <index_t opselA, index_t opselB>
114  const int32_t& a_scale,
115  const BVecType& b_vec,
116  const int32_t& b_scale) const
117  {
118  return Impl{}.template operator()<opselA, opselB>(a_vec, a_scale, b_vec, b_scale);
119  }
120 };
121 
122 template <typename WarpGemmAttributeMfmaImpl_,
123  index_t kKIter,
126 {
127  static_assert(kKIter > 0, "wrong!");
128 
130  static constexpr auto AttrNumAccess = AttrNumAccess_;
131  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
132 
133  using ADataType = typename Impl::ADataType;
134  using BDataType = typename Impl::BDataType;
135  using CDataType = typename Impl::CDataType;
136 
137  using AVecType =
139  using BVecType =
141  using CVecType = typename Impl::CVecType;
142 
143  static constexpr index_t kM = Impl::kM;
144  static constexpr index_t kN = Impl::kN;
145  static constexpr index_t kK = Impl::kK * kKIter;
146  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
147  static constexpr index_t kCMLane = Impl::kCMLane;
148 
149  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
150 
151  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
152  "Multi-block on both M & N directions is not supported");
153 
154  template <index_t kMNLane, index_t kMNBlock, index_t kNMBlock>
155  CK_TILE_DEVICE static constexpr auto get_warp_dstr_encoding()
156  {
157  if constexpr(kMNBlock == 1 && kNMBlock == 1)
158  {
159  static_assert(kKPerThread % AttrNumAccessV == 0,
160  "kKPerThread must be divisible by NumAccess");
161  if constexpr(AttrNumAccessV == 1)
163  sequence<>,
167  sequence<2>,
168  sequence<1>>{};
169  else
171  sequence<>,
174  Impl::kABKLane,
175  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
179  sequence<0, 2>>{};
180  }
181  else if constexpr(kMNBlock == 1 && 1 < kNMBlock)
182  {
183  static_assert(AttrNumAccessV == 1,
184  "Multiple access is not supported when using multi-block");
185  // each M/N blocks share the same data
191  sequence<2>,
192  sequence<1>>{};
193  }
194  else if constexpr(1 < kMNBlock && kNMBlock == 1)
195  {
196  static_assert(AttrNumAccessV == 1,
197  "Multiple access is not supported when using multi-block");
198  // single block to multi-block thread mapping
200  sequence<>,
205  sequence<2>,
206  sequence<1>>{};
207  }
208  }
209 
210  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
211  {
212  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
213  {
215  sequence<>,
221  sequence<0, 2>>{};
222  }
223  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
224  {
226  sequence<>,
232  sequence<0, 2>>{};
233  }
234  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
235  {
237  sequence<>,
238  tuple<
244  sequence<0, 2>>{};
245  }
246  }
247 
249  decltype(get_warp_dstr_encoding<Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock>());
251  decltype(get_warp_dstr_encoding<Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock>());
253 
254  // c_vec += a_vec * b_vec, issuing one Impl MFMA per kKIter slice via the per-slice overload below. NOTE(review): post_nop_ is accepted but not forwarded; the per-slice overload is called with its default post_nop_ (false), so passing bool_constant<true>{} has no effect here. Confirm this is intended.
255  template <bool post_nop_ = false>
257  const AVecType& a_vec,
258  const BVecType& b_vec,
259  bool_constant<post_nop_> = {}) const
260  {
261  static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
262  }
263 
264  template <index_t iKIter, bool post_nop_ = false>
266  const AVecType& a_vec,
267  const BVecType& b_vec,
269  bool_constant<post_nop_> = {}) const
270  {
273 
274  static_assert(iKIter < kKIter);
275 
276  Impl{}(c_vec,
277  reinterpret_cast<const buf_a&>(a_vec)
278  .template get_as<typename Impl::AVecType>()[iKIter],
279  reinterpret_cast<const buf_b&>(b_vec)
280  .template get_as<typename Impl::BVecType>()[iKIter],
281  bool_constant<post_nop_>{});
282  }
283 
284  // c_vec = a_vec * b_vec
285  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
286  {
287  constexpr auto I0 = number<0>{};
290 
291  // c = a * b
292  auto c_vec = Impl{}(
293  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
294  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
295 
296  // c += a * b
297  static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
298 
299  return c_vec;
300  }
301 };
302 
303 template <typename WarpGemmAttributeMfmaImpl_,
306 {
308  static constexpr auto AttrNumAccess = AttrNumAccess_;
309  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
310 
311  using ADataType = typename Impl::BDataType;
312  using BDataType = typename Impl::ADataType;
313  using CDataType = typename Impl::CDataType;
314 
315  using AVecType = typename Impl::BVecType;
316  using BVecType = typename Impl::AVecType;
317  using CVecType = typename Impl::CVecType;
318 
319  static constexpr index_t kM = Impl::kN;
320  static constexpr index_t kN = Impl::kM;
321  static constexpr index_t kK = Impl::kK;
322  static constexpr index_t kKPerThread = Impl::kABKPerLane;
323  static constexpr index_t kCMLane = Impl::kCMLane;
324 
325  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
326 
327  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
328  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
329 
334 
336  sequence<>,
343 
344  // c_vec += a_vec * b_vec
345  template <bool post_nop_ = false>
347  const AVecType& a_vec,
348  const BVecType& b_vec,
349  bool_constant<post_nop_> = {}) const
350  {
351  // swap A and B
352  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
353  }
354 
355  // c_vec = a_vec * b_vec
356  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
357  {
358  // swap A and B
359  return Impl{}(b_vec, a_vec);
360  }
361 };
362 
363 template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
365 {
367 
368  using ADataType = typename Impl::BDataType;
369  using BDataType = typename Impl::ADataType;
370  using CDataType = typename Impl::CDataType;
371 
372  using AVecType = typename Impl::BVecType;
373  using BVecType = typename Impl::AVecType;
374  using CVecType = typename Impl::CVecType;
375 
376  static constexpr index_t kM = Impl::kN;
377  static constexpr index_t kN = Impl::kM;
378  static constexpr index_t kK = Impl::kK;
379  static constexpr index_t kKPerThread = Impl::kABKPerLane;
380  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
381 
382  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
383 
384  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
385  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
386 
388  sequence<>,
392  sequence<2>,
393  sequence<1>>;
394 #if 0
396  sequence<>,
397  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
398  Impl::kABKLane,
399  2,
400  Impl::kABKPerLane>,
404  sequence<2>,
405  sequence<1>>;
406 
408  sequence<>,
410  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
415 #else
416  // TODO: more test not only 32x32
418  sequence<>,
419  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
420  Impl::kCMLane,
421  SFactor,
422  Impl::kCM1PerLane>,
426  sequence<2>,
427  sequence<1>>;
428 
430  sequence<>,
432  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
437 #endif
438  template <bool post_nop_ = false>
439  // c_vec += a_vec * b_vec
441  const AVecType& a_vec,
442  const BVecType& b_vec,
443  bool_constant<post_nop_> = {}) const
444  {
445  // swap A and B
446  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
447  }
448 
449  // c_vec = a_vec * b_vec
450  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
451  {
452  // swap A and B
453  return Impl{}(b_vec, a_vec);
454  }
455 };
456 
457 template <typename WarpGemmAttributeMfmaImpl_,
458  index_t kKIter,
461 {
463  static constexpr auto AttrNumAccess = AttrNumAccess_;
464 
465  // swap A and B
466  using ADataType = typename Impl::BDataType;
467  using BDataType = typename Impl::ADataType;
468  using CDataType = typename Impl::CDataType;
469 
470  using AVecType =
472  using BVecType =
474  using CVecType = typename Impl::CVecType;
475 
476  static constexpr index_t kM = Impl::kN;
477  static constexpr index_t kN = Impl::kM;
478  static constexpr index_t kK = Impl::kK * kKIter;
479  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
480 
481  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
482 
483  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
484  "Multi-block on both M & N directions is not supported");
485 
486  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
487  {
488  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
489  {
491  sequence<>,
497  sequence<0, 2>>{};
498  }
499  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
500  {
502  sequence<>,
508  sequence<0, 2>>{};
509  }
510  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
511  {
513  sequence<>,
514  tuple<
520  sequence<0, 2>>{};
521  }
522  }
523 
529 
530  // c_vec += a_vec * b_vec, issuing one Impl MFMA per kKIter slice; the A/B swap happens inside the per-slice overload below. NOTE(review): post_nop_ is accepted but not forwarded; the per-slice overload runs with its default post_nop_ (false). Confirm this is intended.
531  template <bool post_nop_ = false>
533  const AVecType& a_vec,
534  const BVecType& b_vec,
535  bool_constant<post_nop_> = {}) const
536  {
537  static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
538  }
539 
540  template <index_t iKIter, bool post_nop_ = false>
541  // c_vec += a_vec * b_vec
543  const AVecType& a_vec,
544  const BVecType& b_vec,
546  bool_constant<post_nop_> = {}) const
547  {
550 
551  static_assert(iKIter < kKIter);
552  // swap A and B, value and type
553  Impl{}(c_vec,
554  reinterpret_cast<const buf_b&>(b_vec)
555  .template get_as<typename Impl::BVecType>()[iKIter],
556  reinterpret_cast<const buf_a&>(a_vec)
557  .template get_as<typename Impl::AVecType>()[iKIter],
558  bool_constant<post_nop_>{});
559  }
560 
561  // c_vec = a_vec * b_vec
562  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
563  {
564  constexpr auto I0 = number<0>{};
567 
568  // swap A and B, value and type
569  auto c_vec = Impl{}(
570  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
571  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
572 
573  static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
574 
575  return c_vec;
576  }
577 };
578 
579 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
581 {
583 
584  // swap A and B
585  using ADataType = typename Impl::BDataType;
586  using BDataType = typename Impl::ADataType;
587  using CDataType = typename Impl::CDataType;
588 
589  using AVecType =
591  using BVecType =
593  using CVecType = typename Impl::CVecType;
594 
595  static constexpr index_t kM = Impl::kN;
596  static constexpr index_t kN = Impl::kM;
597  static constexpr index_t kK = Impl::kK * kKIter;
598  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
599  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
600 
601  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
602 
603  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
604  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
605 
607  sequence<>,
611  sequence<2>,
612  sequence<1>>;
613 #if 0
615  sequence<>,
616  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
617  Impl::kABKLane,
618  2,
619  Impl::kABKPerLane>,
623  sequence<2>,
624  sequence<1>>;
625 
627  sequence<>,
629  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
634 #else
635  // TODO: more test not only 32x32
637  sequence<>,
638  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
639  Impl::kCMLane,
640  SFactor,
641  Impl::kCM1PerLane>,
645  sequence<2>,
646  sequence<1>>;
647 
649  sequence<>,
651  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
656 #endif
657  // c_vec += a_vec * b_vec, issuing one Impl MFMA per kKIter slice. NOTE(review): post_nop_ is accepted but not forwarded; the per-slice overload runs with its default post_nop_ (false). Confirm this is intended.
658  template <bool post_nop_ = false>
660  const AVecType& a_vec,
661  const BVecType& b_vec,
662  bool_constant<post_nop_> = {}) const
663  {
664  // no swap here: A and B (values and Impl vector types) are swapped inside the per-slice overload
665  static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
666  }
667 
668  template <index_t iKIter, bool post_nop_ = false>
670  const AVecType& a_vec,
671  const BVecType& b_vec,
673  bool_constant<post_nop_> = {}) const
674  {
677 
678  static_assert(iKIter < kKIter);
679  // swap A and B, value and type
680  Impl{}(c_vec,
681  reinterpret_cast<const buf_b&>(b_vec)
682  .template get_as<typename Impl::BVecType>()[iKIter],
683  reinterpret_cast<const buf_a&>(a_vec)
684  .template get_as<typename Impl::AVecType>()[iKIter],
685  bool_constant<post_nop_>{});
686  }
687 
688  // c_vec = a_vec * b_vec
689  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
690  {
693  constexpr auto I0 = number<0>{};
694 
695  // swap A and B, value and type
696  auto c_vec = Impl{}(
697  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
698  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
699 
700  static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
701 
702  return c_vec;
703  }
704 };
705 
706 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
708 {
710 
711  using ADataType = typename Impl::ADataType;
712  using BDataType = typename Impl::BDataType;
713  using CDataType = typename Impl::CDataType;
714 
715  using AVecType =
717  using BVecType =
719  using CVecType = typename Impl::CVecType;
720 
721  static constexpr index_t kM = Impl::kM;
722  static constexpr index_t kN = Impl::kN;
723  static constexpr index_t kK = Impl::kK * kKIter;
724  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
725  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
726 
727  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
728 
729  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
730  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
731 
733  sequence<>,
734  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
735  Impl::kCMLane,
736  SFactor,
737  Impl::kCM1PerLane>,
741  sequence<2>,
742  sequence<1>>;
743 
745  sequence<>,
749  sequence<2>,
750  sequence<1>>;
751 
753  sequence<>,
754  tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
760 
761  // c_vec += a_vec * b_vec, issuing one Impl MFMA per kKIter slice via the per-slice overload below. NOTE(review): post_nop_ is accepted but not forwarded; the per-slice overload runs with its default post_nop_ (false). Confirm this is intended.
762  template <bool post_nop_ = false>
764  const AVecType& a_vec,
765  const BVecType& b_vec,
766  bool_constant<post_nop_> = {}) const
767  {
768  static_for<0, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
769  }
770 
771  template <index_t iKIter, bool post_nop_ = false>
773  const AVecType& a_vec,
774  const BVecType& b_vec,
776  bool_constant<post_nop_> = {}) const
777  {
780 
781  static_assert(iKIter < kKIter);
782 
783  Impl{}(c_vec,
784  reinterpret_cast<const buf_a&>(a_vec)
785  .template get_as<typename Impl::AVecType>()[iKIter],
786  reinterpret_cast<const buf_b&>(b_vec)
787  .template get_as<typename Impl::BVecType>()[iKIter],
788  bool_constant<post_nop_>{});
789  }
790 
791  // c_vec = a_vec * b_vec
792  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
793  {
794  constexpr auto I0 = number<0>{};
797 
798  auto c_vec = Impl{}(
799  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
800  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
801 
802  static_for<1, kKIter, 1>{}([&](auto iKIter) { operator()(c_vec, a_vec, b_vec, iKIter); });
803 
804  return c_vec;
805  }
806 };
807 
808 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
WGAttrNumAccessEnum
Definition: warp_gemm_attribute_mfma.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:84
int32_t int32_t
Definition: integer.hpp:10
Definition: warp_gemm_attribute_mfma.hpp:23
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:48
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale) const
Definition: warp_gemm_attribute_mfma.hpp:113
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:38
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:29
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:40
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:32
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:34
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:70
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:25
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:30
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:106
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:42
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:26
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:36
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:24
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:84
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:39
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:37
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const int32_t &a_scale, const BVecType &b_vec, const int32_t &b_scale, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:94
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:28
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:71
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:33
Definition: warp_gemm_attribute_mfma.hpp:708
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:709
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:772
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:722
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:721
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:723
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:792
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:724
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:718
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:763
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:712
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:719
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:725
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:727
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:711
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:713
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:716
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:585
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:586
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:598
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:590
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:599
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:597
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:659
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:593
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:582
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:587
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:596
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:689
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:669
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:592
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:595
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:601
Definition: warp_gemm_attribute_mfma.hpp:461
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:474
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:486
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:473
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:479
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:468
typename WarpGemmAttributeMfmaIterateK< Impl, kKIter, AttrNumAccess >::AWarpDstrEncoding BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:527
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:481
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:466
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:476
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:478
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:542
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:477
typename WarpGemmAttributeMfmaIterateK< Impl, kKIter, AttrNumAccess >::BWarpDstrEncoding AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:525
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:562
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:528
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:462
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:463
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:471
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:532
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:467
Definition: warp_gemm_attribute_mfma.hpp:126
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:130
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:147
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:210
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:141
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:134
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:138
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:265
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:252
decltype(get_warp_dstr_encoding< Impl::kBNLane, Impl::kBNBlock, Impl::kAMBlock >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:251
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:140
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:285
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:149
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:135
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:256
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:133
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:143
decltype(get_warp_dstr_encoding< Impl::kAMLane, Impl::kAMBlock, Impl::kBNBlock >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:249
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:145
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:131
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:144
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:129
static constexpr CK_TILE_DEVICE auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:155
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:146
Definition: warp_gemm_attribute_mfma.hpp:365
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:377
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:380
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:372
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:373
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:369
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:450
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:366
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:440
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:379
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:370
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:376
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:368
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:382
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:378
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:374
Definition: warp_gemm_attribute_mfma.hpp:306
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:356
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:323
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:311
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:316
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:322
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:325
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:315
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:321
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:312
typename WarpGemmAttributeMfma< Impl, AttrNumAccess >::AWarpDstrEncoding BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:333
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:309
typename WarpGemmAttributeMfma< Impl, AttrNumAccess >::BWarpDstrEncoding AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:331
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:313
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:308
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:317
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:320
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:307
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:319
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:346
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: debug.hpp:27
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192