/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp Source File
device_gemm_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
19 
20 namespace ck {
21 
22 template <typename GridwiseGemm,
23  typename ADataType,
24  typename BDataType,
25  typename DsPointer,
26  typename EDataType,
27  typename AElementwiseOperation,
28  typename BElementwiseOperation,
29  typename CDEElementwiseOperation,
30  typename AGridDesc_AK0_M_AK1,
31  typename BGridDesc_BK0_N_BK1,
32  typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
33  typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
34  typename Block2ETileMap,
35  bool HasMainKBlockLoop>
36 __global__ void
37 #if CK_USE_LAUNCH_BOUNDS
39 #endif
40  kernel_gemm_multiple_d_xdl_cshuffle(const ADataType* __restrict__ p_a_grid,
41  const BDataType* __restrict__ p_b_grid,
42  DsPointer p_ds_grid,
43  EDataType* __restrict__ p_e_grid,
44  const AElementwiseOperation a_element_op,
45  const BElementwiseOperation b_element_op,
46  const CDEElementwiseOperation cde_element_op,
47  const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
48  const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
49  const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
50  ds_grid_desc_mblock_mperblock_nblock_nperblock,
51  const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
52  e_grid_desc_mblock_mperblock_nblock_nperblock,
53  const Block2ETileMap block_2_etile_map)
54 {
55 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__))
56  __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
57 
58  GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
59  p_b_grid,
60  p_ds_grid,
61  p_e_grid,
62  p_shared,
63  a_element_op,
64  b_element_op,
65  cde_element_op,
66  a_grid_desc_ak0_m_ak1,
67  b_grid_desc_bk0_n_bk1,
68  ds_grid_desc_mblock_mperblock_nblock_nperblock,
69  e_grid_desc_mblock_mperblock_nblock_nperblock,
70  block_2_etile_map);
71 #else
72  ignore = p_a_grid;
73  ignore = p_b_grid;
74  ignore = p_ds_grid;
75  ignore = p_e_grid;
76  ignore = a_element_op;
77  ignore = b_element_op;
78  ignore = cde_element_op;
79  ignore = a_grid_desc_ak0_m_ak1;
80  ignore = b_grid_desc_bk0_n_bk1;
81  ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
82  ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
83  ignore = block_2_etile_map;
84 #endif
85 }
86 
87 } // namespace ck
88 
89 namespace ck {
90 namespace tensor_operation {
91 namespace device {
92 
93 // GEMM:
94 // input : A[M, K]
95 // input : B[N, K]
96 // input : D0[M, N], D1[M, N], ...
97 // output : E[M, N]
98 // C = a_op(A) * b_op(B)
99 // E = cde_op(C, D0, D1, ...)
100 // Assume:
101 // D0, D1, ... and E have the same layout
102 template <typename ALayout,
103  typename BLayout,
104  typename DsLayout,
105  typename ELayout,
106  typename ADataType,
107  typename BDataType,
108  typename AccDataType,
109  typename CShuffleDataType,
110  typename DsDataType,
111  typename EDataType,
112  typename AElementwiseOperation,
113  typename BElementwiseOperation,
114  typename CDEElementwiseOperation,
115  GemmSpecialization GemmSpec,
116  index_t NumGemmKPrefetchStage,
117  index_t BlockSize,
118  index_t MPerBlock,
119  index_t NPerBlock,
120  index_t KPerBlock,
121  index_t AK1,
122  index_t BK1,
123  index_t MPerXDL,
124  index_t NPerXDL,
125  index_t MXdlPerWave,
126  index_t NXdlPerWave,
127  typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
128  typename ABlockTransferThreadClusterArrangeOrder,
129  typename ABlockTransferSrcAccessOrder,
130  index_t ABlockTransferSrcVectorDim,
131  index_t ABlockTransferSrcScalarPerVector,
132  index_t ABlockTransferDstScalarPerVector_AK1,
133  index_t ABlockLdsExtraM,
134  typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
135  typename BBlockTransferThreadClusterArrangeOrder,
136  typename BBlockTransferSrcAccessOrder,
137  index_t BBlockTransferSrcVectorDim,
138  index_t BBlockTransferSrcScalarPerVector,
139  index_t BBlockTransferDstScalarPerVector_BK1,
140  index_t BBlockLdsExtraN,
141  index_t CShuffleMXdlPerWavePerShuffle,
142  index_t CShuffleNXdlPerWavePerShuffle,
143  typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
144  index_t CDEBlockTransferScalarPerVector_NPerBlock,
146  PipelineVersion PipelineVer = PipelineVersion::v1,
147  typename ComputeDataType = EDataType>
149  BLayout,
150  DsLayout,
151  ELayout,
152  ADataType,
153  BDataType,
154  DsDataType,
155  EDataType,
156  AElementwiseOperation,
157  BElementwiseOperation,
158  CDEElementwiseOperation>
159 {
161 
162  static constexpr index_t NumDTensor = DsDataType::Size();
163 
164  static constexpr auto I0 = Number<0>{};
165  static constexpr auto I1 = Number<1>{};
166  static constexpr auto I2 = Number<2>{};
167  static constexpr auto I3 = Number<3>{};
168 
169  static constexpr auto matrix_padder =
170  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};
171 
172  static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
173  {
174  const auto a_grid_desc_mraw_kraw = [&]() {
175  if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
176  {
177  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
178  make_tuple(StrideA, I1));
179  }
180  else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
181  {
182  return make_naive_tensor_descriptor(make_tuple(MRaw, KRaw),
183  make_tuple(I1, StrideA));
184  }
185  }();
186 
187  return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
188  }
189 
190  static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
191  {
192  const auto b_grid_desc_nraw_kraw = [&]() {
194  {
195  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
196  make_tuple(I1, StrideB));
197  }
199  {
200  return make_naive_tensor_descriptor(make_tuple(NRaw, KRaw),
201  make_tuple(StrideB, I1));
202  }
203  }();
204 
205  return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
206  }
207 
208  template <typename ELay>
209  static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
210  {
211  const auto e_grid_desc_mraw_nraw = [&]() {
213  {
214  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
215  make_tuple(StrideE, I1));
216  }
218  {
219  return make_naive_tensor_descriptor(make_tuple(MRaw, NRaw),
220  make_tuple(I1, StrideE));
221  }
222  }();
223 
224  return matrix_padder.PadCDescriptor_M_N(e_grid_desc_mraw_nraw);
225  }
226 
227  static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
228  const std::array<index_t, NumDTensor>& NRaws,
229  const std::array<index_t, NumDTensor>& DsStride)
230  {
231  return generate_tuple(
232  [&](auto i) {
233  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
234 
235  return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
236  },
238  }
239 
240  // desc for problem definition
241  using AGridDesc_M_K = decltype(MakeAGridDescriptor_M_K(1, 1, 1));
242  using BGridDesc_N_K = decltype(MakeBGridDescriptor_N_K(1, 1, 1));
244  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
245 
246  // GridwiseGemm
248  ADataType,
249  BDataType,
250  ComputeDataType,
251  AccDataType,
252  CShuffleDataType,
253  DsDataType,
254  EDataType,
255  AElementwiseOperation,
256  BElementwiseOperation,
257  CDEElementwiseOperation,
259  NumGemmKPrefetchStage,
260  BlockSize,
261  MPerBlock,
262  NPerBlock,
263  KPerBlock,
264  AK1,
265  BK1,
266  MPerXDL,
267  NPerXDL,
268  MXdlPerWave,
269  NXdlPerWave,
270  ABlockTransferThreadClusterLengths_AK0_M_AK1,
271  ABlockTransferThreadClusterArrangeOrder,
272  ABlockTransferSrcAccessOrder,
273  ABlockTransferSrcVectorDim,
274  ABlockTransferSrcScalarPerVector,
275  ABlockTransferDstScalarPerVector_AK1,
276  false,
277  ABlockLdsExtraM,
278  BBlockTransferThreadClusterLengths_BK0_N_BK1,
279  BBlockTransferThreadClusterArrangeOrder,
280  BBlockTransferSrcAccessOrder,
281  BBlockTransferSrcVectorDim,
282  BBlockTransferSrcScalarPerVector,
283  BBlockTransferDstScalarPerVector_BK1,
284  false,
285  BBlockLdsExtraN,
286  CShuffleMXdlPerWavePerShuffle,
287  CShuffleNXdlPerWavePerShuffle,
288  CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
289  CDEBlockTransferScalarPerVector_NPerBlock,
290  LoopSched,
291  PipelineVer>;
292 
293  // desc for blockwise copy
296  AGridDesc_M_K{}))>;
299  BGridDesc_N_K{}))>;
302  DsGridDesc_M_N{}))>;
305  EGridDesc_M_N{}))>;
306 
307  // block-to-e-tile map
310 
311  // Argument
312  struct Argument : public BaseArgument
313  {
// Host-side argument pack for one problem instance: raw tensor pointers,
// derived grid descriptors, the block-to-E-tile map, element-wise ops, and
// the raw sizes kept for IsSupported().
// NOTE(review): this listing is incomplete -- several constructor-initializer
// lines (original lines 334, 336-341) and most member declarations (original
// lines 393, 396-411, 417-420) were dropped by the doc extractor. Compare
// against the upstream header before editing.
314  Argument(const void* p_a_grid,
315  const void* p_b_grid,
316  std::array<const void*, NumDTensor> p_ds_grid,
317  void* p_e_grid,
318  index_t MRaw,
319  index_t NRaw,
320  index_t KRaw,
321  index_t StrideA,
322  index_t StrideB,
323  std::array<index_t, NumDTensor> StrideDs,
324  index_t StrideE,
325  AElementwiseOperation a_element_op,
326  BElementwiseOperation b_element_op,
327  CDEElementwiseOperation cde_element_op)
328  : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
329  p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
330  p_ds_grid_{},
331  p_e_grid_{static_cast<EDataType*>(p_e_grid)},
332  a_grid_desc_m_k_{DeviceOp::MakeAGridDescriptor_M_K(MRaw, KRaw, StrideA)},
333  b_grid_desc_n_k_{DeviceOp::MakeBGridDescriptor_N_K(KRaw, NRaw, StrideB)},
335  e_grid_desc_m_n_{DeviceOp::MakeEGridDescriptor_M_N<ELayout>(MRaw, NRaw, StrideE)},
337  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k_)},
339  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k_)},
342  block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
343  a_element_op_{a_element_op},
344  b_element_op_{b_element_op},
345  cde_element_op_{cde_element_op},
346  MRaw_{MRaw},
347  NRaw_{NRaw},
348  KRaw_{KRaw}
349  {
350  // populate pointer, desc for Ds
351  static_for<0, NumDTensor, 1>{}([&](auto i) {
352  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
353  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
354 
355  // D pointer
356  p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
357 
358  // D desc
359  ds_grid_desc_m_n_(i) =
360  DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaw, NRaw, StrideDs[i]);
361  });
362 
363  // populate desc for Ds/E
// NOTE(review): the guard condition (original lines 364-368) and the two
// blockwise-descriptor assignments (original lines 370-376) were dropped
// from this listing; only the braces remain below.
369  {
373 
377  }
378  }
379 
// Debug helper: dump all problem-definition descriptors to stdout.
380  void Print() const
381  {
382  std::cout << "A[M, K]: " << a_grid_desc_m_k_ << std::endl;
383  std::cout << "B[N, K]: " << b_grid_desc_n_k_ << std::endl;
385  [&](auto i) { std::cout << "Ds[M, N]: " << ds_grid_desc_m_n_[i] << std::endl; });
386  std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
387  }
388 
389  // private:
390  // pointers
391  const ADataType* p_a_grid_;
392  const BDataType* p_b_grid_;
394  EDataType* p_e_grid_;
395 
396  // tensor descriptors for problem definition
401 
402  // tensor descriptors for block/thread-wise copy
408 
409  // block-to-e-tile map
411 
412  // element-wise op
413  AElementwiseOperation a_element_op_;
414  BElementwiseOperation b_element_op_;
415  CDEElementwiseOperation cde_element_op_;
416 
417  // for checking vector load/store
421  };
422 
423  // Invoker
424  struct Invoker : public BaseInvoker
425  {
// Launches kernel_gemm_multiple_d_xdl_cshuffle for a validated Argument and
// returns the measured kernel time (via launch_and_time_kernel).
// NOTE(review): this listing is incomplete -- the head of the CheckValidity
// call (original line 430), several kernel template arguments (original lines
// 449, 454-458), four launch arguments (original lines 473-476), and the
// main-K-loop branch (original lines 482-485) were dropped by the doc
// extractor. Compare against the upstream header before editing.
427 
428  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
429  {
// Validate descriptors/tile map before launching; throws on bad config.
431  arg.b_grid_desc_n_k_,
432  arg.ds_grid_desc_m_n_,
433  arg.e_grid_desc_m_n_,
434  arg.block_2_etile_map_))
435  {
436  throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
437  }
438 
// One workgroup per E tile.
439  const index_t grid_size =
440  arg.block_2_etile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);
441 
// Instantiate and launch the kernel for a compile-time main-K-loop flag.
442  auto launch_kernel = [&](auto has_main_k_block_loop) {
443  constexpr bool has_main_loop = has_main_k_block_loop.value;
444 
445  const auto kernel = kernel_gemm_multiple_d_xdl_cshuffle<
446  GridwiseGemm,
447  ADataType, // TODO: distinguish A/B datatype
448  BDataType, // TODO: distinguish A/B datatype
450  EDataType,
451  AElementwiseOperation,
452  BElementwiseOperation,
453  CDEElementwiseOperation,
459  has_main_loop>;
460 
461  return launch_and_time_kernel(stream_config,
462  kernel,
463  dim3(grid_size),
464  dim3(BlockSize),
465  0,
466  arg.p_a_grid_,
467  arg.p_b_grid_,
468  arg.p_ds_grid_,
469  arg.p_e_grid_,
470  arg.a_element_op_,
471  arg.b_element_op_,
472  arg.cde_element_op_,
477  arg.block_2_etile_map_);
478  };
479 
// Dispatch on whether the K extent needs the main K-block loop.
480  const auto K = arg.a_grid_desc_m_k_.GetLength(I1);
481 
483  {
485  }
486  else
487  {
488  return launch_kernel(integral_constant<bool, false>{});
489  }
490  }
491 
492  // polymorphic
493  float Run(const BaseArgument* p_arg,
494  const StreamConfig& stream_config = StreamConfig{}) override
495  {
496  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
497  }
498  };
499 
500  static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
501  {
502  // check vector load/store
505  // check vector load of A
506  if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
507  {
508  if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
509  {
510  return false;
511  }
512  }
513  else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
514  {
515  // FIXME: not rigorous
516  if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
517  {
518  return false;
519  }
520  }
521  else
522  {
523  return false;
524  }
525  // check vector laod of B
526  if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
527  {
528  if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
529  {
530  return false;
531  }
532  }
533  else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
534  {
535  // FIXME: not rigorous
536  if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
537  {
538  return false;
539  }
540  }
541  else
542  {
543  return false;
544  }
545 
546  // check vector load of Ds
547  // only support RowMajor for now
548  bool all_valid = true;
549 
550  static_for<0, NumDTensor, 1>{}([&](auto i) {
551  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
552 
553  if constexpr(!is_same_v<DLayout, Row>)
554  {
555  all_valid = false;
556  }
557  });
558 
559  if(!all_valid)
560  {
561  return false;
562  }
563 
564  // check vector store of E
565  // only support RowMajor for now
566  if constexpr(is_same_v<ELayout, Row>)
567  {
568  if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
569  {
570  return false;
571  }
572  }
573  else
574  {
575  return false;
576  }
577  return true;
578  }
579 
580  static bool IsSupportedArgument(const Argument& arg)
581  {
582  if(!ck::is_xdl_supported())
583  {
584  return false;
585  }
586 
587  return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
589  arg.b_grid_desc_n_k_,
590  arg.ds_grid_desc_m_n_,
591  arg.e_grid_desc_m_n_,
592  arg.block_2_etile_map_);
593  }
594 
595  // polymorphic
596  bool IsSupportedArgument(const BaseArgument* p_arg) override
597  {
598  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
599  }
600 
601  static auto MakeArgument(const void* p_a,
602  const void* p_b,
603  std::array<const void*, NumDTensor> p_ds,
604  void* p_e,
605  index_t MRaw,
606  index_t NRaw,
607  index_t KRaw,
608  index_t StrideA,
609  index_t StrideB,
610  std::array<index_t, NumDTensor> StrideDs,
611  index_t StrideE,
612  AElementwiseOperation a_element_op,
613  BElementwiseOperation b_element_op,
614  CDEElementwiseOperation cde_element_op)
615  {
616  return Argument{p_a,
617  p_b,
618  p_ds,
619  p_e,
620  MRaw,
621  NRaw,
622  KRaw,
623  StrideA,
624  StrideB,
625  StrideDs,
626  StrideE,
627  a_element_op,
628  b_element_op,
629  cde_element_op};
630  }
631 
632  static auto MakeInvoker() { return Invoker{}; }
633 
634  // polymorphic
635  std::unique_ptr<BaseArgument>
636  MakeArgumentPointer(const void* p_a,
637  const void* p_b,
638  std::array<const void*, NumDTensor> p_ds,
639  void* p_e,
640  index_t MRaw,
641  index_t NRaw,
642  index_t KRaw,
643  index_t StrideA,
644  index_t StrideB,
645  std::array<ck::index_t, NumDTensor> StrideDs,
646  index_t StrideE,
647  AElementwiseOperation a_element_op,
648  BElementwiseOperation b_element_op,
649  CDEElementwiseOperation cde_element_op) override
650  {
651  return std::make_unique<Argument>(p_a,
652  p_b,
653  p_ds,
654  p_e,
655  MRaw,
656  NRaw,
657  KRaw,
658  StrideA,
659  StrideB,
660  StrideDs,
661  StrideE,
662  a_element_op,
663  b_element_op,
664  cde_element_op);
665  }
666 
667  // polymorphic
668  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
669  {
670  return std::make_unique<Invoker>(Invoker{});
671  }
672 
673  // polymorphic
674  std::string GetTypeString() const override
675  {
676  auto str = std::stringstream();
677 
678  std::map<LoopScheduler, std::string> LoopSchedToString{
679  {LoopScheduler::Default, "Default"}, {LoopScheduler::Interwave, "Interwave"}};
680 
681  std::map<PipelineVersion, std::string> PipelineVersionToString{{PipelineVersion::v1, "v1"},
682  {PipelineVersion::v2, "v2"}};
683 
684  // clang-format off
685  str << "DeviceGemmMultipleD_Xdl_CShuffle"
686  << "<"
687  << BlockSize << ", "
688  << MPerBlock << ", "
689  << NPerBlock << ", "
690  << KPerBlock << ", "
691  << AK1 << ", "
692  << BK1 << ", "
693  << MPerXDL << ", "
694  << NPerXDL << ", "
695  << MXdlPerWave << ", "
696  << NXdlPerWave << ", "
697  << ABlockTransferSrcScalarPerVector << ", "
698  << BBlockTransferSrcScalarPerVector << ", "
699  << CShuffleMXdlPerWavePerShuffle << ", "
700  << CShuffleNXdlPerWavePerShuffle << ", "
701  << getGemmSpecializationString(GemmSpec)
702  << ">"
703  << " LoopScheduler: "
704  << LoopSchedToString[LoopSched] << ", "
705  << "PipelineVersion: "
706  << PipelineVersionToString[PipelineVer];
707  // clang-format on
708 
709  return str.str();
710  }
711 
712  template <class ADesc, class BDesc, class DsDesc, class EDesc>
// Compile-time descriptor bundle: precomputes all padded/blockwise grid
// descriptors, the tile map, and the main-K-loop flag from statically-known
// A/B/Ds/E descriptors, for use by the device-side Run() entry point.
// NOTE(review): this listing is incomplete -- the left-hand sides of several
// using-declarations (original lines 721, 723, 726, 728-741) and most member
// declarations (original lines 744-757, 764-768) were dropped by the doc
// extractor. Compare against the upstream header before editing.
713  struct Descriptor
714  {
// Pad every D descriptor to block-tile multiples, as a tuple.
715  static constexpr auto ds_tuple()
716  {
717  return transform_tuples(
718  [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
719  DsDesc{});
720  }
722  remove_cvref_t<decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))>;
724  remove_cvref_t<decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))>;
725  using DsGridDesc_M_N = remove_cvref_t<decltype(ds_tuple())>;
727  remove_cvref_t<decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))>;
730  DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
733  DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
736  ds_tuple()))>;
739  DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
741  DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
742 
743  // tensor descriptors for problem definition
748 
749  // tensor descriptors for block/thread-wise copy
754 
755  // block-to-e-tile map
757 
758  // element-wise op
759  AElementwiseOperation a_element_op;
760  BElementwiseOperation b_element_op;
761  CDEElementwiseOperation cde_element_op;
762 
763  // for checking vector load/store
767 
769 
// Derive every cached descriptor/map/flag from the raw a/b/ds/e descriptors.
770  constexpr Descriptor(ADesc a,
771  BDesc b,
772  DsDesc ds,
773  EDesc e,
774  AElementwiseOperation a_element_op_,
775  BElementwiseOperation b_element_op_,
776  CDEElementwiseOperation cde_element_op_)
777  : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
778  b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
780  [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
781  ds)},
782  e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
783  a_grid_desc_ak0_m_ak1{
784  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
785  b_grid_desc_bk0_n_bk1{
786  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
787  ds_grid_desc_mblock_mperblock_nblock_nperblock{
788  GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
790  [&](auto d) constexpr {
791  return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
792  },
793  ds))},
794  e_grid_desc_mblock_mperblock_nblock_nperblock{
795  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
796  e_grid_desc_m_n)},
797  block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
// has_main_k_block_loop is decided by total K = AK0 * AK1.
798  has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
799  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
800  a_element_op{a_element_op_},
801  b_element_op{b_element_op_},
802  cde_element_op{cde_element_op_},
803  MRaw{e.GetLength(I0)},
804  NRaw{e.GetLength(I1)},
805  KRaw{a.GetLength(I1)}
806  {
807  }
808 
// Combined descriptor validity + vector-width support check.
809  constexpr bool IsValid() const
810  {
811  return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
812  b_grid_desc_n_k,
813  ds_grid_desc_m_n,
814  e_grid_desc_m_n,
815  block_2_etile_map) and
816  IsSupported(MRaw, NRaw, KRaw);
817  }
818 
819  constexpr index_t GetBlockSize() const { return BlockSize; }
820 
821  constexpr index_t GetGridSize() const
822  {
823  return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
824  }
825  };
826 
827  template <class ADesc, class BDesc, class DsDesc, class EDesc>
// Factory for a constexpr Descriptor from raw a/b/ds/e descriptors, with
// default-constructed element-wise ops unless overridden.
// NOTE(review): the line carrying the function name and first parameter
// (original line 829, presumably `make_descriptor(ADesc a,`) was dropped by
// the doc extractor; confirm against the upstream header.
828  static constexpr auto
830  BDesc b,
831  DsDesc ds,
832  EDesc e,
833  AElementwiseOperation a_element_op = AElementwiseOperation{},
834  BElementwiseOperation b_element_op = BElementwiseOperation{},
835  CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
836  {
837  return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
838  a, b, ds, e, a_element_op, b_element_op, cde_element_op);
839  }
840 
841  template <class Desc, class DsPointer>
842  __device__ static void Run(const Desc& desc,
843  const ADataType* __restrict__ p_a_grid,
844  const BDataType* __restrict__ p_b_grid,
845  DsPointer p_ds_grid,
846  EDataType* __restrict__ p_e_grid)
847  {
848  __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
849  assert(desc.IsValid());
850  if(desc.has_main_k_block_loop)
851  {
852  GridwiseGemm::template Run<true>(p_a_grid,
853  p_b_grid,
854  p_ds_grid,
855  p_e_grid,
856  p_shared_block,
857  desc.a_element_op,
858  desc.b_element_op,
859  desc.cde_element_op,
860  desc.a_grid_desc_ak0_m_ak1,
861  desc.b_grid_desc_bk0_n_bk1,
862  desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
863  desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
864  desc.block_2_etile_map);
865  }
866  else
867  {
868  GridwiseGemm::template Run<false>(p_a_grid,
869  p_b_grid,
870  p_ds_grid,
871  p_e_grid,
872  p_shared_block,
873  desc.a_element_op,
874  desc.b_element_op,
875  desc.cde_element_op,
876  desc.a_grid_desc_ak0_m_ak1,
877  desc.b_grid_desc_bk0_n_bk1,
878  desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
879  desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
880  desc.block_2_etile_map);
881  }
882  }
883 };
884 
885 } // namespace device
886 } // namespace tensor_operation
887 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:34
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:33
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
bool is_xdl_supported()
Definition: device_prop.hpp:54
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__global__ void kernel_gemm_multiple_d_xdl_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:40
__host__ constexpr __device__ auto transform_tuples(F f, const X &x)
Definition: tuple_helper.hpp:86
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:289
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:221
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:396
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:204
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:403
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:254
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, [[maybe_unused]] const Block2ETileMap &)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:329
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:187
__host__ static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:242
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: functional2.hpp:31
Definition: tensor_layout.hpp:21
Definition: tensor_layout.hpp:16
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:313
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:404
index_t MRaw_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:418
const BDataType * p_b_grid_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:392
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:403
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:400
BElementwiseOperation b_element_op_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:414
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:406
index_t KRaw_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:420
index_t NRaw_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:419
EDataType * p_e_grid_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:394
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:397
Block2ETileMap block_2_etile_map_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:410
AElementwiseOperation a_element_op_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:413
void Print() const
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:380
DsGridDesc_M_N ds_grid_desc_m_n_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:399
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:407
GridwiseGemm::DsGridPointer p_ds_grid_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:393
const ADataType * p_a_grid_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:391
CDEElementwiseOperation cde_element_op_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:415
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:314
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:398
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:714
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:751
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))> AGridDesc_AK0_M_AK1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:730
index_t NRaw
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:765
index_t MRaw
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:764
AElementwiseOperation a_element_op
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:759
constexpr index_t GetGridSize() const
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:821
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))> BGridDesc_BK0_N_BK1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:733
DsGridDesc_M_N ds_grid_desc_m_n
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:746
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:750
constexpr bool IsValid() const
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:809
EGridDesc_M_N e_grid_desc_m_n
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:747
Block2ETileMap block_2_etile_map
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:756
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))> AGridDesc_M_K
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:722
constexpr Descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:770
AGridDesc_M_K a_grid_desc_m_k
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:744
remove_cvref_t< decltype(ds_tuple())> DsGridDesc_M_N
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:725
remove_cvref_t< decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:739
constexpr index_t GetBlockSize() const
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:819
index_t KRaw
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:766
bool has_main_k_block_loop
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:768
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))> Block2ETileMap
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:741
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))> EGridDesc_M_N
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:727
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:752
BElementwiseOperation b_element_op
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:760
remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_tuple()))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:736
CDEElementwiseOperation cde_element_op
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:761
static constexpr auto ds_tuple()
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:715
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:753
BGridDesc_N_K b_grid_desc_n_k
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:745
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))> BGridDesc_N_K
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:724
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:425
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:493
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:428
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:159
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:227
static constexpr auto I1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:165
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer > GridwiseGemm
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:291
std::string GetTypeString() const override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:674
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:243
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:500
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:242
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:601
remove_cvref_t< decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:305
static constexpr auto make_descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation cde_element_op=CDEElementwiseOperation{})
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:829
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:309
static constexpr auto I0
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:164
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:241
static constexpr auto matrix_padder
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:169
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:244
static __device__ void Run(const Desc &desc, const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:842
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:162
static auto MakeInvoker()
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:632
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:636
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:596
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:668
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:580
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:190
static constexpr auto I3
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:167
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:209
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:172
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:296
static constexpr auto I2
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:166
remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:302
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:299
Definition: device_gemm_multiple_d.hpp:34
Definition: matrix_padder.hpp:180