/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp Source File
warp_gemm_attribute_mfma.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
// Warp-level GEMM attribute wrapping a single MFMA instruction descriptor
// `Impl` (the cvref-stripped WarpGemmAttributeMfmaImpl_), with no K iteration
// and no A/B exchange.
//  - ADataType/BDataType/CDataType, the A/B/C vector types and kM/kN/kK are
//    re-exported unchanged from Impl; kKPerThread is Impl::kABKPerLane.
//  - get_num_of_access() is 1: each warp GEMM is a single Impl call.
//  - Only single-block Impl (kAMBlock == 1 && kBNBlock == 1) is accepted.
// NOTE(review): this rendered listing drops source lines 12 (struct name),
// 14 (Impl alias), 34, 36-39, 42, 44-47, 50, 52-57 (presumably the A/B/C warp
// distribution encodings) and 61 (accumulating operator() signature); consult
// the real header before editing.
11 template <typename WarpGemmAttributeMfmaImpl_>
13 {
15 
16  using ADataType = typename Impl::ADataType;
17  using BDataType = typename Impl::BDataType;
18  using CDataType = typename Impl::CDataType;
19 
20  using AVecType = typename Impl::AVecType;
21  using BVecType = typename Impl::BVecType;
22  using CVecType = typename Impl::CVecType;
23 
24  static constexpr index_t kM = Impl::kM;
25  static constexpr index_t kN = Impl::kN;
26  static constexpr index_t kK = Impl::kK;
27  static constexpr index_t kKPerThread = Impl::kABKPerLane;
28 
29  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
30 
31  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
32  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
33 
35  sequence<>,
40  sequence<1>>;
41 
43  sequence<>,
48  sequence<1>>;
49 
51  sequence<>,
58 
    // Accumulating form: forwards directly to Impl, passing post_nop_ through.
59  // c_vec += a_vec * b_vec
60  template <bool post_nop_ = false>
62  const AVecType& a_vec,
63  const BVecType& b_vec,
64  bool_constant<post_nop_> = {}) const
65  {
66  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
67  }
68 
    // Non-accumulating form: returns the product computed by Impl.
69  // c_vec = a_vec * b_vec
70  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
71  {
72  return Impl{}(a_vec, b_vec);
73  }
74 };
75 
// Warp-level GEMM attribute that issues the MFMA instruction `Impl` kKIter
// times along K, so one warp GEMM covers kK = Impl::kK * kKIter and each
// thread holds kKPerThread = Impl::kABKPerLane * kKIter A/B elements.
//  - AVecType/BVecType are ext_vector_t with vector_size * kKIter elements
//    (per the generated declarations); chunk i (viewed as Impl::AVecType /
//    Impl::BVecType) feeds the i-th Impl call.
//  - get_num_of_access() returns kKIter.
//  - Impl may be multi-block along M or along N, but not both.
// NOTE(review): this rendered listing drops several source lines: the struct
// name (77), the Impl alias (81), the AVecType/BVecType continuations (88/90),
// most of each distribution encoding, the encoding aliases (219/221/223, which
// are decltype of the getters below), the operator() signature lines
// (227/246/249) and the lines that presumably declare buf_a/buf_b
// (232-233, 252-253, 271-272). Consult the real header before editing.
76 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
78 {
    // kKIter must be positive (the diagnostic text below is terse).
79  static_assert(kKIter > 0, "wrong!");
80 
82 
83  using ADataType = typename Impl::ADataType;
84  using BDataType = typename Impl::BDataType;
85  using CDataType = typename Impl::CDataType;
86 
    // A/B vectors concatenate kKIter Impl-sized operand vectors.
87  using AVecType =
89  using BVecType =
91  using CVecType = typename Impl::CVecType;
92 
93  static constexpr index_t kM = Impl::kM;
94  static constexpr index_t kN = Impl::kN;
95  static constexpr index_t kK = Impl::kK * kKIter;
96  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
97 
98  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
99 
    // Multi-block Impl is allowed along at most one of M and N.
100  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
101  "Multi-block on both M & N directions is not supported");
102 
    // The three encoding getters pick a layout from Impl's block shape; the
    // case with both kAMBlock > 1 and kBNBlock > 1 is excluded by the
    // static_assert above, so no final else branch is needed.
