/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_avgpool2d_bwd_nhwc_nhwc.hpp Source File
device_avgpool2d_bwd_nhwc_nhwc.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
17 
18 namespace ck {
19 namespace tensor_operation {
20 namespace device {
21 
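22 // In and Din length/stride vectors are passed in [N, C, Hi, Wi] order
23 // Out and Dout length/stride vectors are passed in [N, C, Ho, Wo] order
24 // (the tensors themselves use the NHWC layout; that layout is expressed through the strides)
25 // Out = AvgPool2dFwd(In)
26 // Din = AvgPool2dBwd(Dout)
27 // Pooling dimension = H, W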
27 template <typename DOutDataType,
28  typename DInDataType,
29  typename ComputeDataType,
30  ck::index_t BlockSize,
31  ck::index_t MThreadClusterSize,
32  ck::index_t KThreadClusterSize,
33  ck::index_t MThreadSliceSize,
34  ck::index_t KThreadSliceSize,
35  ck::index_t InSrcOutDstVectorSize>
37  DOutDataType,
38  DInDataType,
39  tensor_layout::convolution::NHWC,
40  tensor_layout::convolution::NHWC>
41 {
42 
43  static constexpr ck::index_t NDimSpatial = 2;
44 
45  static constexpr auto I0 = Number<0>{};
46  static constexpr auto I1 = Number<1>{};
47 
48  static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
49  static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
50 
51  static auto
52  Make2DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
53  const std::vector<ck::index_t>& din_n_c_wos_length,
54  const std::vector<ck::index_t>& dout_n_c_wos_strides,
55  const std::vector<ck::index_t>& din_n_c_wos_strides,
56  const std::vector<ck::index_t>& window_lengths,
57  const std::vector<ck::index_t>& window_strides,
58  const std::vector<ck::index_t>& window_dilations,
59  const std::vector<ck::index_t>& input_left_pads,
60  const std::vector<ck::index_t>& input_right_pads,
61  const std::vector<ck::index_t>& tildes)
62  {
63  index_t i_ytilde = tildes[0];
64  index_t i_xtilde = tildes[1];
65 
66  const index_t N = dout_n_c_wos_lengths[0];
67  const index_t C = dout_n_c_wos_lengths[1];
68  const index_t Ho = dout_n_c_wos_lengths[2];
69  const index_t Wo = dout_n_c_wos_lengths[3];
70 
71  const index_t Hi = din_n_c_wos_length[2];
72  const index_t Wi = din_n_c_wos_length[3];
73 
74  const index_t Y = window_lengths[0];
75  const index_t X = window_lengths[1];
76 
77  const index_t InLeftPadH = input_left_pads[0];
78  const index_t InLeftPadW = input_left_pads[1];
79 
80  const index_t InRightPadH = input_right_pads[0];
81  const index_t InRightPadW = input_right_pads[1];
82 
83  const index_t ConvStrideH = window_strides[0];
84  const index_t ConvStrideW = window_strides[1];
85 
86  const index_t ConvDilationH = window_dilations[0];
87  const index_t ConvDilationW = window_dilations[1];
88 
89  const index_t Ni_stride = dout_n_c_wos_strides[0];
90  const index_t Ci_stride = dout_n_c_wos_strides[1];
91  const index_t Ho_stride = dout_n_c_wos_strides[2];
92  const index_t Wo_stride = dout_n_c_wos_strides[3];
93 
94  const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
95  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
96 
97  const auto YTilde = ConvStrideH / GcdStrideDilationH;
98  const auto XTilde = ConvStrideW / GcdStrideDilationW;
99 
100  const auto YDot = math::integer_divide_ceil(Y, YTilde);
101  const auto XDot = math::integer_divide_ceil(X, XTilde);
102 
103  const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
104  const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
105 
106  // only work on Tildes that contribute to non-padding area of input tensor
107  const auto IHTildeSliceBegin = math::integer_divide_floor(
108  math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
109  const auto IWTildeSliceBegin = math::integer_divide_floor(
110  math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
111 
112  const auto IHTildeSliceEnd =
113  math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
114  const auto IWTildeSliceEnd =
115  math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
116 
117  const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
118  const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
119 
120  // ReduceK is different for each Reduce
121  const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
122  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
123 
124  // Problem size of reduction kernel
125  const index_t MRaw = N * HTildeSlice * WTildeSlice * C;
126  const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
127 
128  const index_t KRaw = YDotSlice * XDotSlice;
129  const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
130 
131  const auto out_n_ho_wo_c_grid_desc = make_naive_tensor_descriptor(
132  make_tuple(N, Ho, Wo, C), make_tuple(Ni_stride, Ho_stride, Wo_stride, Ci_stride));
133 
134  // Out[ReduceM, ReduceK]
135  const auto out_n_hop_wop_c_grid_desc = transform_tensor_descriptor(
136  out_n_ho_wo_c_grid_desc,
138  make_pad_transform(Ho, I0, I0),
139  make_pad_transform(Wo, I0, I0),
143 
144  const auto out_n_ydot_htilde_xdot_wtilde_c_grid_desc = transform_tensor_descriptor(
145  out_n_hop_wop_c_grid_desc,
147  make_embed_transform(make_tuple(YDot, HTilde),
148  make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
149  make_embed_transform(make_tuple(XDot, WTilde),
150  make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
154 
155  const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
157  out_n_ydot_htilde_xdot_wtilde_c_grid_desc,
159  make_slice_transform(YDot, I0, YDotSlice),
160  make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
161  make_slice_transform(XDot, I0, XDotSlice),
162  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
165  Sequence<1>{},
166  Sequence<2>{},
167  Sequence<3>{},
168  Sequence<4>{},
169  Sequence<5>{}),
171  Sequence<1>{},
172  Sequence<2>{},
173  Sequence<3>{},
174  Sequence<4>{},
175  Sequence<5>{}));
176 
177  const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
178  out_n_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
179  make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C)),
180  make_merge_transform(make_tuple(YDotSlice, XDotSlice))),
183 
184  const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
185  out_grid_desc_reducemraw_reducekraw,
189 
190  // In[ReduceM]
191  const auto in_n_hi_wi_c_grid_desc =
193  make_tuple(din_n_c_wos_strides[0],
194  din_n_c_wos_strides[2],
195  din_n_c_wos_strides[3],
196  din_n_c_wos_strides[1]));
197 
198  const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
199  in_n_hi_wi_c_grid_desc,
201  make_pad_transform(Hi, InLeftPadH, InRightPadH),
202  make_pad_transform(Wi, InLeftPadW, InRightPadW),
206 
207  const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
208  in_n_hip_wip_c_grid_desc,
210  make_embed_transform(make_tuple(YTilde, HTilde),
211  make_tuple(ConvDilationH, ConvStrideH)),
212  make_embed_transform(make_tuple(XTilde, WTilde),
213  make_tuple(ConvDilationW, ConvStrideW)),
217 
218  const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
219  in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
221  make_freeze_transform(i_ytilde),
222  make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
223  make_freeze_transform(i_xtilde),
224  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
227  Sequence<1>{},
228  Sequence<2>{},
229  Sequence<3>{},
230  Sequence<4>{},
231  Sequence<5>{}),
233  Sequence<>{},
234  Sequence<1>{},
235  Sequence<>{},
236  Sequence<2>{},
237  Sequence<3>{}));
238 
239  const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
240  in_n_htildeslice_wtildeslice_c_grid_desc,
241  make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice, C))),
244 
245  const auto in_grid_desc_reducem =
246  transform_tensor_descriptor(in_grid_desc_reducemraw,
250 
251  return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
252  }
253 
254  using DoutDinGridDesc = decltype(Make2DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0},
255  {0, 0, 0, 0},
256  {0, 0, 0, 0},
257  {0, 0, 0, 0},
258  {0, 0},
259  {0, 0},
260  {0, 0},
261  {0, 0},
262  {0, 0},
263  {0, 0}));
264 
267 
268  // FIXME
269  // for NHWC, the dim C is the fastest dimension, and is not reduced.
270  // Hence, it is in M dimension for reduction kernel.
271  static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
272 
275 
277  DInDataType,
278  ComputeDataType,
279  int,
282  reduce::Add,
283  PassThrough,
284  Div,
286  false, // propagate_nan
287  BlockSize,
288  MThreadSliceSize,
289  KThreadSliceSize,
291  InSrcOutDstVectorSize,
292  InSrcOutDstVectorSize>;
293 
294  struct Argument : public BaseArgument
295  {
296  Argument(const DOutDataType* p_dout,
297  DInDataType* p_din,
298  std::vector<ck::index_t> dout_n_c_wos_lengths,
299  std::vector<ck::index_t> din_n_c_wos_length,
300  std::vector<ck::index_t> dout_n_c_wos_strides,
301  std::vector<ck::index_t> din_n_c_wos_strides,
302  std::vector<ck::index_t> window_lengths,
303  std::vector<ck::index_t> window_strides,
304  std::vector<ck::index_t> window_dilations,
305  std::vector<ck::index_t> input_left_pads,
306  std::vector<ck::index_t> input_right_pads)
307  : p_dout_grid_{p_dout},
308  p_din_grid_{p_din},
309  dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
310  din_n_c_wos_length_{din_n_c_wos_length},
311  dout_n_c_wos_strides_{dout_n_c_wos_strides},
312  din_n_c_wos_strides_{din_n_c_wos_strides},
313  num_reduce_{1},
314  div_element_op_{window_lengths[0] * window_lengths[1]}
315  {
316  std::vector<ck::index_t> Tildes(NDimSpatial);
317  for(int i = 0; i < NDimSpatial; ++i)
318  {
319  int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
320  Tildes[i] = window_strides[i] / GcdStrideDilation;
321  num_reduce_ *= Tildes[i];
322  }
323 
324  for(index_t i_ytilde = 0; i_ytilde < Tildes[0]; ++i_ytilde)
325  {
326  for(index_t i_xtilde = 0; i_xtilde < Tildes[1]; ++i_xtilde)
327  {
328  const auto YDotSlice =
329  math::integer_divide_ceil(window_lengths[0] - i_ytilde, Tildes[0]);
330  const auto XDotSlice =
331  math::integer_divide_ceil(window_lengths[1] - i_xtilde, Tildes[1]);
332 
333  if(YDotSlice * XDotSlice <= 0)
334  {
335  continue;
336  }
337 
338  const auto dout_din_grid_desc =
339  Make2DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
340  din_n_c_wos_length,
341  dout_n_c_wos_strides,
342  din_n_c_wos_strides,
343  window_lengths,
344  window_strides,
345  window_dilations,
346  input_left_pads,
347  input_right_pads,
348  {i_ytilde, i_xtilde});
349 
350  dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
351  din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
352  }
353  }
354  }
355 
356  const DOutDataType* p_dout_grid_;
357  DInDataType* p_din_grid_;
358  std::vector<ck::index_t> dout_n_c_wos_lengths_;
359  std::vector<ck::index_t> din_n_c_wos_length_;
360  std::vector<ck::index_t> dout_n_c_wos_strides_;
361  std::vector<ck::index_t> din_n_c_wos_strides_;
362 
364  std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
365  std::vector<DinGridDesc_M> din_grid_desc_m_container_;
366 
368  };
369 
370  struct Invoker : public BaseInvoker
371  {
372  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
373  {
374  float ave_time = 0;
375 
376  for(index_t i = 0; i < arg.num_reduce_; i++)
377  {
378  const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
379  false,
380  false,
381  false, // don't have index input
382  DOutDataType,
383  DInDataType,
384  ComputeDataType,
385  int,
388  PassThrough,
389  Div>;
390 
391  ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
392  const index_t grid_size = (M / M_BlockTileSize);
393 
394  ave_time += launch_and_time_kernel(stream_config,
395  kernel,
396  dim3(grid_size),
397  dim3(BlockSize),
398  0,
401  PassThrough{},
402  arg.div_element_op_,
403  float(1),
404  arg.p_dout_grid_,
405  nullptr,
406  float(0),
407  arg.p_din_grid_,
408  nullptr);
409  }
410 
411  return ave_time;
412  }
413 
414  float Run(const BaseArgument* p_arg,
415  const StreamConfig& stream_config = StreamConfig{}) override
416  {
417  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
418  }
419  };
420 
421  static bool IsSupportedArgument(const Argument& arg)
422  {
423  constexpr index_t Rank = NDimSpatial + 2;
424  int doutFastestDim = -1;
425  int dinFastestDim = -1;
426 
427  for(int i = 0; i < Rank; ++i)
428  {
429  if(arg.dout_n_c_wos_strides_[i] == 1)
430  doutFastestDim = i;
431  if(arg.din_n_c_wos_strides_[i] == 1)
432  dinFastestDim = i;
433  }
434  if(InSrcOutDstVectorSize != 1 && (dinFastestDim != 1 || doutFastestDim != 1))
435  {
436  return false;
437  }
438  if(doutFastestDim == -1 || dinFastestDim == -1)
439  {
440  if constexpr(InSrcOutDstVectorSize != 1)
441  return false;
442  }
443  else
444  {
445  if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
446  return false;
447  if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
448  return false;
449  }
450  return true;
451  }
452 
453  bool IsSupportedArgument(const BaseArgument* p_arg) override
454  {
455  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
456  }
457 
458  std::unique_ptr<BaseArgument>
459  MakeArgumentPointer(const void* p_dout,
460  void* p_din,
461  std::vector<ck::index_t> dout_n_c_wos_lengths,
462  std::vector<ck::index_t> din_n_c_wos_length,
463  std::vector<ck::index_t> dout_n_c_wos_strides,
464  std::vector<ck::index_t> din_n_c_wos_strides,
465  std::vector<ck::index_t> window_lengths,
466  std::vector<ck::index_t> window_strides,
467  std::vector<ck::index_t> window_dilations,
468  std::vector<ck::index_t> input_left_pads,
469  std::vector<ck::index_t> input_right_pads) override
470  {
471  constexpr index_t Rank = NDimSpatial + 2;
472 
473  if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
474  dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
475  {
476  throw std::runtime_error("dimension of [dout|din]_n_c_wos_strides or "
477  "[dout|din]_n_c_wos_lengths is not equal to Rank");
478  }
479 
480  if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
481  window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
482  input_right_pads.size() != NDimSpatial)
483  {
484  throw std::runtime_error(
485  "dimension of [window_lengths, window_strides, window_dilations, input_left_pads, "
486  "input_right_pads] is not equal to Rank");
487  }
488  return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
489  static_cast<DInDataType*>(p_din),
490  dout_n_c_wos_lengths,
491  din_n_c_wos_length,
492  dout_n_c_wos_strides,
493  din_n_c_wos_strides,
494  window_lengths,
495  window_strides,
496  window_dilations,
497  input_left_pads,
498  input_right_pads);
499  }
500 
501  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
502  {
503  return std::make_unique<Invoker>(Invoker{});
504  }
505 
506  std::string GetTypeString() const override
507  {
508  auto str = std::stringstream();
509 
510  // clang-format off
511  str << "DeviceAvgPool2dBwd<" << BlockSize << ",";
512  str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
513  str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
514  str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
515  // clang-format on
516 
517  return str.str();
518  }
519 };
520 
521 } // namespace device
522 } // namespace tensor_operation
523 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ auto integer_divide_floor(X x, Y y)
Definition: math.hpp:66
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
__host__ constexpr __device__ index_t gcd(index_t x, index_t y)
Definition: math.hpp:154
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition: gridwise_2d_reduction_threadwise.hpp:28
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:110
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_2d_reduction_threadwise.hpp:84
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: reduction_operator.hpp:37
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:295
int num_reduce_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:363
const DOutDataType * p_dout_grid_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:356
std::vector< ck::index_t > dout_n_c_wos_strides_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:360
std::vector< ck::index_t > dout_n_c_wos_lengths_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:358
Argument(const DOutDataType *p_dout, DInDataType *p_din, std::vector< ck::index_t > dout_n_c_wos_lengths, std::vector< ck::index_t > din_n_c_wos_length, std::vector< ck::index_t > dout_n_c_wos_strides, std::vector< ck::index_t > din_n_c_wos_strides, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:296
DInDataType * p_din_grid_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:357
std::vector< DinGridDesc_M > din_grid_desc_m_container_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:365
std::vector< DoutGridDesc_M_K > dout_grid_desc_m_k_container_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:364
std::vector< ck::index_t > din_n_c_wos_strides_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:361
Div div_element_op_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:367
std::vector< ck::index_t > din_n_c_wos_length_
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:359
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:371
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:372
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:414
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:41
remove_cvref_t< tuple_element_t< 0, DoutDinGridDesc > > DoutGridDesc_M_K
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:265
decltype(Make2DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0, 0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0})) DoutDinGridDesc
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:263
tensor_operation::element_wise::UnaryDivide Div
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:274
static auto Make2DGridDescriptor_Out_M_K_In_M(const std::vector< ck::index_t > &dout_n_c_wos_lengths, const std::vector< ck::index_t > &din_n_c_wos_length, const std::vector< ck::index_t > &dout_n_c_wos_strides, const std::vector< ck::index_t > &din_n_c_wos_strides, const std::vector< ck::index_t > &window_lengths, const std::vector< ck::index_t > &window_strides, const std::vector< ck::index_t > &window_dilations, const std::vector< ck::index_t > &input_left_pads, const std::vector< ck::index_t > &input_right_pads, const std::vector< ck::index_t > &tildes)
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:52
static constexpr auto I1
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:46
GridwiseReduction_mk_to_m_threadwise< DOutDataType, DInDataType, ComputeDataType, int, DoutGridDesc_M_K, DinGridDesc_M, reduce::Add, PassThrough, Div, InMemoryDataOperationEnum::Set, false, BlockSize, MThreadSliceSize, KThreadSliceSize, OutSrcInDstVectorDim, InSrcOutDstVectorSize, InSrcOutDstVectorSize > gridwise_reduce
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:292
tensor_operation::element_wise::PassThrough PassThrough
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:273
static constexpr ck::index_t NDimSpatial
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:43
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:501
static constexpr ck::index_t M_BlockTileSize
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:48
static constexpr index_t OutSrcInDstVectorDim
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:271
static constexpr ck::index_t K_BlockTileSize
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:49
std::string GetTypeString() const override
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:506
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_dout, void *p_din, std::vector< ck::index_t > dout_n_c_wos_lengths, std::vector< ck::index_t > din_n_c_wos_length, std::vector< ck::index_t > dout_n_c_wos_strides, std::vector< ck::index_t > din_n_c_wos_strides, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads) override
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:459
static constexpr auto I0
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:45
remove_cvref_t< tuple_element_t< 1, DoutDinGridDesc > > DinGridDesc_M
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:266
static bool IsSupportedArgument(const Argument &arg)
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:421
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_avgpool2d_bwd_nhwc_nhwc.hpp:453
Definition: device_avgpool_bwd.hpp:20
Definition: unary_element_wise_operation.hpp:241
Definition: unary_element_wise_operation.hpp:569