/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_pool2d_fwd_nhwc_nhwc.hpp Source File
device_pool2d_fwd_nhwc_nhwc.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 template <typename InDataType,
24  typename OutDataType,
25  typename IndexDataType, // enable if OutputIndex == true
26  typename ComputeDataType,
27  ck::ReduceTensorOp ReduceOpId,
28  bool OutputIndex,
29  ck::index_t BlockSize,
30  ck::index_t MThreadClusterSize,
31  ck::index_t KThreadClusterSize,
32  ck::index_t MThreadSliceSize,
33  ck::index_t KThreadSliceSize,
34  ck::index_t InSrcOutDstVectorSize>
36  2,
37  InDataType,
38  OutDataType,
39  IndexDataType,
40  tensor_layout::convolution::NHWC,
41  tensor_layout::convolution::NHWC,
42  ReduceOpId,
43  OutputIndex>
44 {
// Compile-time indices used to pick elements out of ck tuples (e.g. descs[I0], descs[I1]).
45  static constexpr auto I0 = Number<0>{};
46  static constexpr auto I1 = Number<1>{};
47 
// Rank of the input/output tensors (N, C, H, W) and of the pooling window (Y, X);
// MakeArgumentPointer checks the incoming vectors against these.
48  static constexpr index_t InOutRank = 4;
49  static constexpr index_t WindowRank = 2;
50 
// NOTE(review): listing lines 51-58 are missing from this extraction; per the
// cross-reference entries they declare the ReduceOperation, InElementwiseOperation and
// AccElementwiseOperation type aliases.
52 
55 
58 
// Reduce-M / reduce-K extents handled by one workgroup (thread-cluster size x per-thread
// slice size). MakeABGridDescriptor_A_M_K_B_M pads M and K up to multiples of these, and
// the invoker launches M / M_BlockTileSize workgroups.
59  static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
60  static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
61 
// Builds the tensor descriptors consumed by the threadwise reduction kernel:
//   A: the input viewed as [ReduceM, ReduceK]. Raw ReduceM length is N*Ho*Wo*C and raw
//      ReduceK length is Y*X (the pooling window).
//   B: the output viewed as [ReduceM].
// Returns make_tuple(A descriptor, B descriptor).
//
// Index order of the arguments:
//   - All *_nchw_* vectors are indexed in NCHW order (0 = N, 1 = C, 2 = H, 3 = W).
//   - The NHWC memory layout is expressed only through the strides, which are re-ordered
//     to (N, H, W, C) when the naive descriptors are created.
//   - Window/pad vectors are (Y/H, X/W) ordered.
//
// Ho and Wo are read from output_nchw_lengths. No check of them against the input size,
// window, strides, dilations and pads is visible in this listing.
//
// NOTE(review): several lines of the transform calls below are missing from this
// extraction (listing lines 111, 114-116, 121, 124-126, 130-133, 137-139, 151-154,
// 158-160).
62  static auto MakeABGridDescriptor_A_M_K_B_M(std::vector<ck::index_t> input_nchw_lengths,
63  std::vector<ck::index_t> output_nchw_lengths,
64  std::vector<ck::index_t> input_nchw_stride,
65  std::vector<ck::index_t> output_nchw_stride,
66  std::vector<ck::index_t> window_spatial_yx_lengths,
67  std::vector<ck::index_t> window_yx_strides,
68  std::vector<ck::index_t> window_yx_dilations,
69  std::vector<ck::index_t> input_left_hw_pads,
70  std::vector<ck::index_t> input_right_hw_pads)
71  {
72  const index_t N = input_nchw_lengths[0];
73  const index_t C = input_nchw_lengths[1];
74  const index_t Hi = input_nchw_lengths[2];
75  const index_t Wi = input_nchw_lengths[3];
76 
77  const index_t Ho = output_nchw_lengths[2];
78  const index_t Wo = output_nchw_lengths[3];
79  const index_t Y = window_spatial_yx_lengths[0];
80  const index_t X = window_spatial_yx_lengths[1];
81 
82  const index_t WindowStrideH = window_yx_strides[0];
83  const index_t WindowStrideW = window_yx_strides[1];
84 
85  const index_t WindowDilationH = window_yx_dilations[0];
86  const index_t WindowDilationW = window_yx_dilations[1];
87 
88  const index_t InLeftPadH = input_left_hw_pads[0];
89  const index_t InLeftPadW = input_left_hw_pads[1];
90 
91  const index_t InRightPadH = input_right_hw_pads[0];
92  const index_t InRightPadW = input_right_hw_pads[1];
93 
// Raw reduce extents, and the right padding needed to round each up to a whole block tile.
94  const index_t MRaw = N * Ho * Wo * C;
95  const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
96 
97  const index_t KRaw = Y * X;
98  const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
99 
100  // A[ReduceM, ReduceK]
101  const index_t Ni_stride = input_nchw_stride[0];
102  const index_t Ci_stride = input_nchw_stride[1];
103  const index_t Hi_stride = input_nchw_stride[2];
104  const index_t Wi_stride = input_nchw_stride[3];
105 
// Input as a naive (N, Hi, Wi, C) descriptor; the NCHW-ordered strides are permuted to match.
106  const auto in_grid_desc_n_hi_wi_c = make_naive_tensor_descriptor(
107  make_tuple(N, Hi, Wi, C), make_tuple(Ni_stride, Hi_stride, Wi_stride, Ci_stride));
108 
// Pad H and W with the left/right pads (the N / C transforms and the dimension mappings
// are on lines missing from this listing).
109  const auto in_grid_desc_n_hip_wip_c = transform_tensor_descriptor(
110  in_grid_desc_n_hi_wi_c,
112  make_pad_transform(Hi, InLeftPadH, InRightPadH),
113  make_pad_transform(Wi, InLeftPadW, InRightPadW),
117 
// Unfold each padded spatial dimension into (window position, output position) using embed
// transforms with coefficients (dilation, stride).
118  const auto in_grid_desc_n_y_ho_x_wo_c = transform_tensor_descriptor(
119  in_grid_desc_n_hip_wip_c,
120  make_tuple(
122  make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
123  make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
127 
// Merge into the raw [ReduceM, ReduceK] view (transform arguments missing from this listing).
128  const auto in_grid_desc_reducemraw_reducekraw =
129  transform_tensor_descriptor(in_grid_desc_n_y_ho_x_wo_c,
134 
// Presumably right-pads ReduceM by MPad and ReduceK by KPad (arguments missing from this
// listing; cf. make_right_pad_transform in the cross-reference) - TODO confirm.
135  const auto in_grid_desc_reducem_reducek = transform_tensor_descriptor(
136  in_grid_desc_reducemraw_reducekraw,
140 
141  // B[ReduceM]
142  const index_t No_stride = output_nchw_stride[0];
143  const index_t Co_stride = output_nchw_stride[1];
144  const index_t Ho_stride = output_nchw_stride[2];
145  const index_t Wo_stride = output_nchw_stride[3];
146 
// NOTE(review): this descriptor describes the *output* tensor, yet it is given the *input*
// spatial lengths (Hi, Wi).
//   - The output's spatial extent is Ho x Wo (listing lines 77-78), and the strides used
//     here are output strides.
//   - It very likely should be make_tuple(N, Ho, Wo, C).
//   - With Hi/Wi, the descriptor's extent no longer matches the output buffer whenever
//     Hi != Ho or Wi != Wo.
// TODO confirm and fix.
147  const auto out_grid_desc_n_ho_wo_c = make_naive_tensor_descriptor(
148  make_tuple(N, Hi, Wi, C), make_tuple(No_stride, Ho_stride, Wo_stride, Co_stride));
149 
// Merge into the raw [ReduceM] view, then (presumably) right-pad by MPad; arguments missing
// from this listing.
150  const auto out_grid_desc_reducemraw =
151  transform_tensor_descriptor(out_grid_desc_n_ho_wo_c,
155 
156  const auto out_grid_desc_reducem =
157  transform_tensor_descriptor(out_grid_desc_reducemraw,
161 
162  return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
163  }
164 
// Tuple type returned by MakeABGridDescriptor_A_M_K_B_M. decltype is an unevaluated
// context, so the empty-vector call below is never executed; only its type is used.
165  using ABGridDescs =
166  decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {}));
167 
// NOTE(review): listing line 168 (the AGridDesc_M_K alias for element I0, per the
// cross-reference entries) is missing from this extraction.
// Type of the output [ReduceM] descriptor: element I1 of ABGridDescs.
169  using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
170 
// Type-erased argument bundle for this device op.
//   - Stores the device pointers and the NCHW-ordered lengths/strides.
//   - Builds the A[M, K] / B[M] descriptors once, at construction, via
//     MakeABGridDescriptor_A_M_K_B_M.
// NOTE(review): listing lines 189, 211-212, 218-219 and 221-222 are missing from this
// extraction. Per the cross-reference entries they hold:
//   - the a_grid_desc_m_k_ initialiser;
//   - the a_grid_desc_m_k_ / b_grid_desc_m_ members;
//   - the in_element_op_ / acc_element_op_ members.
// Lines 211-212 presumably set the two element-wise operators from reduceLength
// (cf. GetElementwiseOperator(int32_t reduceLength)) - TODO confirm.
171  struct Argument : public BaseArgument
172  {
173  Argument(const InDataType* p_in_dev,
174  OutDataType* p_out_dev,
175  IndexDataType* p_out_indices_dev,
176  std::vector<ck::index_t>& input_nchw_lengths,
177  std::vector<ck::index_t>& output_nchw_lengths,
178  std::vector<ck::index_t>& input_nchw_stride,
179  std::vector<ck::index_t>& output_nchw_stride,
180  std::vector<ck::index_t>&, // indices_nchw_stride (unused here; MakeArgumentPointer requires it to equal output_nchw_stride)
181  std::vector<ck::index_t>& window_spatial_yx_lengths,
182  std::vector<ck::index_t>& window_yx_strides,
183  std::vector<ck::index_t>& window_yx_dilations,
184  std::vector<ck::index_t>& input_left_hw_pads,
185  std::vector<ck::index_t>& input_right_hw_pads)
186  : p_in_dev_{p_in_dev},
187  p_out_dev_{p_out_dev},
188  p_out_indices_dev_{p_out_indices_dev},
190  b_grid_desc_m_{},
191  input_nchw_lengths_{input_nchw_lengths},
192  output_nchw_lengths_{output_nchw_lengths},
193  input_nchw_stride_{input_nchw_stride},
194  output_nchw_stride_{output_nchw_stride}
195  {
196  const auto descs = MakeABGridDescriptor_A_M_K_B_M(input_nchw_lengths,
197  output_nchw_lengths,
198  input_nchw_stride,
199  output_nchw_stride,
200  window_spatial_yx_lengths,
201  window_yx_strides,
202  window_yx_dilations,
203  input_left_hw_pads,
204  input_right_hw_pads);
205 
206  a_grid_desc_m_k_ = descs[I0];
207  b_grid_desc_m_ = descs[I1];
208 
// Number of input elements reduced into each output value: the window area Y * X.
209  int32_t reduceLength = window_spatial_yx_lengths[0] * window_spatial_yx_lengths[1];
210 
213  }
214 
// Device pointers, stored exactly as given (Argument does not allocate or free them).
215  const InDataType* p_in_dev_;
216  OutDataType* p_out_dev_;
217  IndexDataType* p_out_indices_dev_;
220 
223 
224  // for checking vector load/store (read by IsSupportedArgument)
225  std::vector<ck::index_t> input_nchw_lengths_;
226  std::vector<ck::index_t> output_nchw_lengths_;
227  std::vector<ck::index_t> input_nchw_stride_;
228  std::vector<ck::index_t> output_nchw_stride_;
229  };
230 
231  struct Invoker : public BaseInvoker
232  {
// Launches kernel_reduce_threadwise for this argument and returns the float reported by
// launch_and_time_kernel. That value is presumably the measured kernel time when
// stream_config enables timing - TODO confirm against kernel_launch.hpp.
// NOTE(review): listing lines 240, 244, 246-249, 267 and 269-270 (template arguments of
// gridwise_reduce and of the kernel) are missing from this extraction.
233  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
234  {
235  // for NHWC, the dim C is the fastest dimension, and is not reduced.
236  // Hence, it is in the M dimension of the reduction kernel, which is why vector
// loads/stores are done along M.
237  static constexpr index_t InSrcOutDstVectorDim = 0; // 0: M, 1: K
238 
239  using gridwise_reduce =
241  OutDataType,
242  ComputeDataType,
243  IndexDataType,
245  BGridDesc_M,
250  false, // propagate_nan
251  BlockSize,
252  MThreadSliceSize,
253  KThreadSliceSize,
254  InSrcOutDstVectorDim,
255  InSrcOutDstVectorSize,
256  InSrcOutDstVectorSize>;
257 
258  const auto kernel =
259  kernel_reduce_threadwise<gridwise_reduce,
260  OutputIndex,
261  true, // pooling need to return global index
262  false, // don't have index input
263  InDataType,
264  OutDataType,
265  ComputeDataType,
266  IndexDataType,
268  BGridDesc_M,
271 
// One workgroup per M_BlockTileSize rows of the reduce-M dimension.
// MakeABGridDescriptor_A_M_K_B_M computes MPad to round M up to a multiple of
// M_BlockTileSize. Assuming the descriptor applies that padding, this division is exact.
272  ck::index_t M = arg.a_grid_desc_m_k_.GetLength(I0);
273 
274  const index_t grid_size = (M / M_BlockTileSize);
275 
// Launch config: grid, block, and 0 bytes of LDS (the lds_byte parameter of
// launch_and_time_kernel).
// The remaining arguments follow kernel_reduce_threadwise's parameter order:
//   descriptors, element-wise ops, alpha = 1, input values,
//   no input indices (nullptr), beta = 0, output values, output indices.
276  return launch_and_time_kernel(stream_config,
277  kernel,
278  dim3(grid_size),
279  dim3(BlockSize),
280  0,
281  arg.a_grid_desc_m_k_,
282  arg.b_grid_desc_m_,
283  arg.in_element_op_,
284  arg.acc_element_op_,
285  float(1),
286  arg.p_in_dev_,
287  nullptr,
288  float(0),
289  arg.p_out_dev_,
290  arg.p_out_indices_dev_);
291  }
292 
293  float Run(const BaseArgument* p_arg,
294  const StreamConfig& stream_config = StreamConfig{}) override
295  {
296  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
297  }
298  };
299 
300  bool IsSupportedArgument(const BaseArgument* p_arg) override
301  {
302  const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
303 
304  // C should be fastest dimension
305  if(pArg->input_nchw_stride_[1] != 1)
306  return false;
307 
308  for(int i = 0; i < InOutRank; ++i)
309  {
310  if(pArg->input_nchw_stride_[i] == 1 &&
311  pArg->input_nchw_lengths_[i] % InSrcOutDstVectorSize != 0)
312  return false;
313 
314  if(pArg->output_nchw_stride_[i] == 1 &&
315  pArg->output_nchw_lengths_[i] % InSrcOutDstVectorSize != 0)
316  return false;
317  }
318 
319  return true;
320  }
321 
322  virtual std::unique_ptr<BaseArgument>
323  MakeArgumentPointer(const void* p_in_dev,
324  void* p_out_dev,
325  void* p_out_indices_dev,
326  std::vector<ck::index_t> input_nchw_lengths,
327  std::vector<ck::index_t> window_yx_lengths,
328  std::vector<ck::index_t> output_nchw_lengths,
329  std::vector<ck::index_t> input_nchw_stride,
330  std::vector<ck::index_t> output_nchw_stride,
331  std::vector<ck::index_t> indices_nchw_stride,
332  std::vector<ck::index_t> window_yx_strides,
333  std::vector<ck::index_t> window_yx_dilations,
334  std::vector<ck::index_t> input_left_hw_pads,
335  std::vector<ck::index_t> input_right_hw_pads,
336  std::vector<ck::index_t> pooling_dims) override
337  {
338  if(input_nchw_lengths.size() != InOutRank || window_yx_lengths.size() != WindowRank ||
339  input_nchw_lengths.size() != InOutRank || window_yx_strides.size() != WindowRank ||
340  window_yx_dilations.size() != WindowRank || input_left_hw_pads.size() != WindowRank ||
341  input_right_hw_pads.size() != WindowRank)
342  throw std::runtime_error("dimension is incorrect");
343 
344  if(pooling_dims != std::vector<ck::index_t>{2, 3})
345  throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");
346 
347  if(output_nchw_stride != indices_nchw_stride)
348  throw std::runtime_error(
349  "output_nchw_stride need to be equal to indices_nchw_stride for now");
350 
351  return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_dev),
352  static_cast<OutDataType*>(p_out_dev),
353  static_cast<IndexDataType*>(p_out_indices_dev),
354  input_nchw_lengths,
355  output_nchw_lengths,
356  input_nchw_stride,
357  output_nchw_stride,
358  indices_nchw_stride,
359  window_yx_lengths,
360  window_yx_strides,
361  window_yx_dilations,
362  input_left_hw_pads,
363  input_right_hw_pads);
364  }
365 
366  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
367  {
368  return std::make_unique<Invoker>(Invoker{});
369  }
370 
371  std::string GetTypeString() const override
372  {
373  auto str = std::stringstream();
374 
375  // clang-format off
376  str << "DevicePool2dFwd_NHWC_NHWC<" << BlockSize << ",";
377  str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
378  str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
379  str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
380  // clang-format on
381 
382  return str.str();
383  }
384 };
385 
386 } // namespace device
387 } // namespace tensor_operation
388 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
ReduceTensorOp
Definition: reduction_enums.hpp:9
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition: gridwise_2d_reduction_threadwise.hpp:28
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_2d_reduction_threadwise.hpp:84
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: reduction_operator_mapping.hpp:20
Definition: reduction_operator_mapping.hpp:91
static std::tuple< InElementwiseOperation, AccElementwiseOperation > GetElementwiseOperator(int32_t reduceLength)
Definition: reduction_operator_mapping.hpp:96
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:172
std::vector< ck::index_t > output_nchw_lengths_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:226
OutDataType * p_out_dev_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:216
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:218
const InDataType * p_in_dev_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:215
std::vector< ck::index_t > input_nchw_stride_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:227
InElementwiseOperation in_element_op_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:221
std::vector< ck::index_t > output_nchw_stride_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:228
IndexDataType * p_out_indices_dev_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:217
Argument(const InDataType *p_in_dev, OutDataType *p_out_dev, IndexDataType *p_out_indices_dev, std::vector< ck::index_t > &input_nchw_lengths, std::vector< ck::index_t > &output_nchw_lengths, std::vector< ck::index_t > &input_nchw_stride, std::vector< ck::index_t > &output_nchw_stride, std::vector< ck::index_t > &, std::vector< ck::index_t > &window_spatial_yx_lengths, std::vector< ck::index_t > &window_yx_strides, std::vector< ck::index_t > &window_yx_dilations, std::vector< ck::index_t > &input_left_hw_pads, std::vector< ck::index_t > &input_right_hw_pads)
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:173
BGridDesc_M b_grid_desc_m_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:219
std::vector< ck::index_t > input_nchw_lengths_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:225
AccElementwiseOperation acc_element_op_
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:222
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:232
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:233
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:293
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:44
std::string GetTypeString() const override
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:371
remove_cvref_t< decltype(ABGridDescs{}[I0])> AGridDesc_M_K
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:168
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:366
static auto MakeABGridDescriptor_A_M_K_B_M(std::vector< ck::index_t > input_nchw_lengths, std::vector< ck::index_t > output_nchw_lengths, std::vector< ck::index_t > input_nchw_stride, std::vector< ck::index_t > output_nchw_stride, std::vector< ck::index_t > window_spatial_yx_lengths, std::vector< ck::index_t > window_yx_strides, std::vector< ck::index_t > window_yx_dilations, std::vector< ck::index_t > input_left_hw_pads, std::vector< ck::index_t > input_right_hw_pads)
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:62
typename reduce_unary_operator< ReduceOpId, true, true >::AccElementwiseOperation AccElementwiseOperation
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:57
decltype(MakeABGridDescriptor_A_M_K_B_M({}, {}, {}, {}, {}, {}, {}, {}, {})) ABGridDescs
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:166
static constexpr index_t WindowRank
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:49
static constexpr ck::index_t M_BlockTileSize
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:59
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_dev, void *p_out_dev, void *p_out_indices_dev, std::vector< ck::index_t > input_nchw_lengths, std::vector< ck::index_t > window_yx_lengths, std::vector< ck::index_t > output_nchw_lengths, std::vector< ck::index_t > input_nchw_stride, std::vector< ck::index_t > output_nchw_stride, std::vector< ck::index_t > indices_nchw_stride, std::vector< ck::index_t > window_yx_strides, std::vector< ck::index_t > window_yx_dilations, std::vector< ck::index_t > input_left_hw_pads, std::vector< ck::index_t > input_right_hw_pads, std::vector< ck::index_t > pooling_dims) override
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:323
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:300
remove_cvref_t< decltype(ABGridDescs{}[I1])> BGridDesc_M
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:169
typename reduce_binary_operator< ReduceOpId >::opType ReduceOperation
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:51
static constexpr ck::index_t K_BlockTileSize
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:60
static constexpr auto I0
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:45
static constexpr index_t InOutRank
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:48
typename reduce_unary_operator< ReduceOpId, true, true >::InElementwiseOperation InElementwiseOperation
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:54
static constexpr auto I1
Definition: device_pool2d_fwd_nhwc_nhwc.hpp:46
Definition: device_pool_fwd.hpp:25