103  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
104  {
105  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
106  {
108  sequence<>,
113  sequence<2>,
114  sequence<1>>{};
115  }
116  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
117  {
118  // each M blocks share the same data
125  sequence<2>,
126  sequence<1>>{};
127  }
128  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
129  {
130  // single block to multi-block thread mapping
132  sequence<>,
137  sequence<2>,
138  sequence<1>>{};
139  }
140  }
141 
142  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
143  {
144  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
145  {
147  sequence<>,
152  sequence<2>,
153  sequence<1>>{};
154  }
155  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
156  {
157  // single block to multi-block thread mapping
159  sequence<>,
164  sequence<2>,
165  sequence<1>>{};
166  }
167  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
168  {
169  // each N blocks share the same data
176  sequence<2>,
177  sequence<1>>{};
178  }
179  }
180 
181  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
182  {
183  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
184  {
186  sequence<>,
192  sequence<0, 2>>{};
193  }
194  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
195  {
197  sequence<>,
203  sequence<0, 2>>{};
204  }
205  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
206  {
208  sequence<>,
209  tuple<
215  sequence<0, 2>>{};
216  }
217  }
218 
220 
222 
224 
    // Accumulating form: static_for unrolls kKIter Impl calls at compile time.
    // Call iKIter consumes chunk iKIter of a_vec/b_vec, reinterpreted as arrays
    // of Impl::AVecType/Impl::BVecType, and every call forwards post_nop_.
225  // c_vec += a_vec * b_vec
226  template <bool post_nop_ = false>
228  const AVecType& a_vec,
229  const BVecType& b_vec,
230  bool_constant<post_nop_> = {}) const
231  {
234 
235  static_for<0, kKIter, 1>{}([&](auto iKIter) {
236  Impl{}(c_vec,
237  reinterpret_cast<const buf_a&>(a_vec)
238  .template get_as<typename Impl::AVecType>()[iKIter],
239  reinterpret_cast<const buf_b&>(b_vec)
240  .template get_as<typename Impl::BVecType>()[iKIter],
241  bool_constant<post_nop_>{});
242  });
243  }
244 
    // Single-step form: issues only the Impl call for the compile-time chunk
    // iKIter (passed as number<iKIter>). Presumably this exists so callers can
    // interleave K steps with other work; confirm with call sites.
    // NOTE(review): the static_assert only checks iKIter < kKIter; a negative
    // iKIter (index_t is signed) is not rejected here.
245  template <index_t iKIter, bool post_nop_ = false>
247  const AVecType& a_vec,
248  const BVecType& b_vec,
250  bool_constant<post_nop_> = {}) const
251  {
254 
255  static_assert(iKIter < kKIter);
256 
257  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
258  Impl{}(c_vec,
259  reinterpret_cast<const buf_a&>(a_vec)
260  .template get_as<typename Impl::AVecType>()[iKIter],
261  reinterpret_cast<const buf_b&>(b_vec)
262  .template get_as<typename Impl::BVecType>()[iKIter],
263  bool_constant<post_nop_>{});
264  //});
265  }
266 
    // Non-accumulating form: the product of chunk 0 initializes c_vec, and
    // chunks 1..kKIter-1 accumulate into it. Unlike the overloads above, no
    // post_nop_ flag is passed to Impl.
