/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File
warp_gemm_attribute_mfma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Number of groups of consecutive elements to fill in an ABKLane
13 {
14  Single = 1,
15  Double = 2,
16  Quad = 4,
17  Invalid = -1
18 };
19 
20 template <typename WarpGemmAttributeMfmaImpl_,
23 {
25  static constexpr auto AttrNumAccess = AttrNumAccess_;
26  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
27 
28  using ADataType = typename Impl::ADataType;
29  using BDataType = typename Impl::BDataType;
30  using CDataType = typename Impl::CDataType;
31 
32  using AVecType = typename Impl::AVecType;
33  using BVecType = typename Impl::BVecType;
34  using CVecType = typename Impl::CVecType;
35 
36  static constexpr index_t kM = Impl::kM;
37  static constexpr index_t kN = Impl::kN;
38  static constexpr index_t kK = Impl::kK;
39  static constexpr index_t kKPerThread = Impl::kABKPerLane;
40  static constexpr index_t kCMLane = Impl::kCMLane;
41 
42  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
43 
44  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
45  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
46 
47  template <index_t kMNLane>
48  static constexpr auto get_warp_dstr_encoding()
49  {
50  if constexpr(AttrNumAccessV == 1)
51  {
53  sequence<>,
58  sequence<1>>{};
59  }
60  else
61  {
62  static_assert(kKPerThread % AttrNumAccessV == 0,
63  "kKPerThread must be divisible by NumAccess");
65  sequence<>,
67  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
71  sequence<0, 2>>{};
72  }
73  }
74  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
75  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
76 
78  sequence<>,
85 
86  // c_vec += a_vec * b_vec
87  template <bool post_nop_ = false>
89  const AVecType& a_vec,
90  const BVecType& b_vec,
91  bool_constant<post_nop_> = {}) const
92  {
93  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
94  }
95 
96  // c_vec = a_vec * b_vec: forwards a_vec and b_vec, in that order, to a default-constructed Impl and returns the CVecType it yields
97  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
98  {
99  return Impl{}(a_vec, b_vec);
100  }
101 };
102 
103 template <typename WarpGemmAttributeMfmaImpl_,
104  index_t kKIter,
107 {
108  static_assert(kKIter > 0, "wrong!");
109 
111  static constexpr auto AttrNumAccess = AttrNumAccess_;
112  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
113 
114  using ADataType = typename Impl::ADataType;
115  using BDataType = typename Impl::BDataType;
116  using CDataType = typename Impl::CDataType;
117 
118  using AVecType =
120  using BVecType =
122  using CVecType = typename Impl::CVecType;
123 
124  static constexpr index_t kM = Impl::kM;
125  static constexpr index_t kN = Impl::kN;
126  static constexpr index_t kK = Impl::kK * kKIter;
127  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
128 
129  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
130 
131  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
132  "Multi-block on both M & N directions is not supported");
133 
134  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
135  {
136  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
137  {
138  if constexpr(AttrNumAccessV == 1)
139  {
141  sequence<>,
146  sequence<2>,
147  sequence<1>>{};
148  }
149  else
150  {
151  static_assert(kKPerThread % AttrNumAccessV == 0,
152  "kKPerThread must be divisible by NumAccess");
154  sequence<>,
157  Impl::kABKLane,
158  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
162  sequence<0, 2>>{};
163  }
164  }
165  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
166  {
167  static_assert(AttrNumAccessV == 1,
168  "Multiple access is not supported when using multi-block");
169  // each M block shares the same data
176  sequence<2>,
177  sequence<1>>{};
178  }
179  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
180  {
181  static_assert(AttrNumAccessV == 1,
182  "Multiple access is not supported when using multi-block");
183  // single block to multi-block thread mapping
185  sequence<>,
190  sequence<2>,
191  sequence<1>>{};
192  }
193  }
194 
195  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
196  {
197  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
198  {
199  if constexpr(AttrNumAccessV == 1)
200  {
202  sequence<>,
207  sequence<2>,
208  sequence<1>>{};
209  }
210  else
211  {
212 
213  static_assert(kKPerThread % AttrNumAccessV == 0,
214  "kKPerThread must be divisible by NumAccess");
216  sequence<>,
219  Impl::kABKLane,
220  Impl::kABKPerLane * kKIter / AttrNumAccessV>>,
224  sequence<0, 2>>{};
225  }
226  }
227  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
228  {
229  static_assert(AttrNumAccessV == 1,
230  "Multiple access is not supported when using multi-block");
231  // single block to multi-block thread mapping
233  sequence<>,
238  sequence<2>,
239  sequence<1>>{};
240  }
241  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
242  {
243  static_assert(AttrNumAccessV == 1,
244  "Multiple access is not supported when using multi-block");
245  // each N block shares the same data
252  sequence<2>,
253  sequence<1>>{};
254  }
255  }
256 
257  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
258  {
259  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
260  {
262  sequence<>,
268  sequence<0, 2>>{};
269  }
270  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
271  {
273  sequence<>,
279  sequence<0, 2>>{};
280  }
281  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
282  {
284  sequence<>,
285  tuple<
291  sequence<0, 2>>{};
292  }
293  }
294 
296 
298 
300 
301  // c_vec += a_vec * b_vec
302  template <bool post_nop_ = false>
304  const AVecType& a_vec,
305  const BVecType& b_vec,
306  bool_constant<post_nop_> = {}) const
307  {
310 
311  static_for<0, kKIter, 1>{}([&](auto iKIter) {
312  Impl{}(c_vec,
313  reinterpret_cast<const buf_a&>(a_vec)
314  .template get_as<typename Impl::AVecType>()[iKIter],
315  reinterpret_cast<const buf_b&>(b_vec)
316  .template get_as<typename Impl::BVecType>()[iKIter],
317  bool_constant<post_nop_>{});
318  });
319  }
320 
321  template <index_t iKIter, bool post_nop_ = false>
323  const AVecType& a_vec,
324  const BVecType& b_vec,
326  bool_constant<post_nop_> = {}) const
327  {
330 
331  static_assert(iKIter < kKIter);
332 
333  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
334  Impl{}(c_vec,
335  reinterpret_cast<const buf_a&>(a_vec)
336  .template get_as<typename Impl::AVecType>()[iKIter],
337  reinterpret_cast<const buf_b&>(b_vec)
338  .template get_as<typename Impl::BVecType>()[iKIter],
339  bool_constant<post_nop_>{});
340  //});
341  }
342 
343  // c_vec = a_vec * b_vec
344  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
345  {
346  constexpr auto I0 = number<0>{};
349 
350  // c = a * b
351  auto c_vec = Impl{}(
352  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
353  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
354 
355  // c += a * b
356  static_for<1, kKIter, 1>{}([&](auto iKIter) {
357  Impl{}(c_vec,
358  reinterpret_cast<const buf_a&>(a_vec)
359  .template get_as<typename Impl::AVecType>()[iKIter],
360  reinterpret_cast<const buf_b&>(b_vec)
361  .template get_as<typename Impl::BVecType>()[iKIter]);
362  });
363 
364  return c_vec;
365  }
366 };
367 
368 template <typename WarpGemmAttributeMfmaImpl_,
371 {
373  static constexpr auto AttrNumAccess = AttrNumAccess_;
374  static constexpr auto AttrNumAccessV = static_cast<index_t>(AttrNumAccess);
375 
376  using ADataType = typename Impl::BDataType;
377  using BDataType = typename Impl::ADataType;
378  using CDataType = typename Impl::CDataType;
379 
380  using AVecType = typename Impl::BVecType;
381  using BVecType = typename Impl::AVecType;
382  using CVecType = typename Impl::CVecType;
383 
384  static constexpr index_t kM = Impl::kN;
385  static constexpr index_t kN = Impl::kM;
386  static constexpr index_t kK = Impl::kK;
387  static constexpr index_t kKPerThread = Impl::kABKPerLane;
388 
389  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
390 
391  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
392  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
393 
394  template <index_t kMNLane>
395  static constexpr auto get_warp_dstr_encoding()
396  {
397  if constexpr(AttrNumAccessV == 1)
398  {
400  sequence<>,
404  sequence<2>,
405  sequence<1>>{};
406  }
407  else
408  {
409  static_assert(kKPerThread % AttrNumAccessV == 0,
410  "kKPerThread must be divisible by NumAccess");
412  sequence<>,
414  sequence<AttrNumAccessV, Impl::kABKLane, Impl::kABKPerLane / AttrNumAccessV>>,
418  sequence<0, 2>>{};
419  }
420  }
421  using AWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kBNLane>());
422  using BWarpDstrEncoding = decltype(get_warp_dstr_encoding<Impl::kAMLane>());
423 
425  sequence<>,
432 
433  // c_vec += a_vec * b_vec
434  template <bool post_nop_ = false>
436  const AVecType& a_vec,
437  const BVecType& b_vec,
438  bool_constant<post_nop_> = {}) const
439  {
440  // swap A and B
441  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
442  }
443 
444  // c_vec = a_vec * b_vec, evaluated by a default-constructed Impl with the two operands exchanged; returns the CVecType Impl yields
445  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
446  {
447  // swap A and B: b_vec is passed as Impl's first argument and a_vec as its second (AVecType/BVecType here alias Impl::BVecType/Impl::AVecType)
448  return Impl{}(b_vec, a_vec);
449  }
450 };
451 
452 template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
454 {
456 
457  using ADataType = typename Impl::BDataType;
458  using BDataType = typename Impl::ADataType;
459  using CDataType = typename Impl::CDataType;
460 
461  using AVecType = typename Impl::BVecType;
462  using BVecType = typename Impl::AVecType;
463  using CVecType = typename Impl::CVecType;
464 
465  static constexpr index_t kM = Impl::kN;
466  static constexpr index_t kN = Impl::kM;
467  static constexpr index_t kK = Impl::kK;
468  static constexpr index_t kKPerThread = Impl::kABKPerLane;
469  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
470 
471  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
472 
473  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
474  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
475 
477  sequence<>,
481  sequence<2>,
482  sequence<1>>;
483 #if 0
485  sequence<>,
486  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
487  Impl::kABKLane,
488  2,
489  Impl::kABKPerLane>,
493  sequence<2>,
494  sequence<1>>;
495 
497  sequence<>,
499  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
504 #else
505  // TODO: add more tests, not only for 32x32
507  sequence<>,
508  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
509  Impl::kCMLane,
510  SFactor,
511  Impl::kCM1PerLane>,
515  sequence<2>,
516  sequence<1>>;
517 
519  sequence<>,
521  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
526 #endif
527  template <bool post_nop_ = false>
528  // c_vec += a_vec * b_vec
530  const AVecType& a_vec,
531  const BVecType& b_vec,
532  bool_constant<post_nop_> = {}) const
533  {
534  // swap A and B
535  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
536  }
537 
538  // c_vec = a_vec * b_vec, evaluated by a default-constructed Impl with the two operands exchanged; returns the CVecType Impl yields
539  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
540  {
541  // swap A and B: b_vec is passed as Impl's first argument and a_vec as its second (AVecType/BVecType here alias Impl::BVecType/Impl::AVecType)
542  return Impl{}(b_vec, a_vec);
543  }
544 };
545 
546 template <typename WarpGemmAttributeMfmaImpl_,
547  index_t kKIter,
550 {
552  static constexpr auto AttrNumAccess = AttrNumAccess_;
553 
554  // swap A and B
555  using ADataType = typename Impl::BDataType;
556  using BDataType = typename Impl::ADataType;
557  using CDataType = typename Impl::CDataType;
558 
559  using AVecType =
561  using BVecType =
563  using CVecType = typename Impl::CVecType;
564 
565  static constexpr index_t kM = Impl::kN;
566  static constexpr index_t kN = Impl::kM;
567  static constexpr index_t kK = Impl::kK * kKIter;
568  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
569 
570  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
571 
572  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
573  "Multi-block on both M & N directions is not supported");
574 
575  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
576  {
579  }
580 
581  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
582  {
585  }
586 
587  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
588  {
589  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
590  {
592  sequence<>,
598  sequence<0, 2>>{};
599  }
600  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
601  {
603  sequence<>,
609  sequence<0, 2>>{};
610  }
611  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
612  {
614  sequence<>,
615  tuple<
621  sequence<0, 2>>{};
622  }
623  }
624 
626 
628 
630 
631  template <bool post_nop_ = false>
632  // c_vec += a_vec * b_vec
634  const AVecType& a_vec,
635  const BVecType& b_vec,
636  bool_constant<post_nop_> = {}) const
637  {
640  // swap A and B, value and type
641  static_for<0, kKIter, 1>{}([&](auto iKIter) {
642  Impl{}(c_vec,
643  reinterpret_cast<const buf_b&>(b_vec)
644  .template get_as<typename Impl::BVecType>()[iKIter],
645  reinterpret_cast<const buf_a&>(a_vec)
646  .template get_as<typename Impl::AVecType>()[iKIter],
647  bool_constant<post_nop_>{});
648  });
649  }
650 
651  template <index_t iKIter, bool post_nop_ = false>
652  // c_vec += a_vec * b_vec
654  const AVecType& a_vec,
655  const BVecType& b_vec,
657  bool_constant<post_nop_> = {}) const
658  {
661 
662  static_assert(iKIter < kKIter);
663  // swap A and B, value and type
664  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
665  Impl{}(c_vec,
666  reinterpret_cast<const buf_b&>(b_vec)
667  .template get_as<typename Impl::BVecType>()[iKIter],
668  reinterpret_cast<const buf_a&>(a_vec)
669  .template get_as<typename Impl::AVecType>()[iKIter],
670  bool_constant<post_nop_>{});
671  //});
672  }
673 
674  // c_vec = a_vec * b_vec
675  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
676  {
677  constexpr auto I0 = number<0>{};
680 
681  // swap A and B, value and type
682  auto c_vec = Impl{}(
683  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
684  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
685 
686  static_for<1, kKIter, 1>{}([&](auto iKIter) {
687  Impl{}(c_vec,
688  reinterpret_cast<const buf_b&>(b_vec)
689  .template get_as<typename Impl::BVecType>()[iKIter],
690  reinterpret_cast<const buf_a&>(a_vec)
691  .template get_as<typename Impl::AVecType>()[iKIter]);
692  });
693 
694  return c_vec;
695  }
696 };
697 
698 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
700 {
702 
703  // swap A and B
704  using ADataType = typename Impl::BDataType;
705  using BDataType = typename Impl::ADataType;
706  using CDataType = typename Impl::CDataType;
707 
708  using AVecType =
710  using BVecType =
712  using CVecType = typename Impl::CVecType;
713 
714  static constexpr index_t kM = Impl::kN;
715  static constexpr index_t kN = Impl::kM;
716  static constexpr index_t kK = Impl::kK * kKIter;
717  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
718  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
719 
720  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
721 
722  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
723  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
724 
726  sequence<>,
730  sequence<2>,
731  sequence<1>>;
732 #if 0
734  sequence<>,
735  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
736  Impl::kABKLane,
737  2,
738  Impl::kABKPerLane>,
742  sequence<2>,
743  sequence<1>>;
744 
746  sequence<>,
748  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
753 #else
754  // TODO: add more tests, not only for 32x32
756  sequence<>,
757  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
758  Impl::kCMLane,
759  SFactor,
760  Impl::kCM1PerLane>,
764  sequence<2>,
765  sequence<1>>;
766 
768  sequence<>,
770  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
775 #endif
776  // c_vec += a_vec * b_vec
777  template <bool post_nop_ = false>
779  const AVecType& a_vec,
780  const BVecType& b_vec,
781  bool_constant<post_nop_> = {}) const
782  {
785  // swap A and B, value and type
786  static_for<0, kKIter, 1>{}([&](auto iKIter) {
787  Impl{}(c_vec,
788  reinterpret_cast<const buf_b&>(b_vec)
789  .template get_as<typename Impl::BVecType>()[iKIter],
790  reinterpret_cast<const buf_a&>(a_vec)
791  .template get_as<typename Impl::AVecType>()[iKIter],
792  bool_constant<post_nop_>{});
793  });
794  }
795 
796  template <index_t iKIter, bool post_nop_ = false>
798  const AVecType& a_vec,
799  const BVecType& b_vec,
801  bool_constant<post_nop_> = {}) const
802  {
805 
806  static_assert(iKIter < kKIter);
807  // swap A and B, value and type
808  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
809  Impl{}(c_vec,
810  reinterpret_cast<const buf_b&>(b_vec)
811  .template get_as<typename Impl::BVecType>()[iKIter],
812  reinterpret_cast<const buf_a&>(a_vec)
813  .template get_as<typename Impl::AVecType>()[iKIter],
814  bool_constant<post_nop_>{});
815  //});
816  }
817 
818  // c_vec = a_vec * b_vec
819  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
820  {
823  constexpr auto I0 = number<0>{};
824 
825  // swap A and B, value and type
826  auto c_vec = Impl{}(
827  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
828  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
829 
830  static_for<1, kKIter, 1>{}([&](auto iKIter) {
831  Impl{}(c_vec,
832  reinterpret_cast<const buf_b&>(b_vec)
833  .template get_as<typename Impl::BVecType>()[iKIter],
834  reinterpret_cast<const buf_a&>(a_vec)
835  .template get_as<typename Impl::AVecType>()[iKIter]);
836  });
837 
838  return c_vec;
839  }
840 };
841 
842 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
844 {
846 
847  using ADataType = typename Impl::ADataType;
848  using BDataType = typename Impl::BDataType;
849  using CDataType = typename Impl::CDataType;
850 
851  using AVecType =
853  using BVecType =
855  using CVecType = typename Impl::CVecType;
856 
857  static constexpr index_t kM = Impl::kM;
858  static constexpr index_t kN = Impl::kN;
859  static constexpr index_t kK = Impl::kK * kKIter;
860  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
861  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
862 
863  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
864 
865  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
866  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
867 
869  sequence<>,
870  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
871  Impl::kCMLane,
872  SFactor,
873  Impl::kCM1PerLane>,
877  sequence<2>,
878  sequence<1>>;
879 
881  sequence<>,
885  sequence<2>,
886  sequence<1>>;
887 
889  sequence<>,
890  tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
896 
897  // c_vec += a_vec * b_vec
898  template <bool post_nop_ = false>
900  const AVecType& a_vec,
901  const BVecType& b_vec,
902  bool_constant<post_nop_> = {}) const
903  {
906 
907  static_for<0, kKIter, 1>{}([&](auto iKIter) {
908  Impl{}(c_vec,
909  reinterpret_cast<const buf_a&>(a_vec)
910  .template get_as<typename Impl::AVecType>()[iKIter],
911  reinterpret_cast<const buf_b&>(b_vec)
912  .template get_as<typename Impl::BVecType>()[iKIter],
913  bool_constant<post_nop_>{});
914  });
915  }
916 
917  template <index_t iKIter, bool post_nop_ = false>
919  const AVecType& a_vec,
920  const BVecType& b_vec,
922  bool_constant<post_nop_> = {}) const
923  {
926 
927  static_assert(iKIter < kKIter);
928 
929  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
930  Impl{}(c_vec,
931  reinterpret_cast<const buf_a&>(a_vec)
932  .template get_as<typename Impl::AVecType>()[iKIter],
933  reinterpret_cast<const buf_b&>(b_vec)
934  .template get_as<typename Impl::BVecType>()[iKIter],
935  bool_constant<post_nop_>{});
936  //});
937  }
938 
939  // c_vec = a_vec * b_vec
940  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
941  {
942  constexpr auto I0 = number<0>{};
945 
946  auto c_vec = Impl{}(
947  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
948  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
949 
950  static_for<1, kKIter, 1>{}([&](auto iKIter) {
951  Impl{}(c_vec,
952  reinterpret_cast<const buf_a&>(a_vec)
953  .template get_as<typename Impl::AVecType>()[iKIter],
954  reinterpret_cast<const buf_b&>(b_vec)
955  .template get_as<typename Impl::BVecType>()[iKIter]);
956  });
957 
958  return c_vec;
959  }
960 };
961 
962 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
WGAttrNumAccessEnum
Definition: warp_gemm_attribute_mfma.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
Definition: warp_gemm_attribute_mfma.hpp:23
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:48
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:88
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_mfma.hpp:40
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:74
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:33
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:42
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:29
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:32
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:25
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:26
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:36
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:30
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:24
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:97
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:37
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:38
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:75
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:34
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:39
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:28
Definition: warp_gemm_attribute_mfma.hpp:844
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:857
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:863
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:940
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:848
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:845
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:861
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:854
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:899
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:860
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:855
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:859
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:847
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:849
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:852
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:858
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:918
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:778
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:701
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:717
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:718
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:720
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:711
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:714
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:716
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:712
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:715
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:705
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:819
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:709
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:706
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:797
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:704
Definition: warp_gemm_attribute_mfma.hpp:550
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:555
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:629
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:563
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:560
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:587
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:575
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:567
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:653
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:625
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:627
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:562
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:565
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:633
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:570
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:552
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:566
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:556
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:581
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:557
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:568
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:551
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:675
Definition: warp_gemm_attribute_mfma.hpp:107
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:134
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:297
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:257
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:127
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:124
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:322
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:110
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:114
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:112
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:195
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:299
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:344
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:122
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:295
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:116
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:125
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:111
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:115
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:303
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:121
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:126
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:129
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:119
Definition: warp_gemm_attribute_mfma.hpp:454
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:457
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:468
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:461
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:455
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:529
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:463
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:469
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:466
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:465
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:539
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:462
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:459
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:471
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:458
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:467
Definition: warp_gemm_attribute_mfma.hpp:371
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:384
decltype(get_warp_dstr_encoding< Impl::kAMLane >()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:422
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:445
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:376
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:387
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:372
static constexpr auto get_warp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:395
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:386
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:382
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:435
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:380
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:385
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:389
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:378
decltype(get_warp_dstr_encoding< Impl::kBNLane >()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:421
static constexpr auto AttrNumAccess
Definition: warp_gemm_attribute_mfma.hpp:373
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:377
static constexpr auto AttrNumAccessV
Definition: warp_gemm_attribute_mfma.hpp:374
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:381
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: debug.hpp:67
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192