/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.2/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.2/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.2/include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp Source File
device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <functional>
7 #include <iostream>
8 #include <iterator>
9 #include <numeric>
10 #include <sstream>
11 
25 #include "ck/host_utility/io.hpp"
26 
27 namespace ck {
28 namespace tensor_operation {
29 namespace device {
30 
31 namespace {
32 
33 /*
34  * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
35  *
36  * \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
37  * given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
38  * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
39  * index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
40  * limitations.
41  *
42  * \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
43  * returns the 2D index of the tile that it computes. \see
44  * GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
45  *
46  * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
47  * tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
48  * descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
49  * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
50  * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
51  * pointer offset into \p ComputePtrOffsetOfStridedBatch.
52  *
53  * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
54  * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
55  * realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
56  *
57  */
// Grid-wide wrapper that turns GridwiseGemm::Run into a grouped (batched)
// convolution kernel: the launch grid is split evenly across `batch_count`
// groups, and each work-group shifts its A/B/Ds/E base pointers to its own
// group's slice before running the shared GEMM pipeline.
// NOTE(review): the doc extractor dropped original line 75 here — presumably
// the __launch_bounds__(...) attribute guarded by CK_USE_LAUNCH_BOUNDS.
58 template <typename GridwiseGemm,
59  typename ABDataType,
60  typename DsPointer,
61  typename EDataType,
62  typename AElementwiseOperation,
63  typename BElementwiseOperation,
64  typename CDEElementwiseOperation,
65  typename AGridDesc_K0_M0_M1_K1,
66  typename BGridDesc_K0_N0_N1_K1,
67  typename DsGridDesc_M0_M10_M11_N0_N10_N11,
68  typename CGridDesc_M0_M10_M11_N0_N10_N11,
69  typename Block2CTileMap,
70  typename ComputePtrOffsetOfBatch,
71  bool HasMainKBlockLoop,
72  bool HasDoubleTailKBlockLoop>
73 __global__ void
74 #if CK_USE_LAUNCH_BOUNDS
76 #endif
77  kernel_grouped_conv_fwd_dl_multiple_d(
78  const ABDataType* __restrict__ p_a_grid,
79  const ABDataType* __restrict__ p_b_grid,
80  DsPointer p_ds_grid,
81  EDataType* __restrict__ p_e_grid,
82  const AElementwiseOperation a_element_op,
83  const BElementwiseOperation b_element_op,
84  const CDEElementwiseOperation cde_element_op,
85  const index_t batch_count,
86  const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1,
87  const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1,
88  const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11,
89  const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11,
90  const Block2CTileMap block_2_ctile_map,
91  const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
92 {
93 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx103__) || \
94  defined(__gfx90a__) || defined(__gfx908__) || defined(__gfx94__) || defined(__gfx11__) || \
95  defined(__gfx12__))
96  // offset base pointer for each work-group
97  const index_t num_blocks_per_batch =
98  __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
99  const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
100 
// Per-group element offsets for A, B and E, broadcast from lane 0 via
// readfirstlane so that address arithmetic stays wavefront-uniform.
101  const long_index_t a_batch_offset = amd_wave_read_first_lane(
102  static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
103  const long_index_t b_batch_offset = amd_wave_read_first_lane(
104  static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
105  const long_index_t c_batch_offset = amd_wave_read_first_lane(
106  static_cast<long_index_t>(compute_ptr_offset_of_batch.GetEPtrOffset(g_idx)));
107 
108  const auto ds_batch_offset = compute_ptr_offset_of_batch.GetDsPtrOffset(g_idx);
109 
// LDS scratch, sized by the gridwise GEMM and expressed in ABDataType elements.
110  constexpr index_t shared_block_size =
111  GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(ABDataType);
112 
113  __shared__ ABDataType p_shared[shared_block_size];
114 
115  DsPointer p_ds_grid_grp;
116 
117  static constexpr index_t NumDTensor = DsGridDesc_M0_M10_M11_N0_N10_N11::Size();
118 
// Shift every D pointer to this group's slice.
119  static_for<0, NumDTensor, 1>{}(
120  [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
121 
122  GridwiseGemm::Run(p_a_grid + a_batch_offset,
123  p_b_grid + b_batch_offset,
124  p_ds_grid_grp,
125  p_e_grid + c_batch_offset,
126  p_shared,
127  a_element_op,
128  b_element_op,
129  cde_element_op,
130  a_grid_desc_k0_m0_m1_k1,
131  b_grid_desc_k0_n0_n1_k1,
132  ds_grid_desc_m0_m10_m11_n0_n10_n11,
133  e_grid_desc_m0_m10_m11_n0_n10_n11,
134  block_2_ctile_map,
135  integral_constant<bool, HasMainKBlockLoop>{},
136  integral_constant<bool, HasDoubleTailKBlockLoop>{});
137 #else
// Dead branch for unsupported targets: every parameter is consumed via
// `ignore`, presumably to silence unused-parameter warnings.
138  ignore = p_a_grid;
139  ignore = p_b_grid;
140  ignore = p_ds_grid;
141  ignore = p_e_grid;
142  ignore = a_element_op;
143  ignore = b_element_op;
144  ignore = cde_element_op;
145  ignore = batch_count;
146  ignore = a_grid_desc_k0_m0_m1_k1;
147  ignore = b_grid_desc_k0_n0_n1_k1;
148  ignore = ds_grid_desc_m0_m10_m11_n0_n10_n11;
149  ignore = e_grid_desc_m0_m10_m11_n0_n10_n11;
150  ignore = compute_ptr_offset_of_batch;
151  ignore = block_2_ctile_map;
152 
// NOTE(review): these no-op getter calls look intended to keep the offset
// accessors instantiated on host-only compiles — confirm against real source.
153  compute_ptr_offset_of_batch.GetAPtrOffset(0);
154  compute_ptr_offset_of_batch.GetBPtrOffset(0);
155  compute_ptr_offset_of_batch.GetEPtrOffset(0);
156 #endif
157 }
158 } // namespace
159 
160 //
161 // @brief Device Convolution operation.
162 //
163 // Supports:
164 // @li Forward convolution with up to 3 spatial dimentions
165 // @li Input tensor in GNWC data format
166 // @li Weight tensor in GKXC data format
167 // @li Output tensor in GNWK data format
168 //
169 // 1D:
170 // out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
171 // 2D:
172 // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
173 // 3D:
174 // out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
175 //
// Device op: grouped forward convolution (1–3 spatial dims) lowered to a
// batched DL GEMM, with extra D tensors fused through CDEElementwiseOperation
// (see the format/equation overview in the comment block above).
176 template <index_t NDimSpatial,
177  typename ADataType,
178  typename BDataType,
179  typename DsDataType,
180  typename EDataType,
181  typename AccDataType,
182  typename ALayout,
183  typename BLayout,
184  typename DsLayout,
185  typename ELayout,
186  typename AElementwiseOperation,
187  typename BElementwiseOperation,
188  typename CDEElementwiseOperation,
189  ConvolutionForwardSpecialization ConvForwardSpecialization,
190  GemmSpecialization GemmSpec,
191  index_t BlockSize,
192  index_t MPerBlock,
193  index_t NPerBlock,
194  index_t K0PerBlock,
195  index_t K1,
196  index_t M1PerThread,
197  index_t N1PerThread,
198  index_t KPerThread,
199  typename M1N1ThreadClusterM1Xs,
200  typename M1N1ThreadClusterN1Xs,
201  typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
202  typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
203  typename ABlockTransferThreadClusterArrangeOrder,
204  typename ABlockTransferSrcAccessOrder,
205  typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
206  typename ABlockTransferSrcVectorTensorContiguousDimOrder,
207  typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
208  typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
209  typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
210  typename BBlockTransferThreadClusterArrangeOrder,
211  typename BBlockTransferSrcAccessOrder,
212  typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
213  typename BBlockTransferSrcVectorTensorContiguousDimOrder,
214  typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
215  typename CThreadTransferSrcDstAccessOrder,
216  index_t CThreadTransferSrcDstVectorDim,
217  index_t CThreadTransferDstScalarPerVector>
// NOTE(review): the struct declarator line (original line 218, the
// `struct DeviceGroupedConvFwdDlMultipleD_...` name) was dropped by the
// doc extractor.
219  : public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
220  ALayout,
221  BLayout,
222  DsLayout,
223  ELayout,
224  ADataType,
225  BDataType,
226  DsDataType,
227  EDataType,
228  AElementwiseOperation,
229  BElementwiseOperation,
230  CDEElementwiseOperation>
231 {
233 
// Number of auxiliary D tensors fused in the epilogue (size of DsDataType).
234  static constexpr index_t NumDTensor = DsDataType::Size();
235 
// Compile-time index constants used for descriptor/vector-length queries.
236  static constexpr auto I0 = Number<0>{};
237  static constexpr auto I1 = Number<1>{};
238  static constexpr auto I2 = Number<2>{};
239  static constexpr auto I3 = Number<3>{};
240 
242 
// Pads GEMM M/N/K (per GemmSpec) up toward the block-tile sizes.
243  static constexpr auto matrix_padder =
244  MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, K0PerBlock};
245 
// Build the A (input image) grid descriptor: conv -> GEMM [M, K] view via the
// transformer, padded by matrix_padder, with AK0 = K / K1 computed for the
// K0-split layout.
// NOTE(review): the function declarator (original line 248) and the final
// transform/return statement (lines 260, 262-264) were dropped by the doc
// extractor; only the M/K/AK0 computation survives here.
246  template <typename ALay>
247  static auto
249  {
250  const auto in_gemmmraw_gemmkraw_desc =
251  conv_to_gemm_transformer.template MakeADescriptor_M_K<ALay>();
252 
253  const auto in_gemmm_gemmk_desc =
254  matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
255 
256  const auto M = in_gemmm_gemmk_desc.GetLength(I0);
257  const auto K = in_gemmm_gemmk_desc.GetLength(I1);
258  const auto AK0 = K / K1;
259 
261  in_gemmm_gemmk_desc,
265  }
266 
// Build the B (weight) grid descriptor: conv -> GEMM [N, K] view via the
// transformer, padded by matrix_padder, with BK0 = K / K1 for the K0-split
// layout.
// NOTE(review): the function declarator (original line 269) and the final
// transform/return statement (lines 282, 284-286) were dropped by the doc
// extractor.
267  template <typename BLay>
268  static auto
270  {
271  const auto wei_gemmnraw_gemmkraw_desc =
272  conv_to_gemm_transformer.template MakeBDescriptor_N_K<BLay>();
273 
274  const auto wei_gemmn_gemmk_desc =
275  matrix_padder.PadBDescriptor_N_K(wei_gemmnraw_gemmkraw_desc);
276 
277  const auto N = wei_gemmn_gemmk_desc.GetLength(I0);
278  const auto K = wei_gemmn_gemmk_desc.GetLength(I1);
279 
280  const auto BK0 = K / K1;
281 
283  wei_gemmn_gemmk_desc,
287  }
288 
289  template <typename ELay>
290  static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
291  {
292  const auto out_gemmmraw_gemmnraw_desc =
293  conv_to_gemm_transformer.template MakeCDescriptor_M_N<ELay>();
294 
295  const auto out_gemmm_gemmn_desc =
296  matrix_padder.PadCDescriptor_M_N(out_gemmmraw_gemmnraw_desc);
297 
298  return out_gemmm_gemmn_desc;
299  }
300 
// Build one E-style [M, N] descriptor per D tensor and collect them in a
// tuple, using each D's own layout from DsLayout.
// NOTE(review): the generate_tuple arity argument (original line 309,
// presumably Number<NumDTensor>{}) was dropped by the doc extractor.
301  static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer& conv_to_gemm_transformer)
302  {
303  return generate_tuple(
304  [&](auto i) {
305  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
306 
307  return DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer);
308  },
310  }
311 
312  // desc for problem definition
322 
323  // GridwiseGemm
// Compile-time configuration of the DL gridwise GEMM that implements the
// fused conv + D-tensor epilogue.
// NOTE(review): the doc extractor dropped the problem-definition descriptor
// aliases (lines 313-321), the gridwise-gemm class name (line 325) and
// several leading template arguments (lines 333-336).
324  using GridwiseGemm =
326  ADataType,
327  AccDataType,
328  DsDataType,
329  EDataType,
330  AElementwiseOperation,
331  BElementwiseOperation,
332  CDEElementwiseOperation,
337  MPerBlock,
338  NPerBlock,
339  K0PerBlock,
340  K1,
341  M1PerThread,
342  N1PerThread,
343  KPerThread,
344  M1N1ThreadClusterM1Xs,
345  M1N1ThreadClusterN1Xs,
346  ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
347  ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
348  ABlockTransferThreadClusterArrangeOrder,
349  ABlockTransferSrcAccessOrder,
350  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
351  ABlockTransferSrcVectorTensorContiguousDimOrder,
352  ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
353  BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
354  BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
355  BBlockTransferThreadClusterArrangeOrder,
356  BBlockTransferSrcAccessOrder,
357  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
358  BBlockTransferSrcVectorTensorContiguousDimOrder,
359  BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
360  CThreadTransferSrcDstAccessOrder,
361  CThreadTransferSrcDstVectorDim,
362  CThreadTransferDstScalarPerVector>;
363 
374 
375  // Argument
// Host-side argument bundle: raw device pointers, the conv->GEMM transformer,
// per-D descriptors, the batch-offset helper, element-wise ops, and copies of
// every conv parameter so IsSupportedArgument() can inspect them later.
// NOTE(review): initializer-list lines 414-425 (descriptor member inits),
// the Ds/E descriptor population at 473-488 and several member declarations
// were dropped by the doc extractor.
376  struct Argument : public BaseArgument
377  {
378  Argument(const void* p_a,
379  const void* p_b,
380  const std::array<const void*, NumDTensor>& p_ds,
381  void* p_e,
382  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
383  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
384  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
385  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
386  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
387  ds_g_n_k_wos_lengths,
388  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
389  ds_g_n_k_wos_strides,
390  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
391  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
392  const std::array<index_t, NDimSpatial>& conv_filter_strides,
393  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
394  const std::array<index_t, NDimSpatial>& input_left_pads,
395  const std::array<index_t, NDimSpatial>& input_right_pads,
396  const AElementwiseOperation& a_element_op,
397  const BElementwiseOperation& b_element_op,
398  const CDEElementwiseOperation& cde_element_op)
399  : p_a_grid_{static_cast<const ADataType*>(p_a)},
400  p_b_grid_{static_cast<const BDataType*>(p_b)},
401  p_ds_grid_{},
402  p_e_grid_{static_cast<EDataType*>(p_e)},
403  num_group_{a_g_n_c_wis_lengths[0]},
404  conv_to_gemm_transformer_{a_g_n_c_wis_lengths,
405  a_g_n_c_wis_strides,
406  b_g_k_c_xs_lengths,
407  b_g_k_c_xs_strides,
408  e_g_n_k_wos_lengths,
409  e_g_n_k_wos_strides,
410  conv_filter_strides,
411  conv_filter_dilations,
412  input_left_pads,
413  input_right_pads},
426  a_element_op_{a_element_op},
427  b_element_op_{b_element_op},
428  cde_element_op_{cde_element_op},
429  a_g_n_c_wis_lengths_{a_g_n_c_wis_lengths},
430  a_g_n_c_wis_strides_{a_g_n_c_wis_strides},
431  b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
432  b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
433  e_g_n_k_wos_lengths_{e_g_n_k_wos_lengths},
434  e_g_n_k_wos_strides_{e_g_n_k_wos_strides},
435  conv_filter_strides_{conv_filter_strides},
436  conv_filter_dilations_{conv_filter_dilations},
437  input_left_pads_{input_left_pads},
438  input_right_pads_{input_right_pads}
439  {
440  // A/B/E Batch Stride
// Index 0 of each stride array is the G (group) stride.
441  compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_c_wis_strides[0];
442  compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
443  compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_k_wos_strides[0];
444 
445  // populate pointer, batch stride, desc for Ds
446  static_for<0, NumDTensor, 1>{}([&](auto i) {
447  using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
448  using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
449 
// Each D may have its own lengths/strides, so build a dedicated
// transformer per D tensor.
450  ConvToGemmFwdTransformer conv_to_gemm_transformer_d{a_g_n_c_wis_lengths,
451  a_g_n_c_wis_strides,
452  b_g_k_c_xs_lengths,
453  b_g_k_c_xs_strides,
454  ds_g_n_k_wos_lengths[i],
455  ds_g_n_k_wos_strides[i],
456  conv_filter_strides,
457  conv_filter_dilations,
458  input_left_pads,
459  input_right_pads},
460 
461  // D pointer
462  p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
463 
464  // D batch stride
465  compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_k_wos_strides[i][0];
466 
467  // D desc
468  ds_grid_desc_m_n_(i) =
469  DeviceOp::MakeEGridDescriptor_M_N<DLayout>(conv_to_gemm_transformer_d);
470  });
471 
472  // populate desc for Ds/E
// NOTE(review): the validity check plus the block/thread-wise descriptor
// and tile-map construction (original lines 473-474, 477-487) were
// dropped by the doc extractor.
475  {
476 
483 
486 
488  }
489  }
490 
// Debug dump of the problem-definition and block/thread-wise descriptors.
491  void Print() const
492  {
493  std::cout << "A[K0, M, K1]: " << a_grid_desc_ak0_m_ak1_ << std::endl;
494  std::cout << "B[K0, N, K1]: " << b_grid_desc_bk0_n_bk1_ << std::endl;
495  std::cout << "E[M, N]: " << e_grid_desc_m_n_ << std::endl;
496  std::cout << "num_group: " << num_group_ << std::endl;
497 
498  std::cout << "A[k0, m0, m1, k1]: " << a_grid_desc_k0_m0_m1_k1_ << std::endl;
499  std::cout << "B[k0, n0, n1, k1]: " << b_grid_desc_k0_n0_n1_k1_ << std::endl;
500  std::cout << "A[m0, m10, m11, n0, n10, n11]: " << e_grid_desc_m0_m10_m11_n0_n10_n11_
501  << std::endl;
502  }
503 
504  // private:
505  // pointers
506  const ADataType* p_a_grid_;
507  const BDataType* p_b_grid_;
509  EDataType* p_e_grid_;
510 
511  // tensor descriptors for problem definition
// NOTE(review): descriptor member declarations (original lines 512,
// 514, 516-519, 522-525, 528) were dropped by the doc extractor.
513 
515 
520 
521  // tensor descriptors for block/thread-wise copy
526 
527  // block-to-e-tile map
529 
530  // for computing batch offset
531  ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor> compute_ptr_offset_of_batch_;
532 
533  // element-wise op
534  AElementwiseOperation a_element_op_;
535  BElementwiseOperation b_element_op_;
536  CDEElementwiseOperation cde_element_op_;
537 
538  // for checking IsSupportedArgument()
539  std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_;
540  std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_;
541  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
542  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
543  std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_;
544  std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_;
545  std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_;
546  std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_;
547  std::array<index_t, NDimSpatial> conv_filter_strides_;
548  std::array<index_t, NDimSpatial> conv_filter_dilations_;
549  std::array<index_t, NDimSpatial> input_left_pads_;
550  std::array<index_t, NDimSpatial> input_right_pads_;
551  };
552 
553  // Invoker
// Launches the grouped-conv kernel for a prepared Argument, choosing the
// kernel instantiation matching the K-loop shape (main-loop / double-tail
// combinations) and timing it via launch_and_time_kernel.
// NOTE(review): the doc extractor dropped several lines here (556, 565-566,
// 573, 585, 590-594, 612-615, 617, 623, and the bodies of the four dispatch
// branches at 627-643) — the validity check condition, part of the grid-size
// math and the actual launch_kernel(...) calls are incomplete below.
554  struct Invoker : public BaseInvoker
555  {
557 
558  float Run(const Argument& arg, const StreamConfig& stream_config)
559  {
560  if(stream_config.log_level_ > 0)
561  {
562  arg.Print();
563  }
564 
// Reject unsupported settings before touching the device.
567  {
568  throw std::runtime_error(
569  "wrong! DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK has invalid setting");
570  }
571 
// One tile-grid per group: per-GEMM block count times num_group_.
572  const index_t grid_size =
574  arg.e_grid_desc_m_n_.GetLength(I1)) *
575  arg.num_group_;
576 
577  auto launch_kernel = [&](auto has_main_k_block_loop,
578  auto has_double_tail_k_block_loop) {
579  constexpr bool has_main_loop = has_main_k_block_loop.value;
580  constexpr bool has_double_loop = has_double_tail_k_block_loop;
581 
582  const auto kernel = kernel_grouped_conv_fwd_dl_multiple_d<
583  GridwiseGemm,
584  ADataType, // TODO: distinguish A/B datatype
586  EDataType,
587  AElementwiseOperation,
588  BElementwiseOperation,
589  CDEElementwiseOperation,
595  ComputePtrOffsetOfStridedBatch<I1, I1, NumDTensor>,
596  has_main_loop,
597  has_double_loop>;
598 
599  return launch_and_time_kernel(stream_config,
600  kernel,
601  dim3(grid_size),
602  dim3(BlockSize),
603  0,
604  arg.p_a_grid_,
605  arg.p_b_grid_,
606  arg.p_ds_grid_,
607  arg.p_e_grid_,
608  arg.a_element_op_,
609  arg.b_element_op_,
610  arg.cde_element_op_,
611  arg.a_g_n_c_wis_lengths_[0], // Group count
616  arg.block_2_ctile_map_,
618  };
619 
// Pick the kernel variant from the K extent of the A descriptor.
620  const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
621  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
622  const bool has_double_tail_k_block_loop =
624 
// NOTE(review): each branch body (elided) presumably invokes
// launch_kernel(...) with the matching integral_constant pair and
// returns/records its timing — confirm against real source.
625  if(has_main_k_block_loop && has_double_tail_k_block_loop)
626  {
629  }
630  else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
631  {
634  }
635  else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
636  {
639  }
640  else
641  {
644  }
645  return 0;
646  }
647 
// Polymorphic entry point required by BaseInvoker.
648  float Run(const BaseArgument* p_arg,
649  const StreamConfig& stream_config = StreamConfig{}) override
650  {
651  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
652  }
653  };
654 
// Host-side feasibility check: device architecture, the structural
// assumptions of ConvForwardSpecialization (1x1 filters / stride-1 / pad-0),
// and the vectorized-access constraints of the A/B/E layouts.
// NOTE(review): the doc extractor dropped lines 661 (rest of the device
// check), 668, 688 (specialization enum values) and 788-789 (the final
// GridwiseGemm validity return) — the function is incomplete as shown.
655  static bool IsSupportedArgument(const Argument& arg)
656  {
657  namespace ctc = tensor_layout::convolution;
658 
659  // check device
660  if(!(ck::get_device_name() == "gfx906" || ck::is_xdl_supported() ||
662  {
663  return false;
664  }
665 
666  // check ConvolutionForwardSpecialization
667  if constexpr(ConvForwardSpecialization ==
669  {
670  // check if it's 1x1, stride=1 conv
671  for(index_t i = 0; i < NDimSpatial; ++i)
672  {
// Filter spatial length for dim i (layout G, K, C, <spatial...>).
673  const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
674  const index_t ConvStride = arg.conv_filter_strides_[i];
675  const index_t LeftPad = arg.input_left_pads_[i];
676  const index_t RightPad = arg.input_right_pads_[i];
677 
678  if(!(X == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
679  {
680  std::cout << "Filter1x1Stride1Pad0 check: XY_index = " << i << " X = " << X
681  << " ConvStride = " << ConvStride << " LeftPad = " << LeftPad
682  << " RightPad = " << RightPad << std::endl;
683  return false;
684  }
685  }
686  }
687  else if constexpr(ConvForwardSpecialization ==
689  {
690  // check if it's 1x1 conv
691  for(index_t i = 0; i < NDimSpatial; ++i)
692  {
693  const index_t X = arg.b_g_k_c_xs_lengths_[i + 3];
694  const index_t LeftPad = arg.input_left_pads_[i];
695  const index_t RightPad = arg.input_right_pads_[i];
696 
697  if(!(X == 1 && LeftPad == 0 && RightPad == 0))
698  {
// NOTE(review): this log text says "Filter1x1Stride1Pad0" but the branch
// checks the 1x1/pad-0 (stride-free) specialization — looks like a
// copy-paste from the branch above; fix the string in the real source.
699  std::cout << "Filter1x1Stride1Pad0 check: XY_index = " << i << " X = " << X
700  << " LeftPad = " << LeftPad << " RightPad = " << RightPad
701  << std::endl;
702  return false;
703  }
704  }
705  }
706 
707  // check vector access of A
708  // FIXME: layout
709  if constexpr(is_same_v<ALayout, ctc::G_NW_C> || is_same_v<ALayout, ctc::G_NHW_C> ||
710  is_same_v<ALayout, ctc::G_NDHW_C> || is_same_v<ALayout, ctc::GNWC> ||
711  is_same_v<ALayout, ctc::GNHWC> || is_same_v<ALayout, ctc::GNDHWC> ||
712  is_same_v<ALayout, ctc::NWGC> || is_same_v<ALayout, ctc::NHWGC> ||
713  is_same_v<ALayout, ctc::NDHWGC>)
714  {
// Only the K0 and K1 components may be vectorized (M0/M1 must be 1),
// and C must be divisible by the combined vector length.
715  auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
716  if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
717  {
718  return false;
719  }
720  if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
721  {
722  return false;
723  }
724 
725  const index_t C = arg.a_g_n_c_wis_lengths_[2];
726 
727  if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
728  {
729  return false;
730  }
731  }
732  else
733  {
734  return false;
735  }
736 
737  // check vector access of B
738  // FIXME: layout
739  if constexpr(is_same_v<BLayout, ctc::G_K_X_C> || is_same_v<BLayout, ctc::G_K_YX_C> ||
740  is_same_v<BLayout, ctc::G_K_ZYX_C> || is_same_v<BLayout, ctc::GKXC> ||
741  is_same_v<BLayout, ctc::GKYXC> || is_same_v<BLayout, ctc::GKZYXC> ||
742  is_same_v<BLayout, ctc::KXGC> || is_same_v<BLayout, ctc::KYXGC> ||
743  is_same_v<BLayout, ctc::KZYXGC>)
744 
745  {
// Same rule as A: vectorization only along K0/K1, divisibility on C.
746  auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
747  if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
748  {
749  return false;
750  }
751  if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
752  {
753  return false;
754  }
755 
756  const index_t C = arg.b_g_k_c_xs_lengths_[2];
757 
758  if(C % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
759  {
760  return false;
761  }
762  }
763  else
764  {
765  return false;
766  }
767 
768  // check vector access of E
769  if constexpr(is_same_v<ELayout, ctc::G_NW_K> || is_same_v<ELayout, ctc::G_NHW_K> ||
770  is_same_v<ELayout, ctc::G_NDHW_K> || is_same_v<ELayout, ctc::GNWK> ||
771  is_same_v<ELayout, ctc::GNHWK> || is_same_v<ELayout, ctc::GNDHWK> ||
772  is_same_v<ELayout, ctc::NWGK> || is_same_v<ELayout, ctc::NHWGK> ||
773  is_same_v<ELayout, ctc::NDHWGK>)
774  {
775  const index_t K = arg.e_g_n_k_wos_lengths_[2];
776 
777  if(!(K % CThreadTransferDstScalarPerVector == 0 && CThreadTransferSrcDstVectorDim == 5))
778  {
779  return false;
780  }
781  }
782  else
783  {
784  return false;
785  }
786 
787  // check Gridwise GEMM
790  }
791 
792  bool IsSupportedArgument(const BaseArgument* p_arg) override
793  {
794  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
795  }
796 
797  static auto MakeArgument(
798  const void* p_a,
799  const void* p_b,
800  const std::array<const void*, NumDTensor>& p_ds,
801  void* p_e,
802  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
803  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
804  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
805  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
806  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
807  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
808  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
809  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
810  const std::array<index_t, NDimSpatial>& conv_filter_strides,
811  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
812  const std::array<index_t, NDimSpatial>& input_left_pads,
813  const std::array<index_t, NDimSpatial>& input_right_pads,
814  const AElementwiseOperation& a_element_op,
815  const BElementwiseOperation& b_element_op,
816  const CDEElementwiseOperation& cde_element_op)
817  {
818  return Argument{p_a,
819  p_b,
820  p_ds,
821  p_e,
822  a_g_n_c_wis_lengths,
823  a_g_n_c_wis_strides,
824  b_g_k_c_xs_lengths,
825  b_g_k_c_xs_strides,
826  ds_g_n_k_wos_lengths,
827  ds_g_n_k_wos_strides,
828  e_g_n_k_wos_lengths,
829  e_g_n_k_wos_strides,
830  conv_filter_strides,
831  conv_filter_dilations,
832  input_left_pads,
833  input_right_pads,
834  a_element_op,
835  b_element_op,
836  cde_element_op};
837  }
838 
839  static auto
840  MakeArgument(const void* p_a,
841  const void* p_b,
842  const std::array<const void*, NumDTensor>& p_ds,
843  void* p_e,
844  const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
845  const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
846  const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
847  const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
848  const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
849  ds_g_n_k_wos_lengths,
850  const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
851  ds_g_n_k_wos_strides,
852  const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
853  const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
854  const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
855  const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
856  const std::array<long_index_t, NDimSpatial>& input_left_pads,
857  const std::array<long_index_t, NDimSpatial>& input_right_pads,
858  const AElementwiseOperation& a_element_op,
859  const BElementwiseOperation& b_element_op,
860  const CDEElementwiseOperation& cde_element_op)
861  {
862  std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
863  std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
864  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
865  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
866  std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
867  std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
868  std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
869  std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
870  std::array<index_t, NDimSpatial> conv_filter_strides_i32;
871  std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
872  std::array<index_t, NDimSpatial> input_left_pads_i32;
873  std::array<index_t, NDimSpatial> input_right_pads_i32;
874 
875  array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
876  array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
877  array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
878  array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
879  for(index_t d = 0; d < NumDTensor; d++)
880  {
881  array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
882  array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
883  }
884  array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
885  array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
886  array_convert(conv_filter_strides_i32, conv_filter_strides);
887  array_convert(conv_filter_dilations_i32, conv_filter_dilations);
888  array_convert(input_left_pads_i32, input_left_pads);
889  array_convert(input_right_pads_i32, input_right_pads);
890 
891  return Argument{p_a,
892  p_b,
893  p_ds,
894  p_e,
895  a_g_n_c_wis_lengths_i32,
896  a_g_n_c_wis_strides_i32,
897  b_g_k_c_xs_lengths_i32,
898  b_g_k_c_xs_strides_i32,
899  ds_g_n_k_wos_lengths_i32,
900  ds_g_n_k_wos_strides_i32,
901  e_g_n_k_wos_lengths_i32,
902  e_g_n_k_wos_strides_i32,
903  conv_filter_strides_i32,
904  conv_filter_dilations_i32,
905  input_left_pads_i32,
906  input_right_pads_i32,
907  a_element_op,
908  b_element_op,
909  cde_element_op};
910  }
911 
912  static auto MakeInvoker() { return Invoker{}; }
913 
914  std::unique_ptr<BaseArgument> MakeArgumentPointer(
915  const void* p_a,
916  const void* p_b,
917  const std::array<const void*, NumDTensor>& p_ds,
918  void* p_e,
919  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
920  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
921  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
922  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
923  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
924  const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
925  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
926  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
927  const std::array<index_t, NDimSpatial>& conv_filter_strides,
928  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
929  const std::array<index_t, NDimSpatial>& input_left_pads,
930  const std::array<index_t, NDimSpatial>& input_right_pads,
931  const AElementwiseOperation& a_element_op,
932  const BElementwiseOperation& b_element_op,
933  const CDEElementwiseOperation& cde_element_op) override
934  {
935  return std::make_unique<Argument>(p_a,
936  p_b,
937  p_ds,
938  p_e,
939  a_g_n_c_wis_lengths,
940  a_g_n_c_wis_strides,
941  b_g_k_c_xs_lengths,
942  b_g_k_c_xs_strides,
943  ds_g_n_k_wos_lengths,
944  ds_g_n_k_wos_strides,
945  e_g_n_k_wos_lengths,
946  e_g_n_k_wos_strides,
947  conv_filter_strides,
948  conv_filter_dilations,
949  input_left_pads,
950  input_right_pads,
951  a_element_op,
952  b_element_op,
953  cde_element_op);
954  }
955 
956  std::unique_ptr<BaseArgument>
957  MakeArgumentPointer(const void* p_a,
958  const void* p_b,
959  const std::array<const void*, NumDTensor>& p_ds,
960  void* p_e,
961  const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
962  const std::array<long_index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
963  const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
964  const std::array<long_index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
965  const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
966  ds_g_n_k_wos_lengths,
967  const std::array<std::array<long_index_t, NDimSpatial + 3>, NumDTensor>&
968  ds_g_n_k_wos_strides,
969  const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
970  const std::array<long_index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
971  const std::array<long_index_t, NDimSpatial>& conv_filter_strides,
972  const std::array<long_index_t, NDimSpatial>& conv_filter_dilations,
973  const std::array<long_index_t, NDimSpatial>& input_left_pads,
974  const std::array<long_index_t, NDimSpatial>& input_right_pads,
975  const AElementwiseOperation& a_element_op,
976  const BElementwiseOperation& b_element_op,
977  const CDEElementwiseOperation& cde_element_op) override
978  {
979  std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths_i32;
980  std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_strides_i32;
981  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_i32;
982  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_i32;
983  std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_lengths_i32;
984  std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_k_wos_strides_i32;
985  std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_lengths_i32;
986  std::array<index_t, NDimSpatial + 3> e_g_n_k_wos_strides_i32;
987  std::array<index_t, NDimSpatial> conv_filter_strides_i32;
988  std::array<index_t, NDimSpatial> conv_filter_dilations_i32;
989  std::array<index_t, NDimSpatial> input_left_pads_i32;
990  std::array<index_t, NDimSpatial> input_right_pads_i32;
991 
992  array_convert(a_g_n_c_wis_lengths_i32, a_g_n_c_wis_lengths);
993  array_convert(a_g_n_c_wis_strides_i32, a_g_n_c_wis_strides);
994  array_convert(b_g_k_c_xs_lengths_i32, b_g_k_c_xs_lengths);
995  array_convert(b_g_k_c_xs_strides_i32, b_g_k_c_xs_strides);
996  for(index_t d = 0; d < NumDTensor; d++)
997  {
998  array_convert(ds_g_n_k_wos_lengths_i32[d], ds_g_n_k_wos_lengths[d]);
999  array_convert(ds_g_n_k_wos_strides_i32[d], ds_g_n_k_wos_strides[d]);
1000  }
1001  array_convert(e_g_n_k_wos_lengths_i32, e_g_n_k_wos_lengths);
1002  array_convert(e_g_n_k_wos_strides_i32, e_g_n_k_wos_strides);
1003  array_convert(conv_filter_strides_i32, conv_filter_strides);
1004  array_convert(conv_filter_dilations_i32, conv_filter_dilations);
1005  array_convert(input_left_pads_i32, input_left_pads);
1006  array_convert(input_right_pads_i32, input_right_pads);
1007 
1008  return std::make_unique<Argument>(p_a,
1009  p_b,
1010  p_ds,
1011  p_e,
1012  a_g_n_c_wis_lengths_i32,
1013  a_g_n_c_wis_strides_i32,
1014  b_g_k_c_xs_lengths_i32,
1015  b_g_k_c_xs_strides_i32,
1016  ds_g_n_k_wos_lengths_i32,
1017  ds_g_n_k_wos_strides_i32,
1018  e_g_n_k_wos_lengths_i32,
1019  e_g_n_k_wos_strides_i32,
1020  conv_filter_strides_i32,
1021  conv_filter_dilations_i32,
1022  input_left_pads_i32,
1023  input_right_pads_i32,
1024  a_element_op,
1025  b_element_op,
1026  cde_element_op);
1027  }
1028 
1029  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1030  {
1031  return std::make_unique<Invoker>(Invoker{});
1032  }
1033 
1034  std::string GetTypeString() const override
1035  {
1036  auto str = std::stringstream();
1037 
1038  // clang-format off
1039  str << "DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK"
1040  << "<"
1041  << BlockSize << ", "
1042  << MPerBlock << ", "
1043  << NPerBlock << ", "
1044  << K0PerBlock << ", "
1045  << getConvForwardSpecializationString(ConvForwardSpecialization) << ", "
1046  << K1
1047  << ">";
1048  // clang-format on
1049 
1050  return str.str();
1051  }
1052 };
1053 
1054 } // namespace device
1055 } // namespace tensor_operation
1056 } // namespace ck
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:14
GemmSpecialization
Definition: gemm_specialization.hpp:11
ConvolutionForwardSpecialization
Definition: convolution_forward_specialization.hpp:15
std::string getConvForwardSpecializationString(const ConvolutionForwardSpecialization &s)
Definition: convolution_forward_specialization.hpp:24
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables &&... callables)
Definition: kernel_launch.hpp:94
Definition: ck.hpp:269
bool is_xdl_supported()
Definition: device_prop.hpp:55
__device__ index_t get_grid_size()
Definition: get_id.hpp:27
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
int64_t long_index_t
Definition: ck.hpp:301
std::string get_device_name()
Definition: device_prop.hpp:19
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
bool is_gfx12_supported()
Definition: device_prop.hpp:94
bool is_gfx103_supported()
Definition: device_prop.hpp:79
__host__ __device__ void array_convert(std::array< Y, NumElems > &y, const std::array< X, NumElems > &x)
Definition: type_convert.hpp:2390
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:300
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
bool is_gfx11_supported()
Definition: device_prop.hpp:86
Definition: stream_config.hpp:10
int log_level_
Definition: stream_config.hpp:13
Definition: gridwise_gemm_dl_multiple_d.hpp:60
__host__ static constexpr __device__ auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition: gridwise_gemm_dl_multiple_d.hpp:178
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_multiple_d.hpp:143
__host__ static constexpr __device__ auto MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_gemm_dl_multiple_d.hpp:234
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_multiple_d.hpp:242
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_dl_multiple_d.hpp:253
__host__ static constexpr __device__ auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition: gridwise_gemm_dl_multiple_d.hpp:158
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_dl_multiple_d.hpp:136
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N_ &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_multiple_d.hpp:200
__host__ static constexpr __device__ bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_multiple_d.hpp:150
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_multiple_d.hpp:110
Definition: multi_index_transform.hpp:196
Definition: multi_index_transform.hpp:284
Definition: sequence.hpp:43
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: transform_conv_fwd_to_gemm.hpp:25
Definition: device_base.hpp:51
Definition: device_base.hpp:62
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:377
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_strides_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:542
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_lengths_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:539
void Print() const
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:491
const ADataType * p_a_grid_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:506
ConvToGemmFwdTransformer conv_to_gemm_transformer_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:514
std::array< index_t, NDimSpatial > conv_filter_strides_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:547
EDataType * p_e_grid_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:509
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:522
std::array< index_t, NDimSpatial+3 > a_g_n_c_wis_strides_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:540
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_strides_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:546
GridwiseGemm::DsGridPointer p_ds_grid_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:508
std::array< index_t, NDimSpatial > conv_filter_dilations_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:548
ComputePtrOffsetOfStridedBatch< I1, I1, NumDTensor > compute_ptr_offset_of_batch_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:531
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:523
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:516
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_strides_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:544
const BDataType * p_b_grid_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:507
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:519
std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > ds_g_n_k_wos_lengths_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:543
BElementwiseOperation b_element_op_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:535
AElementwiseOperation a_element_op_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:534
DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:524
Argument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:378
DefaultBlock2CTileMap block_2_ctile_map_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:528
DsGridDesc_M_N ds_grid_desc_m_n_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:518
index_t num_group_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:512
std::array< index_t, NDimSpatial+3 > e_g_n_k_wos_lengths_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:545
std::array< index_t, NDimSpatial+3 > b_g_k_c_xs_lengths_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:541
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:517
std::array< index_t, NDimSpatial > input_right_pads_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:550
CDEElementwiseOperation cde_element_op_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:536
CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:525
std::array< index_t, NDimSpatial > input_left_pads_
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:549
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:555
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:648
float Run(const Argument &arg, const StreamConfig &stream_config)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:558
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:231
static constexpr auto I1
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:237
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_AK0_M_AK1{})) AGridDesc_K0_M0_M1_K1
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:365
constexpr static ConvToGemmFwdTransformer dummy_conv_to_gemm_transformer
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:313
static auto MakeBGridDescriptor_BK0_N_BK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:269
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(EGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:371
static constexpr index_t NumDTensor
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:234
static constexpr auto I3
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:239
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:914
static constexpr auto I2
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:238
remove_cvref_t< decltype(MakeAGridDescriptor_AK0_M_AK1< ALayout >(dummy_conv_to_gemm_transformer))> AGridDesc_AK0_M_AK1
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:315
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(EGridDesc_M_N{})) DefaultBlock2CTileMap
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:373
std::string GetTypeString() const override
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:1034
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_BK0_N_BK1{})) BGridDesc_K0_N0_N1_K1
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:367
static auto MakeAGridDescriptor_AK0_M_AK1(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:248
remove_cvref_t< decltype(MakeBGridDescriptor_BK0_N_BK1< BLayout >(dummy_conv_to_gemm_transformer))> BGridDesc_BK0_N_BK1
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:317
static constexpr auto matrix_padder
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:243
GridwiseGemmDlMultipleD_km_kn_mn< BlockSize, ADataType, AccDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, EGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:362
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N(dummy_conv_to_gemm_transformer))> DsGridDesc_M_N
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:319
static bool IsSupportedArgument(const Argument &arg)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:655
static auto MakeInvoker()
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:912
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:797
static auto MakeEGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:290
static auto MakeArgument(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:840
static auto MakeDsGridDescriptor_M_N(const ConvToGemmFwdTransformer &conv_to_gemm_transformer)
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:301
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const std::array< const void *, NumDTensor > &p_ds, void *p_e, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< long_index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< long_index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_lengths, const std::array< std::array< long_index_t, NDimSpatial+3 >, NumDTensor > &ds_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< long_index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_strides, const std::array< long_index_t, NDimSpatial > &conv_filter_dilations, const std::array< long_index_t, NDimSpatial > &input_left_pads, const std::array< long_index_t, NDimSpatial > &input_right_pads, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CDEElementwiseOperation &cde_element_op) override
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:957
decltype(GridwiseGemm::MakeDsGridDescriptor_M0_M10_M11_N0_N10_N11(DsGridDesc_M_N{})) DsGridDesc_M0_M10_M11_N0_N10_N11
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:369
remove_cvref_t< decltype(MakeEGridDescriptor_M_N< ELayout >(dummy_conv_to_gemm_transformer))> EGridDesc_M_N
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:321
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:1029
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:792
static constexpr auto I0
Definition: device_grouped_conv_fwd_dl_multiple_d_nhwc_kyxc_nhwk.hpp:236
Grouped Convolution Forward.
Definition: device_grouped_conv_fwd_multiple_abd.hpp:73
Definition: matrix_padder.hpp:180