267  // c_vec = a_vec * b_vec
268  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
269  {
270  constexpr auto I0 = number<0>{};
273 
274  // c = a * b
275  auto c_vec = Impl{}(
276  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
277  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
278 
279  // c += a * b
280  static_for<1, kKIter, 1>{}([&](auto iKIter) {
281  Impl{}(c_vec,
282  reinterpret_cast<const buf_a&>(a_vec)
283  .template get_as<typename Impl::AVecType>()[iKIter],
284  reinterpret_cast<const buf_b&>(b_vec)
285  .template get_as<typename Impl::BVecType>()[iKIter]);
286  });
287 
288  return c_vec;
289  }
290 };
291 
// Warp-level GEMM attribute wrapping one MFMA instruction `Impl` with the
// roles of A and B exchanged:
//  - this attribute's A data/vector types are Impl's B types and vice versa,
//    and kM/kN are Impl's kN/kM;
//  - both operator() overloads call Impl with (b_vec, a_vec).
// No K iteration: kK = Impl::kK and get_num_of_access() is 1. Only
// single-block Impl is accepted.
// Presumably the swapped operands make Impl produce C in transposed layout,
// which the C distribution encoding (lines not shown) describes; confirm
// against the full header.
// NOTE(review): this rendered listing drops source lines 293 (struct name),
// 295 (Impl alias), 315, 317-319, 323, 325-327, 331, 333-338 (encodings) and
// 342 (accumulating operator() signature).
292 template <typename WarpGemmAttributeMfmaImpl_>
294 {
296 
297  using ADataType = typename Impl::BDataType;
298  using BDataType = typename Impl::ADataType;
299  using CDataType = typename Impl::CDataType;
300 
301  using AVecType = typename Impl::BVecType;
302  using BVecType = typename Impl::AVecType;
303  using CVecType = typename Impl::CVecType;
304 
305  static constexpr index_t kM = Impl::kN;
306  static constexpr index_t kN = Impl::kM;
307  static constexpr index_t kK = Impl::kK;
308  static constexpr index_t kKPerThread = Impl::kABKPerLane;
309 
310  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
311 
312  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
313  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
314 
316  sequence<>,
320  sequence<2>,
321  sequence<1>>;
322 
324  sequence<>,
328  sequence<2>,
329  sequence<1>>;
330 
332  sequence<>,
339 
    // Accumulating form: b_vec goes in Impl's A slot and a_vec in its B slot;
    // post_nop_ is forwarded unchanged.
340  // c_vec += a_vec * b_vec
341  template <bool post_nop_ = false>
343  const AVecType& a_vec,
344  const BVecType& b_vec,
345  bool_constant<post_nop_> = {}) const
346  {
347  // swap A and B
348  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
349  }
350 
    // Non-accumulating form: returns Impl's product with the operands exchanged.
351  // c_vec = a_vec * b_vec
352  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
353  {
354  // swap A and B
355  return Impl{}(b_vec, a_vec);
356  }
357 };
358 
// A/B-exchanged single-instruction attribute whose distribution encodings
// regroup lanes by SFactor (default 2; "group how many CM1 together").
// Types, tile sizes and both operator() overloads match the plain A/B-swapped
// form: Impl is called with (b_vec, a_vec). Within this struct, SFactor is
// used only by the encodings.
//  - In the active (#else) branch, the first encoding splits Impl::kAMLane as
//    (kAMLane / (kCMLane * SFactor * kCM1PerLane), kCMLane, SFactor, kCM1PerLane).
//  - The second encoding in that branch uses
//    (kCM0PerLane / SFactor, kCMLane, kCM1PerLane * SFactor).
//  - The disabled #if 0 branch uses a literal 2 and different lane factors
//    (kABKPerLane, kABKLane).
// NOTE(review): the integer divisions above presumably require kAMLane and
// kCM0PerLane to be divisible by the corresponding products. No such
// static_assert is visible in this listing, and the in-code TODO says only
// 32x32 has been tested.
// NOTE(review): this rendered listing drops source lines 360, 362, 383,
// 385-387, 391, 397-399, 403, 405, 407-410, 413, 419-421, 425, 427, 429-432
// and 436 (struct name, Impl alias, encoding aliases/parts, operator
// signature).
359 template <typename WarpGemmAttributeMfmaImpl_, index_t SFactor_ = 2>
361 {
363 
364  using ADataType = typename Impl::BDataType;
365  using BDataType = typename Impl::ADataType;
366  using CDataType = typename Impl::CDataType;
367 
368  using AVecType = typename Impl::BVecType;
369  using BVecType = typename Impl::AVecType;
370  using CVecType = typename Impl::CVecType;
371 
372  static constexpr index_t kM = Impl::kN;
373  static constexpr index_t kN = Impl::kM;
374  static constexpr index_t kK = Impl::kK;
375  static constexpr index_t kKPerThread = Impl::kABKPerLane;
376  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
377 
378  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
379 
380  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
381  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
382 
384  sequence<>,
388  sequence<2>,
389  sequence<1>>;
390 #if 0
392  sequence<>,
393  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
394  Impl::kABKLane,
395  2,
396  Impl::kABKPerLane>,
400  sequence<2>,
401  sequence<1>>;
402 
404  sequence<>,
406  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
411 #else
412  // TODO: more test not only 32x32
414  sequence<>,
415  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
416  Impl::kCMLane,
417  SFactor,
418  Impl::kCM1PerLane>,
422  sequence<2>,
423  sequence<1>>;
424 
426  sequence<>,
428  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
433 #endif
    // Accumulating form: b_vec goes in Impl's A slot and a_vec in its B slot;
    // post_nop_ is forwarded unchanged.
434  template <bool post_nop_ = false>
435  // c_vec += a_vec * b_vec
437  const AVecType& a_vec,
438  const BVecType& b_vec,
439  bool_constant<post_nop_> = {}) const
440  {
441  // swap A and B
442  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
443  }
444 
    // Non-accumulating form: returns Impl's product with the operands exchanged.
