/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_wmma.hpp Source File
warp_gemm_attribute_wmma.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // TODO: currently only support 16 bit input, which means only support tr16_b128; will use ADataType
13 // to determine the layout in the future
// Warp distribution-encoding trait for the first WMMA input operand.
// The line naming this struct and the leading encoding arguments are not shown in this
// listing. Because the attribute struct's AWarpDstrEncoding is
// typename AWarpDstrEncodingTrait<Impl>::type, this is presumably AWarpDstrEncodingTrait
// (TODO confirm against the header).
// Visible part: the Ys-to-RHs major/minor mapping is taken from
// Impl::kABYs2RHsMajor / Impl::kABYs2RHsMinor, the same members the next trait uses.
14 template <typename Impl>
16 {
23  typename Impl::kABYs2RHsMajor,
24  typename Impl::kABYs2RHsMinor>;
25 };
26 
// Warp distribution-encoding trait for the second WMMA input operand.
// This is presumably BWarpDstrEncodingTrait, whose ::type becomes the attribute struct's
// BWarpDstrEncoding (TODO confirm). The naming line and leading encoding arguments are
// not shown in this listing.
// Visible part: like the previous (A-operand) trait, it takes its Ys-to-RHs major/minor
// mapping from Impl::kABYs2RHsMajor / Impl::kABYs2RHsMinor.
27 template <typename Impl>
29 {
36  typename Impl::kABYs2RHsMajor,
37  typename Impl::kABYs2RHsMinor>;
38 };
39 
// Warp distribution-encoding trait for the C (output/accumulator) tile in its
// non-transposed layout.
// This is presumably CWarpDstrEncodingTrait, which the attribute struct's
// CWarpDstrEncoding selects when kTransC is false (TODO confirm). The naming line and
// most encoding arguments are not shown in this listing.
// Visible arguments: an empty sequence<>, and the Ys-to-RHs major/minor mapping from
// Impl::kCYs2RHsMajor / Impl::kCYs2RHsMinor.
40 template <typename Impl>
42 {
44  sequence<>,
49  typename Impl::kCYs2RHsMajor,
50  typename Impl::kCYs2RHsMinor>;
51 };
52 
// Warp distribution-encoding trait for the transposed C tile.
// This is presumably CTransposedWarpDstrEncodingTrait, which the attribute struct's
// CWarpDstrEncoding selects when kTransC is true (TODO confirm). The naming line and
// most encoding arguments are not shown in this listing.
// Its visible shape matches the non-transposed C trait (an empty sequence<> first).
// The difference is that its Ys-to-RHs mapping comes from
// Impl::kCTYs2RHsMajor / Impl::kCTYs2RHsMinor.
53 template <typename Impl>
55 {
57  sequence<>,
62  typename Impl::kCTYs2RHsMajor,
63  typename Impl::kCTYs2RHsMinor>;
64 };
65 
// Warp-level GEMM attribute adapter over a WMMA implementation type.
// The line naming this struct is not shown in this listing.
//
// Template parameters:
//   WarpGemmAttributeWmmaImpl_ - WMMA implementation. Its cv/ref qualifiers are stripped
//                                into the member alias Impl.
//   kTransC                    - when true, A and B are handed to Impl in swapped order,
//                                and the transposed C distribution encoding is selected.
66 template <typename WarpGemmAttributeWmmaImpl_, bool kTransC = false>
68 {
  // NOTE: not shown in this listing:
  //   alias Impl = remove_cvref_t<WarpGemmAttributeWmmaImpl_>
70 
  // Element types, forwarded unchanged from Impl.
71  using ADataType = typename Impl::ADataType;
72  using BDataType = typename Impl::BDataType;
73  using CDataType = typename Impl::CDataType;
74 
  // Vector types taken and returned by the call operators below, forwarded from Impl.
75  using AVecType = typename Impl::AVecType;
76  using BVecType = typename Impl::BVecType;
77  using CVecType = typename Impl::CVecType;
78 
  // Warp tile sizes, forwarded from Impl.
79  static constexpr index_t kM = Impl::kM;
80  static constexpr index_t kN = Impl::kN;
81  static constexpr index_t kK = Impl::kK;
82  static constexpr index_t kCMLane = Impl::kCMLane;
  // Product of Impl::kABK0PerLane and Impl::kABK1PerLane.
  // For both layouts listed below this is 8 (16-bit: 4 * 2, 8-bit: 2 * 4).
83  static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
84 
  // Always returns 1. Each call operator below issues exactly one Impl invocation.
85  CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
86 
87  // 16 bit input, kAMLane = 16, kABK0PerLane = 4, kABKLane = 2, kABK1PerLane = 2
88  // 8 bit input, kAMLane = 16, kABK0PerLane = 2, kABKLane = 2, kABK1PerLane = 4
  // NOTE: not shown in this listing:
  //   AWarpDstrEncoding = typename AWarpDstrEncodingTrait<Impl>::type
  //   BWarpDstrEncoding = typename BWarpDstrEncodingTrait<Impl>::type
91 
92  // kCM0PerLane = 1, kCMLane = 2, kCM1PerLane = 2, kCNLane = 16
  // NOTE: only part of the CWarpDstrEncoding alias is shown here. Its full form is:
  //   std::conditional_t<kTransC,
  //                      typename CTransposedWarpDstrEncodingTrait<Impl>::type,
  //                      typename CWarpDstrEncodingTrait<Impl>::type>
  // So the transposed C encoding is used exactly when the operands are swapped below.
94  std::conditional_t<kTransC,
97 
98  // c_vec += a_vec * b_vec
  // Accumulating call. The first line of its signature is not shown in this listing.
  // The full signature is:
  //   CK_TILE_DEVICE void operator()(CVecType& c_vec, const AVecType& a_vec,
  //                                  const BVecType& b_vec,
  //                                  bool_constant<post_nop_> = {}) const
  // post_nop_ is forwarded to Impl unchanged; its meaning is defined by Impl.
  // With kTransC, the operands reach Impl as (b_vec, a_vec). Since (A*B)^T = B^T * A^T,
  // this presumably yields the fragment of the transposed product, matching the
  // transposed C encoding above. TODO confirm against Impl's operand layout.
  // The swapped call only compiles if Impl accepts a BVecType in its A slot and an
  // AVecType in its B slot (e.g. identical vector types).
99  template <bool post_nop_ = false>
101  const AVecType& a_vec,
102  const BVecType& b_vec,
103  bool_constant<post_nop_> = {}) const
104  {
105  if constexpr(kTransC)
106  {
107  Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
108  }
109  else
110  {
111  Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
112  }
113  }
114 
115  // c_vec = a_vec * b_vec
  // Non-accumulating call. Returns Impl's result as a CVecType.
  // Applies the same kTransC operand swap as the accumulating overload.
