Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <map>
#include <sstream>

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

/*
 * \brief Wrapper of GridwiseGemm::Run that realizes BatchedGEMM.
 *
 * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of the A, B, and C
 * matrices for a given batch index. For example, ComputePtrOffsetOfStridedBatch() computes the
 * offsets of evenly strided batches, but this can easily be extended to other layouts. The
 * returned offset can be either \p index_t or \p long_index_t. If it returns \p long_index_t, we
 * are not subject to the 2GB limitation.
 *
 * \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes the id of a workgroup and
 * returns the 2D index of the tile that it computes. \see
 * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
 *
 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that two workgroups can compute
 * two tiles from different matrices. Keep in mind that these two matrices can share the same grid
 * descriptor (as in BatchedGEMM) or use their own grid descriptors (as in GroupedGemm). \link
 * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
 * \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
 * computation of the pointer offset into \p ComputePtrOffsetOfStridedBatch.
 *
 * \note \p Block2CTileMap allows a customized mapping between a workgroup and the C-tile it
 * computes. Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm
 * fusions) to realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusions).
 */
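/*
 * Illustrative note (not in the original source): the batched kernel below is launched with
 * grid_size = Batch * num_blocks_per_batch. With Batch = 4 and 8 blocks per GEMM, blocks 0..7
 * compute batch 0, blocks 8..15 compute batch 1, and so on; the kernel recovers the batch index
 * as g_idx = get_block_1d_id() / num_blocks_per_batch.
 */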
template <typename DeviceOp, typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / karg.Batch);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
        static_cast<long_index_t>(karg.compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const auto a_grid_desc_k0_m_k1 =
        amd_wave_read_first_lane(GridwiseGemm::MakeAGridDescriptor_K0_M_K1(
            karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
    const auto b_grid_desc_k0_n_k1 =
        amd_wave_read_first_lane(GridwiseGemm::MakeBGridDescriptor_K0_N_K1(
            karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
    const auto c_grid_desc_m_n = amd_wave_read_first_lane(GridwiseGemm::MakeCGridDescriptor_M_N(
        karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));

    GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid + a_batch_offset,
                                                  karg.p_b_grid + b_batch_offset,
                                                  karg.p_c_grid + c_batch_offset,
                                                  p_shared,
                                                  a_grid_desc_k0_m_k1,
                                                  b_grid_desc_k0_n_k1,
                                                  c_grid_desc_m_n);
#else
    ignore = karg;
#endif
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock,
          ck::index_t K0PerBlock,
          ck::index_t K1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_K1,
          bool ABlockLdsAddExtraM,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_K1,
          bool BBlockLdsAddExtraN,
          ck::index_t CThreadTransferSrcDstVectorDim,
          ck::index_t CThreadTransferDstScalarPerVector,
          ck::index_t NumGemmKPrefetchStage = 1,
          LoopScheduler LoopSched           = make_default_loop_scheduler(),
          PipelineVersion PipelineVer       = PipelineVersion::v1>
struct DeviceBatchedGemmXdl : public DeviceBatchedGemm<ALayout,
                                                       BLayout,
                                                       CLayout,
                                                       ADataType,
                                                       BDataType,
                                                       CDataType,
                                                       AElementwiseOperation,
                                                       BElementwiseOperation,
                                                       CElementwiseOperation>
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr auto K1Number = Number<K1>{};

    struct ComputePtrOffsetOfStridedBatch
    {
        ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
                                       index_t BatchStrideB,
                                       index_t BatchStrideC)
            : BatchStrideA_(BatchStrideA), BatchStrideB_(BatchStrideB), BatchStrideC_(BatchStrideC)
        {
        }

        __host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideA_);
        }

        __host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideB_);
        }

        __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
        {
            return g_idx * static_cast<long_index_t>(BatchStrideC_);
        }

        private:
        index_t BatchStrideA_;
        index_t BatchStrideB_;
        index_t BatchStrideC_;
    };
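    // Worked example (illustrative, not in the original source): for packed row-major A of
    // shape [Batch, M, K], BatchStrideA = M * K, so batch g starts at
    // p_a_grid + g * static_cast<long_index_t>(M * K). Casting the stride to long_index_t
    // before the multiply performs the product in 64 bits, so the offset does not overflow
    // the 32-bit index_t when g_idx * BatchStride exceeds 2^31 - 1 elements.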

    // GridwiseGemm
    using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType,
        InMemoryDataOperationEnum::Set,
        ALayout,
        BLayout,
        CLayout,
        AElementwiseOperation,
        BElementwiseOperation,
        CElementwiseOperation,
        GemmSpecialization::MNKPadding, // (reconstructed; this line is missing from the listing)
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        K1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun,
        ABlockLdsAddExtraM,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false, // BThreadTransferSrcResetCoordinateAfterRun,
        BBlockLdsAddExtraN,
        Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder (reconstructed)
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        NumGemmKPrefetchStage,
        LoopSched,
        PipelineVer>;

    using Problem = typename GridwiseGemm::Problem;

    // Argument
    struct Argument : public Problem, public BaseArgument
    {
        Argument(const ADataType* p_a_grid_,
                 const BDataType* p_b_grid_,
                 CDataType* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
                 index_t BatchStrideA,
                 index_t BatchStrideB,
                 index_t BatchStrideC,
                 index_t Batch_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_},
              p_b_grid{p_b_grid_},
              p_c_grid{p_c_grid_},
              Batch(Batch_),
              compute_ptr_offset_of_batch{BatchStrideA, BatchStrideB, BatchStrideC}
        {
        }

        const ADataType* p_a_grid;
        const BDataType* p_b_grid;
        CDataType* p_c_grid;
        index_t Batch;
        ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceBatchedGemmXdl::Argument;

        float Run(const Argument& karg, const StreamConfig& stream_config = StreamConfig{})
        {
            if(stream_config.log_level_ > 0)
            {
                karg.Print();
            }

            if(!GridwiseGemm::CheckValidity(karg))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
            }

            auto [gdx, gdy, gdz] = GridwiseGemm::CalculateGridSize(karg.M, karg.N);
            gdx *= karg.Batch;

            float ave_time = 0;

            if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
            {
                const auto kernel =
                    kernel_batched_gemm_xdlops_v2r3<DeviceBatchedGemmXdl, GridwiseGemm, true>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
            }
            else
            {
                const auto kernel =
                    kernel_batched_gemm_xdlops_v2r3<DeviceBatchedGemmXdl, GridwiseGemm, false>;

                ave_time = launch_and_time_kernel(
                    stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
            }

            return ave_time;
        }

        // polymorphic
        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    static constexpr bool IsValidCompilationParameter()
    {
        // TODO: properly implement this check
        return true;
    }

    static bool IsSupportedArgument(const Problem& problem)
    {
        if(!ck::is_xdl_supported())
        {
            return false;
        }

        return GridwiseGemm::CheckValidity(problem);
    }

    // polymorphic
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
    }

    static auto MakeArgument(const ADataType* p_a,
                             const BDataType* p_b,
                             CDataType* p_c,
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC,
                             index_t BatchStrideA,
                             index_t BatchStrideB,
                             index_t BatchStrideC,
                             index_t Batch)
    {
        return Argument{p_a,
                        p_b,
                        p_c,
                        M,
                        N,
                        K,
                        StrideA,
                        StrideB,
                        StrideC,
                        BatchStrideA,
                        BatchStrideB,
                        BatchStrideC,
                        Batch};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC,
                                                      index_t BatchStrideA,
                                                      index_t BatchStrideB,
                                                      index_t BatchStrideC,
                                                      index_t Batch,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          BatchStrideA,
                                          BatchStrideB,
                                          BatchStrideC,
                                          Batch);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }

    // polymorphic
    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        std::map<LoopScheduler, std::string> LoopSchedToString{
            {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};

        std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
                                                                       {PipelineVersion::v2, "v2"}};

        // clang-format off
        str << "DeviceBatchedGemmXdl"
            << "<"
            << BlockSize << ", "
            << MPerBlock << ", "
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
            << MPerXDL << ", "
            << NPerXDL << ", "
            << MXdlPerWave << ", "
            << NXdlPerWave << ", "
            << ">"
            << " NumGemmKPrefetchStage: "
            << NumGemmKPrefetchStage << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "
            << "PipelineVersion: "
            << PipelineVersionToString[PipelineVer];
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
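
A minimal usage sketch, not part of this header: it assumes fp16 row-major A and C with column-major B, the PassThrough elementwise operation from ck::tensor_operation::element_wise, packed batches, and a tuning configuration copied from typical Composable Kernel fp16 instances. None of these values come from this file, and the device buffers p_a/p_b/p_c are assumed to be already allocated on the GPU.

#include <stdexcept>

#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using F16         = ck::half_t;
using F32         = float;
using Row         = ck::tensor_layout::gemm::RowMajor;
using Col         = ck::tensor_layout::gemm::ColumnMajor;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

// The tile/transfer parameters are illustrative assumptions, not values taken from this file.
using DeviceGemmInstance = ck::tensor_operation::device::DeviceBatchedGemmXdl<
    F16, F16, F16, F32,                    // ADataType, BDataType, CDataType, AccDataType
    Row, Col, Row,                         // ALayout, BLayout, CLayout
    PassThrough, PassThrough, PassThrough, // elementwise operations on A, B, C
    256,                                   // BlockSize
    256, 128, 4, 8,                        // MPerBlock, NPerBlock, K0PerBlock, K1
    32, 32, 4, 2,                          // MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, // A block transfer
    S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, // B block transfer
    7, 1>;                                 // CThreadTransfer: SrcDstVectorDim, DstScalarPerVector

float RunBatchedGemm(const F16* p_a, const F16* p_b, F16* p_c,
                     ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t Batch)
{
    auto gemm    = DeviceGemmInstance{};
    auto invoker = gemm.MakeInvoker();

    // Packed batches: StrideA/B/C are the leading-dimension strides of each 2D matrix;
    // BatchStrideA/B/C are the per-batch element counts.
    auto argument = gemm.MakeArgument(p_a, p_b, p_c,
                                      M, N, K,
                                      K, K, N,             // StrideA, StrideB, StrideC
                                      M * K, K * N, M * N, // BatchStrideA/B/C
                                      Batch);

    if(!gemm.IsSupportedArgument(&argument))
    {
        throw std::runtime_error("this instance does not support the given problem size");
    }

    // Launch on the null stream and return the measured kernel time in milliseconds.
    return invoker.Run(argument, StreamConfig{nullptr, true});
}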