445  // c_vec = a_vec * b_vec
446  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
447  {
448  // swap A and B
449  return Impl{}(b_vec, a_vec);
450  }
451 };
452 
// Combines K iteration with A/B exchange:
//  - Impl is issued kKIter times along K (kK = Impl::kK * kKIter,
//    get_num_of_access() == kKIter);
//  - every call receives (chunk of b_vec, chunk of a_vec), so this
//    attribute's A/B play Impl's B/A roles, and the data types and kM/kN are
//    swapped to match.
// Impl may be multi-block along M or along N, but not both.
// NOTE(review): the generated declarations give the following vector types:
//  - AVecType = ext_vector_t<Impl::BDataType, vector_size(Impl::AVecType) * kKIter>;
//  - BVecType = ext_vector_t<Impl::ADataType, vector_size(Impl::BVecType) * kKIter>.
// The operators slice a_vec as Impl::AVecType and pass b_vec's Impl::BVecType
// chunks as Impl's first operand. That is only consistent when
// Impl::AVecType and Impl::BVecType have the same length and are
// interchangeable. The non-iterating swapped attribute instead uses
// Impl::BVecType/Impl::AVecType directly, so confirm before pairing this
// with an Impl whose A/B vectors differ.
// NOTE(review): this rendered listing drops several source lines:
//  - the struct name (454) and the Impl alias (456);
//  - the AVecType/BVecType continuations (464/466) and most of each encoding;
//  - the encoding aliases (595/597/599, decltype of the getters below);
//  - the operator() signature lines and the presumed buf_a/buf_b declarations.
453 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter>
455 {
457 
458  // swap A and B
459  using ADataType = typename Impl::BDataType;
460  using BDataType = typename Impl::ADataType;
461  using CDataType = typename Impl::CDataType;
462 
463  using AVecType =
465  using BVecType =
467  using CVecType = typename Impl::CVecType;
468 
469  static constexpr index_t kM = Impl::kN;
470  static constexpr index_t kN = Impl::kM;
471  static constexpr index_t kK = Impl::kK * kKIter;
472  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
473 
474  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
475 
    // Multi-block Impl is allowed along at most one of M and N.
476  static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1,
477  "Multi-block on both M & N directions is not supported");
478 
    // The encoding getters pick a layout from Impl's block shape. Compared
    // with the non-swapped K-iterating attribute, the "share the same data" /
    // "single block to multi-block" roles of the A and B getters are exchanged.
479  CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding()
480  {
481  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
482  {
484  sequence<>,
489  sequence<2>,
490  sequence<1>>{};
491  }
492  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
493  {
494  // single block to multi-block thread mapping
496  sequence<>,
501  sequence<2>,
502  sequence<1>>{};
503  }
504  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
505  {
506  // each N blocks share the same data
513  sequence<2>,
514  sequence<1>>{};
515  }
516  }
517 
518  CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding()
519  {
520  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
521  {
523  sequence<>,
528  sequence<2>,
529  sequence<1>>{};
530  }
531  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
532  {
533  // each M blocks share the same data
540  sequence<2>,
541  sequence<1>>{};
542  }
543  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
544  {
545  // single block to multi-block thread mapping
547  sequence<>,
552  sequence<2>,
553  sequence<1>>{};
554  }
555  }
556 
557  CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding()
558  {
559  if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1)
560  {
562  sequence<>,
568  sequence<0, 2>>{};
569  }
570  else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock)
571  {
573  sequence<>,
579  sequence<0, 2>>{};
580  }
581  else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1)
582  {
584  sequence<>,
585  tuple<
591  sequence<0, 2>>{};
592  }
593  }
594 
596 
598 
600 
    // Accumulating form: static_for unrolls kKIter Impl calls. Call iKIter
    // passes b_vec's chunk iKIter first and a_vec's chunk iKIter second, and
    // forwards post_nop_ to every call.
