/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_avgpool3d_bwd_ndhwc_ndhwc.hpp Source File
device_avgpool3d_bwd_ndhwc_ndhwc.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
17 
18 namespace ck {
19 namespace tensor_operation {
20 namespace device {
21 
22 // In and Din = [N, C, Di, Hi, Wi]
23 // Out and Dout = [N, C, Do, Ho, Wo]
24 // Out = AvgPoolFwd(In)
25 // Din = AvgPoolBwd(Dout)
26 // Pooling dimension = D, H, W
27 template <typename DOutDataType,
28  typename DInDataType,
29  typename ComputeDataType,
30  ck::index_t BlockSize,
31  ck::index_t MThreadClusterSize,
32  ck::index_t KThreadClusterSize,
33  ck::index_t MThreadSliceSize,
34  ck::index_t KThreadSliceSize,
35  ck::index_t InSrcOutDstVectorSize>
37  DOutDataType,
38  DInDataType,
39  tensor_layout::convolution::NDHWC,
40  tensor_layout::convolution::NDHWC>
41 {
42  static constexpr ck::index_t NDimSpatial = 3;
43 
44  static constexpr auto I0 = Number<0>{};
45  static constexpr auto I1 = Number<1>{};
46 
47  static constexpr ck::index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
48  static constexpr ck::index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
49 
50  static auto
51  Make3DGridDescriptor_Out_M_K_In_M(const std::vector<ck::index_t>& dout_n_c_wos_lengths,
52  const std::vector<ck::index_t>& din_n_c_wos_length,
53  const std::vector<ck::index_t>& dout_n_c_wos_strides,
54  const std::vector<ck::index_t>& din_n_c_wos_strides,
55  const std::vector<ck::index_t>& window_lengths,
56  const std::vector<ck::index_t>& window_strides,
57  const std::vector<ck::index_t>& window_dilations,
58  const std::vector<ck::index_t>& input_left_pads,
59  const std::vector<ck::index_t>& input_right_pads,
60  const std::vector<ck::index_t>& tildes)
61  {
62  index_t i_ztilde = tildes[0];
63  index_t i_ytilde = tildes[1];
64  index_t i_xtilde = tildes[2];
65 
66  const index_t N = dout_n_c_wos_lengths[0];
67  const index_t C = dout_n_c_wos_lengths[1];
68 
69  const index_t Di = din_n_c_wos_length[2];
70  const index_t Hi = din_n_c_wos_length[3];
71  const index_t Wi = din_n_c_wos_length[4];
72 
73  const index_t Do = dout_n_c_wos_lengths[2];
74  const index_t Ho = dout_n_c_wos_lengths[3];
75  const index_t Wo = dout_n_c_wos_lengths[4];
76 
77  const index_t Z = window_lengths[0];
78  const index_t Y = window_lengths[1];
79  const index_t X = window_lengths[2];
80 
81  const index_t InLeftPadD = input_left_pads[0];
82  const index_t InLeftPadH = input_left_pads[1];
83  const index_t InLeftPadW = input_left_pads[2];
84 
85  const index_t InRightPadD = input_right_pads[0];
86  const index_t InRightPadH = input_right_pads[1];
87  const index_t InRightPadW = input_right_pads[2];
88 
89  const index_t ConvStrideD = window_strides[0];
90  const index_t ConvStrideH = window_strides[1];
91  const index_t ConvStrideW = window_strides[2];
92 
93  const index_t ConvDilationD = window_dilations[0];
94  const index_t ConvDilationH = window_dilations[1];
95  const index_t ConvDilationW = window_dilations[2];
96 
97  const auto out_n_do_ho_wo_c_grid_desc =
98  make_naive_tensor_descriptor(make_tuple(N, Do, Ho, Wo, C),
99  make_tuple(dout_n_c_wos_strides[0],
100  dout_n_c_wos_strides[2],
101  dout_n_c_wos_strides[3],
102  dout_n_c_wos_strides[4],
103  dout_n_c_wos_strides[1]));
104 
105  const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
106  const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
107  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
108 
109  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
110  const auto YTilde = ConvStrideH / GcdStrideDilationH;
111  const auto XTilde = ConvStrideW / GcdStrideDilationW;
112 
113  const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
114  const auto YDot = math::integer_divide_ceil(Y, YTilde);
115  const auto XDot = math::integer_divide_ceil(X, XTilde);
116 
117  const auto DTilde = Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
118  const auto HTilde = Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
119  const auto WTilde = Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
120 
121  // only work on Tildes that contribute to non-padding area of input tensor
122  const auto IDTildeSliceBegin = math::integer_divide_floor(
123  math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
124  const auto IHTildeSliceBegin = math::integer_divide_floor(
125  math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
126  const auto IWTildeSliceBegin = math::integer_divide_floor(
127  math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
128 
129  const auto IDTildeSliceEnd =
130  math::min(DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
131  const auto IHTildeSliceEnd =
132  math::min(HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
133  const auto IWTildeSliceEnd =
134  math::min(WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
135 
136  const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
137  const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
138  const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
139 
140  // ReduceK is different for each Reduce
141  const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
142  const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
143  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
144 
145  // Problem size of reduction kernel
146  const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
147  const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
148 
149  const index_t KRaw = ZDotSlice * YDotSlice * XDotSlice;
150  const index_t KPad = math::integer_least_multiple(KRaw, K_BlockTileSize) - KRaw;
151 
152  // Out[ReduceM, ReduceK]
153  const auto out_n_dop_hop_wop_c_grid_desc = transform_tensor_descriptor(
154  out_n_do_ho_wo_c_grid_desc,
156  make_pad_transform(Do, I0, I0),
157  make_pad_transform(Ho, I0, I0),
158  make_pad_transform(Wo, I0, I0),
162 
163  const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc =
165  out_n_dop_hop_wop_c_grid_desc,
166  make_tuple(
168  make_embed_transform(make_tuple(ZDot, DTilde),
169  make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
170  make_embed_transform(make_tuple(YDot, HTilde),
171  make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
172  make_embed_transform(make_tuple(XDot, WTilde),
173  make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
175  make_tuple(
178  Sequence<1, 2>{},
179  Sequence<3, 4>{},
180  Sequence<5, 6>{},
181  Sequence<7>{}));
182 
183  const auto
184  out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc =
186  out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_c_grid_desc,
188  make_slice_transform(ZDot, I0, ZDotSlice),
189  make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
190  make_slice_transform(YDot, I0, YDotSlice),
191  make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
192  make_slice_transform(XDot, I0, XDotSlice),
193  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
196  Sequence<1>{},
197  Sequence<2>{},
198  Sequence<3>{},
199  Sequence<4>{},
200  Sequence<5>{},
201  Sequence<6>{},
202  Sequence<7>{}),
204  Sequence<1>{},
205  Sequence<2>{},
206  Sequence<3>{},
207  Sequence<4>{},
208  Sequence<5>{},
209  Sequence<6>{},
210  Sequence<7>{}));
211 
212  const auto out_grid_desc_reducemraw_reducekraw = transform_tensor_descriptor(
213  out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_c_grid_desc,
214  make_tuple(
215  make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C)),
216  make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice))),
219 
220  const auto out_grid_desc_reducem_reducek = transform_tensor_descriptor(
221  out_grid_desc_reducemraw_reducekraw,
225 
226  // In[ReduceM]
227  const auto in_n_di_hi_wi_c_grid_desc =
228  make_naive_tensor_descriptor(make_tuple(N, Di, Hi, Wi, C),
229  make_tuple(din_n_c_wos_strides[0],
230  din_n_c_wos_strides[2],
231  din_n_c_wos_strides[3],
232  din_n_c_wos_strides[4],
233  din_n_c_wos_strides[1]));
234 
235  const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
236  in_n_di_hi_wi_c_grid_desc,
238  make_pad_transform(Di, InLeftPadD, InRightPadD),
239  make_pad_transform(Hi, InLeftPadH, InRightPadH),
240  make_pad_transform(Wi, InLeftPadW, InRightPadW),
244 
245  const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
247  in_n_dip_hip_wip_c_grid_desc,
249  make_embed_transform(make_tuple(XTilde, DTilde),
250  make_tuple(ConvDilationD, ConvStrideD)),
251  make_embed_transform(make_tuple(YTilde, HTilde),
252  make_tuple(ConvDilationH, ConvStrideH)),
253  make_embed_transform(make_tuple(XTilde, WTilde),
254  make_tuple(ConvDilationW, ConvStrideW)),
256  make_tuple(
259  Sequence<1, 2>{},
260  Sequence<3, 4>{},
261  Sequence<5, 6>{},
262  Sequence<7>{}));
263 
264  const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
266  in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
268  make_freeze_transform(i_ztilde),
269  make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
270  make_freeze_transform(i_ytilde),
271  make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
272  make_freeze_transform(i_xtilde),
273  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
276  Sequence<1>{},
277  Sequence<2>{},
278  Sequence<3>{},
279  Sequence<4>{},
280  Sequence<5>{},
281  Sequence<6>{},
282  Sequence<7>{}),
284  Sequence<>{},
285  Sequence<1>{},
286  Sequence<>{},
287  Sequence<2>{},
288  Sequence<>{},
289  Sequence<3>{},
290  Sequence<4>{}));
291 
292  const auto in_grid_desc_reducemraw = transform_tensor_descriptor(
293  in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
294  make_tuple(
295  make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice, C))),
298 
299  const auto in_grid_desc_reducem =
300  transform_tensor_descriptor(in_grid_desc_reducemraw,
304 
305  return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
306  }
307 
308  using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0},
309  {0, 0, 0, 0, 0},
310  {0, 0, 0, 0, 0},
311  {0, 0, 0, 0, 0},
312  {0, 0, 0},
313  {0, 0, 0},
314  {0, 0, 0},
315  {0, 0, 0},
316  {0, 0, 0},
317  {0, 0, 0}));
318 
321 
322  // FIXME
323  // for NDHWC, the dim C is the fastest dimension, and is not reduced.
324  // Hence, it is in M dimension for reduction kernel.
325  static constexpr index_t OutSrcInDstVectorDim = 0; // 0: M, 1: K
326 
329 
331  DInDataType,
332  ComputeDataType,
333  int,
336  reduce::Add,
337  PassThrough,
338  Div,
340  false, // propagate_nan
341  BlockSize,
342  MThreadSliceSize,
343  KThreadSliceSize,
345  InSrcOutDstVectorSize,
346  InSrcOutDstVectorSize>;
347 
348  struct Argument : public BaseArgument
349  {
350  Argument(const DOutDataType* p_dout,
351  DInDataType* p_din,
352  std::vector<ck::index_t> dout_n_c_wos_lengths,
353  std::vector<ck::index_t> din_n_c_wos_length,
354  std::vector<ck::index_t> dout_n_c_wos_strides,
355  std::vector<ck::index_t> din_n_c_wos_strides,
356  std::vector<ck::index_t> window_lengths,
357  std::vector<ck::index_t> window_strides,
358  std::vector<ck::index_t> window_dilations,
359  std::vector<ck::index_t> input_left_pads,
360  std::vector<ck::index_t> input_right_pads)
361  : p_dout_grid_{p_dout},
362  p_din_grid_{p_din},
363  dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
364  din_n_c_wos_length_{din_n_c_wos_length},
365  dout_n_c_wos_strides_{dout_n_c_wos_strides},
366  din_n_c_wos_strides_{din_n_c_wos_strides},
367  num_reduce_{1},
368  div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]}
369  {
370  std::vector<ck::index_t> Tildes(NDimSpatial);
371  for(int i = 0; i < NDimSpatial; ++i)
372  {
373  int GcdStrideDilation = math::gcd(window_strides[i], window_dilations[i]);
374  Tildes[i] = window_strides[i] / GcdStrideDilation;
375  num_reduce_ *= Tildes[i];
376  }
377 
378  for(index_t i_ztilde = 0; i_ztilde < Tildes[0]; ++i_ztilde)
379  {
380  for(index_t i_ytilde = 0; i_ytilde < Tildes[1]; ++i_ytilde)
381  {
382  for(index_t i_xtilde = 0; i_xtilde < Tildes[2]; ++i_xtilde)
383  {
384  // check slice is valid
385  const auto ZDotSlice =
386  math::integer_divide_ceil(window_lengths[0] - i_ztilde, Tildes[0]);
387  const auto YDotSlice =
388  math::integer_divide_ceil(window_lengths[1] - i_ytilde, Tildes[1]);
389  const auto XDotSlice =
390  math::integer_divide_ceil(window_lengths[2] - i_xtilde, Tildes[2]);
391 
392  if(ZDotSlice * YDotSlice * XDotSlice <= 0)
393  {
394  continue;
395  }
396 
397  const auto dout_din_grid_desc =
398  Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
399  din_n_c_wos_length,
400  dout_n_c_wos_strides,
401  din_n_c_wos_strides,
402  window_lengths,
403  window_strides,
404  window_dilations,
405  input_left_pads,
406  input_right_pads,
407  {i_ztilde, i_ytilde, i_xtilde});
408 
409  dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
410  din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
411  }
412  }
413  }
414  }
415 
416  const DOutDataType* p_dout_grid_;
417  DInDataType* p_din_grid_;
418  std::vector<ck::index_t> dout_n_c_wos_lengths_;
419  std::vector<ck::index_t> din_n_c_wos_length_;
420  std::vector<ck::index_t> dout_n_c_wos_strides_;
421  std::vector<ck::index_t> din_n_c_wos_strides_;
422 
424  std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
425  std::vector<DinGridDesc_M> din_grid_desc_m_container_;
426 
428  };
429 
430  struct Invoker : public BaseInvoker
431  {
432  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
433  {
434  float ave_time = 0;
435 
436  for(index_t i = 0; i < arg.num_reduce_; i++)
437  {
438  const auto kernel = kernel_reduce_threadwise<gridwise_reduce,
439  false,
440  false,
441  false, // don't have index input
442  DOutDataType,
443  DInDataType,
444  ComputeDataType,
445  int,
448  PassThrough,
449  Div>;
450 
451  ck::index_t M = arg.dout_grid_desc_m_k_container_[i].GetLength(I0);
452  const index_t grid_size = (M / M_BlockTileSize);
453 
454  ave_time += launch_and_time_kernel(stream_config,
455  kernel,
456  dim3(grid_size),
457  dim3(BlockSize),
458  0,
461  PassThrough{},
462  arg.div_element_op_,
463  float(1),
464  arg.p_dout_grid_,
465  nullptr,
466  float(0),
467  arg.p_din_grid_,
468  nullptr);
469  }
470 
471  return ave_time;
472  }
473 
474  float Run(const BaseArgument* p_arg,
475  const StreamConfig& stream_config = StreamConfig{}) override
476  {
477  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
478  }
479  };
480 
481  static bool IsSupportedArgument(const Argument& arg)
482  {
483  constexpr index_t Rank = NDimSpatial + 2;
484  int doutFastestDim = -1;
485  int dinFastestDim = -1;
486 
487  for(int i = 0; i < Rank; ++i)
488  {
489  if(arg.dout_n_c_wos_strides_[i] == 1)
490  doutFastestDim = i;
491  if(arg.din_n_c_wos_strides_[i] == 1)
492  dinFastestDim = i;
493  }
494 
495  if(doutFastestDim == -1 || dinFastestDim == -1)
496  {
497  if constexpr(InSrcOutDstVectorSize != 1)
498  return false;
499  }
500  else
501  {
502  if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
503  return false;
504  if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
505  return false;
506  }
507 
508  return true;
509  }
510 
511  bool IsSupportedArgument(const BaseArgument* p_arg) override
512  {
513  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
514  }
515 
516  std::unique_ptr<BaseArgument>
517  MakeArgumentPointer(const void* p_dout,
518  void* p_din,
519  std::vector<ck::index_t> dout_n_c_wos_lengths,
520  std::vector<ck::index_t> din_n_c_wos_length,
521  std::vector<ck::index_t> dout_n_c_wos_strides,
522  std::vector<ck::index_t> din_n_c_wos_strides,
523  std::vector<ck::index_t> window_lengths,
524  std::vector<ck::index_t> window_strides,
525  std::vector<ck::index_t> window_dilations,
526  std::vector<ck::index_t> input_left_pads,
527  std::vector<ck::index_t> input_right_pads) override
528  {
529  constexpr index_t Rank = NDimSpatial + 2;
530 
531  if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
532  dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
533  throw std::runtime_error("dimension is incorrect");
534 
535  if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
536  window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
537  input_right_pads.size() != NDimSpatial)
538  throw std::runtime_error("dimension is incorrect");
539 
540  return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
541  static_cast<DInDataType*>(p_din),
542  dout_n_c_wos_lengths,
543  din_n_c_wos_length,
544  dout_n_c_wos_strides,
545  din_n_c_wos_strides,
546  window_lengths,
547  window_strides,
548  window_dilations,
549  input_left_pads,
550  input_right_pads);
551  }
552 
553  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
554  {
555  return std::make_unique<Invoker>(Invoker{});
556  }
557 
558  std::string GetTypeString() const override
559  {
560  auto str = std::stringstream();
561 
562  // clang-format off
563  str << "DeviceAvgPool3dBwd<" << BlockSize << ",";
564  str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
565  str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
566  str <<"InSrcOutDstVectorSize_" << InSrcOutDstVectorSize << ">";
567  // clang-format on
568 
569  return str.str();
570  }
571 };
572 
573 } // namespace device
574 } // namespace tensor_operation
575 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ auto integer_divide_floor(X x, Y y)
Definition: math.hpp:66
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
__host__ constexpr __device__ index_t gcd(index_t x, index_t y)
Definition: math.hpp:154
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition: gridwise_2d_reduction_threadwise.hpp:28
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:110
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_2d_reduction_threadwise.hpp:84
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: reduction_operator.hpp:37
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:349
int num_reduce_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:423
Div div_element_op_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:427
std::vector< ck::index_t > dout_n_c_wos_strides_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:420
const DOutDataType * p_dout_grid_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:416
std::vector< ck::index_t > din_n_c_wos_length_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:419
DInDataType * p_din_grid_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:417
Argument(const DOutDataType *p_dout, DInDataType *p_din, std::vector< ck::index_t > dout_n_c_wos_lengths, std::vector< ck::index_t > din_n_c_wos_length, std::vector< ck::index_t > dout_n_c_wos_strides, std::vector< ck::index_t > din_n_c_wos_strides, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:350
std::vector< ck::index_t > dout_n_c_wos_lengths_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:418
std::vector< DinGridDesc_M > din_grid_desc_m_container_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:425
std::vector< ck::index_t > din_n_c_wos_strides_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:421
std::vector< DoutGridDesc_M_K > dout_grid_desc_m_k_container_
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:424
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:431
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:474
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:432
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:41
static constexpr auto I0
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:44
tensor_operation::element_wise::UnaryDivide Div
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:328
remove_cvref_t< tuple_element_t< 1, DoutDinGridDesc > > DinGridDesc_M
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:320
static auto Make3DGridDescriptor_Out_M_K_In_M(const std::vector< ck::index_t > &dout_n_c_wos_lengths, const std::vector< ck::index_t > &din_n_c_wos_length, const std::vector< ck::index_t > &dout_n_c_wos_strides, const std::vector< ck::index_t > &din_n_c_wos_strides, const std::vector< ck::index_t > &window_lengths, const std::vector< ck::index_t > &window_strides, const std::vector< ck::index_t > &window_dilations, const std::vector< ck::index_t > &input_left_pads, const std::vector< ck::index_t > &input_right_pads, const std::vector< ck::index_t > &tildes)
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:51
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_dout, void *p_din, std::vector< ck::index_t > dout_n_c_wos_lengths, std::vector< ck::index_t > din_n_c_wos_length, std::vector< ck::index_t > dout_n_c_wos_strides, std::vector< ck::index_t > din_n_c_wos_strides, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads) override
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:517
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:511
remove_cvref_t< tuple_element_t< 0, DoutDinGridDesc > > DoutGridDesc_M_K
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:319
static constexpr ck::index_t M_BlockTileSize
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:47
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:553
tensor_operation::element_wise::PassThrough PassThrough
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:327
std::string GetTypeString() const override
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:558
static constexpr ck::index_t K_BlockTileSize
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:48
decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0}, {0, 0, 0})) DoutDinGridDesc
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:317
static bool IsSupportedArgument(const Argument &arg)
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:481
static constexpr index_t OutSrcInDstVectorDim
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:325
static constexpr auto I1
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:45
GridwiseReduction_mk_to_m_threadwise< DOutDataType, DInDataType, ComputeDataType, int, DoutGridDesc_M_K, DinGridDesc_M, reduce::Add, PassThrough, Div, InMemoryDataOperationEnum::Set, false, BlockSize, MThreadSliceSize, KThreadSliceSize, OutSrcInDstVectorDim, InSrcOutDstVectorSize, InSrcOutDstVectorSize > gridwise_reduce
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:346
static constexpr ck::index_t NDimSpatial
Definition: device_avgpool3d_bwd_ndhwc_ndhwc.hpp:42
Definition: device_avgpool_bwd.hpp:20
Definition: unary_element_wise_operation.hpp:241
Definition: unary_element_wise_operation.hpp:569