116  CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
117  {
118  if constexpr(kTransC)
119  {
120  return Impl{}(b_vec, a_vec);
121  }
122  else
123  {
124  return Impl{}(a_vec, b_vec);
125  }
126  }
127 };
128 
// Reports whether a WMMA warp-GEMM configuration is available for the current target.
//
// Template parameters:
//   ADataType, BDataType                   - element types of the A and B operands.
//   AccDataType                            - accumulator element type.
//   M_Warp_Tile, N_Warp_Tile, K_Warp_Tile  - warp tile sizes to look up.
//
// Returns:
//   - has_wmma_traits_v<gfx12_t, ...> if is_gfx12_supported();
//   - otherwise has_wmma_traits_v<gfx11_t, ...> if is_gfx11_supported();
//   - otherwise false.
// gfx12 is tested first, so it wins if both predicates hold.
// The architecture predicates are plain (non-constexpr) functions, so the branch is
// taken at run time. They presumably query the active device; TODO confirm in
// device_prop.hpp.
129 template <typename ADataType,
130  typename BDataType,
131  typename AccDataType,
132  index_t M_Warp_Tile,
133  index_t N_Warp_Tile,
134  index_t K_Warp_Tile>
135 CK_TILE_HOST bool check_wmma_supported()
136 {
137  if(is_gfx12_supported())
138  {
139  return has_wmma_traits_v<gfx12_t,
140  ADataType,
141  BDataType,
142  AccDataType,
143  M_Warp_Tile,
144  N_Warp_Tile,
145  K_Warp_Tile>;
146  }
147  else if(is_gfx11_supported())
148  {
149  return has_wmma_traits_v<gfx11_t,
150  ADataType,
151  BDataType,
152  AccDataType,
153  M_Warp_Tile,
154  N_Warp_Tile,
155  K_Warp_Tile>;
156  }
157  else
158  {
159  return false;
160  }
161 }
162 
163 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr bool has_wmma_traits_v
Definition: warp_gemm_attribute_wmma_impl.hpp:138
CK_TILE_HOST bool check_wmma_supported()
Definition: warp_gemm_attribute_wmma.hpp:135
bool is_gfx12_supported()
Definition: device_prop.hpp:63
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
bool is_gfx11_supported()
Definition: device_prop.hpp:55
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: warp_gemm_attribute_wmma.hpp:16
Definition: warp_gemm_attribute_wmma.hpp:29
Definition: warp_gemm_attribute_wmma.hpp:55
Definition: warp_gemm_attribute_wmma.hpp:42
Definition: warp_gemm_attribute_wmma.hpp:68
remove_cvref_t< WarpGemmAttributeWmmaImpl_ > Impl
Definition: warp_gemm_attribute_wmma.hpp:69
CK_TILE_DEVICE void operator()(CVecType &c_vec, const AVecType &a_vec, const BVecType &b_vec, bool_constant< post_nop_ >={}) const
Definition: warp_gemm_attribute_wmma.hpp:100
std::conditional_t< kTransC, typename CTransposedWarpDstrEncodingTrait< Impl >::type, typename CWarpDstrEncodingTrait< Impl >::type > CWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:96
static constexpr index_t kN
Definition: warp_gemm_attribute_wmma.hpp:80
static constexpr index_t kKPerThread
Definition: warp_gemm_attribute_wmma.hpp:83
typename Impl::ADataType ADataType
Definition: warp_gemm_attribute_wmma.hpp:71
typename Impl::CDataType CDataType
Definition: warp_gemm_attribute_wmma.hpp:73
static constexpr index_t kCMLane
Definition: warp_gemm_attribute_wmma.hpp:82
typename Impl::BVecType BVecType
Definition: warp_gemm_attribute_wmma.hpp:76
static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access()
Definition: warp_gemm_attribute_wmma.hpp:85
typename Impl::BDataType BDataType
Definition: warp_gemm_attribute_wmma.hpp:72
typename Impl::CVecType CVecType
Definition: warp_gemm_attribute_wmma.hpp:77
typename Impl::AVecType AVecType
Definition: warp_gemm_attribute_wmma.hpp:75
static constexpr index_t kK
Definition: warp_gemm_attribute_wmma.hpp:81
static constexpr index_t kM
Definition: warp_gemm_attribute_wmma.hpp:79
CK_TILE_DEVICE CVecType operator()(const AVecType &a_vec, const BVecType &b_vec) const
Definition: warp_gemm_attribute_wmma.hpp:116
typename AWarpDstrEncodingTrait< Impl >::type AWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:89
typename BWarpDstrEncodingTrait< Impl >::type BWarpDstrEncoding
Definition: warp_gemm_attribute_wmma.hpp:90
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192