601  template <bool post_nop_ = false>
602  // c_vec += a_vec * b_vec
604  const AVecType& a_vec,
605  const BVecType& b_vec,
606  bool_constant<post_nop_> = {}) const
607  {
610  // swap A and B, value and type
611  static_for<0, kKIter, 1>{}([&](auto iKIter) {
612  Impl{}(c_vec,
613  reinterpret_cast<const buf_b&>(b_vec)
614  .template get_as<typename Impl::BVecType>()[iKIter],
615  reinterpret_cast<const buf_a&>(a_vec)
616  .template get_as<typename Impl::AVecType>()[iKIter],
617  bool_constant<post_nop_>{});
618  });
619  }
620 
    // Single-step form: issues only the swapped Impl call for compile-time
    // chunk iKIter.
    // NOTE(review): the static_assert only checks iKIter < kKIter; a negative
    // iKIter (index_t is signed) is not rejected here.
621  template <index_t iKIter, bool post_nop_ = false>
622  // c_vec += a_vec * b_vec
624  const AVecType& a_vec,
625  const BVecType& b_vec,
627  bool_constant<post_nop_> = {}) const
628  {
631 
632  static_assert(iKIter < kKIter);
633  // swap A and B, value and type
634  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
635  Impl{}(c_vec,
636  reinterpret_cast<const buf_b&>(b_vec)
637  .template get_as<typename Impl::BVecType>()[iKIter],
638  reinterpret_cast<const buf_a&>(a_vec)
639  .template get_as<typename Impl::AVecType>()[iKIter],
640  bool_constant<post_nop_>{});
641  //});
642  }
643 
    // Non-accumulating form: chunk 0's swapped product initializes c_vec, and
    // chunks 1..kKIter-1 accumulate into it (no post_nop_ flag is passed).
