/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_dpp.hpp Source File
device_gemm_dpp.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <sstream>
7 
17 
18 namespace ck {
19 namespace tensor_operation {
20 namespace device {
21 
22 template <typename ADataType,
23  typename BDataType,
24  typename CDataType,
25  typename AccDataType,
26  typename ALayout,
27  typename BLayout,
28  typename CLayout,
29  typename AElementwiseOperation,
30  typename BElementwiseOperation,
31  typename CElementwiseOperation,
32  GemmSpecialization GemmSpec,
33  ck::index_t BlockSize,
34  ck::index_t MPerBlock,
35  ck::index_t NPerBlock,
36  ck::index_t KPerBlock,
37  ck::index_t AK1,
38  ck::index_t BK1,
39  ck::index_t MPerDpp,
40  ck::index_t NPerDpp,
41  ck::index_t MDppPerWave,
42  ck::index_t NDppPerWave,
43  typename ABlockTransferThreadClusterLengths_K0_M_K1,
44  typename ABlockTransferThreadClusterArrangeOrder,
45  typename ABlockTransferSrcAccessOrder,
46  ck::index_t ABlockTransferSrcVectorDim,
47  ck::index_t ABlockTransferSrcScalarPerVector,
48  ck::index_t ABlockTransferDstScalarPerVector_K1,
49  bool ABlockLdsAddExtraM,
50  typename BBlockTransferThreadClusterLengths_K0_N_K1,
51  typename BBlockTransferThreadClusterArrangeOrder,
52  typename BBlockTransferSrcAccessOrder,
53  ck::index_t BBlockTransferSrcVectorDim,
54  ck::index_t BBlockTransferSrcScalarPerVector,
55  ck::index_t BBlockTransferDstScalarPerVector_K1,
56  bool BBlockLdsAddExtraN,
57  ck::index_t CThreadTransferSrcDstVectorDim,
58  ck::index_t CThreadTransferDstScalarPerVector,
59  ck::index_t NumPrefetch = 1,
61 struct DeviceGemmDpp : public DeviceGemm<ALayout,
62  BLayout,
63  CLayout,
64  ADataType,
65  BDataType,
66  CDataType,
67  AElementwiseOperation,
68  BElementwiseOperation,
69  CElementwiseOperation>
70 {
72  BlockSize,
73  ADataType,
74  AccDataType,
75  CDataType,
77  ALayout,
78  BLayout,
79  CLayout,
80  AElementwiseOperation,
81  BElementwiseOperation,
82  CElementwiseOperation,
83  GemmSpec,
84  MPerBlock,
85  NPerBlock,
86  KPerBlock,
87  MPerDpp,
88  NPerDpp,
89  AK1,
90  BK1,
91  MDppPerWave,
92  NDppPerWave,
93  ABlockTransferThreadClusterLengths_K0_M_K1,
94  ABlockTransferThreadClusterArrangeOrder,
95  ABlockTransferSrcAccessOrder,
96  ABlockTransferSrcVectorDim,
97  ABlockTransferSrcScalarPerVector,
98  ABlockTransferDstScalarPerVector_K1,
99  false, // AThreadTransferSrcResetCoordinateAfterRun,
100  ABlockLdsAddExtraM,
101  BBlockTransferThreadClusterLengths_K0_N_K1,
102  BBlockTransferThreadClusterArrangeOrder,
103  BBlockTransferSrcAccessOrder,
104  BBlockTransferSrcVectorDim,
105  BBlockTransferSrcScalarPerVector,
106  BBlockTransferDstScalarPerVector_K1,
107  false, // BThreadTransferSrcResetCoordinateAfterRun,
108  BBlockLdsAddExtraN,
109  Sequence<0, 2, 4, 1, 3, 5>, // CThreadTransferSrcDstAccessOrder,
110  CThreadTransferSrcDstVectorDim,
111  CThreadTransferDstScalarPerVector,
112  NumPrefetch,
113  PipelineVer>;
114 
116 
117  // Invoker
118  struct Invoker : public BaseInvoker
119  {
120  float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
121  {
122  if(stream_config.log_level_ > 0)
123  {
124  karg.Print();
125  }
126 
127  if(!GridwiseGemm::CheckValidity(karg))
128  {
129  throw std::runtime_error(
130  "wrong! GridwiseGemm_k0mk1_k0nk1_mn_dpp has invalid setting");
131  }
132 
133  const auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
134 
135  float ave_time = 0;
136 
138  {
139  const auto kernel = kernel_gemm_dpp<GridwiseGemm, true>;
140 
141  ave_time = launch_and_time_kernel(
142  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
143  }
144  else
145  {
146  const auto kernel = kernel_gemm_dpp<GridwiseGemm, false>;
147 
148  ave_time = launch_and_time_kernel(
149  stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
150  }
151 
152  return ave_time;
153  }
154 
155  // polymorphic
156  float Run(const BaseArgument* p_arg,
157  const StreamConfig& stream_config = StreamConfig{}) override
158  {
159  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
160  }
161  };
162 
/// Compile-time hook for validating the template parameters.
/// No check is implemented yet, so every instantiation is reported as valid.
static constexpr bool IsValidCompilationParameter()
{
    // TODO: properly implement this check
    constexpr bool is_valid = true;
    return is_valid;
}
168 
169  static bool IsSupportedArgument(const Argument& karg)
170  {
172  {
173  return GridwiseGemm::CheckValidity(karg);
174  }
175  return false;
176  }
177 
178  // polymorphic
179  bool IsSupportedArgument(const BaseArgument* p_arg) override
180  {
181  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
182  }
183 
184  static auto MakeArgument(const ADataType* p_a,
185  const BDataType* p_b,
186  CDataType* p_c,
187  index_t M,
188  index_t N,
189  index_t K,
190  index_t StrideA,
191  index_t StrideB,
192  index_t StrideC,
193  AElementwiseOperation,
194  BElementwiseOperation,
195  CElementwiseOperation)
196  {
197  return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
198  }
199 
200  static auto MakeInvoker() { return Invoker{}; }
201 
202  // polymorphic
203  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
204  const void* p_b,
205  void* p_c,
206  index_t M,
207  index_t N,
208  index_t K,
209  index_t StrideA,
210  index_t StrideB,
211  index_t StrideC,
212  AElementwiseOperation,
213  BElementwiseOperation,
214  CElementwiseOperation) override
215  {
216  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
217  static_cast<const BDataType*>(p_b),
218  static_cast<CDataType*>(p_c),
219  M,
220  N,
221  K,
222  StrideA,
223  StrideB,
224  StrideC);
225  }
226 
227  // polymorphic
228  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
229  {
230  return std::make_unique<Invoker>(Invoker{});
231  }
232 
233  // polymorphic
234  std::string GetTypeString() const override
235  {
236  auto str = std::stringstream();
237 
238  std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
239  {PipelineVersion::v2, "v2"}};
240 
241  // clang-format off
242  str << "DeviceGemmDpp"
243  << "<"
244  << BlockSize << ", "
245  << MPerBlock << ", "
246  << NPerBlock << ", "
247  << KPerBlock << ", "
248  << AK1 << ", "
249  << BK1 << ", "
250  << MPerDpp << ", "
251  << NPerDpp << ", "
252  << MDppPerWave << ", "
253  << MDppPerWave << ", "
254  << ABlockTransferSrcScalarPerVector << ", "
255  << ABlockTransferDstScalarPerVector_K1 << ", "
256  << BBlockTransferSrcScalarPerVector << ", "
257  << BBlockTransferDstScalarPerVector_K1
258  << ">"
259  << " NumPrefetch: "
260  << NumPrefetch << ", "
261  << "PipelineVersion: "
262  << PipelineVersionToString[PipelineVer];
263  // clang-format on
264 
265  return str.str();
266  }
267 };
268 
269 } // namespace device
270 } // namespace tensor_operation
271 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
bool is_gfx103_supported()
Definition: device_prop.hpp:81
int32_t index_t
Definition: ck.hpp:289
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
bool is_gfx11_supported()
Definition: device_prop.hpp:88
Definition: stream_config.hpp:10
Definition: gridwise_gemm_dpp.hpp:184
Definition: gridwise_gemm_dpp.hpp:96
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_dpp.hpp:356
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_dpp.hpp:115
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_dpp.hpp:263
Definition: sequence.hpp:43
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_dpp.hpp:119
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_dpp.hpp:156
float Run(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_dpp.hpp:120
Definition: device_gemm_dpp.hpp:70
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_dpp.hpp:203
typename GridwiseGemm::Argument Argument
Definition: device_gemm_dpp.hpp:115
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_dpp.hpp:228
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_dpp.hpp:179
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_dpp.hpp:163
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_dpp.hpp:184
static bool IsSupportedArgument(const Argument &karg)
Definition: device_gemm_dpp.hpp:169
static auto MakeInvoker()
Definition: device_gemm_dpp.hpp:200
std::string GetTypeString() const override
Definition: device_gemm_dpp.hpp:234
Definition: device_gemm.hpp:22