Composable Kernel (ROCm docs-6.4.3) — source listing for
include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp

device_gemm_dl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
// DeviceGemmDl: device-level GEMM operation (C = op(A) * op(B)) implemented with
// the "DL" (dot-product, non-MFMA) gridwise GEMM, for GPUs without XDL
// instructions. Tile sizes, thread-cluster shapes and vector-transfer lengths
// are all compile-time template parameters.
// NOTE(review): this is a documentation extraction; original line 62 (the
// enable_if_t< opening of the constraint) is elided here. The visible lines
// 63-66 show the constraint requires all three elementwise operations to be
// PassThrough.
23 template <
24  typename ADataType,
25  typename BDataType,
26  typename CDataType,
27  typename AccDataType,
28  typename ALayout,
29  typename BLayout,
30  typename CLayout,
31  typename AElementwiseOperation,
32  typename BElementwiseOperation,
33  typename CElementwiseOperation,
34  GemmSpecialization GemmSpec,
35  index_t BlockSize,
36  index_t MPerBlock,
37  index_t NPerBlock,
38  index_t K0PerBlock,
39  index_t K1,
40  index_t M1PerThread,
41  index_t N1PerThread,
42  index_t KPerThread,
43  typename M1N1ThreadClusterM1Xs,
44  typename M1N1ThreadClusterN1Xs,
45  typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
46  typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
47  typename ABlockTransferThreadClusterArrangeOrder,
48  typename ABlockTransferSrcAccessOrder,
49  typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
50  typename ABlockTransferSrcVectorTensorContiguousDimOrder,
51  typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
52  typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
53  typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
54  typename BBlockTransferThreadClusterArrangeOrder,
55  typename BBlockTransferSrcAccessOrder,
56  typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
57  typename BBlockTransferSrcVectorTensorContiguousDimOrder,
58  typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
59  typename CThreadTransferSrcDstAccessOrder,
60  index_t CThreadTransferSrcDstVectorDim,
61  index_t CThreadTransferDstScalarPerVector,
63  is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
64  is_same_v<BElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
65  is_same_v<CElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
66  bool> = false>
67 struct DeviceGemmDl : public DeviceGemm<ALayout,
68  BLayout,
69  CLayout,
70  ADataType,
71  BDataType,
72  CDataType,
73  AElementwiseOperation,
74  BElementwiseOperation,
75  CElementwiseOperation>
76 
77 {
// Compile-time index constants used for descriptor length queries and
// Sequence element access throughout this struct.
78  static constexpr auto I0 = Number<0>{};
79  static constexpr auto I1 = Number<1>{};
80  static constexpr auto I2 = Number<2>{};
81  static constexpr auto I3 = Number<3>{};
82  static constexpr auto I4 = Number<4>{};
83  static constexpr auto I5 = Number<5>{};
84 
// K1 as a compile-time Number (K is split as K = K0 * K1).
85  static constexpr auto K1Number = Number<K1>{};
86 
// Builds the grid descriptor for the A matrix, splitting K into (K0, K1).
// Signature (elided in this extraction, per the file's reference listing):
//   static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
// NOTE(review): the lambda's branch conditions (original lines 94/98) are
// elided; presumably they select on ALayout row-major vs column-major —
// confirm against the full source.
88  {
// K must decompose exactly into K0 * K1.
89  assert(K % K1 == 0);
90 
91  const index_t K0 = K / K1;
92 
// 2-D (M, K) naive descriptor; the two branches differ only in which
// dimension carries StrideA.
93  const auto a_grid_desc_m_k = [&]() {
95  {
96  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
97  }
99  {
100  return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
101  }
102  }();
103 
104  if constexpr(GemmSpec == GemmSpecialization::MNPadding)
105  {
// Pad M up to a multiple of MPerBlock (0 when already aligned).
106  const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
107 
// NOTE(review): transform_tensor_descriptor call is partially elided
// (original lines 108, 110, 112-113); the visible arguments show the
// right-pad transform on M applied to a_grid_desc_m_k.
109  a_grid_desc_m_k,
111  make_right_pad_transform(M, PadM)),
114  }
115  else
116  {
// Unpadded path: transform (M, K) -> (K0, M, K1) without padding M
// (call partially elided in this extraction).
118  a_grid_desc_m_k,
123  }
124  }
125 
// Builds the grid descriptor for the B matrix, splitting K into (K0, K1).
// Signature (elided in this extraction, per the file's reference listing):
//   static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
// NOTE(review): the lambda's branch conditions (original lines 133/137) are
// elided; presumably they select on BLayout row-major vs column-major —
// confirm against the full source.
127  {
128  assert(K % K1 == 0);
129 
130  const index_t K0 = K / K1;
131 
// 2-D (K, N) naive descriptor; branches differ in which dimension
// carries StrideB.
132  const auto b_grid_desc_k_n = [&]() {
134  {
135  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
136  }
138  {
139  return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
140  }
141  }();
142 
143  if constexpr(GemmSpec == GemmSpecialization::MNPadding)
144  {
// Pad N up to a multiple of NPerBlock (0 when already aligned).
145  const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
146 
// NOTE(review): transform_tensor_descriptor call partially elided
// (original lines 147, 149, 151-152); visible arguments show the
// right-pad transform on N applied to b_grid_desc_k_n.
148  b_grid_desc_k_n,
150  make_right_pad_transform(N, PadN)),
153  }
154  else
155  {
// Unpadded path (call partially elided in this extraction).
157  b_grid_desc_k_n,
162  }
163  }
164 
// Builds the 2-D (M, N) grid descriptor for the output C matrix, optionally
// right-padding both dimensions to block-tile multiples.
// NOTE(review): the lambda's branch conditions (original lines 168/172) are
// elided; presumably they select on CLayout row-major vs column-major —
// confirm against the full source.
165  static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
166  {
167  const auto c_grid_desc_m_n = [&]() {
169  {
170  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
171  }
173  {
174  return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
175  }
176  }();
177 
178  if constexpr(GemmSpec == GemmSpecialization::MNPadding)
179  {
// Pad M and N up to tile multiples (0 when already aligned).
180  const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
181  const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
182 
// NOTE(review): transform_tensor_descriptor call partially elided
// (original lines 183, 185-187).
184  c_grid_desc_m_n,
188  }
189  else
190  {
// Unpadded path (call partially elided in this extraction).
191 
193  c_grid_desc_m_n,
197  }
198  }
199 
// Descriptor types deduced from dummy (1, 1, 1) problem sizes; only the types
// matter here, the actual descriptors are rebuilt per-Argument at runtime.
200  using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1));
201  using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
202  using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
203 
204  // GridwiseGemm
// Instantiation of the DL gridwise GEMM. Per the file's reference listing,
// the elided head (original line 206) is GridwiseGemmDl_km_kn_mn_v1r3<BlockSize,
// and elided lines 210-213 pass InMemoryDataOperationEnum::Set plus the three
// grid-descriptor types above.
205  using GridwiseGemm =
207  ADataType,
208  AccDataType,
209  CDataType,
214  MPerBlock,
215  NPerBlock,
216  K0PerBlock,
217  K1,
218  M1PerThread,
219  N1PerThread,
220  KPerThread,
221  M1N1ThreadClusterM1Xs,
222  M1N1ThreadClusterN1Xs,
223  ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
224  ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
225  ABlockTransferThreadClusterArrangeOrder,
226  ABlockTransferSrcAccessOrder,
227  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
228  ABlockTransferSrcVectorTensorContiguousDimOrder,
229  ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
230  BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
231  BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
232  BBlockTransferThreadClusterArrangeOrder,
233  BBlockTransferSrcAccessOrder,
234  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
235  BBlockTransferSrcVectorTensorContiguousDimOrder,
236  BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
237  CThreadTransferSrcDstAccessOrder,
238  CThreadTransferSrcDstVectorDim,
239  CThreadTransferDstScalarPerVector>;
240 
// NOTE(review): elided lines 241-248 declare the blocked descriptor aliases
// (AGridDesc_K0_M0_M1_K1, BGridDesc_K0_N0_N1_K1, CGridDesc_M0_M10_M11_N0_N10_N11)
// and DefaultBlock2CTileMap via GridwiseGemm factory functions, per the
// reference listing at the bottom of this file.
249 
250  // Argument
// Concrete problem descriptor: raw data pointers, grid descriptors built from
// the (M, N, K, strides) problem size, and the block-to-C-tile map.
251  struct Argument : public BaseArgument
252  {
253  Argument(const ADataType* p_a_grid,
254  const BDataType* p_b_grid,
255  CDataType* p_c_grid,
256  index_t M,
257  index_t N,
258  index_t K,
259  index_t StrideA,
260  index_t StrideB,
261  index_t StrideC,
262  index_t M01,
263  index_t N01,
264  AElementwiseOperation a_element_op,
265  BElementwiseOperation b_element_op,
266  CElementwiseOperation c_element_op)
267  : p_a_grid_{p_a_grid},
268  p_b_grid_{p_b_grid},
269  p_c_grid_{p_c_grid},
// NOTE(review): elided init lines 270-273 initialize the grid
// descriptors from (M, N, K, strides) via the Make*Descriptor helpers —
// confirm against the full source.
274  M01_{M01},
275  N01_{N01},
// Pre-padding problem sizes, kept for vector-width divisibility checks
// in IsSupportedArgument.
276  M_raw_{M},
277  N_raw_{N},
278  K_raw_{K},
279  a_element_op_{a_element_op},
280  b_element_op_{b_element_op},
281  c_element_op_{c_element_op}
282  {
// NOTE(review): constructor body lines 283-298 are largely elided;
// presumably they validate the descriptors and build the blocked
// descriptors + block_2_ctile_map_ — confirm against the full source.
286 
289  {
296 
298  }
299  }
300 
301  // private:
302  const ADataType* p_a_grid_;
303  const BDataType* p_b_grid_;
304  CDataType* p_c_grid_;
305 
// NOTE(review): member declarations at elided lines 306-314 (per the
// reference listing): a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_,
// c_grid_desc_m_n_, the blocked variants, and block_2_ctile_map_.
309 
313 
315 
316  // TODO: unused, but may be useful in future.
319 
323 
324  // TODO: unused since gridwise_gemm_dl_v1r3 does NOT support prologue for the time being.
325  AElementwiseOperation a_element_op_;
326  BElementwiseOperation b_element_op_;
327  CElementwiseOperation c_element_op_;
328  };
329 
330  // Invoker
// Launches the DL GEMM kernel for a prepared Argument, selecting one of four
// kernel instantiations based on the K-loop structure, and returns the
// averaged kernel time in milliseconds (via launch_and_time_kernel).
331  struct Invoker : public BaseInvoker
332  {
334 
335  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
336  {
// Optional debug dump of the (already blocked) descriptor lengths,
// gated by the CK_LOGGING environment variable.
337  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
338  {
339  std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
340  << arg.a_grid_desc_k0_m_k1_.GetLength(I0) << ", "
341  << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
342  << arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
343 
344  std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
345  << arg.b_grid_desc_k0_n_k1_.GetLength(I0) << ", "
346  << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
347  << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
348 
349  std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
350  << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
351  }
352 
// NOTE(review): the condition (original lines 353-354) is elided;
// presumably !GridwiseGemm::CheckValidity(...) on the three
// descriptors. Also note the message says "xdl_v2r3" although this is
// the DL (non-XDL) path — looks like a copy-paste from the XDL device
// op; worth fixing upstream.
355  {
356  throw std::runtime_error(
357  "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdl_v2r3 has invalid setting");
358  }
359 
360  const index_t grid_size = GridwiseGemm::CalculateGridSize(
361  arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
362 
// K-loop shape decides which of the four kernel instantiations below
// is launched (main loop and/or double-tail variants).
363  const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
364  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
365  const bool has_double_tail_k_block_loop =
367 
368  float ave_time = 0;
369 
// NOTE(review): in each branch below, the kernel template head
// (kernel_gemm_dl_v1r3<...>) and the descriptor template/launch
// arguments are partially elided by this extraction; the four branches
// differ only in the trailing <HasMainKBlockLoop, HasDoubleTailKBlockLoop>
// bool pair.
370  if(has_main_k_block_loop && has_double_tail_k_block_loop)
371  {
372  const auto kernel =
374  ADataType,
375  CDataType,
380  true,
381  true>;
382 
383  ave_time = launch_and_time_kernel(stream_config,
384  kernel,
385  dim3(grid_size),
386  dim3(BlockSize),
387  0,
388  arg.p_a_grid_,
389  arg.p_b_grid_,
390  arg.p_c_grid_,
394  arg.block_2_ctile_map_);
395  }
396  else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
397  {
398  const auto kernel =
400  ADataType,
401  CDataType,
406  true,
407  false>;
408 
409  ave_time = launch_and_time_kernel(stream_config,
410  kernel,
411  dim3(grid_size),
412  dim3(BlockSize),
413  0,
414  arg.p_a_grid_,
415  arg.p_b_grid_,
416  arg.p_c_grid_,
420  arg.block_2_ctile_map_);
421  }
422  else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
423  {
424  const auto kernel =
426  ADataType,
427  CDataType,
428  remove_reference_t<AGridDesc_K0_M0_M1_K1>,
429  remove_reference_t<BGridDesc_K0_N0_N1_K1>,
430  remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
431  remove_reference_t<DefaultBlock2CTileMap>,
432  false,
433  true>;
434 
435  ave_time = launch_and_time_kernel(stream_config,
436  kernel,
437  dim3(grid_size),
438  dim3(BlockSize),
439  0,
440  arg.p_a_grid_,
441  arg.p_b_grid_,
442  arg.p_c_grid_,
446  arg.block_2_ctile_map_);
447  }
448  else
449  {
450  const auto kernel =
452  ADataType,
453  CDataType,
454  remove_reference_t<AGridDesc_K0_M0_M1_K1>,
455  remove_reference_t<BGridDesc_K0_N0_N1_K1>,
456  remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
457  remove_reference_t<DefaultBlock2CTileMap>,
458  false,
459  false>;
460 
461  ave_time = launch_and_time_kernel(stream_config,
462  kernel,
463  dim3(grid_size),
464  dim3(BlockSize),
465  0,
466  arg.p_a_grid_,
467  arg.p_b_grid_,
468  arg.p_c_grid_,
472  arg.block_2_ctile_map_);
473  }
474 
475  return ave_time;
476  }
477 
478  // polymorphic
// Type-erased entry point: downcasts and forwards to the typed Run above.
479  float Run(const BaseArgument* p_arg,
480  const StreamConfig& stream_config = StreamConfig{}) override
481  {
482  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
483  }
484  };
485 
486  static constexpr bool IsValidCompilationParameter()
487  {
488  // TODO: properly implement this check
489  return true;
490  }
491 
// Returns true iff this compiled configuration can run the given problem:
// the pre-padding M/N/K must be divisible by the configured source vector
// transfer widths, and the device must be a supported non-XDL target.
492  static bool IsSupportedArgument(const Argument& arg)
493  {
494  // Make sure that the M, N, K dimensions before padding are divisible by respective vector
495  // lengths.
// NOTE(review): the branch condition (original line 496) is elided;
// presumably an if-constexpr on ALayout (K-contiguous vs M-contiguous)
// choosing which raw dimension the A vector load spans.
497  {
498  constexpr auto A_K_vec_length =
499  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I0) *
500  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I3);
501  if(arg.K_raw_ % A_K_vec_length != 0)
502  {
503  return false;
504  }
505  }
506  else
507  {
// NOTE(review): "lenght" typo is upstream code; left as-is here.
508  constexpr auto A_M_vec_lenght =
509  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I1) *
510  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1::At(I2);
511  if(arg.M_raw_ % A_M_vec_lenght != 0)
512  {
513  return false;
514  }
515  }
516 
// NOTE(review): elided condition (original line 517) — presumably the
// analogous if-constexpr on BLayout.
518  {
519  constexpr auto B_N_vec_lenght =
520  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I1) *
521  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I2);
522  if(arg.N_raw_ % B_N_vec_lenght != 0)
523  {
524  return false;
525  }
526  }
527  else
528  {
529  constexpr auto B_K_vec_length =
530  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I0) *
531  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1::At(I3);
532  if(arg.K_raw_ % B_K_vec_length != 0)
533  {
534  return false;
535  }
536  }
537 
// Device allowlist: gfx906 or the gfx103/gfx11/(per the reference
// listing, likely gfx12) families. Elided lines 539/541-542 presumably
// complete the condition and return the GridwiseGemm validity check.
538  if(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
540  {
543  }
544  return false;
545  }
546 
547  // polymorphic
548  bool IsSupportedArgument(const BaseArgument* p_arg) override
549  {
550  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
551  }
552 
553  static auto MakeArgument(const ADataType* p_a,
554  const BDataType* p_b,
555  CDataType* p_c,
556  index_t M,
557  index_t N,
558  index_t K,
559  index_t StrideA,
560  index_t StrideB,
561  index_t StrideC,
562  AElementwiseOperation a_element_op,
563  BElementwiseOperation b_element_op,
564  CElementwiseOperation c_element_op)
565  {
566  return Argument{p_a,
567  p_b,
568  p_c,
569  M,
570  N,
571  K,
572  StrideA,
573  StrideB,
574  StrideC,
575  1,
576  1,
577  a_element_op,
578  b_element_op,
579  c_element_op};
580  }
581 
582  static auto MakeInvoker() { return Invoker{}; }
583 
584  // polymorphic
585  std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
586  const void* p_b,
587  void* p_c,
588  index_t M,
589  index_t N,
590  index_t K,
591  index_t StrideA,
592  index_t StrideB,
593  index_t StrideC,
594  AElementwiseOperation a_element_op,
595  BElementwiseOperation b_element_op,
596  CElementwiseOperation c_element_op) override
597  {
598  return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
599  static_cast<const BDataType*>(p_b),
600  static_cast<CDataType*>(p_c),
601  M,
602  N,
603  K,
604  StrideA,
605  StrideB,
606  StrideC,
607  1,
608  1,
609  a_element_op,
610  b_element_op,
611  c_element_op);
612  }
613 
614  // polymorphic
615  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
616  {
617  return std::make_unique<Invoker>(Invoker{});
618  }
619 
620  // polymorphic
621  virtual std::string GetTypeString() const override
622  {
623  auto str = std::stringstream();
624 
625  // clang-format off
626  str << "DeviceGemmDl"
627  << "<"
628  << BlockSize << ", "
629  << MPerBlock << ", "
630  << NPerBlock << ", "
631  << K0PerBlock << ", "
632  << K1 << ", "
633  << M1PerThread << ", "
634  << N1PerThread << ", "
635  << KPerThread
636  << ">";
637  // clang-format on
638 
639  return str.str();
640  }
641 };
642 
643 } // namespace device
644 } // namespace tensor_operation
645 } // namespace ck
#define CK_ENV(name)
Definition: env.hpp:128
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
std::string get_device_name()
Definition: device_prop.hpp:12
bool is_gfx12_supported()
Definition: device_prop.hpp:94
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_dl_v1r3.hpp:33
bool is_gfx103_supported()
Definition: device_prop.hpp:81
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
bool is_gfx11_supported()
Definition: device_prop.hpp:88
Definition: stream_config.hpp:10
Definition: gridwise_gemm_dl_v1r3.hpp:93
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:129
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:208
__host__ static constexpr __device__ auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition: gridwise_gemm_dl_v1r3.hpp:188
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_dl_v1r3.hpp:146
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_v1r3.hpp:153
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:241
__host__ static constexpr __device__ bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_v1r3.hpp:160
__host__ static constexpr __device__ auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition: gridwise_gemm_dl_v1r3.hpp:168
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: device_base.hpp:50
Definition: device_base.hpp:61
index_t M_raw_
Definition: device_gemm_dl.hpp:320
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition: device_gemm_dl.hpp:306
CGridDesc_M_N c_grid_desc_m_n_
Definition: device_gemm_dl.hpp:308
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition: device_gemm_dl.hpp:311
index_t M01_
Definition: device_gemm_dl.hpp:317
index_t N01_
Definition: device_gemm_dl.hpp:318
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_
Definition: device_gemm_dl.hpp:312
index_t K_raw_
Definition: device_gemm_dl.hpp:322
CDataType * p_c_grid_
Definition: device_gemm_dl.hpp:304
index_t N_raw_
Definition: device_gemm_dl.hpp:321
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition: device_gemm_dl.hpp:307
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_dl.hpp:253
const BDataType * p_b_grid_
Definition: device_gemm_dl.hpp:303
AElementwiseOperation a_element_op_
Definition: device_gemm_dl.hpp:325
BElementwiseOperation b_element_op_
Definition: device_gemm_dl.hpp:326
DefaultBlock2CTileMap block_2_ctile_map_
Definition: device_gemm_dl.hpp:314
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition: device_gemm_dl.hpp:310
CElementwiseOperation c_element_op_
Definition: device_gemm_dl.hpp:327
const ADataType * p_a_grid_
Definition: device_gemm_dl.hpp:302
Definition: device_gemm_dl.hpp:332
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_dl.hpp:335
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_dl.hpp:479
Definition: device_gemm_dl.hpp:77
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition: device_gemm_dl.hpp:248
static constexpr auto I0
Definition: device_gemm_dl.hpp:78
static constexpr auto I2
Definition: device_gemm_dl.hpp:80
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition: device_gemm_dl.hpp:239
virtual std::string GetTypeString() const override
Definition: device_gemm_dl.hpp:621
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_dl.hpp:548
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition: device_gemm_dl.hpp:201
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_dl.hpp:492
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition: device_gemm_dl.hpp:242
static constexpr auto I3
Definition: device_gemm_dl.hpp:81
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_dl.hpp:615
static auto MakeInvoker()
Definition: device_gemm_dl.hpp:582
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition: device_gemm_dl.hpp:200
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition: device_gemm_dl.hpp:126
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition: device_gemm_dl.hpp:244
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_dl.hpp:585
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition: device_gemm_dl.hpp:246
static constexpr auto I5
Definition: device_gemm_dl.hpp:83
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition: device_gemm_dl.hpp:202
static constexpr auto I1
Definition: device_gemm_dl.hpp:79
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_dl.hpp:486
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_dl.hpp:553
static constexpr auto I4
Definition: device_gemm_dl.hpp:82
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition: device_gemm_dl.hpp:165
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition: device_gemm_dl.hpp:87
static constexpr auto K1Number
Definition: device_gemm_dl.hpp:85
Definition: device_gemm.hpp:22