644  // c_vec = a_vec * b_vec
645  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
646  {
647  constexpr auto I0 = number<0>{};
650 
651  // swap A and B, value and type
652  auto c_vec = Impl{}(
653  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
654  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
655 
656  static_for<1, kKIter, 1>{}([&](auto iKIter) {
657  Impl{}(c_vec,
658  reinterpret_cast<const buf_b&>(b_vec)
659  .template get_as<typename Impl::BVecType>()[iKIter],
660  reinterpret_cast<const buf_a&>(a_vec)
661  .template get_as<typename Impl::AVecType>()[iKIter]);
662  });
663 
664  return c_vec;
665  }
666 };
667 
// K-iterating, A/B-exchanged attribute whose distribution encodings regroup
// lanes by SFactor (default 2; "group how many CM1 together").
//  - Impl is issued kKIter times along K (kK = Impl::kK * kKIter,
//    get_num_of_access() == kKIter), each call receiving (chunk of b_vec,
//    chunk of a_vec).
//  - Only single-block Impl is accepted (unlike the non-SFactor K-iterating
//    swapped attribute, which allows multi-block along one direction).
//  - In the active (#else) branch, the first encoding splits Impl::kAMLane as
//    (kAMLane / (kCMLane * SFactor * kCM1PerLane), kCMLane, SFactor,
//    kCM1PerLane).
//  - The second encoding in that branch uses
//    (kCM0PerLane / SFactor, kCMLane, kCM1PerLane * SFactor).
//  - The disabled #if 0 branch uses a literal 2 and different lane factors.
// NOTE(review): the divisions above presumably require exact divisibility;
// no static_assert enforcing it is visible here, and the in-code TODO says
// only 32x32 has been tested.
// NOTE(review): the generated declarations give the following vector types:
//  - AVecType = ext_vector_t<Impl::BDataType, vector_size(Impl::AVecType) * kKIter>;
//  - BVecType pairs Impl::ADataType with Impl::BVecType's length.
// The operators pass b_vec's Impl::BVecType chunks as Impl's first operand.
// This is only consistent when Impl's A and B vectors are interchangeable;
// confirm before using an Impl with asymmetric A/B vectors.
// NOTE(review): this rendered listing drops many source lines: the struct
// name, the Impl alias, the vector-type continuations, encoding aliases and
// parts, the operator() signatures and the presumed buf_a/buf_b declarations.
668 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
670 {
672 
673  // swap A and B
674  using ADataType = typename Impl::BDataType;
675  using BDataType = typename Impl::ADataType;
676  using CDataType = typename Impl::CDataType;
677 
678  using AVecType =
680  using BVecType =
682  using CVecType = typename Impl::CVecType;
683 
684  static constexpr index_t kM = Impl::kN;
685  static constexpr index_t kN = Impl::kM;
686  static constexpr index_t kK = Impl::kK * kKIter;
687  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
688  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
689 
690  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
691 
692  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
693  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
694 
696  sequence<>,
700  sequence<2>,
701  sequence<1>>;
702 #if 0
704  sequence<>,
705  tuple<sequence<Impl::kAMLane / (Impl::kABKPerLane * Impl::kABKLane * 2),
706  Impl::kABKLane,
707  2,
708  Impl::kABKPerLane>,
712  sequence<2>,
713  sequence<1>>;
714 
716  sequence<>,
718  sequence<Impl::kCM0PerLane / 2, Impl::kCMLane, Impl::kCM1PerLane * 2>>,
723 #else
724  // TODO: more test not only 32x32
726  sequence<>,
727  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
728  Impl::kCMLane,
729  SFactor,
730  Impl::kCM1PerLane>,
734  sequence<2>,
735  sequence<1>>;
736 
738  sequence<>,
740  sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>>,
745 #endif
    // Accumulating form: static_for unrolls kKIter swapped Impl calls, each
    // forwarding post_nop_.
746  // c_vec += a_vec * b_vec
747  template <bool post_nop_ = false>
749  const AVecType& a_vec,
750  const BVecType& b_vec,
751  bool_constant<post_nop_> = {}) const
752  {
755  // swap A and B, value and type
756  static_for<0, kKIter, 1>{}([&](auto iKIter) {
757  Impl{}(c_vec,
758  reinterpret_cast<const buf_b&>(b_vec)
759  .template get_as<typename Impl::BVecType>()[iKIter],
760  reinterpret_cast<const buf_a&>(a_vec)
761  .template get_as<typename Impl::AVecType>()[iKIter],
762  bool_constant<post_nop_>{});
763  });
764  }
765 
    // Single-step form: issues only the swapped Impl call for compile-time
    // chunk iKIter.
    // NOTE(review): negative iKIter is not rejected by the static_assert.
766  template <index_t iKIter, bool post_nop_ = false>
768  const AVecType& a_vec,
769  const BVecType& b_vec,
771  bool_constant<post_nop_> = {}) const
772  {
775 
776  static_assert(iKIter < kKIter);
777  // swap A and B, value and type
778  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
779  Impl{}(c_vec,
780  reinterpret_cast<const buf_b&>(b_vec)
781  .template get_as<typename Impl::BVecType>()[iKIter],
782  reinterpret_cast<const buf_a&>(a_vec)
783  .template get_as<typename Impl::AVecType>()[iKIter],
784  bool_constant<post_nop_>{});
785  //});
786  }
787 
    // Non-accumulating form: chunk 0's swapped product initializes c_vec, and
    // chunks 1..kKIter-1 accumulate into it (no post_nop_ flag is passed).
