/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File
device_gemm_xdl_waveletmodel_cshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 
// GPU kernel entry point that forwards its arguments to
// GridwiseGemm::Run<HasMainKBlockLoop>.
//
// NOTE(review): listing lines 35 and 37 are absent from this rendering. The
// symbol index of this page names the kernel
// `kernel_gemm_xdl_waveletmodel_cshuffle` (defined at listing line 37); line 35
// presumably carries the launch-bounds attribute guarded by
// CK_USE_LAUNCH_BOUNDS -- confirm against the original header.
//
// Template parameters:
//   GridwiseGemm       - provides GetSharedMemoryNumberOfByte() and Run<>().
//   HasMainKBlockLoop  - compile-time flag passed through to Run<>().
// Function parameters: A/B input pointers, E output pointer, the three
// element-wise operators, the A/B/E grid descriptors and the block-to-E-tile
// map; all are handed to Run<>() unchanged.
22 template <typename GridwiseGemm,
23  typename ABDataType,
24  typename EDataType,
25  typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename EElementwiseOperation,
28  typename AGridDesc_AK0_M_AK1,
29  typename BGridDesc_BK0_N_BK1,
30  typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
31  typename Block2ETileMap,
32  bool HasMainKBlockLoop>
33 __global__ void
34 #if CK_USE_LAUNCH_BOUNDS
36 #endif
38  const ABDataType* __restrict__ p_a_grid,
39  const ABDataType* __restrict__ p_b_grid,
40  EDataType* __restrict__ p_e_grid,
41  const AElementwiseOperation a_element_op,
42  const BElementwiseOperation b_element_op,
43  const EElementwiseOperation e_element_op,
44  const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
45  const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
46  const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
47  e_grid_desc_mblock_mperblock_nblock_nperblock,
48  const Block2ETileMap block_2_etile_map)
49 {
// The real body is compiled only when __HIP_DEVICE_COMPILE__ is not defined or
// when __gfx9__ is defined; every other case takes the #else branch below.
50 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
// Statically sized __shared__ scratch buffer; its size is whatever
// GridwiseGemm reports at compile time.
51  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
52 
53  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
54  p_b_grid,
55  p_e_grid,
56  p_shared,
57  a_element_op,
58  b_element_op,
59  e_element_op,
60  a_grid_desc_ak0_m_ak1,
61  b_grid_desc_bk0_n_bk1,
62  e_grid_desc_mblock_mperblock_nblock_nperblock,
63  block_2_etile_map);
64 #else
// No work is done in this branch: each parameter is only assigned to `ignore`
// (presumably to silence unused-parameter diagnostics).
65  ignore = p_a_grid;
66  ignore = p_b_grid;
67  ignore = p_e_grid;
68  ignore = a_element_op;
69  ignore = b_element_op;
70  ignore = e_element_op;
71  ignore = a_grid_desc_ak0_m_ak1;
72  ignore = b_grid_desc_bk0_n_bk1;
73  ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
74  ignore = block_2_etile_map;
75 #endif
76 }
77 
78 } // namespace ck
79 
80 namespace ck {
81 namespace tensor_operation {
82 namespace device {
83 
84 template <typename ALayout,
85  typename BLayout,
86  typename ELayout,
87  typename ADataType,
88  typename BDataType,
89  typename GemmAcEDataType,
90  typename CShuffleDataType,
91  typename EDataType,
92  typename AElementwiseOperation,
93  typename BElementwiseOperation,
94  typename CDEElementwiseOperation,
95  GemmSpecialization GemmSpec,
96  index_t NumGemmKPrefetchStage,
97  index_t TileLoadThreadGroupSize,
98  index_t TileMathThreadGroupSize,
99  index_t MPerBlock,
100  index_t NPerBlock,
101  index_t KPerBlock,
102  index_t AK1,
103  index_t BK1,
104  index_t MPerXDL,
105  index_t NPerXDL,
106  index_t MXdlPerWave,
107  index_t NXdlPerWave,
108  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
109  typename ABlockTransferThreadClusterArrangeOrder,
110  typename ABlockTransferSrcAccessOrder,
111  index_t ABlockTransferSrcVectorDim,
112  index_t ABlockTransferSrcScalarPerVector,
113  index_t ABlockTransferDstScalarPerVector_AK1,
114  bool ABlockLdsExtraM,
115  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
116  typename BBlockTransferThreadClusterArrangeOrder,
117  typename BBlockTransferSrcAccessOrder,
118  index_t BBlockTransferSrcVectorDim,
119  index_t BBlockTransferSrcScalarPerVector,
120  index_t BBlockTransferDstScalarPerVector_BK1,
121  bool BBlockLdsExtraN,
122  index_t CShuffleMXdlPerWavePerShuffle,
123  index_t CShuffleNXdlPerWavePerShuffle,
124  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
125  index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
127  BLayout,
128  ELayout,
129  ADataType,
130  BDataType,
131  EDataType,
132  AElementwiseOperation,
133  BElementwiseOperation,
134  CDEElementwiseOperation>
135 {
137 
// Compile-time index constants. Within this class they are used as
// GetLength() dimension selectors and (I1) as the unit stride handed to
// make_tuple() when building naive descriptors.
138  static constexpr auto I0 = Number<0>{};
139  static constexpr auto I1 = Number<1>{};
140  static constexpr auto I2 = Number<2>{};
141 
// Padder configured with GemmSpec and the per-block tile sizes; the
// Make{A,B,E}GridDescriptor helpers pass their raw descriptors through it.
142  static constexpr auto matrix_padder =
143  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
144 
145  static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
146  {
147  const auto a_grid_desc_mraw_kraw = [&]() {
148  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
149  {
150  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
151  make_tuple(StrideA, I1));
152  }
153  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
154  {
155  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
156  make_tuple(I1, StrideA));
157  }
158  }();
159 
160  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
161  }
162 
// Builds the grid descriptor for matrix B with logical shape [NRaw, KRaw] and
// returns it after matrix_padder.PadBDescriptor_N_K has been applied.
//
// Note the parameter order: (KRaw, NRaw, StrideB), while the descriptor
// dimensions are ordered (N, K).
//
// NOTE(review): the two branch conditions (listing lines 166 and 171) are
// absent from this rendering; presumably they select on BLayout the same way
// MakeAGridDescriptor_M_K selects on ALayout -- confirm against the header.
163  static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
164  {
165  const auto b_grid_desc_nraw_kraw = [&]() {
// First branch: unit stride along N, StrideB along K.
167  {
168  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
169  make_tuple(I1, StrideB));
170  }
// Second branch: StrideB along N, unit stride along K.
172  {
173  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
174  make_tuple(StrideB, I1));
175  }
176  }();
177 
178  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
179  }
180 
// Builds the grid descriptor for the output matrix E with logical shape
// [MRaw, NRaw] and returns it after matrix_padder.PadCDescriptor_M_N has been
// applied.
//
// NOTE(review): the two branch conditions (listing lines 185 and 190) are
// absent from this rendering; presumably they select on the ELay template
// parameter -- confirm against the header.
181  template <typename ELay>
182  static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
183  {
184  const auto e_grid_desc_mraw_nraw = [&]() {
// First branch: StrideE along M, unit stride along N.
186  {
187  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
188  make_tuple(StrideE, I1));
189  }
// Second branch: unit stride along M, StrideE along N.
191  {
192  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
193  make_tuple(I1, StrideE));
194  }
195  }();
196 
197  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
198  }
199 
// Descriptor types, deduced from the return types of the helper functions
// above. The (1, 1, 1) arguments appear only inside decltype, which does not
// evaluate its operand.
200  using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
201  using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
202  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
203 
204  // GridwiseGemm
206  ADataType, // TODO: distinguish A/B datatype
207  GemmAcEDataType,
208  CShuffleDataType,
209  EDataType,
210  AElementwiseOperation,
211  BElementwiseOperation,
212  CDEElementwiseOperation,
217  NumGemmKPrefetchStage,
218  TileLoadThreadGroupSize,
219  TileMathThreadGroupSize,
220  MPerBlock,
221  NPerBlock,
222  KPerBlock,
223  AK1,
224  BK1,
225  MPerXDL,
226  NPerXDL,
227  MXdlPerWave,
228  NXdlPerWave,
229  ABlockTransferThreadClusterLengths_AK0_M_AK1,
230  ABlockTransferThreadClusterArrangeOrder,
231  ABlockTransferSrcAccessOrder,
232  ABlockTransferSrcVectorDim,
233  ABlockTransferSrcScalarPerVector,
234  ABlockTransferDstScalarPerVector_AK1,
235  false,
236  ABlockLdsExtraM,
237  BBlockTransferThreadClusterLengths_BK0_N_BK1,
238  BBlockTransferThreadClusterArrangeOrder,
239  BBlockTransferSrcAccessOrder,
240  BBlockTransferSrcVectorDim,
241  BBlockTransferSrcScalarPerVector,
242  BBlockTransferDstScalarPerVector_BK1,
243  false,
244  BBlockLdsExtraN,
245  CShuffleMXdlPerWavePerShuffle,
246  CShuffleNXdlPerWavePerShuffle,
247  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
248  CShuffleBlockTransferScalarPerVector_NPerBlock>;
249 
252  AGridDesc_M_K{}))>;
255  BGridDesc_N_K{}))>;
256 
258 
259  // Argument
260  struct Argument : public BaseArgument
261  {
// Captures the raw problem description and precomputes every descriptor the
// invoker needs.
//
// Parameters:
//   p_a_grid, p_b_grid, p_e_grid - A/B input and E output pointers (stored).
//   MRaw, NRaw, KRaw             - problem extents, forwarded to the
//                                  Make{A,B,E}GridDescriptor helpers.
//   StrideA, StrideB, StrideE    - strides forwarded to the same helpers.
//   a_element_op, b_element_op, cde_element_op - element-wise operators (stored).
//
// NOTE(review): listing lines 280, 282, 284 and 290-295 are absent from this
// rendering. Judging by the symbol index of this page, the missing initializer
// names are a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_ and
// e_grid_desc_mblock_mperblock_nblock_nperblock_; the constructor body
// (a conditional block) cannot be documented from here.
262  Argument(const ADataType* p_a_grid,
263  const BDataType* p_b_grid,
264  EDataType* p_e_grid,
265  index_t MRaw,
266  index_t NRaw,
267  index_t KRaw,
268  index_t StrideA,
269  index_t StrideB,
270  index_t StrideE,
271  AElementwiseOperation a_element_op,
272  BElementwiseOperation b_element_op,
273  CDEElementwiseOperation cde_element_op)
274  : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
275  p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
276  p_e_grid_{static_cast<EDataType*>(p_e_grid)},
// Problem-level descriptors. B's helper takes (KRaw, NRaw), unlike A and E.
277  a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
278  b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
279  e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
// Descriptors derived from the problem-level ones by GridwiseGemm.
281  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
283  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
285  block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
286  a_element_op_{a_element_op},
287  b_element_op_{b_element_op},
288  cde_element_op_{cde_element_op}
289  {
292  {
296  }
297  }
298 
299  void Print() const
300  {
301  std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
302  std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
303  std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
304  }
305 
306  // private:
307  // pointers
308  const ADataType* p_a_grid_;
309  const BDataType* p_b_grid_;
310  EDataType* p_e_grid_;
311 
312  // tensor descriptors for problem definition
316 
317  // tensor descriptors for block/thread-wise copy
322 
323  // block-to-e-tile map
325 
326  // element-wise op
327  AElementwiseOperation a_element_op_;
328  BElementwiseOperation b_element_op_;
329  CDEElementwiseOperation cde_element_op_;
330  };
331 
332  // Invoker
333  struct Invoker : public BaseInvoker
334  {
336 
// Launches the GEMM kernel for `arg` and returns the value produced by
// launch_and_time_kernel().
//
// Throws std::runtime_error("wrong! GridwiseGemm has invalid setting") when the
// validity check on the A/B/E descriptors and the block-to-E-tile map fails.
//
// NOTE(review): listing lines 356, 364, 377-380, 395-397, 401 and 403 are
// absent from this rendering. The symbol index of this page lists
// GridwiseGemm::CheckValidity, CalculateGridSize and CalculateHasMainKBlockLoop,
// which presumably supply the missing condition, `grid_size`, and the branch
// that picks the `true` instantiation -- confirm against the header.
337  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
338  {
339 #if 0
340  {
341  std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
342  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
343  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
344  << arg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
345 
346  std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
347  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
348  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
349  << arg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
350 
351  std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
352  << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
353  }
354 #endif
355 
// Tail of the validity check (its first line is missing from this rendering).
357  arg.b_grid_desc_n_k_,
358  arg.e_grid_desc_m_n_,
359  arg.block_2_etile_map_))
360  {
361  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
362  }
363 
// K is the second dimension of the A[M, K] descriptor.
365  const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
366 
// Generic lambda: the argument's static `.value` selects the kernel
// instantiation at compile time, so each call site compiles its own kernel.
367  auto launch_kernel = [&](auto has_main_k_block_loop) {
368  constexpr bool has_main_loop = has_main_k_block_loop.value;
369 
370  const auto kernel = kernel_gemm_xdl_waveletmodel_cshuffle<
371  GridwiseGemm,
372  ADataType, // TODO: distinguish A/B datatype
373  EDataType,
374  AElementwiseOperation,
375  BElementwiseOperation,
376  CDEElementwiseOperation,
381  has_main_loop>;
382 
// Block size is the sum of the load and math thread-group sizes; the literal 0
// fills the lds_byte parameter of launch_and_time_kernel.
383  return launch_and_time_kernel(
384  stream_config,
385  kernel,
386  dim3(grid_size),
387  dim3(TileLoadThreadGroupSize + TileMathThreadGroupSize),
388  0,
389  arg.p_a_grid_,
390  arg.p_b_grid_,
391  arg.p_e_grid_,
392  arg.a_element_op_,
393  arg.b_element_op_,
394  arg.cde_element_op_,
398  arg.block_2_etile_map_);
399  };
400 
402  {
404  }
405  else
406  {
407  return launch_kernel(integral_constant<bool, false>{});
408  }
409  }
410 
411  // polymorphic
412  float Run(const BaseArgument* p_arg,
413  const StreamConfig& stream_config = StreamConfig{}) override
414  {
415  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
416  }
417  };
418 
// Reports whether `arg` can be run by this operation.
//
// Returns false immediately when ck::is_xdl_supported() is false; otherwise the
// result comes from a check on the A/B/E descriptors and the block-to-E-tile
// map.
//
// NOTE(review): listing line 426 (the head of the return statement) is absent
// from this rendering; presumably it is `return GridwiseGemm::CheckValidity(
// arg.a_grid_desc_m_k_,` -- confirm against the header.
419  static bool IsSupportedArgument(const Argument& arg)
420  {
421  if(!ck::is_xdl_supported())
422  {
423  return false;
424  }
425 
427  arg.b_grid_desc_n_k_,
428  arg.e_grid_desc_m_n_,
429  arg.block_2_etile_map_);
430  }
431 
432  // polymorphic
433  bool IsSupportedArgument(const BaseArgument* p_arg) override
434  {
435  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
436  }
437 
438  static auto MakeArgument(const ADataType* p_a,
439  const BDataType* p_b,
440  EDataType* p_e,
441  index_t MRaw,
442  index_t NRaw,
443  index_t KRaw,
444  index_t StrideA,
445  index_t StrideB,
446  index_t StrideE,
447  AElementwiseOperation a_element_op,
448  BElementwiseOperation b_element_op,
449  CDEElementwiseOperation cde_element_op)
450  {
451  return Argument{p_a,
452  p_b,
453  p_e,
454  MRaw,
455  NRaw,
456  KRaw,
457  StrideA,
458  StrideB,
459  StrideE,
460  a_element_op,
461  b_element_op,
462  cde_element_op};
463  }
464 
465  static auto MakeInvoker() { return Invoker{}; }
466 
467  // polymorphic
468  std::unique_ptr<BaseArgument>
469  MakeArgumentPointer(const void* p_a,
470  const void* p_b,
471  void* p_e,
472  index_t MRaw,
473  index_t NRaw,
474  index_t KRaw,
475  index_t StrideA,
476  index_t StrideB,
477  index_t StrideE,
478  AElementwiseOperation a_element_op,
479  BElementwiseOperation b_element_op,
480  CDEElementwiseOperation cde_element_op) override
481  {
482  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
483  static_cast<const BDataType*>(p_b),
484  static_cast<EDataType*>(p_e),
485  MRaw,
486  NRaw,
487  KRaw,
488  StrideA,
489  StrideB,
490  StrideE,
491  a_element_op,
492  b_element_op,
493  cde_element_op);
494  }
495 
496  // polymorphic
497  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
498  {
499  return std::make_unique<Invoker>(Invoker{});
500  }
501 
502  // polymorphic
503  std::string GetTypeString() const override
504  {
505  auto str = std::stringstream();
506 
507  // clang-format off
508  str << "DeviceGemm_Xdl_WaveletModel_CShuffle"
509  << "<"
510  << TileLoadThreadGroupSize << ", "
511  << TileMathThreadGroupSize << ", "
512  << MPerBlock << ", "
513  << NPerBlock << ", "
514  << KPerBlock << ", "
515  << AK1 << ", "
516  << BK1
517  << ">";
518  // clang-format on
519 
520  return str.str();
521  }
522 };
523 
524 } // namespace device
525 } // namespace tensor_operation
526 } // namespace ck
#define CK_WAVELET_MIN_BLOCK_PER_CU
Definition: ck.hpp:38
#define CK_WAVELET_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:37
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
bool is_xdl_supported()
Definition: device_prop.hpp:54
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const EElementwiseOperation e_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:37
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:314
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:298
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:177
remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:338
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:226
__host__ static constexpr __device__ index_t CalculateGridSize(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:270
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:335
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:282
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:261
AElementwiseOperation a_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:327
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:314
const BDataType * p_b_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:309
Block2ETileMap block_2_etile_map_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:324
CDEElementwiseOperation cde_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:329
const ADataType * p_a_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:308
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:319
BElementwiseOperation b_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:328
EDataType * p_e_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:310
void Print() const
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:299
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:313
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, EDataType *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:262
GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:321
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:318
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:315
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:334
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:337
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:412
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:135
static constexpr auto matrix_padder
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:142
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:252
static constexpr auto I1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:139
static auto MakeInvoker()
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:465
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:163
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:469
static constexpr auto I2
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:140
static constexpr auto I0
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:138
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:497
typename GridwiseGemm::DefaultBlock2ETileMap Block2ETileMap
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:257
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:202
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:255
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:433
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:145
GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< ADataType, GemmAcEDataType, CShuffleDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, NumGemmKPrefetchStage, TileLoadThreadGroupSize, TileMathThreadGroupSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock > GridwiseGemm
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:248
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:200
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:419
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:182
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, EDataType *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:438
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:201
std::string GetTypeString() const override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:503
Definition: device_gemm.hpp:22
Definition: matrix_padder.hpp:180