/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp Source File
device_gemm_xdl_cshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 // Note: the inter-wave loop scheduler is rolled out to the c-shuffle version first, because the
24 // non-c-shuffle version currently has compiler issues with register spills, which in turn cause
25 // validation failures.
// DeviceGemm_Xdl_CShuffle: a DeviceGemm implementation that delegates the device work to the
// GridwiseGemm type configured below and launches kernel_gemm_xdl_cshuffle_v1 from its Invoker.
//
// Template parameters are forwarded unchanged to GridwiseGemm, which defines their exact meaning.
// Beyond that forwarding, this struct itself uses:
//  - ALayout/BLayout/CLayout, ADataType/BDataType/CDataType and the three elementwise-operation
//    types to parameterize the DeviceGemm base interface;
//  - GemmSpec in IsSupportedArgument (whether K may be a non-multiple of AK1/BK1) and in
//    GetTypeString;
//  - BlockSize as the kernel launch block dimension in Invoker::Run;
//  - AK1 and BK1 in the K divisibility check, and AK1 again when Invoker::Run recomputes K;
//  - ComputeTypeA, which defaults to CDataType, and ComputeTypeB, which defaults to ComputeTypeA.
26 template <typename ALayout,
27  typename BLayout,
28  typename CLayout,
29  typename ADataType,
30  typename BDataType,
31  typename CDataType,
32  typename GemmAccDataType,
33  typename CShuffleDataType,
34  typename AElementwiseOperation,
35  typename BElementwiseOperation,
36  typename CElementwiseOperation,
37  GemmSpecialization GemmSpec,
38  index_t NumGemmKPrefetchStage,
39  index_t BlockSize,
40  index_t MPerBlock,
41  index_t NPerBlock,
42  index_t KPerBlock,
43  index_t AK1,
44  index_t BK1,
45  index_t MPerXDL,
46  index_t NPerXDL,
47  index_t MXdlPerWave,
48  index_t NXdlPerWave,
49  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
50  typename ABlockTransferThreadClusterArrangeOrder,
51  typename ABlockTransferSrcAccessOrder,
52  index_t ABlockTransferSrcVectorDim,
53  index_t ABlockTransferSrcScalarPerVector,
54  index_t ABlockTransferDstScalarPerVector_AK1,
55  bool ABlockLdsExtraM,
56  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
57  typename BBlockTransferThreadClusterArrangeOrder,
58  typename BBlockTransferSrcAccessOrder,
59  index_t BBlockTransferSrcVectorDim,
60  index_t BBlockTransferSrcScalarPerVector,
61  index_t BBlockTransferDstScalarPerVector_BK1,
62  bool BBlockLdsExtraN,
63  index_t CShuffleMXdlPerWavePerShuffle,
64  index_t CShuffleNXdlPerWavePerShuffle,
65  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
66  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
// NOTE(review): source lines 67-68 are missing from this listing. LoopSched and PipelineVer are
// used below but not declared in what is shown. The doxygen index lists
// make_default_loop_scheduler() and PipelineVersion, so these are presumably defaulted
// LoopScheduler/PipelineVersion template parameters — confirm against the header.
69  typename ComputeTypeA = CDataType,
70  typename ComputeTypeB = ComputeTypeA>
71 struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
72  BLayout,
73  CLayout,
74  ADataType,
75  BDataType,
76  CDataType,
77  AElementwiseOperation,
78  BElementwiseOperation,
79  CElementwiseOperation>
80 {
82 
83  static constexpr auto I0 = Number<0>{};
84  static constexpr auto I1 = Number<1>{};
85  static constexpr auto I2 = Number<2>{};
86 
87  // GridwiseGemm
// NOTE(review): source line 88, which opens this alias, is missing from this listing. Given the
// uses of GridwiseGemm below, it is presumably
// `using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v1<` — confirm against the header.
// Every template parameter of this struct is passed through to it.
89  ALayout,
90  BLayout,
91  CLayout,
92  ADataType,
93  BDataType,
94  GemmAccDataType,
95  CShuffleDataType,
96  CDataType,
97  AElementwiseOperation,
98  BElementwiseOperation,
99  CElementwiseOperation,
100  GemmSpec,
// NOTE(review): source line 101 (one template argument) is missing from this listing.
102  NumGemmKPrefetchStage,
103  BlockSize,
104  MPerBlock,
105  NPerBlock,
106  KPerBlock,
107  AK1,
108  BK1,
109  MPerXDL,
110  NPerXDL,
111  MXdlPerWave,
112  NXdlPerWave,
113  ABlockTransferThreadClusterLengths_AK0_M_AK1,
114  ABlockTransferThreadClusterArrangeOrder,
115  ABlockTransferSrcAccessOrder,
116  ABlockTransferSrcVectorDim,
117  ABlockTransferSrcScalarPerVector,
118  ABlockTransferDstScalarPerVector_AK1,
// NOTE(review): this hard-coded `false` (and the matching one on the B side) fills a
// GridwiseGemm template slot whose meaning is defined in gridwise_gemm_xdl_cshuffle_v1.hpp,
// which is not visible here.
119  false,
120  ABlockLdsExtraM,
121  BBlockTransferThreadClusterLengths_BK0_N_BK1,
122  BBlockTransferThreadClusterArrangeOrder,
123  BBlockTransferSrcAccessOrder,
124  BBlockTransferSrcVectorDim,
125  BBlockTransferSrcScalarPerVector,
126  BBlockTransferDstScalarPerVector_BK1,
127  false,
128  BBlockLdsExtraN,
129  CShuffleMXdlPerWavePerShuffle,
130  CShuffleNXdlPerWavePerShuffle,
131  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
132  CShuffleBlockTransferScalarPerVector_NPerBlock,
133  LoopSched,
134  PipelineVer,
135  ComputeTypeA,
136  ComputeTypeB>;
137 
// NOTE(review): source line 138 is missing from this listing. The doxygen index describes it as
// `Argument`, an alias of `typename GridwiseGemm::Argument`.
139 
140  // Invoker
// Launches the gridwise kernel for one Argument and returns the value reported by
// launch_and_time_kernel.
141  struct Invoker : public BaseInvoker
142  {
143  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
144  {
// Any positive log level prints the argument before launching.
145  if(stream_config.log_level_ > 0)
146  {
147  arg.Print();
148  }
149 
// NOTE(review): source line 150, the condition guarding this throw, is missing from this
// listing. It is presumably `if(!GridwiseGemm::CheckValidity(arg))`, mirroring
// IsSupportedArgument — confirm.
151  {
152  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
153  }
154 
// The grid dimensions are derived from the problem's M and N only.
155  index_t gdx, gdy, gdz;
156  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
157 
// K is recomputed as CalculateAK0(arg.K) * AK1 rather than taken from arg.K directly.
// Presumably this rounds K to whole AK1 chunks when K is padded — confirm in CalculateAK0.
158  const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
159 
160  float ave_time = 0;
161 
// NOTE(review): source line 162, this branch's condition, is missing from this listing. The index
// lists GridwiseGemm::CalculateHasMainKBlockLoop(index_t K), so it is presumably
// `if(GridwiseGemm::CalculateHasMainKBlockLoop(K))` — confirm.
// Both branches launch kernel_gemm_xdl_cshuffle_v1 and differ only in its bool template argument.
// The launch arguments are: grid (gdx, gdy, gdz), BlockSize threads per block, 0 for lds_byte,
// and then arg.
163  {
164  const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, true>;
165 
166  ave_time = launch_and_time_kernel(
167  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
168  }
169  else
170  {
171  const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm, false>;
172 
173  ave_time = launch_and_time_kernel(
174  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
175  }
176 
177  return ave_time;
178  }
179 
180  // polymorphic
// NOTE(review): dynamic_cast yields nullptr when p_arg is not this op's Argument, and the result
// is dereferenced without a check.
181  float Run(const BaseArgument* p_arg,
182  const StreamConfig& stream_config = StreamConfig{}) override
183  {
184  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
185  }
186  };
187 
// Always reports true; no compile-time parameter validation is implemented yet.
188  static constexpr bool IsValidCompilationParameter()
189  {
190  // TODO: properly implement this check
191  return true;
192  }
193 
// Runtime support check for a concrete problem. The problem is rejected when:
//  - ck::is_xdl_supported() is false; or
//  - K is not divisible by both AK1 and BK1, unless GemmSpec is one of the K-padding
//    specializations (MKPadding, NKPadding, MNKPadding, KPadding).
// Otherwise the result of GridwiseGemm::CheckValidity(arg) is returned.
194  static bool IsSupportedArgument(const Argument& arg)
195  {
196  if(!ck::is_xdl_supported())
197  {
198  return false;
199  }
200 
201  if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
202  GemmSpec == GemmSpecialization::NKPadding ||
203  GemmSpec == GemmSpecialization::MNKPadding ||
204  GemmSpec == GemmSpecialization::KPadding))
205  {
206  return false;
207  }
208 
209  return GridwiseGemm::CheckValidity(arg);
210  }
211 
212  // polymorphic
// NOTE(review): as in Invoker::Run, a failed dynamic_cast (nullptr) is dereferenced unchecked.
213  bool IsSupportedArgument(const BaseArgument* p_arg) override
214  {
215  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
216  }
217 
// Builds an Argument from the data pointers, the problem sizes and the strides. The three
// elementwise-operation parameters are unnamed and are not forwarded into the Argument.
218  static auto MakeArgument(const ADataType* p_a,
219  const BDataType* p_b,
220  CDataType* p_c,
221  index_t M,
222  index_t N,
223  index_t K,
224  index_t StrideA,
225  index_t StrideB,
226  index_t StrideC,
227  AElementwiseOperation,
228  BElementwiseOperation,
229  CElementwiseOperation)
230  {
231  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
232  }
233 
// Returns a value-initialized Invoker.
234  static auto MakeInvoker() { return Invoker{}; }
235 
236  // polymorphic
// Type-erased counterpart of MakeArgument. It static_casts the void pointers to the typed element
// pointers; the elementwise-operation parameters are likewise not forwarded.
237  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
238  const void* p_b,
239  void* p_c,
240  index_t M,
241  index_t N,
242  index_t K,
243  index_t StrideA,
244  index_t StrideB,
245  index_t StrideC,
246  AElementwiseOperation,
247  BElementwiseOperation,
248  CElementwiseOperation) override
249  {
250  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
251  static_cast<const BDataType*>(p_b),
252  static_cast<CDataType*>(p_c),
253  M,
254  N,
255  K,
256  StrideA,
257  StrideB,
258  StrideC);
259  }
260 
261  // polymorphic
// Heap-allocated Invoker for callers that use the BaseInvoker interface.
262  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
263  {
264  return std::make_unique<Invoker>(Invoker{});
265  }
266 
267  // polymorphic
// Returns a readable identifier listing GemmSpec, selected tuning parameters, the loop scheduler
// and the pipeline version.
// NOTE(review): only Default/Interwave and v1/v2 are in the local lookup tables. For any other
// value, std::map::operator[] inserts and prints an empty string.
268  std::string GetTypeString() const override
269  {
270  auto str = std::stringstream();
271 
272  std::map<LoopScheduler, std::string> LoopSchedToString{
273  {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
274 
275  std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
276  {PipelineVersion::v2, "v2"}};
277 
278  // clang-format off
279  str << "DeviceGemm_Xdl_CShuffle"
280  << "<"
281  << getGemmSpecializationString(GemmSpec) << ", "
282  << BlockSize << ", "
283  << MPerBlock << ", "
284  << NPerBlock << ", "
285  << KPerBlock << ", "
286  << AK1 << ", "
287  << BK1 << ", "
288  << MPerXDL << ", "
289  << NPerXDL << ", "
290  << MXdlPerWave << ", "
291  << NXdlPerWave << ", "
292  << ABlockTransferSrcScalarPerVector << ", "
293  << BBlockTransferSrcScalarPerVector << ", "
294  << CShuffleMXdlPerWavePerShuffle << ", "
295  << CShuffleNXdlPerWavePerShuffle
296  << ">"
297  << " LoopScheduler: "
298  << LoopSchedToString[LoopSched] << ", "
299  << "PipelineVersion: "
300  << PipelineVersionToString[PipelineVer];
301  // clang-format on
302 
303  return str.str();
304  }
305 };
306 
307 } // namespace device
308 } // namespace tensor_operation
309 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
bool is_xdl_supported()
Definition: device_prop.hpp:54
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:289
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:472
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:114
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:557
static __host__ auto CalculateAK0(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:152
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:661
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:132
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_xdl_cshuffle.hpp:142
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_cshuffle.hpp:143
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle.hpp:181
Definition: device_gemm_xdl_cshuffle.hpp:80
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle.hpp:213
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle.hpp:218
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle.hpp:268
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle.hpp:262
static constexpr auto I1
Definition: device_gemm_xdl_cshuffle.hpp:84
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle.hpp:194
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle.hpp:188
static constexpr auto I0
Definition: device_gemm_xdl_cshuffle.hpp:83
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle.hpp:237
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle.hpp:234
static constexpr auto I2
Definition: device_gemm_xdl_cshuffle.hpp:85
typename GridwiseGemm::Argument Argument
Definition: device_gemm_xdl_cshuffle.hpp:138
Definition: device_gemm.hpp:22