788  // c_vec = a_vec * b_vec
789  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
790  {
793  constexpr auto I0 = number<0>{};
794 
795  // swap A and B, value and type
796  auto c_vec = Impl{}(
797  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0],
798  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0]);
799 
800  static_for<1, kKIter, 1>{}([&](auto iKIter) {
801  Impl{}(c_vec,
802  reinterpret_cast<const buf_b&>(b_vec)
803  .template get_as<typename Impl::BVecType>()[iKIter],
804  reinterpret_cast<const buf_a&>(a_vec)
805  .template get_as<typename Impl::AVecType>()[iKIter]);
806  });
807 
808  return c_vec;
809  }
810 };
811 
// K-iterating attribute (no A/B exchange) whose distribution encodings
// regroup lanes by SFactor (default 2; "group how many CM1 together").
//  - Types come straight from Impl; A/B vectors hold kKIter Impl-sized chunks
//    (kK = Impl::kK * kKIter, get_num_of_access() == kKIter).
//  - The operator() bodies issue Impl on chunk i of a_vec/b_vec in the
//    original A, B order.
//  - Only single-block Impl is accepted.
//  - The first encoding splits Impl::kAMLane as
//    (kAMLane / (kCMLane * SFactor * kCM1PerLane), kCMLane, SFactor,
//    kCM1PerLane).
//  - The encoding starting at source line 858 uses
//    (kCM0PerLane / SFactor, kCMLane, kCM1PerLane * SFactor).
// NOTE(review): these divisions presumably require exact divisibility; no
// static_assert enforcing it is visible in this listing.
// NOTE(review): this rendered listing drops many source lines: the struct
// name, the Impl alias, the vector-type continuations, encoding aliases and
// parts, the operator() signatures and the presumed buf_a/buf_b declarations.
812 template <typename WarpGemmAttributeMfmaImpl_, index_t kKIter, index_t SFactor_ = 2>
814 {
816 
817  using ADataType = typename Impl::ADataType;
818  using BDataType = typename Impl::BDataType;
819  using CDataType = typename Impl::CDataType;
820 
821  using AVecType =
823  using BVecType =
825  using CVecType = typename Impl::CVecType;
826 
827  static constexpr index_t kM = Impl::kM;
828  static constexpr index_t kN = Impl::kN;
829  static constexpr index_t kK = Impl::kK * kKIter;
830  static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
831  static constexpr index_t SFactor = SFactor_; // group how many CM1 together
832 
833  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
834 
835  static_assert(Impl::kAMBlock == 1 && Impl::kBNBlock == 1,
836  "Multi-block WarpGemmAttributeMfmaImpl is not supported");
837 
839  sequence<>,
840  tuple<sequence<Impl::kAMLane / (Impl::kCMLane * SFactor * Impl::kCM1PerLane),
841  Impl::kCMLane,
842  SFactor,
843  Impl::kCM1PerLane>,
847  sequence<2>,
848  sequence<1>>;
849 
851  sequence<>,
855  sequence<2>,
856  sequence<1>>;
857 
859  sequence<>,
860  tuple<sequence<Impl::kCM0PerLane / SFactor, Impl::kCMLane, Impl::kCM1PerLane * SFactor>,
866 
    // Accumulating form: static_for unrolls kKIter Impl calls on chunk iKIter
    // of a_vec/b_vec, each forwarding post_nop_.
867  // c_vec += a_vec * b_vec
868  template <bool post_nop_ = false>
870  const AVecType& a_vec,
871  const BVecType& b_vec,
872  bool_constant<post_nop_> = {}) const
873  {
876 
877  static_for<0, kKIter, 1>{}([&](auto iKIter) {
878  Impl{}(c_vec,
879  reinterpret_cast<const buf_a&>(a_vec)
880  .template get_as<typename Impl::AVecType>()[iKIter],
881  reinterpret_cast<const buf_b&>(b_vec)
882  .template get_as<typename Impl::BVecType>()[iKIter],
883  bool_constant<post_nop_>{});
884  });
885  }
886 
    // Single-step form: issues only the Impl call for compile-time chunk iKIter.
    // NOTE(review): negative iKIter is not rejected by the static_assert.
887  template <index_t iKIter, bool post_nop_ = false>
889  const AVecType& a_vec,
890  const BVecType& b_vec,
892  bool_constant<post_nop_> = {}) const
893  {
896 
897  static_assert(iKIter < kKIter);
898 
899  // static_for<0, kKIter, 1>{}([&](auto iKIter) {
900  Impl{}(c_vec,
901  reinterpret_cast<const buf_a&>(a_vec)
902  .template get_as<typename Impl::AVecType>()[iKIter],
903  reinterpret_cast<const buf_b&>(b_vec)
904  .template get_as<typename Impl::BVecType>()[iKIter],
905  bool_constant<post_nop_>{});
906  //});
907  }
908 
    // Non-accumulating form: chunk 0's product initializes c_vec, and chunks
    // 1..kKIter-1 accumulate into it (no post_nop_ flag is passed).
