/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File
warp_gemm_attribute_wmma.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // TODO: currently only supports 16-bit input, i.e. only tr16_b128; will use ADataType
13 // to determine the layout in the future
// NOTE(review): the operand-layout comments inside the attribute struct below also list 8-bit
// input parameters, so this TODO may be out of date -- confirm.
// NOTE(review): this listing is missing lines 15 and 17-22 (the struct name and the leading
// template arguments of what is presumably the nested `type` alias). The visible tail forwards
// Impl::kABYs2RHsMajor/Minor, so this is presumably one of the A/B operand encoding traits
// (AWarpDstrEncodingTrait / BWarpDstrEncodingTrait, whose ::type the attribute struct uses)
// -- verify against the actual header.
14 template <typename Impl>
16 {
23  typename Impl::kABYs2RHsMajor,
24  typename Impl::kABYs2RHsMinor>;
25 };
26 
// NOTE(review): this listing is missing lines 28 and 30-35 (the struct name and the leading
// template arguments). Like the trait above it, the visible tail forwards
// Impl::kABYs2RHsMajor/Minor, so this is presumably the other A/B operand encoding trait
// -- verify against the actual header.
27 template <typename Impl>
29 {
36  typename Impl::kABYs2RHsMajor,
37  typename Impl::kABYs2RHsMinor>;
38 };
39 
// NOTE(review): this listing is missing lines 41, 43 and 45-48. The visible tail forwards
// Impl::kCYs2RHsMajor/Minor, so this is presumably the C (accumulator) encoding trait,
// CWarpDstrEncodingTrait, whose ::type the attribute struct uses as CWarpDstrEncoding when
// kTransC is false -- verify against the actual header.
40 template <typename Impl>
42 {
44  sequence<>,
49  typename Impl::kCYs2RHsMajor,
50  typename Impl::kCYs2RHsMinor>;
51 };
52 
// NOTE(review): this listing is missing lines 54, 56 and 58-61. The visible tail forwards
// Impl::kCTYs2RHsMajor/Minor, so this is presumably CTransposedWarpDstrEncodingTrait, whose
// ::type the attribute struct uses as CWarpDstrEncoding when kTransC is true -- verify against
// the actual header.
53 template <typename Impl>
55 {
57  sequence<>,
62  typename Impl::kCTYs2RHsMajor,
63  typename Impl::kCTYs2RHsMinor>;
64 };
65 
66 template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
68 {
70 
71  // When kTransC is true and A/B types differ, we need an impl with swapped types
73  std::conditional_t<kTransC &&
74  !std::is_same_v<typename Impl::ADataType, typename Impl::BDataType>,
75  WarpGemmAttributeWmmaImpl<WmmaTraits<typename Impl::TraitsType::ArchType,
76  typename Impl::BDataType,
77  typename Impl::ADataType,
78  typename Impl::CDataType,
79  Impl::kM,
80  Impl::kN,
81  Impl::kK>>,
82  Impl>;
83 
84  using ADataType = typename Impl::ADataType;
85  using BDataType = typename Impl::BDataType;
86  using CDataType = typename Impl::CDataType;
87 
88  using AVecType = typename Impl::AVecType;
89  using BVecType = typename Impl::BVecType;
90  using CVecType = typename Impl::CVecType;
91 
92  static constexpr index_t kM = Impl::kM;
93  static constexpr index_t kN = Impl::kN;
94  static constexpr index_t kK = Impl::kK;
95  static constexpr index_t kCMLane = Impl::kCMLane;
96  static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
97 
98  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
99 
100  // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
101  // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
104 
105  // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
107  std::conditional_t<kTransC,
110 
111  // c_vec += a_vec * b_vec
112  template <bool post_nop_ = false>
114  const AVecType& a_vec,
115  const BVecType& b_vec,
116  bool_constant<post_nop_> = {}) const
117  {
118  if constexpr(kTransC)
119  {
120  TransposedImpl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
121  }
122  else
123  {
124  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
125  }
126  }
127 
128  // c_vec = a_vec * b_vec
129  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
130  {
131  if constexpr(kTransC)
132  {
133  return TransposedImpl{}(b_vec, a_vec);
134  }
135  else
136  {
137  return Impl{}(a_vec, b_vec);
138  }
139  }
140 };
141 
142 template <typename ADataType,
143  typename BDataType,
144  typename AccDataType,
145  index_t M_Warp_Tile,
146  index_t N_Warp_Tile,
147  index_t K_Warp_Tile>
149 {
150  if(is_gfx12_supported())
151  {
152  return has_wmma_traits_v<gfx12_t,
153  ADataType,
154  BDataType,
155  AccDataType,
156  M_Warp_Tile,
157  N_Warp_Tile,
158  K_Warp_Tile>;
159  }
160  else if(is_gfx11_supported())
161  {
162  return has_wmma_traits_v<gfx11_t,
163  ADataType,
164  BDataType,
165  AccDataType,
166  M_Warp_Tile,
167  N_Warp_Tile,
168  K_Warp_Tile>;
169  }
170  else
171  {
172  return false;
173  }
174 }
175 
176 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition: warp_gemm_attribute_wmma_impl.hpp:139
CK_TILE_HOST bool check_wmma_supported()
Definition: warp_gemm_attribute_wmma.hpp:148
bool is_gfx12_supported()
Definition: device_prop.hpp:63
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
bool is_gfx11_supported()
Definition: device_prop.hpp:55
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: warp_gemm_attribute_wmma.hpp:16
Definition: warp_gemm_attribute_wmma.hpp:29
Definition: warp_gemm_attribute_wmma.hpp:55
Definition: warp_gemm_attribute_wmma.hpp:42
Definition: warp_gemm_attribute_wmma.hpp:68
remove_cvref_t< WarpGemmAttributeWmmaImpl_ > Impl
Definition: warp_gemm_attribute_wmma.hpp:69
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_wmma.hpp:113
std::conditional_t< kTransC, typename CTransposedWarpDstrEncodingTrait< Impl >::type, typename CWarpDstrEncodingTrait< Impl >::type > CWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:109
static constexpr index_t kN
Definition: warp_gemm_attribute_wmma.hpp:93
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_wmma.hpp:96
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_wmma.hpp:84
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_wmma.hpp:86
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_wmma.hpp:95
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_wmma.hpp:89
std::conditional_t< kTransC &&!std::is_same_v< typename Impl::ADataType, typename Impl::BDataType >, WarpGemmAttributeWmmaImpl< WmmaTraits< typename Impl::TraitsType::ArchType, typename Impl::BDataType, typename Impl::ADataType, typename Impl::CDataType, Impl::kM, Impl::kN, Impl::kK > >, Impl > TransposedImpl
Definition: warp_gemm_attribute_wmma.hpp:82
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_wmma.hpp:98
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_wmma.hpp:85
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_wmma.hpp:90
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_wmma.hpp:88
static constexpr index_t kK
Definition: warp_gemm_attribute_wmma.hpp:94
static constexpr index_t kM
Definition: warp_gemm_attribute_wmma.hpp:92
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_wmma.hpp:129
typename AWarpDstrEncodingTrait< Impl >::type AWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:102
typename BWarpDstrEncodingTrait< Impl >::type BWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:103
Definition: warp_gemm_attribute_wmma_impl.hpp:24
Definition: warp_gemm_attribute_wmma_impl.hpp:19
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192