/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl.hpp Source File
device_gemm_xdl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
// NOTE(review): this is a Doxygen source listing, not compilable text. Each line
// carries its original source line number, and some source lines are missing
// (e.g. 60-61, 79, 84, 145, 187 -- see the jumps in the numbering). Those gaps are
// called out in place below; do not treat this listing as the full header.
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
// DeviceGemmXdl: a DeviceGemm implementation (base class at source line 62).
// It forwards most of its template configuration to a gridwise GEMM type
// (GridwiseGemm, source lines 79-121) and launches kernel_gemm_xdlops_v2r3 with it.
// Template parameters:
//   A/B/C/AccDataType, A/B/CLayout            - element types and layouts.
//   A/B/CElementwiseOperation                 - elementwise op types from the DeviceGemm
//                                               interface (instances are ignored by
//                                               MakeArgument below).
//   MPerBlock, NPerBlock, K0PerBlock, K1,
//   MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave - tiling sizes, forwarded to GridwiseGemm.
//   A/BBlockTransfer*, CThreadTransfer*        - data-transfer settings, forwarded to
//                                                GridwiseGemm.
//   NumPrefetch                                - defaults to 1; forwarded to GridwiseGemm.
23 template <typename ADataType,
24  typename BDataType,
25  typename CDataType,
26  typename AccDataType,
27  typename ALayout,
28  typename BLayout,
29  typename CLayout,
30  typename AElementwiseOperation,
31  typename BElementwiseOperation,
32  typename CElementwiseOperation,
33  GemmSpecialization GemmSpec,
34  ck::index_t BlockSize,
35  ck::index_t MPerBlock,
36  ck::index_t NPerBlock,
37  ck::index_t K0PerBlock,
38  ck::index_t K1,
39  ck::index_t MPerXDL,
40  ck::index_t NPerXDL,
41  ck::index_t MXdlPerWave,
42  ck::index_t NXdlPerWave,
43  typename ABlockTransferThreadClusterLengths_K0_M_K1,
44  typename ABlockTransferThreadClusterArrangeOrder,
45  typename ABlockTransferSrcAccessOrder,
46  ck::index_t ABlockTransferSrcVectorDim,
47  ck::index_t ABlockTransferSrcScalarPerVector,
48  ck::index_t ABlockTransferDstScalarPerVector_K1,
49  bool ABlockLdsAddExtraM,
50  typename BBlockTransferThreadClusterLengths_K0_N_K1,
51  typename BBlockTransferThreadClusterArrangeOrder,
52  typename BBlockTransferSrcAccessOrder,
53  ck::index_t BBlockTransferSrcVectorDim,
54  ck::index_t BBlockTransferSrcScalarPerVector,
55  ck::index_t BBlockTransferDstScalarPerVector_K1,
56  bool BBlockLdsAddExtraN,
57  ck::index_t CThreadTransferSrcDstVectorDim,
58  ck::index_t CThreadTransferDstScalarPerVector,
59  ck::index_t NumPrefetch = 1,
// NOTE(review): source lines 60-61 are missing from this listing. They must declare
// the template parameters LoopSched and PipelineVer, which are used below (GridwiseGemm
// arguments; LoopScheduler / PipelineVersion maps in GetTypeString), and close the
// '<...>' list. The page's cross-reference list links make_default_loop_scheduler(),
// presumably LoopSched's default -- confirm against the header.
62 struct DeviceGemmXdl : public DeviceGemm<ALayout,
63  BLayout,
64  CLayout,
65  ADataType,
66  BDataType,
67  CDataType,
68  AElementwiseOperation,
69  BElementwiseOperation,
70  CElementwiseOperation>
71 {
// Compile-time index constants and K1 as a Number<>; none of them is referenced
// elsewhere in this listing.
72  static constexpr auto I0 = Number<0>{};
73  static constexpr auto I1 = Number<1>{};
74  static constexpr auto I2 = Number<2>{};
75 
76  static constexpr auto K1Number = Number<K1>{};
77 
78  // GridwiseGemm
// NOTE(review): source line 79 (the `using GridwiseGemm = <template><` opener) and
// source line 84 are missing from this listing. The runtime_error text in
// Invoker::Run names GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext, presumably the
// template instantiated here -- confirm.
// NOTE(review): only ADataType is passed as the input element type. BDataType is not
// forwarded (see the TODO below), so B is presumably handled as ADataType -- verify
// before instantiating with ADataType != BDataType.
80  BlockSize,
81  ADataType, // TODO: distinguish A/B datatype
82  AccDataType,
83  CDataType,
85  ALayout,
86  BLayout,
87  CLayout,
88  AElementwiseOperation,
89  BElementwiseOperation,
90  CElementwiseOperation,
91  GemmSpec,
92  MPerBlock,
93  NPerBlock,
94  K0PerBlock,
95  MPerXDL,
96  NPerXDL,
97  K1,
98  MXdlPerWave,
99  NXdlPerWave,
100  ABlockTransferThreadClusterLengths_K0_M_K1,
101  ABlockTransferThreadClusterArrangeOrder,
102  ABlockTransferSrcAccessOrder,
103  ABlockTransferSrcVectorDim,
104  ABlockTransferSrcScalarPerVector,
105  ABlockTransferDstScalarPerVector_K1,
106  false, // AThreadTransferSrcResetCoordinateAfterRun,
107  ABlockLdsAddExtraM,
108  BBlockTransferThreadClusterLengths_K0_N_K1,
109  BBlockTransferThreadClusterArrangeOrder,
110  BBlockTransferSrcAccessOrder,
111  BBlockTransferSrcVectorDim,
112  BBlockTransferSrcScalarPerVector,
113  BBlockTransferDstScalarPerVector_K1,
114  false, // BThreadTransferSrcResetCoordinateAfterRun,
115  BBlockLdsAddExtraN,
116  Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
117  CThreadTransferSrcDstVectorDim,
118  CThreadTransferDstScalarPerVector,
119  NumPrefetch,
120  LoopSched,
121  PipelineVer>;
122 
// The argument type is the gridwise GEMM's own Argument. MakeArgument /
// MakeArgumentPointer build it from the A/B/C pointers, M, N, K and the three strides.
123  using Argument = typename GridwiseGemm::Argument;
124 
125  // Invoker
126  struct Invoker : public BaseInvoker
127  {
// Launches the GEMM described by karg and returns launch_and_time_kernel's result.
//  - Prints karg when stream_config.log_level_ > 0.
//  - Throws std::runtime_error if GridwiseGemm::CheckValidity(karg) is false.
//  - Takes the grid size from GridwiseGemm::CalculateGridSize(karg.M, karg.N).
//  - Uses dim3(BlockSize) threads per block and passes 0 as the lds_byte argument.
128  float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
129  {
130  if(stream_config.log_level_ > 0)
131  {
132  karg.Print();
133  }
134 
135  if(!GridwiseGemm::CheckValidity(karg))
136  {
137  throw std::runtime_error(
138  "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
139  }
140 
141  const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
142 
143  float ave_time = 0;
144 
// NOTE(review): source line 145 (the `if(...)` selecting between the two launches)
// is missing from this listing. The branches differ only in the bool template
// argument of kernel_gemm_xdlops_v2r3 (true vs false). That argument presumably
// says whether karg.K needs a main K-block loop -- confirm against the header.
146  {
147  const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
148 
149  ave_time = launch_and_time_kernel(
150  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
151  }
152  else
153  {
154  const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, false>;
155 
156  ave_time = launch_and_time_kernel(
157  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
158  }
159 
160  return ave_time;
161  }
162 
163  // polymorphic
// Type-erased entry point from BaseInvoker; forwards to Run(const Argument&, ...).
// NOTE(review): if p_arg is not a DeviceGemmXdl::Argument, dynamic_cast yields
// nullptr, and this code dereferences it without a check.
164  float Run(const BaseArgument* p_arg,
165  const StreamConfig& stream_config = StreamConfig{}) override
166  {
167  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
168  }
169  };
170 
// Currently always returns true (see TODO); no compile-time validation is performed.
171  static constexpr bool IsValidCompilationParameter()
172  {
173  // TODO: properly implement this check
174  return true;
175  }
176 
// Returns whether karg can run on the current device with this configuration:
//  - gfx908: AccDataType must be float or int32_t.
//  - devices passing the (missing) else-if check at source line 187: AccDataType
//    may also be double.
//  - all other devices: unsupported.
// It then requires karg.K % K1 == 0 and GridwiseGemm::CheckValidity(karg).
// NOTE(review): `is_same_v<AccDataType, float>` appears twice in both type checks.
// The duplicate is redundant; verify no other type (e.g. a half type) was intended.
177  static bool IsSupportedArgument(const Argument& karg)
178  {
179  if(ck::get_device_name() == "gfx908")
180  {
181  if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
182  is_same_v<AccDataType, int32_t>))
183  {
184  return false;
185  }
186  }
// NOTE(review): source line 187 (an `else if(...)` device-capability check) is
// missing from this listing. The page's cross-reference list links
// is_lds_direct_load_supported(), presumably the predicate used here -- confirm.
188  {
189  if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
190  is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
191  {
192  return false;
193  }
194  }
195  else
196  {
197  return false;
198  }
199 
// K must split evenly into K1-sized chunks.
200  if(karg.K % K1 != 0)
201  {
202  return false;
203  }
204 
205  return GridwiseGemm::CheckValidity(karg);
206  }
207 
208  // polymorphic
// Type-erased overload from the base interface.
// NOTE(review): same unchecked dynamic_cast dereference as Invoker::Run(const BaseArgument*).
209  bool IsSupportedArgument(const BaseArgument* p_arg) override
210  {
211  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
212  }
213 
// Builds an Argument from typed pointers, problem sizes and strides.
// The three elementwise-operation parameters are unnamed and not stored: Argument
// receives only pointers, M/N/K and StrideA/B/C.
214  static auto MakeArgument(const ADataType* p_a,
215  const BDataType* p_b,
216  CDataType* p_c,
217  index_t M,
218  index_t N,
219  index_t K,
220  index_t StrideA,
221  index_t StrideB,
222  index_t StrideC,
223  AElementwiseOperation,
224  BElementwiseOperation,
225  CElementwiseOperation)
226  {
227  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
228  }
229 
230  static auto MakeInvoker() { return Invoker{}; }
231 
232  // polymorphic
// Type-erased counterpart of MakeArgument. The void pointers are static_cast to
// ADataType / BDataType / CDataType pointers without any check, so callers must pass
// buffers of those element types. The elementwise operations are ignored, as in
// MakeArgument.
233  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
234  const void* p_b,
235  void* p_c,
236  index_t M,
237  index_t N,
238  index_t K,
239  index_t StrideA,
240  index_t StrideB,
241  index_t StrideC,
242  AElementwiseOperation,
243  BElementwiseOperation,
244  CElementwiseOperation) override
245  {
246  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
247  static_cast<const BDataType*>(p_b),
248  static_cast<CDataType*>(p_c),
249  M,
250  N,
251  K,
252  StrideA,
253  StrideB,
254  StrideC);
255  }
256 
257  // polymorphic
258  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
259  {
260  return std::make_unique<Invoker>(Invoker{});
261  }
262 
263  // polymorphic
// Returns a human-readable description of this instance. It starts with
// "DeviceGemmXdl<...>", listing the tiling and A/B vector sizes, then appends
// NumPrefetch, the LoopScheduler name and the PipelineVersion name.
// NOTE(review): the name maps only cover Default/Interwave and v1/v2. std::map's
// operator[] inserts an empty string for any other key, so any other enum value
// (if one exists) would print as empty. The maps are also rebuilt on every call.
264  std::string GetTypeString() const override
265  {
266  auto str = std::stringstream();
267 
268  std::map<LoopScheduler, std::string> LoopSchedToString{
269  {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
270 
271  std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
272  {PipelineVersion::v2, "v2"}};
273 
274  // clang-format off
275  str << "DeviceGemmXdl"
276  << "<"
277  << BlockSize << ", "
278  << MPerBlock << ", "
279  << NPerBlock << ", "
280  << K0PerBlock << ", "
281  << K1 << ", "
282  << MPerXDL << ", "
283  << NPerXDL << ", "
284  << MXdlPerWave << ", "
285  << NXdlPerWave << ", "
286  << ABlockTransferSrcScalarPerVector << ", "
287  << ABlockTransferDstScalarPerVector_K1 << ", "
288  << BBlockTransferSrcScalarPerVector << ", "
289  << BBlockTransferDstScalarPerVector_K1
290  << ">"
291  << " NumPrefetch: "
292  << NumPrefetch << ", "
293  << "LoopScheduler: "
294  << LoopSchedToString[LoopSched] << ", "
295  << "PipelineVersion: "
296  << PipelineVersionToString[PipelineVer];
297  // clang-format on
298 
299  return str.str();
300  }
301 };
302 
303 } // namespace device
304 } // namespace tensor_operation
305 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
bool is_lds_direct_load_supported()
Definition: device_prop.hpp:61
std::string get_device_name()
Definition: device_prop.hpp:12
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:289
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdlops_v2r3.hpp:781
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdlops_v2r3.hpp:968
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm.hpp:22
Definition: device_gemm_xdl.hpp:127
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl.hpp:164
float Run(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl.hpp:128
Definition: device_gemm_xdl.hpp:71
static bool IsSupportedArgument(const Argument &karg)
Definition: device_gemm_xdl.hpp:177
static constexpr auto K1Number
Definition: device_gemm_xdl.hpp:76
static constexpr auto I0
Definition: device_gemm_xdl.hpp:72
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl.hpp:209
static auto MakeInvoker()
Definition: device_gemm_xdl.hpp:230
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl.hpp:258
std::string GetTypeString() const override
Definition: device_gemm_xdl.hpp:264
static constexpr auto I2
Definition: device_gemm_xdl.hpp:74
static constexpr auto I1
Definition: device_gemm_xdl.hpp:73
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl.hpp:214
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl.hpp:171
typename GridwiseGemm::Argument Argument
Definition: device_gemm_xdl.hpp:123
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl.hpp:233