909  // c_vec = a_vec * b_vec
910  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
911  {
912  constexpr auto I0 = number<0>{};
915 
916  auto c_vec = Impl{}(
917  reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
918  reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
919 
920  static_for<1, kKIter, 1>{}([&](auto iKIter) {
921  Impl{}(c_vec,
922  reinterpret_cast<const buf_a&>(a_vec)
923  .template get_as<typename Impl::AVecType>()[iKIter],
924  reinterpret_cast<const buf_b&>(b_vec)
925  .template get_as<typename Impl::BVecType>()[iKIter]);
926  });
927 
928  return c_vec;
929  }
930 };
931 
932 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
tuple_array< T, N > thread_buffer
Definition: thread_buffer.hpp:14
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::ext_vector< T, N >::type ext_vector_t
Definition: vector_type.hpp:83
Definition: warp_gemm_attribute_mfma.hpp:13
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:14
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:29
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:17
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:16
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:22
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:20
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:70
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:18
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:26
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:25
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:24
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:21
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:27
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:61
Definition: warp_gemm_attribute_mfma.hpp:814
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:827
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:833
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:910
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:818
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:815
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:831
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:824
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:869
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:830
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:825
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:829
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:817
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:819
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:822
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:828
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:888
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:748
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:671
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:687
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:688
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:690
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:681
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:684
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:686
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:682
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:685
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:675
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:789
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:679
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:676
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:767
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:674
Definition: warp_gemm_attribute_mfma.hpp:455
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:595
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:466
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:461
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:467
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:645
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:472
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:557
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:479
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:471
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:460
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:474
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:599
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:603
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:623
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:518
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:464
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:459
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:597
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:470
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:469
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:456
Definition: warp_gemm_attribute_mfma.hpp:78
ext_vector_t< ADataType, vector_traits< typename Impl::AVecType >::vector_size *kKIter > AVecType
Definition: warp_gemm_attribute_mfma.hpp:88
decltype(get_cwarp_dstr_encoding()) CWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:223
static constexpr CK_TILE_DEVICE auto get_awarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:103
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:268
decltype(get_awarp_dstr_encoding()) AWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:219
ext_vector_t< BDataType, vector_traits< typename Impl::BVecType >::vector_size *kKIter > BVecType
Definition: warp_gemm_attribute_mfma.hpp:90
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:93
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:84
static constexpr CK_TILE_DEVICE auto get_cwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:181
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:227
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:81
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, number< iKIter >, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:246
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:96
static constexpr CK_TILE_DEVICE auto get_bwarp_dstr_encoding()
Definition: warp_gemm_attribute_mfma.hpp:142
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:98
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:94
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:83
decltype(get_bwarp_dstr_encoding()) BWarpDstrEncoding
Definition: warp_gemm_attribute_mfma.hpp:221
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:85
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:91
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:95
Definition: warp_gemm_attribute_mfma.hpp:361
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:364
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:375
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:368
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:362
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:436
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:370
static constexpr index_t SFactor
Definition: warp_gemm_attribute_mfma.hpp:376
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:373
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:372
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:446
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:369
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:366
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:378
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:365
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:374
Definition: warp_gemm_attribute_mfma.hpp:294
static constexpr index_t kM
Definition: warp_gemm_attribute_mfma.hpp:305
typename Impl::ADataType BDataType
Definition: warp_gemm_attribute_mfma.hpp:298
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_mfma.hpp:303
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_mfma.hpp:308
static constexpr index_t kK
Definition: warp_gemm_attribute_mfma.hpp:307
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_mfma.hpp:342
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_mfma.hpp:299
typename Impl::BVecType AVecType
Definition: warp_gemm_attribute_mfma.hpp:301
static constexpr index_t kN
Definition: warp_gemm_attribute_mfma.hpp:306
remove_cvref_t< WarpGemmAttributeMfmaImpl_ > Impl
Definition: warp_gemm_attribute_mfma.hpp:295
typename Impl::AVecType BVecType
Definition: warp_gemm_attribute_mfma.hpp:302
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_mfma.hpp:310
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_mfma.hpp:352
typename Impl::BDataType ADataType
Definition: warp_gemm_attribute_mfma.hpp:297
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192