/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_pool3d_fwd_ndhwc_ndhwc.hpp Source File
device_pool3d_fwd_ndhwc_ndhwc.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 template <typename InDataType,
24  typename OutDataType,
25  typename IndexDataType, // enable if OutputIndex == true
26  typename ComputeDataType,
27  ck::ReduceTensorOp ReduceOpId,
28  bool OutputIndex,
29  ck::index_t BlockSize,
30  ck::index_t MThreadClusterSize,
31  ck::index_t KThreadClusterSize,
32  ck::index_t MThreadSliceSize,
33  ck::index_t KThreadSliceSize,
34  ck::index_t InSrcOutDstVectorSize>
36  3,
37  InDataType,
38  OutDataType,
39  IndexDataType,
40  tensor_layout::convolution::NDHWC,
41  tensor_layout::convolution::NDHWC,
42  ReduceOpId,
43  OutputIndex>
44 {
// Compile-time indices used to select tuple elements and descriptor dimensions
// (e.g. descs[I0] / descs[I1] and GetLength(I0)).
45  static constexpr auto I0 = Number<0>{};
46  static constexpr auto I1 = Number<1>{};
47  static constexpr auto I2 = Number<2>{};
48  static constexpr auto I3 = Number<3>{};
49  static constexpr auto I4 = Number<4>{};
50  static constexpr auto I5 = Number<5>{};
51 
// Input/output tensors carry 5 dims (N, C, D, H, W order in the *_ncdhw_* vectors);
// the pooling window carries 3 (Z, Y, X over D, H, W).
52  static constexpr index_t InOutRank = 5;
53  static constexpr index_t WindowRank = 3;
54 
56 
59 
62 
// Elements of M (kept dims, N*Do*Ho*Wo*C) and of K (window taps) covered by one
// workgroup tile; the descriptor factory rounds M and K up to multiples of these.
63  static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
64  static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
65 
// Builds the two tensor views consumed by the threadwise reduction kernel:
//   A: the input viewed as [M, K], with M = N*Do*Ho*Wo*C (one row per output element)
//      and K = Z*Y*X (one column per window tap);
//   B: the output viewed as [M].
// Length/stride vectors use N, C, D, H, W index order; window, stride, dilation and pad
// vectors use D, H, W (Z, Y, X) order. MPad/KPad are the amounts needed to round M and K
// up to M_BlockTileSize / K_BlockTileSize (presumably applied as right padding when
// building in_grid_desc_reducem_reducek / out_grid_desc_reducem -- confirm).
// Returns make_tuple(A descriptor, B descriptor).
66  static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_ncdhw_lengths,
67  std::vector<ck::index_t> output_ncdhw_lengths,
68  std::vector<ck::index_t> input_ncdhw_stride,
69  std::vector<ck::index_t> output_ncdhw_stride,
70  std::vector<ck::index_t> window_spatial_zyx_lengths,
71  std::vector<ck::index_t> window_zyx_strides,
72  std::vector<ck::index_t> window_zyx_dilations,
73  std::vector<ck::index_t> input_left_dhw_pads,
74  std::vector<ck::index_t> input_right_dhw_pads)
75  {
76  const index_t N = input_ncdhw_lengths[0];
77  const index_t C = input_ncdhw_lengths[1];
78  const index_t Di = input_ncdhw_lengths[2];
79  const index_t Hi = input_ncdhw_lengths[3];
80  const index_t Wi = input_ncdhw_lengths[4];
81 
82  const index_t Do = output_ncdhw_lengths[2];
83  const index_t Ho = output_ncdhw_lengths[3];
84  const index_t Wo = output_ncdhw_lengths[4];
85 
86  const index_t Z = window_spatial_zyx_lengths[0];
87  const index_t Y = window_spatial_zyx_lengths[1];
88  const index_t X = window_spatial_zyx_lengths[2];
89 
90  const index_t WindowStrideD = window_zyx_strides[0];
91  const index_t WindowStrideH = window_zyx_strides[1];
92  const index_t WindowStrideW = window_zyx_strides[2];
93 
94  const index_t WindowDilationD = window_zyx_dilations[0];
95  const index_t WindowDilationH = window_zyx_dilations[1];
96  const index_t WindowDilationW = window_zyx_dilations[2];
97 
98  const index_t InLeftPadD = input_left_dhw_pads[0];
99  const index_t InLeftPadH = input_left_dhw_pads[1];
100  const index_t InLeftPadW = input_left_dhw_pads[2];
101 
102  const index_t InRightPadD = input_right_dhw_pads[0];
103  const index_t InRightPadH = input_right_dhw_pads[1];
104  const index_t InRightPadW = input_right_dhw_pads[2];
105 
106  const index_t MRaw = N * Do * Ho * Wo * C;
107  const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
108 
109  const index_t KRaw = Z * Y * X;
110  const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
111 
112  // A[ReduceM, ReduceK]
// Strides arrive in NCDHW order, but the naive descriptor below is laid out as N, D, H, W, C.
113  const index_t Ni_stride = input_ncdhw_stride[0];
114  const index_t Ci_stride = input_ncdhw_stride[1];
115  const index_t Di_stride = input_ncdhw_stride[2];
116  const index_t Hi_stride = input_ncdhw_stride[3];
117  const index_t Wi_stride = input_ncdhw_stride[4];
118 
119  const auto in_grid_desc_n_di_hi_wi_c = make_naive_tensor_descriptor(
120  make_tuple(N, Di, Hi, Wi, C),
121  make_tuple(Ni_stride, Di_stride, Hi_stride, Wi_stride, Ci_stride));
122 
// Apply the left/right spatial padding to D, H and W.
123  const auto in_grid_desc_n_dip_hip_wip_c = transform_tensor_descriptor(
124  in_grid_desc_n_di_hi_wi_c,
126  make_pad_transform(Di, InLeftPadD, InRightPadD),
127  make_pad_transform(Hi, InLeftPadH, InRightPadH),
128  make_pad_transform(Wi, InLeftPadW, InRightPadW),
132 
// Unfold each padded spatial dim into (window tap, output position) pairs: the embed
// transform maps e.g. (z, do) to z * WindowDilationD + do * WindowStrideD.
133  const auto in_grid_desc_n_z_do_y_ho_x_wo_c = transform_tensor_descriptor(
134  in_grid_desc_n_dip_hip_wip_c,
135  make_tuple(
137  make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
138  make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
139  make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
143  Sequence<1, 2>{},
144  Sequence<3, 4>{},
145  Sequence<5, 6>{},
146  Sequence<7>{}));
147 
// Merge (N, Do, Ho, Wo, C) into the kept dimension M (C innermost) and (Z, Y, X) into the
// reduced dimension K.
148  const auto in_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
149  in_grid_desc_n_z_do_y_ho_x_wo_c,
150  make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
151  make_merge_transform(make_tuple(Z, Y, X))),
154 
155  const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
156  in_grid_desc_reducemraw_reducekraw,
160 
161  // B[ReduceM]
162  const index_t No_stride = output_ncdhw_stride[0];
163  const index_t Co_stride = output_ncdhw_stride[1];
164  const index_t Do_stride = output_ncdhw_stride[2];
165  const index_t Ho_stride = output_ncdhw_stride[3];
166  const index_t Wo_stride = output_ncdhw_stride[4];
167 
// NOTE(review): this output descriptor is created with the *input* spatial lengths
// (Di, Hi, Wi), while the merge below decomposes M with the output lengths (Do, Ho, Wo).
// It looks like it should be make_tuple(N, Do, Ho, Wo, C); as written the descriptor's
// element-space size need not match the real output buffer -- confirm and fix.
168  const auto out_grid_desc_n_do_ho_wo_c = make_naive_tensor_descriptor(
169  make_tuple(N, Di, Hi, Wi, C),
170  make_tuple(No_stride, Do_stride, Ho_stride, Wo_stride, Co_stride));
171 
172  const auto out_grid_desc_reducemraw = transform_tensor_descriptor(
173  out_grid_desc_n_do_ho_wo_c,
174  make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C))),
177 
178  const auto out_grid_desc_reducem =
179  transform_tensor_descriptor(out_grid_desc_reducemraw,
183 
184  return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
185  }
186 
// Descriptor types are taken from the factory's return type. The call sits in an
// unevaluated decltype context, so the empty vectors are never actually indexed.
187  using ABGridDescs =
188  decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
189 
// Type of the output (B) descriptor: element I1 of the returned tuple.
191  using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
192 
// Runtime arguments for one pooling call: device pointers, the A [M, K] / B [M]
// descriptors built by MakeABGridDescriptor_A_M_K_B_M, and copies of the shape/stride
// vectors that IsSupportedArgument inspects. The vector parameters are only read here.
193  struct Argument : public BaseArgument
194  {
// The unnamed indices-stride parameter is accepted for interface compatibility but not
// stored; MakeArgumentPointer requires it to equal output_ncdhw_stride.
195  Argument(const InDataType* p_in_dev,
196  OutDataType* p_out_dev,
197  IndexDataType* p_out_indices_dev,
198  std::vector<ck::index_t>& input_ncdhw_lengths,
199  std::vector<ck::index_t>& output_ncdhw_lengths,
200  std::vector<ck::index_t>& input_ncdhw_stride,
201  std::vector<ck::index_t>& output_ncdhw_stride,
202  std::vector<ck::index_t>&, // indices_ncdhw_stride
203  std::vector<ck::index_t>& window_spatial_zyx_lengths,
204  std::vector<ck::index_t>& window_zyx_strides,
205  std::vector<ck::index_t>& window_zyx_dilations,
206  std::vector<ck::index_t>& input_left_dhw_pads,
207  std::vector<ck::index_t>& input_right_dhw_pads)
208  : p_in_dev_{p_in_dev},
209  p_out_dev_{p_out_dev},
210  p_out_indices_dev_{p_out_indices_dev},
212  b_grid_desc_m_{},
213  input_ncdhw_lengths_{input_ncdhw_lengths},
214  output_ncdhw_lengths_{output_ncdhw_lengths},
215  input_ncdhw_stride_{input_ncdhw_stride},
216  output_ncdhw_stride_{output_ncdhw_stride}
217  {
218  const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_ncdhw_lengths,
219  output_ncdhw_lengths,
220  input_ncdhw_stride,
221  output_ncdhw_stride,
222  window_spatial_zyx_lengths,
223  window_zyx_strides,
224  window_zyx_dilations,
225  input_left_dhw_pads,
226  input_right_dhw_pads);
227 
228  a_grid_desc_m_k_ = descs[I0];
229  b_grid_desc_m_ = descs[I1];
230 
// Window volume Z*Y*X, i.e. the number of elements reduced per output value. Presumably
// used to configure the in/acc elementwise operators (e.g. an averaging divisor) -- confirm.
231  int32_t reduceLength = window_spatial_zyx_lengths[0] * window_spatial_zyx_lengths[1] *
232  window_spatial_zyx_lengths[2];
233 
236  }
237 
238  const InDataType* p_in_dev_;
239  OutDataType* p_out_dev_;
240  IndexDataType* p_out_indices_dev_;
243 
246 
247  // for checking vector load/store
248  std::vector<ck::index_t> input_ncdhw_lengths_;
249  std::vector<ck::index_t> output_ncdhw_lengths_;
250  std::vector<ck::index_t> input_ncdhw_stride_;
251  std::vector<ck::index_t> output_ncdhw_stride_;
252  };
253 
254  struct Invoker : public BaseInvoker
255  {
// Launches the threadwise reduction kernel over the A [M, K] / B [M] descriptors held in
// arg, one workgroup of BlockSize threads per M_BlockTileSize rows of M. Returns the value
// reported by launch_and_time_kernel (presumably the kernel time when stream_config
// enables timing).
256  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
257  {
258  // for NDHWC, the dim C is the fastest dimension, and is not reduced.
259  // Hence, it is in M dimension for reduction kernel.
260  static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K
261 
262  using gridwise_reduce =
264  OutDataType,
265  ComputeDataType,
266  IndexDataType,
268  BGridDesc_M,
273  false, // propagate_nan
274  BlockSize,
275  MThreadSliceSize,
276  KThreadSliceSize,
277  InSrcOutDstVectorDim,
278  InSrcOutDstVectorSize,
279  InSrcOutDstVectorSize>;
280 
281  const auto kernel =
282  kernel_reduce_threadwise<gridwise_reduce,
283  OutputIndex,
284  true, // pooling need to return global index
285  false, // don't have index input
286  InDataType,
287  OutDataType,
288  ComputeDataType,
289  IndexDataType,
291  BGridDesc_M,
294 
295  ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0);
296 
// M here is the padded length; presumably a multiple of M_BlockTileSize (see MPad in the
// descriptor factory), so this division is exact -- confirm.
297  const index_t grid_size = (M / M_BlockTileSize);
298 
// Launch arguments: 0 bytes of dynamic LDS, then the kernel parameters. alpha = 1 and
// beta = 0 (presumably the previous output contents are not blended in), and nullptr for
// the unused input-index buffer.
299  return launch_and_time_kernel(stream_config,
300  kernel,
301  dim3(grid_size),
302  dim3(BlockSize),
303  0,
304  arg.a_grid_desc_m_k_,
305  arg.b_grid_desc_m_,
306  arg.in_element_op_,
307  arg.acc_element_op_,
308  float(1),
309  arg.p_in_dev_,
310  nullptr,
311  float(0),
312  arg.p_out_dev_,
313  arg.p_out_indices_dev_);
314  }
315 
316  float Run(const BaseArgument* p_arg,
317  const StreamConfig& stream_config = StreamConfig{}) override
318  {
319  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
320  }
321  };
322 
323  bool IsSupportedArgument(const BaseArgument* p_arg) override
324  {
325  const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
326 
327  // C should be fastest dimension
328  if(pArg->input_ncdhw_stride_[1] != 1)
329  return false;
330 
331  for(int i = 0; i < InOutRank; ++i)
332  {
333  if(pArg->input_ncdhw_stride_[i] == 1 &&
334  pArg->input_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
335  return false;
336 
337  if(pArg->output_ncdhw_stride_[i] == 1 &&
338  pArg->output_ncdhw_lengths_[i] % InSrcOutDstVectorSize != 0)
339  return false;
340  }
341 
342  return true;
343  }
344 
345  virtual std::unique_ptr<BaseArgument>
346  MakeArgumentPointer(const void* p_in_dev,
347  void* p_out_dev,
348  void* p_out_indices_dev,
349  std::vector<ck::index_t> input_ncdhw_lengths,
350  std::vector<ck::index_t> window_zyx_lengths,
351  std::vector<ck::index_t> output_ncdhw_lengths,
352  std::vector<ck::index_t> input_ncdhw_stride,
353  std::vector<ck::index_t> output_ncdhw_stride,
354  std::vector<ck::index_t> indices_ncdhw_stride,
355  std::vector<ck::index_t> window_zyx_strides,
356  std::vector<ck::index_t> window_zyx_dilations,
357  std::vector<ck::index_t> input_left_dhw_pads,
358  std::vector<ck::index_t> input_right_dhw_pads,
359  std::vector<ck::index_t> pooling_dims) override
360  {
361  if(input_ncdhw_lengths.size() != InOutRank || window_zyx_lengths.size() != WindowRank ||
362  input_ncdhw_lengths.size() != InOutRank || window_zyx_strides.size() != WindowRank ||
363  window_zyx_dilations.size() != WindowRank || input_left_dhw_pads.size() != WindowRank ||
364  input_right_dhw_pads.size() != WindowRank)
365  throw std::runtime_error("dimension is incorrect");
366 
367  if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
368  throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");
369 
370  if(output_ncdhw_stride != indices_ncdhw_stride)
371  throw std::runtime_error(
372  "output_ncdhw_stride need to be equal to indices_ncdhw_stride for now");
373 
374  return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
375  static_cast<OutDataType*>(p_out_dev),
376  static_cast<IndexDataType*>(p_out_indices_dev),
377  input_ncdhw_lengths,
378  output_ncdhw_lengths,
379  input_ncdhw_stride,
380  output_ncdhw_stride,
381  indices_ncdhw_stride,
382  window_zyx_lengths,
383  window_zyx_strides,
384  window_zyx_dilations,
385  input_left_dhw_pads,
386  input_right_dhw_pads);
387  }
388 
389  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
390  {
391  return std::make_unique<Invoker>(Invoker{});
392  }
393 
394  std::string GetTypeString() const override
395  {
396  auto str = std::stringstream();
397 
398  // clang-format off
399  str << "DevicePool3dFwd_NDHWC_NDHWC<" << BlockSize << ",";
400  str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
401  str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
402  str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
403  // clang-format on
404 
405  return str.str();
406  }
407 };
408 
409 } // namespace device
410 } // namespace tensor_operation
411 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
ReduceTensorOp
Definition: reduction_enums.hpp:9
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition: gridwise_2d_reduction_threadwise.hpp:28
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_2d_reduction_threadwise.hpp:84
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: reduction_operator_mapping.hpp:20
Definition: reduction_operator_mapping.hpp:91
static std::tuple< InElementwiseOperation, AccElementwiseOperation > GetElementwiseOperator(int32_t reduceLength)
Definition: reduction_operator_mapping.hpp:96
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:194
BGridDesc_M b_grid_desc_m_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:242
Argument(const InDataType *p_in_dev, OutDataType *p_out_dev, IndexDataType *p_out_indices_dev, std::vector< ck::index_t > &input_ncdhw_lengths, std::vector< ck::index_t > &output_ncdhw_lengths, std::vector< ck::index_t > &input_ncdhw_stride, std::vector< ck::index_t > &output_ncdhw_stride, std::vector< ck::index_t > &, std::vector< ck::index_t > &window_spatial_zyx_lengths, std::vector< ck::index_t > &window_zyx_strides, std::vector< ck::index_t > &window_zyx_dilations, std::vector< ck::index_t > &input_left_dhw_pads, std::vector< ck::index_t > &input_right_dhw_pads)
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:195
IndexDataType * p_out_indices_dev_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:240
std::vector< ck::index_t > input_ncdhw_lengths_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:248
std::vector< ck::index_t > output_ncdhw_lengths_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:249
const InDataType * p_in_dev_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:238
std::vector< ck::index_t > output_ncdhw_stride_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:251
InElementwiseOperation in_element_op_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:244
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:241
std::vector< ck::index_t > input_ncdhw_stride_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:250
AccElementwiseOperation acc_element_op_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:245
OutDataType * p_out_dev_
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:239
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:255
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:316
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:256
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:44
static constexpr index_t WindowRank
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:53
static constexpr auto I2
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:47
static constexpr auto I1
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:46
static constexpr auto I3
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:48
static constexpr ck::index_t M_BlockTileSize
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:63
remove_cvref_t< decltype(ABGridDescs{}[I0])> AGridDesc_M_K
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:190
static auto MakeABGridDescriptor_A_M_K_B_M(std::vector< ck::index_t > input_ncdhw_lengths, std::vector< ck::index_t > output_ncdhw_lengths, std::vector< ck::index_t > input_ncdhw_stride, std::vector< ck::index_t > output_ncdhw_stride, std::vector< ck::index_t > window_spatial_zyx_lengths, std::vector< ck::index_t > window_zyx_strides, std::vector< ck::index_t > window_zyx_dilations, std::vector< ck::index_t > input_left_dhw_pads, std::vector< ck::index_t > input_right_dhw_pads)
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:66
static constexpr ck::index_t K_BlockTileSize
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:64
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})) ABGridDescs
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:188
static constexpr index_t InOutRank
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:52
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:323
static constexpr auto I4
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:49
typename reduce_unary_operator< ReduceOpId, true, true >::AccElementwiseOperation AccElementwiseOperation
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:61
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:389
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_dev, void *p_out_dev, void *p_out_indices_dev, std::vector< ck::index_t > input_ncdhw_lengths, std::vector< ck::index_t > window_zyx_lengths, std::vector< ck::index_t > output_ncdhw_lengths, std::vector< ck::index_t > input_ncdhw_stride, std::vector< ck::index_t > output_ncdhw_stride, std::vector< ck::index_t > indices_ncdhw_stride, std::vector< ck::index_t > window_zyx_strides, std::vector< ck::index_t > window_zyx_dilations, std::vector< ck::index_t > input_left_dhw_pads, std::vector< ck::index_t > input_right_dhw_pads, std::vector< ck::index_t > pooling_dims) override
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:346
typename reduce_unary_operator< ReduceOpId, true, true >::InElementwiseOperation InElementwiseOperation
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:58
remove_cvref_t< decltype(ABGridDescs{}[I1])> BGridDesc_M
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:191
std::string GetTypeString() const override
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:394
typename reduce_binary_operator< ReduceOpId >::opType ReduceOperation
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:55
static constexpr auto I5
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:50
static constexpr auto I0
Definition: device_pool3d_fwd_ndhwc_ndhwc.hpp:45
Definition: device_pool_fwd.hpp:25