Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp Source File

device_cgemm_4gemm_xdl_cshuffle.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
21 
22 namespace ck {
23 namespace tensor_operation {
24 namespace device {
25 
26 template <
27  typename ALayout,
28  typename BLayout,
29  typename CLayout,
30  typename ADataType,
31  typename BDataType,
32  typename CDataType,
33  typename GemmAccDataType,
34  typename CShuffleDataType,
35  typename AElementwiseOperation,
36  typename BElementwiseOperation,
37  typename CElementwiseOperation,
38  GemmSpecialization GemmSpec,
39  index_t NumGemmKPrefetchStage,
40  index_t BlockSize,
41  index_t MPerBlock,
42  index_t NPerBlock,
43  index_t KPerBlock,
44  index_t AK1,
45  index_t BK1,
46  index_t MPerXDL,
47  index_t NPerXDL,
48  index_t MXdlPerWave,
49  index_t NXdlPerWave,
50  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
51  typename ABlockTransferThreadClusterArrangeOrder,
52  typename ABlockTransferSrcAccessOrder,
53  index_t ABlockTransferSrcVectorDim,
54  index_t ABlockTransferSrcScalarPerVector,
55  index_t ABlockTransferDstScalarPerVector_AK1,
56  bool ABlockLdsExtraM,
57  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
58  typename BBlockTransferThreadClusterArrangeOrder,
59  typename BBlockTransferSrcAccessOrder,
60  index_t BBlockTransferSrcVectorDim,
61  index_t BBlockTransferSrcScalarPerVector,
62  index_t BBlockTransferDstScalarPerVector_BK1,
63  bool BBlockLdsExtraN,
64  index_t CShuffleMXdlPerWavePerShuffle,
65  index_t CShuffleNXdlPerWavePerShuffle,
66  typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
67  index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
 68  LoopScheduler LoopSched = make_default_loop_scheduler(),
 69  enable_if_t<
 70  is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
71  is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
72  is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
73  bool> = false>
 74 struct DeviceCGemm_4Gemm_Xdl_CShuffle
 75  : public DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
76 {
 77  using DeviceOp = DeviceCGemm_4Gemm_Xdl_CShuffle;
 78 
79  static constexpr auto I0 = Number<0>{};
80  static constexpr auto I1 = Number<1>{};
81  static constexpr auto I2 = Number<2>{};
82 
83  static constexpr index_t MPerThread =
84  MPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(1);
85  static constexpr index_t NPerThread =
86  NPerBlock / CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock::At(3);
87 
88  static constexpr auto AScalarPerVector = Number<4>{};
89  static constexpr auto BScalarPerVector = Number<4>{};
90  static constexpr auto CScalarPerVector = Number<4>{};
91 
92  template <typename Desc_M_N>
93  static auto PadDescriptor_M_N(Desc_M_N desc)
94  {
95  const auto M = desc.GetLength(I0);
96  const auto N = desc.GetLength(I1);
97  const auto pad_M = math::integer_divide_ceil(M, MPerThread) * MPerThread - M;
98  const auto pad_N = math::integer_divide_ceil(N, NPerThread) * NPerThread - N;
99 
100  const auto padded_desc = transform_tensor_descriptor(
101  desc,
 102  make_tuple(make_right_pad_transform(M, pad_M), make_right_pad_transform(N, pad_N)),
 103  make_tuple(Sequence<0>{}, Sequence<1>{}),
 104  make_tuple(Sequence<0>{}, Sequence<1>{}));
 105 
106  return padded_desc;
107  }
108 
109  static auto MakeDescriptor_M_N(const std::vector<index_t>& lengths,
110  const std::vector<index_t>& strides)
111  {
112  auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<2>{});
113  auto tupleOfStride = generate_tuple([&](auto I) { return strides[I]; }, Number<2>{});
114 
115  // nd desc - [s0, s1, s2, ...]
116  const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
117  return PadDescriptor_M_N(desc);
118  }
119 
120  // GridwiseGemm
 121  using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1<
 122  ALayout,
123  BLayout,
124  CLayout,
125  ADataType,
126  BDataType,
127  GemmAccDataType,
128  CShuffleDataType,
129  CDataType,
130  AElementwiseOperation,
131  BElementwiseOperation,
132  CElementwiseOperation,
133  GemmSpec,
 134  InMemoryDataOperationEnum::Set,
 135  NumGemmKPrefetchStage,
136  BlockSize,
137  MPerBlock,
138  NPerBlock,
139  KPerBlock,
140  AK1,
141  BK1,
142  MPerXDL,
143  NPerXDL,
144  MXdlPerWave,
145  NXdlPerWave,
146  ABlockTransferThreadClusterLengths_AK0_M_AK1,
147  ABlockTransferThreadClusterArrangeOrder,
148  ABlockTransferSrcAccessOrder,
149  ABlockTransferSrcVectorDim,
150  ABlockTransferSrcScalarPerVector,
151  ABlockTransferDstScalarPerVector_AK1,
152  false,
153  ABlockLdsExtraM,
154  BBlockTransferThreadClusterLengths_BK0_N_BK1,
155  BBlockTransferThreadClusterArrangeOrder,
156  BBlockTransferSrcAccessOrder,
157  BBlockTransferSrcVectorDim,
158  BBlockTransferSrcScalarPerVector,
159  BBlockTransferDstScalarPerVector_BK1,
160  false,
161  BBlockLdsExtraN,
162  CShuffleMXdlPerWavePerShuffle,
163  CShuffleNXdlPerWavePerShuffle,
164  CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
165  CShuffleBlockTransferScalarPerVector_NPerBlock,
166  LoopSched>;
167 
168  using CGridDesc_M_N = decltype(MakeDescriptor_M_N({1, 1}, {1, 1}));
169 
170  // Argument
172  {
173  using Problem = typename GridwiseGemm::Problem;
174 
175  Argument(const ADataType* p_a_grid_real_,
176  const ADataType* p_a_grid_imag_,
177  const BDataType* p_b_grid_real_,
178  const BDataType* p_b_grid_imag_,
179  CDataType* p_c_grid_real_,
180  CDataType* p_c_grid_imag_,
181  CDataType* p_workspace,
182  index_t M_,
183  index_t N_,
184  index_t K_,
185  index_t StrideA_,
186  index_t StrideB_,
187  index_t StrideC_)
188  : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
189  p_a_grid_real{p_a_grid_real_},
190  p_a_grid_imag{p_a_grid_imag_},
191  p_b_grid_real{p_b_grid_real_},
192  p_b_grid_imag{p_b_grid_imag_},
193  p_c_grid_real{p_c_grid_real_},
194  p_c_grid_imag{p_c_grid_imag_},
195  p_aux_grid{p_workspace}
196  {
198  {
199  c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {StrideC_, I1});
200  }
202  {
203  c_grid_desc_m_n = DeviceOp::MakeDescriptor_M_N({M_, N_}, {I1, StrideC_});
204  }
205 
206  p_aux_2_grid = p_workspace + GetCElementSpaceSize(M_, N_, StrideC_);
207  }
208 
209  // private:
210  const ADataType* p_a_grid_real;
211  const ADataType* p_a_grid_imag;
212  const BDataType* p_b_grid_real;
213  const BDataType* p_b_grid_imag;
214  CDataType* p_c_grid_real;
215  CDataType* p_c_grid_imag;
216  CDataType* p_aux_grid;
217  CDataType* p_aux_2_grid;
 218  CGridDesc_M_N c_grid_desc_m_n;
 219  };
220 
221  // Invoker
222  struct Invoker : public BaseInvoker
223  {
224  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
225  {
226  if(stream_config.log_level_ > 0)
227  {
228  arg.Print();
229  }
230 
 231  if(!GridwiseGemm::CheckValidity(arg))
 232  {
233  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
234  }
235 
236  index_t gdx, gdy, gdz;
237  std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N);
238 
239  const auto K = GridwiseGemm::CalculateAK0(arg.K) * AK1;
240 
241  float ave_time = 0;
242 
245 
247 
252  Block2TileMap,
253  Add,
254  BlockSize,
255  MPerBlock,
256  NPerBlock,
257  MPerThread,
258  NPerThread,
262  I1,
263  I1>;
264 
265  using GridwiseBinSubtract =
270  Block2TileMap,
271  Subtract,
272  BlockSize,
273  MPerBlock,
274  NPerBlock,
275  MPerThread,
276  NPerThread,
280  I1,
281  I1>;
282 
283  const index_t M = arg.c_grid_desc_m_n.GetLength(I0);
284  const index_t N = arg.c_grid_desc_m_n.GetLength(I1);
285  const auto block_2_tile_map = Block2TileMap(M, N);
286 
287  const auto add_kernel = kernel_elementwise<GridwiseBinAdd,
292  Block2TileMap,
293  Add>;
294 
295  const auto subtract_kernel =
296  kernel_elementwise<GridwiseBinSubtract,
301  Block2TileMap,
302  Subtract>;
303 
 304  if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
 305  {
306  const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
307  ADataType,
308  BDataType,
309  CDataType,
310  true>;
311 
312  ave_time += launch_and_time_kernel(stream_config,
313  kernel,
314  dim3(gdx, gdy, gdz),
315  dim3(BlockSize),
316  0,
317  arg.p_a_grid_real,
318  arg.p_b_grid_real,
319  arg.p_aux_grid,
320  arg);
321 
322  ave_time += launch_and_time_kernel(stream_config,
323  kernel,
324  dim3(gdx, gdy, gdz),
325  dim3(BlockSize),
326  0,
327  arg.p_a_grid_imag,
328  arg.p_b_grid_imag,
329  arg.p_aux_2_grid,
330  arg);
331 
332  // c_real = aux - aux_2
333  ave_time += launch_and_time_kernel(
334  stream_config,
335  subtract_kernel,
336  dim3(gdx, gdy, gdz),
337  dim3(BlockSize),
338  0,
341  make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
342  const_cast<const CDataType*>(arg.p_aux_2_grid)),
344  block_2_tile_map,
345  Subtract{});
346 
347  ave_time += launch_and_time_kernel(stream_config,
348  kernel,
349  dim3(gdx, gdy, gdz),
350  dim3(BlockSize),
351  0,
352  arg.p_a_grid_real,
353  arg.p_b_grid_imag,
354  arg.p_aux_grid,
355  arg);
356 
357  ave_time += launch_and_time_kernel(stream_config,
358  kernel,
359  dim3(gdx, gdy, gdz),
360  dim3(BlockSize),
361  0,
362  arg.p_a_grid_imag,
363  arg.p_b_grid_real,
364  arg.p_aux_2_grid,
365  arg);
366 
367  // c_imag = aux + aux_2
368  ave_time += launch_and_time_kernel(
369  stream_config,
370  add_kernel,
371  dim3(gdx, gdy, gdz),
372  dim3(BlockSize),
373  0,
376  make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
377  const_cast<const CDataType*>(arg.p_aux_2_grid)),
379  block_2_tile_map,
380  Add{});
381  }
382  else
383  {
384  const auto kernel = kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
385  ADataType,
386  BDataType,
387  CDataType,
388  false>;
389 
390  ave_time += launch_and_time_kernel(stream_config,
391  kernel,
392  dim3(gdx, gdy, gdz),
393  dim3(BlockSize),
394  0,
395  arg.p_a_grid_real,
396  arg.p_b_grid_real,
397  arg.p_aux_grid,
398  arg);
399 
400  ave_time += launch_and_time_kernel(stream_config,
401  kernel,
402  dim3(gdx, gdy, gdz),
403  dim3(BlockSize),
404  0,
405  arg.p_a_grid_imag,
406  arg.p_b_grid_imag,
407  arg.p_aux_2_grid,
408  arg);
409 
410  // c_real = aux - aux_2
411  ave_time += launch_and_time_kernel(
412  stream_config,
413  subtract_kernel,
414  dim3(gdx, gdy, gdz),
415  dim3(BlockSize),
416  0,
419  make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
420  const_cast<const CDataType*>(arg.p_aux_2_grid)),
422  block_2_tile_map,
423  Subtract{});
424 
425  ave_time += launch_and_time_kernel(stream_config,
426  kernel,
427  dim3(gdx, gdy, gdz),
428  dim3(BlockSize),
429  0,
430  arg.p_a_grid_real,
431  arg.p_b_grid_imag,
432  arg.p_aux_grid,
433  arg);
434 
435  ave_time += launch_and_time_kernel(stream_config,
436  kernel,
437  dim3(gdx, gdy, gdz),
438  dim3(BlockSize),
439  0,
440  arg.p_a_grid_imag,
441  arg.p_b_grid_real,
442  arg.p_aux_2_grid,
443  arg);
444 
445  // c_imag = aux + aux_2
446  ave_time += launch_and_time_kernel(
447  stream_config,
448  add_kernel,
449  dim3(gdx, gdy, gdz),
450  dim3(BlockSize),
451  0,
454  make_tuple(const_cast<const CDataType*>(arg.p_aux_grid),
455  const_cast<const CDataType*>(arg.p_aux_2_grid)),
457  block_2_tile_map,
458  Add{});
459  }
460 
461  return ave_time;
462  }
463 
464  // polymorphic
465  float Run(const BaseArgument* p_arg,
466  const StreamConfig& stream_config = StreamConfig{}) override
467  {
468  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
469  }
470  };
471 
472  static constexpr bool IsValidCompilationParameter()
473  {
474  // TODO: properly implement this check
475  return true;
476  }
477 
478  static bool IsSupportedArgument(const Argument& arg)
479  {
480  if(!ck::is_xdl_supported())
481  {
482  return false;
483  }
484 
485  return GridwiseGemm::CheckValidity(arg);
486  }
487 
488  // polymorphic
489  bool IsSupportedArgument(const BaseArgument* p_arg) override
490  {
491  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
492  }
493 
494  static auto MakeArgument(const ADataType* p_a_real,
495  const ADataType* p_a_imag,
496  const BDataType* p_b_real,
497  const BDataType* p_b_imag,
498  CDataType* p_c_real,
499  CDataType* p_c_imag,
500  CDataType* p_workspace,
501  index_t M,
502  index_t N,
503  index_t K,
504  index_t StrideA,
505  index_t StrideB,
506  index_t StrideC,
507  AElementwiseOperation,
508  BElementwiseOperation,
509  CElementwiseOperation)
510  {
511  return Argument{p_a_real,
512  p_a_imag,
513  p_b_real,
514  p_b_imag,
515  p_c_real,
516  p_c_imag,
517  p_workspace,
518  M,
519  N,
520  K,
521  StrideA,
522  StrideB,
523  StrideC};
524  }
525 
526  static auto MakeInvoker() { return Invoker{}; }
527 
528  // polymorphic
529  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a_real,
530  const void* p_a_imag,
531  const void* p_b_real,
532  const void* p_b_imag,
533  void* p_c_real,
534  void* p_c_imag,
535  void* p_workspace,
536  index_t M,
537  index_t N,
538  index_t K,
539  index_t StrideA,
540  index_t StrideB,
541  index_t StrideC,
542  AElementwiseOperation,
543  BElementwiseOperation,
544  CElementwiseOperation,
545  index_t /* KBatch */ = 1) override
546  {
547  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a_real),
548  static_cast<const ADataType*>(p_a_imag),
549  static_cast<const BDataType*>(p_b_real),
550  static_cast<const BDataType*>(p_b_imag),
551  static_cast<CDataType*>(p_c_real),
552  static_cast<CDataType*>(p_c_imag),
553  static_cast<CDataType*>(p_workspace),
554  M,
555  N,
556  K,
557  StrideA,
558  StrideB,
559  StrideC);
560  }
561 
562  // polymorphic
563  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
564  {
565  return std::make_unique<Invoker>(Invoker{});
566  }
567 
568  // polymorphic
569  std::string GetTypeString() const override
570  {
571  auto str = std::stringstream();
572 
573  // clang-format off
574  str << "DeviceCGemm_4Gemm_Xdl_CShuffle"
575  << "<"
576  << BlockSize << ", "
577  << MPerBlock << ", "
578  << NPerBlock << ", "
579  << KPerBlock << ", "
580  << AK1 << ", "
581  << BK1
582  << ">";
583  // clang-format on
584 
585  return str.str();
586  }
587 
588  static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
589  {
590  const auto c_grid_desc_m_n = GridwiseGemm::MakeCGridDescriptor_M_N(
 591  M, GridwiseGemm::CalculateMPadded(M), N, GridwiseGemm::CalculateNPadded(N), StrideC);
 592 
593  return c_grid_desc_m_n.GetElementSpaceSize();
594  }
595 
596  std::size_t GetWorkspaceSize(index_t M,
597  index_t N,
598  [[maybe_unused]] index_t K,
599  [[maybe_unused]] index_t StrideA,
600  [[maybe_unused]] index_t StrideB,
601  index_t StrideC) const override
602  {
603  return 2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC);
604  }
605 
606  std::size_t GetWorkSpaceSize(const BaseArgument* base_arg) const override
607  {
608  const auto* parg = dynamic_cast<const Argument*>(base_arg);
609 
610  if(!parg)
611  {
612  std::ostringstream err;
613  err << "Provided argument pointer is not of an Argument class!"
614  << " In " << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
615  throw std::runtime_error(err.str());
616  }
617 
618  return GetWorkspaceSize(
619  parg->M, parg->N, parg->K, parg->StrideA, parg->StrideB, parg->StrideC);
620  }
621 };
622 
623 } // namespace device
624 } // namespace tensor_operation
625 } // namespace ck
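
The launch sequence in Invoker::Run above realizes a complex GEMM as four real GEMMs plus two elementwise passes: aux = A_real*B_real and aux_2 = A_imag*B_imag are combined into c_real = aux - aux_2, then aux = A_real*B_imag and aux_2 = A_imag*B_real are combined into c_imag = aux + aux_2, with the elementwise kernels writing the final real and imaginary outputs. The following host-side reference is a minimal sketch, not part of Composable Kernel (all names below are illustrative); it mirrors that schedule with plain loops standing in for the XDL GEMM and elementwise kernels.

// Host-side reference sketch of the four-GEMM complex GEMM schedule.
#include <cstddef>
#include <vector>

// Naive real-valued GEMM: C[m][n] = sum_k A[m][k] * B[k][n] (row-major, packed).
static void real_gemm(const std::vector<float>& A, const std::vector<float>& B,
                      std::vector<float>& C, std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];
            C[m * N + n] = acc;
        }
}

// Complex GEMM from four real GEMMs, staged through two auxiliary buffers:
//   aux = A_real*B_real, aux_2 = A_imag*B_imag, C_real = aux - aux_2
//   aux = A_real*B_imag, aux_2 = A_imag*B_real, C_imag = aux + aux_2
void cgemm_4gemm_reference(const std::vector<float>& a_real, const std::vector<float>& a_imag,
                           const std::vector<float>& b_real, const std::vector<float>& b_imag,
                           std::vector<float>& c_real, std::vector<float>& c_imag,
                           std::size_t M, std::size_t N, std::size_t K)
{
    std::vector<float> aux(M * N), aux_2(M * N); // workspace: two C-sized buffers

    real_gemm(a_real, b_real, aux, M, N, K);
    real_gemm(a_imag, b_imag, aux_2, M, N, K);
    for(std::size_t i = 0; i < M * N; ++i) // c_real = aux - aux_2
        c_real[i] = aux[i] - aux_2[i];

    real_gemm(a_real, b_imag, aux, M, N, K);
    real_gemm(a_imag, b_real, aux_2, M, N, K);
    for(std::size_t i = 0; i < M * N; ++i) // c_imag = aux + aux_2
        c_imag[i] = aux[i] + aux_2[i];
}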
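Workspace handling follows from GetWorkspaceSize and the Argument constructor: the user-provided workspace holds the two auxiliary C-sized buffers back to back (2 * sizeof(CDataType) * GetCElementSpaceSize(M, N, StrideC) bytes), with p_aux_grid at the start and p_aux_2_grid offset by one element-space size. The sketch below is a simplified, hypothetical calculation assuming a row-major C padded up to MPerBlock/NPerBlock multiples; the actual element-space size comes from GridwiseGemm::MakeCGridDescriptor_M_N and depends on GemmSpec and StrideC, so treat the formula here as an approximation, not the library's exact rule.

// Simplified workspace-size sketch (assumptions noted in the comments).
#include <cstddef>
#include <cstdio>

struct WorkspaceLayout
{
    std::size_t element_space; // elements per auxiliary C-sized buffer
    std::size_t bytes;         // total workspace bytes (two buffers)
};

WorkspaceLayout make_workspace_layout(std::size_t M, std::size_t N, std::size_t StrideC,
                                      std::size_t MPerBlock, std::size_t NPerBlock,
                                      std::size_t sizeof_c)
{
    // Assumption: M and N are right-padded to block-tile multiples.
    const std::size_t MPad = (M + MPerBlock - 1) / MPerBlock * MPerBlock;
    const std::size_t NPad = (N + NPerBlock - 1) / NPerBlock * NPerBlock;
    // Assumption: row-major C, so the last element sits at (MPad - 1) * StrideC + NPad - 1.
    const std::size_t element_space = (MPad - 1) * StrideC + NPad;
    return {element_space, 2 * sizeof_c * element_space}; // mirrors GetWorkspaceSize
}

int main()
{
    // Hypothetical problem: 1000 x 1000 fp16 C with leading dimension 1024,
    // tiled with 256 x 128 blocks.
    const auto ws = make_workspace_layout(1000, 1000, 1024, 256, 128, /*sizeof(CDataType)=*/2);
    std::printf("aux buffer elements: %zu, workspace bytes: %zu\n", ws.element_space, ws.bytes);
    // p_aux_grid   = p_workspace
    // p_aux_2_grid = p_workspace + ws.element_space   (as in the Argument constructor)
    return 0;
}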