/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_base.hpp Source File
device_base.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
#include <cassert>
#include <memory>
#include <optional>
#include <regex>
#include <sstream>
#include <string>
#include <typeinfo>

#include "ck/stream_config.hpp"

#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/description.hpp"
#endif
#endif
#include "ck/utility/get_id.hpp"
#include "ck/utility/sequence.hpp"
21 
22 namespace ck {
23 namespace tensor_operation {
24 namespace device {
25 
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
// Expands to an override of GetObjectName() that recovers the enclosing class name from
// __PRETTY_FUNCTION__. The regex expects the pretty-printed signature to contain
// "<std::string> <Name>::GetObjectName". On a match it returns "<Name>;"; otherwise it returns
// the whole pretty-printed signature. (Comments cannot live inside the macro body because of
// the line continuations.)
#define GET_OBJECT_NAME_IMPL                                                  \
    std::optional<std::string> GetObjectName() const override                 \
    {                                                                         \
        std::string str = __PRETTY_FUNCTION__;                                \
        static std::regex obj_name_expr{"<std::string> (.*)::GetObjectName"}; \
        std::smatch match;                                                    \
        if(!std::regex_search(str, match, obj_name_expr))                     \
        {                                                                     \
            return str;                                                       \
        }                                                                     \
        return std::string(match[1]) + ';';                                   \
    }

// Historical, misspelled name of GET_OBJECT_NAME_IMPL. It is kept so existing code that expands
// it directly keeps compiling; new code should use GET_OBJECT_NAME_IMPL.
#define GET_OBJECT_NAME_IMLP GET_OBJECT_NAME_IMPL

// Expands to an override of GetTemplateInfo() that returns the text between the first '[' and
// the last ']' of __PRETTY_FUNCTION__ (the template-argument listing, when the compiler prints
// one), or std::nullopt when the signature contains no bracketed section.
#define GET_TEMPLATE_INFO_IMPL                                  \
    std::optional<std::string> GetTemplateInfo() const override \
    {                                                           \
        std::string str = __PRETTY_FUNCTION__;                  \
        static std::regex template_expr{"\\[(.*)\\]"};          \
        std::smatch match;                                      \
        if(!std::regex_search(str, match, template_expr))       \
        {                                                       \
            return std::nullopt;                                \
        }                                                       \
        return std::string(match[1]);                           \
    }

// Convenience macro that adds both printing overrides above to a class.
#define REGISTER_EXTRA_PRINTING_METHODS GET_OBJECT_NAME_IMPL GET_TEMPLATE_INFO_IMPL
#endif
55 
56 template <index_t BlockSize_,
57  index_t MPerBlock_,
58  index_t NPerBlock_,
59  index_t MPerXDL_,
60  index_t NPerXDL_,
61  index_t MXdlPerWave_,
62  bool IsWave64>
63 static constexpr auto GetXdlPerWave2()
64 {
65  constexpr index_t Waves = IsWave64 ? BlockSize_ / 64 : BlockSize_ / 32;
66  constexpr index_t MWaves = MPerBlock_ / (MXdlPerWave_ * MPerXDL_);
67  static_assert(MWaves > 0);
68 
69  constexpr index_t NWaves = Waves / MWaves;
70  if constexpr(NWaves == 0)
71  {
72  return 0;
73  }
74  else
75  {
76  if constexpr(NPerBlock_ % (NPerXDL_ * NWaves) == 0)
77  {
78  return NPerBlock_ / (NWaves * NPerXDL_);
79  }
80  else
81  {
82  return 0;
83  }
84  }
85 }
86 
// Expands, inside a class that defines BlockSize, MPerBlock, NPerBlock, MPerXDL, NPerXDL and
// MXdlPerWave, to GetNXdlPerWave<IsWave64>(). That function forwards those values to
// GetXdlPerWave2 to obtain the per-wave XDL repeat count along N. The call is
// namespace-qualified, as IS_VALID_COMPILATION_PARAMETER_IMPL does, so resolving GetXdlPerWave2
// does not depend on the namespace of the class that expands the macro.
#define GET_NXDL_PER_WAVE_IMPL                                           \
    template <bool IsWave64>                                             \
    static constexpr auto GetNXdlPerWave()                               \
    {                                                                    \
        return ck::tensor_operation::device::GetXdlPerWave2<BlockSize,   \
                                                            MPerBlock,   \
                                                            NPerBlock,   \
                                                            MPerXDL,     \
                                                            NPerXDL,     \
                                                            MXdlPerWave, \
                                                            IsWave64>(); \
    }
99 
// Expands to GetMXdlPerWave<IsWave64, MPerXDLAligned, NPerXDLAligned, NXdlPerWaveAligned>().
// It returns the per-wave XDL repeat count along M by calling GetXdlPerWave2 with the M and N
// roles swapped: NPerBlock and the N XDL size go in the "M" slots, MPerBlock and the M XDL size
// go in the "N" slots, and NXdlPerWaveAligned is passed as the fixed repeat count. The alignment
// parameters default to the enclosing class's MPerXDL, NPerXDL and NXdlPerWave. The call is
// namespace-qualified, as IS_VALID_COMPILATION_PARAMETER_IMPL does, so resolving GetXdlPerWave2
// does not depend on the namespace of the class that expands the macro.
#define GET_MXDL_PER_WAVE_IMPL                                                  \
    template <bool IsWave64,                                                    \
              index_t MPerXDLAligned     = MPerXDL,                             \
              index_t NPerXDLAligned     = NPerXDL,                             \
              index_t NXdlPerWaveAligned = NXdlPerWave>                         \
    static constexpr auto GetMXdlPerWave()                                      \
    {                                                                           \
        return ck::tensor_operation::device::GetXdlPerWave2<BlockSize,          \
                                                            NPerBlock,          \
                                                            MPerBlock,          \
                                                            NPerXDLAligned,     \
                                                            MPerXDLAligned,     \
                                                            NXdlPerWaveAligned, \
                                                            IsWave64>();        \
    }
115 
116 template <index_t BlockSize_,
117  index_t MPerBlock_,
118  index_t NPerBlock_,
119  index_t MPerXDL_,
120  index_t NPerXDL_,
121  index_t MXdlPerWave_,
122  index_t CShuffleMXdlPerWavePerShuffle_,
123  index_t CShuffleNXdlPerWavePerShuffle_,
124  bool IsWave64>
125 static constexpr auto GetWarpTileConfig()
126 {
127  constexpr auto MXdlPerWave64 = MXdlPerWave_;
128  constexpr auto MXdlPerWave32 = MXdlPerWave_ * MPerXDL_ / 16;
129  constexpr auto CShuffleMXdlPerWavePerShuffle32 = CShuffleMXdlPerWavePerShuffle_ * MPerXDL_ / 16;
130 
131  constexpr auto NXdlPerWave =
132  IsWave64
133  ? GetXdlPerWave2<BlockSize_,
134  MPerBlock_,
135  NPerBlock_,
136  MPerXDL_,
137  NPerXDL_,
138  MXdlPerWave_,
139  true>()
140  : GetXdlPerWave2<BlockSize_, MPerBlock_, NPerBlock_, 16, 16, MXdlPerWave32, false>();
141 
142  if constexpr(IsWave64 == false && NXdlPerWave != 0)
143  {
144  constexpr auto CShuffleNXdlPerWavePerShuffle32 =
145  NXdlPerWave >= CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
146  ? CShuffleNXdlPerWavePerShuffle_ * NPerXDL_ / 16
147  : CShuffleNXdlPerWavePerShuffle_;
148  static_assert(CShuffleNXdlPerWavePerShuffle32 > 0);
149  return Sequence<16,
150  16,
151  MXdlPerWave32,
152  NXdlPerWave,
153  CShuffleMXdlPerWavePerShuffle32,
154  CShuffleNXdlPerWavePerShuffle32>{};
155  }
156  else
157  {
158  return Sequence<MPerXDL_,
159  NPerXDL_,
160  MXdlPerWave64,
161  NXdlPerWave,
162  CShuffleMXdlPerWavePerShuffle_,
163  CShuffleNXdlPerWavePerShuffle_>{};
164  }
165 }
166 
// Expands to a Run(const Argument&, const StreamConfig&) member.
//   - When get_warp_size() is not 64, it calls RunImp<GridwiseGemm32>.
//   - Otherwise it calls RunImp<GridwiseGemm64>.
// The enclosing class must provide Argument, RunImp, GridwiseGemm64/GridwiseGemm32 and
// NXdlPerWave64/NXdlPerWave32. A variant whose NXdlPerWave is not positive is discarded via
// if constexpr, and Run then returns 0 for that warp size.
#define INVOKER_RUN_IMPL                                                               \
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
    {                                                                                  \
        if(get_warp_size() != 64)                                                      \
        {                                                                              \
            if constexpr(NXdlPerWave32 > 0)                                            \
            {                                                                          \
                return RunImp<GridwiseGemm32>(arg, stream_config);                     \
            }                                                                          \
        }                                                                              \
        else if constexpr(NXdlPerWave64 > 0)                                           \
        {                                                                              \
            return RunImp<GridwiseGemm64>(arg, stream_config);                         \
        }                                                                              \
        return 0;                                                                      \
    }
186 
// Same dispatch as INVOKER_RUN_IMPL, except that the 32-lane path reinterprets `arg` as
// `const GridwiseGemm32::Argument&` before calling RunImp<GridwiseGemm32>.
// NOTE(review): that reinterpret_cast presumably relies on Argument and
// GridwiseGemm32::Argument sharing a layout — confirm at each expansion site.
#define INVOKER_RUN3_IMPL                                                              \
    float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) \
    {                                                                                  \
        if(get_warp_size() != 64)                                                      \
        {                                                                              \
            if constexpr(NXdlPerWave32 > 0)                                            \
            {                                                                          \
                const auto& arg32 =                                                    \
                    reinterpret_cast<const typename GridwiseGemm32::Argument&>(arg);   \
                return RunImp<GridwiseGemm32>(arg32, stream_config);                   \
            }                                                                          \
        }                                                                              \
        else if constexpr(NXdlPerWave64 > 0)                                           \
        {                                                                              \
            return RunImp<GridwiseGemm64>(arg, stream_config);                         \
        }                                                                              \
        return 0;                                                                      \
    }
208 
209 template <index_t BlockSize,
210  index_t MPerBlock,
211  index_t NPerBlock,
212  index_t MPerXdl,
213  index_t NPerXdl,
214  index_t MXdlPerWave,
215  index_t NXdlPerWave,
216  typename CDataType,
217  InMemoryDataOperationEnum CGlobalMemoryDataOperation_ = InMemoryDataOperationEnum::Set>
218 __device__ static bool constexpr IsValidGemmCompilationParameter()
219 {
220 #if defined(__gfx11__) || defined(__gfx12__)
221  if constexpr(MPerXdl != 16 || NPerXdl != 16)
222  {
223  return false;
224  }
225 #endif
226 
227 #if defined(__gfx11__)
228  constexpr bool SupportMemOp = CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set;
229 #else
230  constexpr bool SupportMemOp =
231  sizeof(CDataType) >= 2 || (CGlobalMemoryDataOperation_ == InMemoryDataOperationEnum::Set);
232 #endif
233  if constexpr(SupportMemOp == false)
234  {
235  return false;
236  }
237 
238  if constexpr(MXdlPerWave > 0 && NXdlPerWave > 0)
239  {
240  constexpr index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
241  constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
242  if constexpr(MWaves > 0 && NWaves > 0)
243  {
244  constexpr index_t WaveSize = BlockSize / (MWaves * NWaves);
245  return WaveSize == get_warp_size();
246  }
247  }
248  return false;
249 }
250 
// Expands to a static IsValidCompilationParameter<CGlobalMemoryDataOperation_>() member.
// The member forwards the enclosing class's BlockSize, MPerBlock, NPerBlock, MPerXdl, NPerXdl,
// MXdlPerWave and NXdlPerWave, together with CDataType_, to the namespace-qualified
// IsValidGemmCompilationParameter.
#define IS_VALID_COMPILATION_PARAMETER_IMPL(CDataType_)                         \
    template <InMemoryDataOperationEnum CGlobalMemoryDataOperation_ =           \
                  InMemoryDataOperationEnum::Set>                               \
    __device__ static constexpr bool IsValidCompilationParameter()              \
    {                                                                           \
        return ck::tensor_operation::device::                                   \
            IsValidGemmCompilationParameter<BlockSize,                          \
                                            MPerBlock,                          \
                                            NPerBlock,                          \
                                            MPerXdl,                            \
                                            NPerXdl,                            \
                                            MXdlPerWave,                        \
                                            NXdlPerWave,                        \
                                            CDataType_,                         \
                                            CGlobalMemoryDataOperation_>();     \
    }
267 
268 #ifndef CK_CODE_GEN_RTC
270 {
271  BaseArgument() = default;
272  BaseArgument(const BaseArgument&) = default;
273  BaseArgument& operator=(const BaseArgument&) = default;
274 
275  virtual __host__ __device__ ~BaseArgument() {}
276 
277  void* p_workspace_ = nullptr;
278 };
279 
281 {
282  BaseInvoker() = default;
283  BaseInvoker(const BaseInvoker&) = default;
284  BaseInvoker& operator=(const BaseInvoker&) = default;
285 
286  virtual float Run(const BaseArgument*, const StreamConfig& = StreamConfig{})
287  {
288  return float{0};
289  }
290 
291  virtual ~BaseInvoker() {}
292 };
293 #endif
294 
296 {
297  BaseOperator() = default;
298  BaseOperator(const BaseOperator&) = default;
299  BaseOperator& operator=(const BaseOperator&) = default;
300 #if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
301  virtual bool IsSupportedArgument(const BaseArgument*) { return false; }
302  virtual std::string GetTypeString() const { return ""; }
303 
304 #ifdef CK_EXPERIMENTAL_BUILDER
305  // Return a description object for this operator, or nullptr if not supported.
306  virtual std::unique_ptr<ck_tile::reflect::Description> describe() const { return nullptr; }
307 #endif
308 
309  virtual std::string GetInstanceString() const { return ""; }
310 
311  virtual std::string GetTypeIdName() const { return typeid(*this).name(); }
312 
313  virtual std::optional<std::string> GetObjectName() const { return std::nullopt; }
314 
315  virtual std::optional<std::string> GetTemplateInfo() const { return std::nullopt; }
316 
317  virtual std::string GetTypeIdHashCode() const
318  {
319  std::ostringstream oss;
320 
321  oss << std::hex << typeid(*this).hash_code();
322 
323  return oss.str();
324  };
325 
326  virtual size_t GetWorkSpaceSize(const BaseArgument*) const { return 0; }
327 
328  virtual void SetWorkSpacePointer(BaseArgument* p_arg,
329  void* p_workspace,
330  const StreamConfig& = StreamConfig{}) const
331  {
332  assert(p_arg);
333  p_arg->p_workspace_ = p_workspace;
334  }
335 #endif
336  virtual ~BaseOperator() {}
337 };
338 
339 } // namespace device
340 } // namespace tensor_operation
341 } // namespace ck
Definition: ck.hpp:270
InMemoryDataOperationEnum
Definition: ck.hpp:279
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
int32_t index_t
Definition: ck.hpp:301
Definition: stream_config.hpp:10
Definition: device_base.hpp:270
BaseArgument & operator=(const BaseArgument &)=default
BaseArgument(const BaseArgument &)=default
void * p_workspace_
Definition: device_base.hpp:277
virtual __host__ __device__ ~BaseArgument()
Definition: device_base.hpp:275
Definition: device_base.hpp:281
virtual ~BaseInvoker()
Definition: device_base.hpp:291
BaseInvoker & operator=(const BaseInvoker &)=default
virtual float Run(const BaseArgument *, const StreamConfig &=StreamConfig{})
Definition: device_base.hpp:286
BaseInvoker(const BaseInvoker &)=default
Definition: device_base.hpp:296
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:328
virtual std::string GetInstanceString() const
Definition: device_base.hpp:309
virtual bool IsSupportedArgument(const BaseArgument *)
Definition: device_base.hpp:301
virtual size_t GetWorkSpaceSize(const BaseArgument *) const
Definition: device_base.hpp:326
virtual std::optional< std::string > GetTemplateInfo() const
Definition: device_base.hpp:315
virtual std::string GetTypeString() const
Definition: device_base.hpp:302
BaseOperator(const BaseOperator &)=default
virtual std::string GetTypeIdHashCode() const
Definition: device_base.hpp:317
virtual std::optional< std::string > GetObjectName() const
Definition: device_base.hpp:313
BaseOperator & operator=(const BaseOperator &)=default
virtual std::string GetTypeIdName() const
Definition: device_base.hpp:311
virtual ~BaseOperator()
Definition: device_base.hpp:336