/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp Source File
device_gemm_multiple_d_dl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 
// Free-standing GPU kernel entry point: forwards all GEMM-with-multiple-D
// arguments to GridwiseGemm::Run, providing an LDS scratch buffer sized from
// GridwiseGemm::GetSharedMemoryNumberOfByte().
// NOTE(review): this listing was extracted from generated documentation and
// several original lines are missing from this view (the launch-bounds macro
// on line 37 and the kernel-name line 39, per the reference section:
// kernel_gemm_dl_multiple_d).
21 template <typename GridwiseGemm,
22  typename ABDataType,
23  typename DsPointer,
24  typename EDataType,
25  typename AElementwiseOperation,
26  typename BElementwiseOperation,
27  typename CDEElementwiseOperation,
28  typename AGridDesc_K0_M0_M1_K1,
29  typename BGridDesc_K0_N0_N1_K1,
30  typename DsGridDesc_M0_M10_M11_N0_N10_N11,
31  typename CGridDesc_M0_M10_M11_N0_N10_N11,
32  typename Block2CTileMap,
33  bool HasMainKBlockLoop,
34  bool HasDoubleTailKBlockLoop>
35 __global__ void
36 #if CK_USE_LAUNCH_BOUNDS
38 #endif
// Arguments: raw A/B pointers, tuple of D pointers, E output pointer, the
// elementwise operators, pre-transformed grid descriptors, and the
// block-to-C-tile mapping.
40  const ABDataType* __restrict__ p_a_grid,
41  const ABDataType* __restrict__ p_b_grid,
42  DsPointer p_ds_grid,
43  EDataType* __restrict__ p_e_grid,
44  const AElementwiseOperation a_element_op,
45  const BElementwiseOperation b_element_op,
46  const CDEElementwiseOperation cde_element_op,
47  const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
48  const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
49  const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
50  const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
51  const Block2CTileMap block_2_ctile_map)
52 {
// Real body only for the device targets listed below; for any other target
// the kernel compiles to a no-op that consumes its arguments via ck::ignore.
53 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx9__) || \
54  defined(__gfx103__) || defined(__gfx11__) || defined(__gfx12__))
55 
// Shared-memory (LDS) scratch buffer, sized in ABDataType elements from the
// byte count the gridwise GEMM reports it needs.
56  constexpr index_t shared_block_size =
57  GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
58 
59  __shared__ ABDataType p_shared[shared_block_size];
60 
61  GridwiseGemm::Run(p_a_grid,
62  p_b_grid,
63  p_ds_grid,
64  p_e_grid,
65  p_shared,
66  a_element_op,
67  b_element_op,
68  cde_element_op,
69  a_grid_desc_k0_m0_m1_k1,
70  b_grid_desc_k0_n0_n1_k1,
71  ds_grid_desc_m0_m10_m11_n0_n10_n11,
72  e_grid_desc_m0_m10_m11_n0_n10_n11,
73  block_2_ctile_map,
// NOTE(review): the closing arguments of the Run call (original lines 74-75,
// presumably the loop-flag integral constants) are missing from this view.
76 #else
// Host-only / unsupported-target compilation: discard every parameter so the
// compiler emits no unused-parameter warnings.
77  ignore = p_a_grid;
78  ignore = p_b_grid;
79  ignore = p_ds_grid;
80  ignore = p_e_grid;
81  ignore = a_element_op;
82  ignore = b_element_op;
83  ignore = cde_element_op;
84  ignore = a_grid_desc_k0_m0_m1_k1;
85  ignore = b_grid_desc_k0_n0_n1_k1;
86  ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
87  ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
88  ignore = block_2_ctile_map;
89 #endif
90 }
91 } // namespace ck
92 
93 namespace ck {
94 namespace tensor_operation {
95 namespace device {
96 
// Device operation: GEMM computing E = cde_op(A*B, D0, D1, ...) using the
// "DL" gridwise GEMM.  NOTE(review): the class-declaration line itself
// (original line 141, per the reference section a struct deriving from
// DeviceGemmMultipleD<...>) is missing from this extracted view; the
// base-class template-argument lines (142-151) survive below.
97 template <typename ALayout,
98  typename BLayout,
99  typename DsLayout,
100  typename ELayout,
101  typename ADataType,
102  typename BDataType,
103  typename AccDataType,
104  typename DsDataType,
105  typename EDataType,
106  typename AElementwiseOperation,
107  typename BElementwiseOperation,
108  typename CDEElementwiseOperation,
109  GemmSpecialization GemmSpec,
// Tile/thread decomposition parameters.
110  index_t BlockSize,
111  index_t MPerBlock,
112  index_t NPerBlock,
113  index_t K0PerBlock,
114  index_t K1,
115  index_t M1PerThread,
116  index_t N1PerThread,
117  index_t KPerThread,
118  typename M1N1ThreadClusterM1Xs,
119  typename M1N1ThreadClusterN1Xs,
// Block-transfer tuning parameters for the A and B operands.
120  typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
121  typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
122  typename ABlockTransferThreadClusterArrangeOrder,
123  typename ABlockTransferSrcAccessOrder,
124  typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
125  typename ABlockTransferSrcVectorTensorContiguousDimOrder,
126  typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
127  typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
128  typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
129  typename BBlockTransferThreadClusterArrangeOrder,
130  typename BBlockTransferSrcAccessOrder,
131  typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
132  typename BBlockTransferSrcVectorTensorContiguousDimOrder,
133  typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
134  typename CThreadTransferSrcDstAccessOrder,
135  index_t CThreadTransferSrcDstVectorDim,
136  index_t CThreadTransferDstScalarPerVector,
// This specialization only accepts PassThrough A/B elementwise operators.
137  enable_if_t<
138  is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
139  is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
140  bool> = false>
142  BLayout,
143  DsLayout,
144  ELayout,
145  ADataType,
146  BDataType,
147  DsDataType,
148  EDataType,
149  AElementwiseOperation,
150  BElementwiseOperation,
151  CDEElementwiseOperation>
152 
153 {
// Number of auxiliary D tensors fused into the epilogue.
// NOTE(review): a `using DeviceOp = ...` alias (original line 154, used in
// MakeDsGridDescriptor_M_N below) is missing from this view.
155  static constexpr index_t NumDTensor = DsDataType::Size();
156 
// Compile-time index constants used for descriptor length queries.
157  static constexpr auto I0 = Number<0>{};
158  static constexpr auto I1 = Number<1>{};
159  static constexpr auto I2 = Number<2>{};
160  static constexpr auto I3 = Number<3>{};
161  static constexpr auto I4 = Number<4>{};
162  static constexpr auto I5 = Number<5>{};
163 
// K1 lifted to a compile-time Number.
164  static constexpr auto K1Number = Number<K1>{};
165 
// Builds the A-operand grid descriptor in (K0, M, K1) form from the raw
// M x K problem.  Per the reference section the signature (original line
// 166, missing here) is: static auto MakeAGridDescriptor_K0_M_K1(index_t M,
// index_t K, index_t StrideA).
167  {
168  assert(K % K1 == 0);
169 
// Split K into K0 * K1 (asserted divisible above).
170  const index_t K0 = K / K1;
171 
// Raw M x K naive descriptor; the two branches choose the stride order —
// NOTE(review): the `if constexpr` condition lines on ALayout (original
// lines 173/177) are missing from this extracted view; presumably row- vs
// column-major — confirm against the repository source.
172  const auto a_grid_desc_m_k = [&]() {
174  {
175  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
176  }
178  {
179  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
180  }
181  }();
182 
183  if constexpr(GemmSpec == GemmSpecialization::MNPadding)
184  {
// Pad M up to the next multiple of MPerBlock so the tile grid covers it.
185  const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
186 
// transform_tensor_descriptor call; several argument lines (187-192) were
// dropped by the documentation extraction.
188  a_grid_desc_m_k,
190  make_right_pad_transform(M, PadM)),
193  }
194  else
195  {
// Unpadded path: transform without the right-pad (argument lines missing).
197  a_grid_desc_m_k,
202  }
203  }
204 
// Builds the B-operand grid descriptor in (K0, N, K1) form from the raw
// K x N problem; mirrors MakeAGridDescriptor_K0_M_K1.  Per the reference
// section the signature (original line 205, missing here) is:
// static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB).
206  {
207  assert(K % K1 == 0);
208 
// Split K into K0 * K1.
209  const index_t K0 = K / K1;
210 
// Raw K x N naive descriptor; the layout `if constexpr` condition lines on
// BLayout (original lines 212/216) are missing from this extracted view.
211  const auto b_grid_desc_k_n = [&]() {
213  {
214  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
215  }
217  {
218  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
219  }
220  }();
221 
222  if constexpr(GemmSpec == GemmSpecialization::MNPadding)
223  {
// Pad N up to the next multiple of NPerBlock.
224  const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
225 
// transform_tensor_descriptor call; argument lines dropped by extraction.
227  b_grid_desc_k_n,
229  make_right_pad_transform(N, PadN)),
232  }
233  else
234  {
// Unpadded path (argument lines missing from this view).
236  b_grid_desc_k_n,
241  }
242  }
243 
// Builds an output-style (M, N) grid descriptor for the given layout ELay;
// used for the E tensor and reused per-D by MakeDsGridDescriptor_M_N.
244  template <typename ELay>
245  static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
246  {
// Raw M x N naive descriptor; the `if constexpr` condition lines on ELay
// (original lines 248/252) are missing from this extracted view.
247  const auto c_grid_desc_m_n = [&]() {
249  {
250  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideE, I1));
251  }
253  {
254  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideE));
255  }
256  }();
257 
258  if constexpr(GemmSpec == GemmSpecialization::MNPadding)
259  {
// Pad both M and N up to whole MPerBlock/NPerBlock tiles.
260  const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
261  const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
262 
// transform_tensor_descriptor call; argument lines dropped by extraction.
264  c_grid_desc_m_n,
268  }
269  else
270  {
271 
// Unpadded path (argument lines missing from this view).
273  c_grid_desc_m_n,
277  }
278  }
279 
// Builds one (M, N) descriptor per D tensor by invoking
// MakeEGridDescriptor_M_N with each D's layout and stride.
280  static auto MakeDsGridDescriptor_M_N(const std::array<index_t, NumDTensor>& MRaws,
281  const std::array<index_t, NumDTensor>& NRaws,
282  const std::array<index_t, NumDTensor>& DsStride)
283  {
284  return generate_tuple(
285  [&](auto i) {
286  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
287 
288  return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(MRaws[i], NRaws[i], DsStride[i]);
289  },
// NOTE(review): the generate_tuple arity argument (original line 290,
// presumably Number<NumDTensor>{}) is missing from this extracted view.
291  }
292 
// Descriptor types, probed with dummy (1, 1, 1) problem sizes.
293  using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
294  using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
295  using DsGridDesc_M_N = decltype(MakeDsGridDescriptor_M_N({}, {}, {}));
296  using EGridDesc_M_N = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
297 
298  // GridwiseGemm
// Per the reference section of this page this alias instantiates
// GridwiseGemmDlMultipleD_km_kn_mn<BlockSize, ..., InMemoryDataOperationEnum::Set,
// AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, ...>; the leading
// template-argument lines (original 300, 308-311) are missing from this view.
299  using GridwiseGemm =
301  ADataType,
302  AccDataType,
303  DsDataType,
304  EDataType,
305  AElementwiseOperation,
306  BElementwiseOperation,
307  CDEElementwiseOperation,
312  MPerBlock,
313  NPerBlock,
314  K0PerBlock,
315  K1,
316  M1PerThread,
317  N1PerThread,
318  KPerThread,
319  M1N1ThreadClusterM1Xs,
320  M1N1ThreadClusterN1Xs,
321  ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
322  ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
323  ABlockTransferThreadClusterArrangeOrder,
324  ABlockTransferSrcAccessOrder,
325  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
326  ABlockTransferSrcVectorTensorContiguousDimOrder,
327  ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
328  BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
329  BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
330  BBlockTransferThreadClusterArrangeOrder,
331  BBlockTransferSrcAccessOrder,
332  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
333  BBlockTransferSrcVectorTensorContiguousDimOrder,
334  BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
335  CThreadTransferSrcDstAccessOrder,
336  CThreadTransferSrcDstVectorDim,
337  CThreadTransferDstScalarPerVector>;
349 
350  // Argument
// Bundles the runtime problem description: typed tensor pointers, grid
// descriptors derived from (M, N, K, strides), and the elementwise
// operators.  NOTE(review): several member declarations and the descriptor
// initializers (original lines 371-374, 379-382, 394, 397-412, 418, 421-431)
// are missing from this extracted view; the reference section lists them
// (a/b/ds/e grid descriptors and block_2_ctile_map_).
351  struct Argument : public BaseArgument
352  {
353  Argument(const void* p_a_grid,
354  const void* p_b_grid,
355  std::array<const void*, NumDTensor> p_ds_grid,
356  void* p_e_grid,
357  index_t M,
358  index_t N,
359  index_t K,
360  index_t StrideA,
361  index_t StrideB,
362  std::array<index_t, NumDTensor> StrideDs,
363  index_t StrideE,
364  AElementwiseOperation a_element_op,
365  BElementwiseOperation b_element_op,
366  CDEElementwiseOperation cde_element_op)
// Type-erased void pointers are cast once here to the concrete data types.
367  : p_a_grid_{static_cast<const ADataType*>(p_a_grid)},
368  p_b_grid_{static_cast<const BDataType*>(p_b_grid)},
369  p_ds_grid_{},
370  p_e_grid_{static_cast<EDataType*>(p_e_grid)},
375  a_element_op_{a_element_op},
376  b_element_op_{b_element_op},
377  cde_element_op_{cde_element_op}
378  {
// Populate per-D pointer and per-D (M, N) descriptor, one entry per tensor
// in DsDataType/DsLayout.
383  static_for<0, NumDTensor, 1>{}([&](auto i) {
384  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
385  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
386 
387  // D pointer
388  p_ds_grid_(i) = static_cast<const DDataType*>(p_ds_grid[i]);
389 
390  // D desc
391  ds_grid_desc_m_n_(i) =
392  DeviceOp::MakeEGridDescriptor_M_N<DLayout>(M, N, StrideDs[i]);
393  });
// E descriptor construction (the assignment target line 394 is missing).
395  DeviceGemmMultipleD_Dl::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);
396 
// NOTE(review): a conditional block follows (original lines 397-412,
// presumably the CheckValidity-gated creation of the blocked descriptors
// and block_2_ctile_map_); its statements are missing from this view.
399  {
404 
407 
410 
412  }
413  }
414 
415  // private:
416  const ADataType* p_a_grid_;
417  const BDataType* p_b_grid_;
419  EDataType* p_e_grid_;
420 
425 
430 
432 
433  // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
434  AElementwiseOperation a_element_op_;
435  BElementwiseOperation b_element_op_;
436  CDEElementwiseOperation cde_element_op_;
437  };
438 
439  // Invoker
// Launches (and times) the kernel for a prepared Argument, dispatching to
// one of four template instantiations selected by the main-loop /
// double-tail-loop flags computed from K0.
440  struct Invoker : public BaseInvoker
441  {
443 
444  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
445  {
// Debug dump of the descriptor extents to stdout.  NOTE(review): the print
// appears unconditional here, but a guard line (original line 445/446 area)
// may have been dropped by the documentation extraction — confirm against
// the repository source.
446  {
447  std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
448  << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
449  << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
450  << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
451 
452  std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
453  << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
454  << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
455  << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
456 
457  std::cout << "arg.e_grid_desc_m_n_{ " << arg.e_grid_desc_m_n_.GetLength(I0) << ", "
458  << arg.e_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
459  }
460 
// Validity gate (the condition line, original 461-462, presumably
// !GridwiseGemm::CheckValidity(...), is missing from this view).
463  {
464  throw std::runtime_error(
465  "wrong! GridwiseGemmDlMultipleD_km_kn_mn has invalid setting");
466  }
467 
// One workgroup per C tile, derived from the output extents.
468  const index_t grid_size = GridwiseGemm::CalculateGridSize(
469  arg.e_grid_desc_m_n_.GetLength(I0), arg.e_grid_desc_m_n_.GetLength(I1));
470 
// Instantiates kernel_gemm_dl_multiple_d for a given pair of compile-time
// loop flags and launches it with the argument's grids/descriptors.
471  auto launch_kernel = [&](auto has_main_k_block_loop,
472  auto has_double_tail_k_block_loop) {
473  constexpr bool has_main_loop = has_main_k_block_loop.value;
474  constexpr bool has_double_loop = has_double_tail_k_block_loop.value;
475 
// NOTE(review): the kernel name and several template-argument lines
// (original 477, 479, 484-488) are missing from this extracted view.
476  const auto kernel =
478  ADataType,
480  EDataType,
481  AElementwiseOperation,
482  BElementwiseOperation,
483  CDEElementwiseOperation,
489  has_main_loop,
490  has_double_loop>;
491 
492  return launch_and_time_kernel(stream_config,
493  kernel,
494  dim3(grid_size),
495  dim3(BlockSize),
496  0,
497  arg.p_a_grid_,
498  arg.p_b_grid_,
499  arg.p_ds_grid_,
500  arg.p_e_grid_,
501  arg.a_element_op_,
502  arg.b_element_op_,
503  arg.cde_element_op_,
// Descriptor argument lines (original 504-507) missing from this view.
508  arg.block_2_ctile_map_);
509  };
510 
// Choose the loop-structure specialization from the K0 extent.
511  const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
512  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
513  const bool has_double_tail_k_block_loop =
// (Original line 514, presumably CalculateHasDoubleTailKBlockLoop(K0),
// missing from this view.)
515 
516  if(has_main_k_block_loop && has_double_tail_k_block_loop)
517  {
// (First launch_kernel argument line, original 518, missing.)
519  integral_constant<bool, true>{});
520  }
521  else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
522  {
523  return launch_kernel(integral_constant<bool, true>{},
524  integral_constant<bool, false>{});
525  }
526  else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
527  {
528  return launch_kernel(integral_constant<bool, false>{},
529  integral_constant<bool, true>{});
530  }
531  else
532  {
533  return launch_kernel(integral_constant<bool, false>{},
534  integral_constant<bool, false>{});
535  }
536  }
537 
538  // polymorphic
// Type-erased entry: downcast and forward to the typed Run above.
539  float Run(const BaseArgument* p_arg,
540  const StreamConfig& stream_config = StreamConfig{}) override
541  {
542  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
543  }
544  };
545 
// Compile-time sanity check of the template configuration.
// TODO: properly implement this check — it currently accepts every
// configuration unconditionally.
static constexpr bool IsValidCompilationParameter()
{
    constexpr bool parameters_ok = true; // placeholder until real checks exist
    return parameters_ok;
}
551 
// Device-capability gate for this DL kernel.  NOTE(review): the rest of the
// condition (original line 555 — the reference section of this page lists
// is_gfx103_supported/is_gfx11_supported/is_gfx12_supported helpers) and the
// taken-branch body (original lines 557-559, presumably `return true;`) are
// missing from this extracted view — confirm against the repository source.
552  static bool IsSupportedArgument(const Argument& arg)
553  {
554  if(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
556  {
559  }
560  else
561  {
562  return false;
563  }
564  }
565 
566  // polymorphic
567  bool IsSupportedArgument(const BaseArgument* p_arg) override
568  {
569  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
570  }
571 
572  static auto MakeArgument(const void* p_a,
573  const void* p_b,
574  std::array<const void*, NumDTensor> p_ds,
575  void* p_e,
576  index_t M,
577  index_t N,
578  index_t K,
579  index_t StrideA,
580  index_t StrideB,
581  std::array<ck::index_t, NumDTensor> StrideDs,
582  index_t StrideE,
583  AElementwiseOperation a_element_op,
584  BElementwiseOperation b_element_op,
585  CDEElementwiseOperation cde_element_op)
586  {
587  return Argument{p_a,
588  p_b,
589  p_ds,
590  p_e,
591  M,
592  N,
593  K,
594  StrideA,
595  StrideB,
596  StrideDs,
597  StrideE,
598  a_element_op,
599  b_element_op,
600  cde_element_op};
601  }
602 
603  static auto MakeInvoker() { return Invoker{}; }
604 
605  // polymorphic
606  std::unique_ptr<BaseArgument>
607  MakeArgumentPointer(const void* p_a,
608  const void* p_b,
609  std::array<const void*, NumDTensor> p_ds,
610  void* p_e,
611  index_t M,
612  index_t N,
613  index_t K,
614  index_t StrideA,
615  index_t StrideB,
616  std::array<ck::index_t, NumDTensor> StrideDs,
617  index_t StrideE,
618  AElementwiseOperation a_element_op,
619  BElementwiseOperation b_element_op,
620  CDEElementwiseOperation cde_element_op) override
621  {
622  return std::make_unique<Argument>(p_a,
623  p_b,
624  p_ds,
625  p_e,
626  M,
627  N,
628  K,
629  StrideA,
630  StrideB,
631  StrideDs,
632  StrideE,
633  a_element_op,
634  b_element_op,
635  cde_element_op);
636  }
637 
638  // polymorphic
639  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
640  {
641  return std::make_unique<Invoker>(Invoker{});
642  }
643 
644  // polymorphic
645  std::string GetTypeString() const override
646  {
647  auto str = std::stringstream();
648 
649  // clang-format off
650  str << "DeviceGemmMultipleD_Dl"
651  << "<"
652  << BlockSize << ", "
653  << MPerBlock << ", "
654  << NPerBlock << ", "
655  << K0PerBlock << ", "
656  << K1 << ", "
657  << M1PerThread << ", "
658  << N1PerThread << ", "
659  << KPerThread
660  << ">";
661  // clang-format on
662 
663  return str.str();
664  }
665 };
666 
667 } // namespace device
668 } // namespace tensor_operation
669 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:34
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:33
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__global__ void kernel_gemm_dl_multiple_d(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map)
Definition: device_batched_gemm_multiple_d_dl.hpp:57
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
bool is_xdl_supported()
Definition: device_prop.hpp:54
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
std::string get_device_name()
Definition: device_prop.hpp:12
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
bool is_gfx12_supported()
Definition: device_prop.hpp:94
__global__ void kernel_gemm_dl_multiple_d(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition: device_gemm_multiple_d_dl.hpp:39
bool is_gfx103_supported()
Definition: device_prop.hpp:81
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
bool is_gfx11_supported()
Definition: device_prop.hpp:88
Definition: stream_config.hpp:10
Definition: gridwise_gemm_dl_multiple_d.hpp:60
__host__ static constexpr __device__ auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition: gridwise_gemm_dl_multiple_d.hpp:178
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_multiple_d.hpp:143
__host__ static constexpr __device__ auto MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_gemm_dl_multiple_d.hpp:234
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_multiple_d.hpp:242
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_dl_multiple_d.hpp:253
__host__ static constexpr __device__ auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition: gridwise_gemm_dl_multiple_d.hpp:158
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_dl_multiple_d.hpp:136
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_multiple_d.hpp:200
__host__ static constexpr __device__ bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_multiple_d.hpp:150
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_multiple_d.hpp:110
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: functional2.hpp:31
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_multiple_d_dl.hpp:352
GridwiseGemm::DsGridPointer p_ds_grid_
Definition: device_gemm_multiple_d_dl.hpp:418
const BDataType * p_b_grid_
Definition: device_gemm_multiple_d_dl.hpp:417
BElementwiseOperation b_element_op_
Definition: device_gemm_multiple_d_dl.hpp:435
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition: device_gemm_multiple_d_dl.hpp:427
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition: device_gemm_multiple_d_dl.hpp:426
CDEElementwiseOperation cde_element_op_
Definition: device_gemm_multiple_d_dl.hpp:436
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition: device_gemm_multiple_d_dl.hpp:422
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition: device_gemm_multiple_d_dl.hpp:421
EDataType * p_e_grid_
Definition: device_gemm_multiple_d_dl.hpp:419
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_multiple_d_dl.hpp:353
DefaultBlock2CTileMap block_2_ctile_map_
Definition: device_gemm_multiple_d_dl.hpp:431
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_gemm_multiple_d_dl.hpp:424
const ADataType * p_a_grid_
Definition: device_gemm_multiple_d_dl.hpp:416
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_
Definition: device_gemm_multiple_d_dl.hpp:428
DsGridDesc_M_N ds_grid_desc_m_n_
Definition: device_gemm_multiple_d_dl.hpp:423
AElementwiseOperation a_element_op_
Definition: device_gemm_multiple_d_dl.hpp:434
EGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_
Definition: device_gemm_multiple_d_dl.hpp:429
Definition: device_gemm_multiple_d_dl.hpp:441
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_multiple_d_dl.hpp:444
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_multiple_d_dl.hpp:539
Definition: device_gemm_multiple_d_dl.hpp:153
GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition: device_gemm_multiple_d_dl.hpp:337
std::string GetTypeString() const override
Definition: device_gemm_multiple_d_dl.hpp:645
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})) DsGridDesc_M0_M10_M11_N0_N10_N11
Definition: device_gemm_multiple_d_dl.hpp:344
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})) DefaultBlock2CTileMap
Definition: device_gemm_multiple_d_dl.hpp:348
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition: device_gemm_multiple_d_dl.hpp:293
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_multiple_d_dl.hpp:546
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition: device_gemm_multiple_d_dl.hpp:280
static constexpr auto I4
Definition: device_gemm_multiple_d_dl.hpp:161
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition: device_gemm_multiple_d_dl.hpp:342
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d_dl.hpp:155
static constexpr auto K1Number
Definition: device_gemm_multiple_d_dl.hpp:164
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition: device_gemm_multiple_d_dl.hpp:294
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_multiple_d_dl.hpp:552
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_multiple_d_dl.hpp:572
static auto MakeEGridDescriptor_M_N(index_t M, index_t N, index_t StrideE)
Definition: device_gemm_multiple_d_dl.hpp:245
static constexpr auto I3
Definition: device_gemm_multiple_d_dl.hpp:160
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition: device_gemm_multiple_d_dl.hpp:205
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_multiple_d_dl.hpp:639
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_multiple_d_dl.hpp:567
decltype(MakeDsGridDescriptor_M_N({}, {}, {})) DsGridDesc_M_N
Definition: device_gemm_multiple_d_dl.hpp:295
static constexpr auto I0
Definition: device_gemm_multiple_d_dl.hpp:157
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition: device_gemm_multiple_d_dl.hpp:166
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition: device_gemm_multiple_d_dl.hpp:340
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})) EGridDesc_M0_M10_M11_N0_N10_N11
Definition: device_gemm_multiple_d_dl.hpp:346
static constexpr auto I2
Definition: device_gemm_multiple_d_dl.hpp:159
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_gemm_multiple_d_dl.hpp:296
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_multiple_d_dl.hpp:607
static constexpr auto I1
Definition: device_gemm_multiple_d_dl.hpp:158
static auto MakeInvoker()
Definition: device_gemm_multiple_d_dl.hpp:603
static constexpr auto I5
Definition: device_gemm_multiple_d_dl.hpp:162
Definition: device_gemm_multiple_d.hpp:34