/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_max_pool_bwd_impl.hpp Source File
device_max_pool_bwd_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
11 
17 
21 
22 namespace ck {
23 namespace tensor_operation {
24 namespace device {
25 
26 // Max-pool backward scatter: din[indices[i]] = dout[i] (accumulated with atomic add when pooling windows overlap)
27 template <typename DOutDataType,
28  typename IndexDataType,
29  typename DInDataType,
30  ck::index_t InOutVectorSize>
31 struct DeviceMaxPoolBwdImpl : public DeviceMaxPoolBwd<DOutDataType, IndexDataType, DInDataType>
32 {
34  conditional_t<is_same_v<DInDataType, float> || is_same_v<DInDataType, double>,
35  DInDataType,
36  float>;
37 
40 
41  static constexpr auto I0 = Number<0>{};
42  static constexpr auto I1 = Number<1>{};
     // Compile-time dimension indices, used as desc.GetLength(I0) / desc.GetLength(I1).
43 
44  template <typename Desc_M>
45  static auto PadDescriptor_M_1d(Desc_M& desc_m, index_t loop_step)
46  {
     // Returns a 1-D descriptor derived from desc_m whose length accounts for
     // `pad` extra elements computed below.
47  const auto m = desc_m.GetLength(I0);
     // Assumes integer_least_multiple(m, loop_step) is the smallest multiple of
     // loop_step >= m, so `pad` is 0 when m is already a multiple — TODO confirm.
48  const auto pad = math::integer_least_multiple(m, loop_step) - m;
     // NOTE(review): the transform building desc_m_pad is elided in this listing;
     // presumably it right-pads dimension 0 by `pad` — verify against the header.
49  const auto desc_m_pad =
54  return desc_m_pad;
55  }
56 
57  static auto MakeDescriptor_M(index_t length, index_t loop_step)
58  {
     // Wraps `length` contiguous elements in a packed 1-D descriptor and hands it
     // to PadDescriptor_M_1d, which pads it with respect to `loop_step`.
59  const auto desc_m = make_naive_tensor_descriptor_packed(make_tuple(length));
60  return PadDescriptor_M_1d(desc_m, loop_step);
61  }
62 
63  template <typename Desc_M>
64  static auto ExpendDescFirstDim(Desc_M desc_m)
65  {
     // Views a 1-D descriptor of length L as 2-D with lengths (1, L) by unmerging
     // dimension 0; Invoker::Run uses the result as the 2-D descriptor of the
     // UnaryConvert cast kernel. ("Expend" is a misspelling of "Expand"; the name
     // is kept as-is.)
67  desc_m,
68  make_tuple(make_unmerge_transform(make_tuple(I1, desc_m.GetLength(I0)))),
71  }
72 
73  using InOutGrid1dDesc = decltype(MakeDescriptor_M(1, 1));
     // Type of the padded 1-D descriptors built by MakeDescriptor_M (used for both
     // din and dout); the call inside decltype is unevaluated and only names the type.
75 
77  DOutDataType,
78  IndexDataType,
79  DInDataType,
82  InOutVectorSize>;
83 
85  DOutDataType,
86  IndexDataType,
90  InOutVectorSize>;
91 
92  static constexpr index_t BlockSize = 256;
93  static constexpr index_t MPerThread = 1;
94  static constexpr index_t NPerThread = InOutVectorSize;
95  static constexpr index_t MPerBlock = 1;
96  static constexpr index_t NPerBlock = BlockSize * NPerThread;
     // Tile parameters of the GridwiseCasting (UnaryConvert) kernel: each tile is
     // MPerBlock x NPerBlock = 1 x (BlockSize * InOutVectorSize) elements.
     // NOTE(review): Invoker::Run launches every kernel with dim3(arg.blockSize_)
     // rather than BlockSize; confirm Argument initialises blockSize_ to BlockSize
     // so the cast kernel's compile-time block size matches the launch.
97 
99 
105  UnaryConvert,
106  BlockSize,
107  MPerBlock,
108  NPerBlock,
109  MPerThread,
110  NPerThread,
114  I1,
115  I1>;
116 
117  struct Argument : public BaseArgument
118  {
     // Per-call state: typed device pointers, packed element counts of dout/din,
     // and whether any pooling windows overlap.
119  Argument(const DOutDataType* p_dout,
120  const IndexDataType* p_indices,
121  DInDataType* p_din,
122  index_t dout_length,
123  index_t din_length,
124  const std::vector<ck::index_t>& window_lengths,
125  const std::vector<ck::index_t>& window_strides,
126  const std::vector<ck::index_t>& window_dilations)
127  : p_dout_{p_dout},
128  p_indices_{p_indices},
129  p_din_{p_din},
130  dout_length_raw_{dout_length},
131  din_length_raw_{din_length},
133  windowOverlap_{false}
134  {
     // Windows overlap in a dimension when the dilated extent
     // (length - 1) * dilation + 1 exceeds the stride. One din element can then be
     // picked by several windows, which is why Invoker::Run switches to the
     // AtomicAdd put kernel when windowOverlap_ is set.
     // window_strides / window_dilations are read with .at(), so they must be at
     // least as long as window_lengths or std::out_of_range is thrown.
135  for(size_t i = 0; i < window_lengths.size(); ++i)
136  {
137  auto eff = (window_lengths.at(i) - 1) * window_dilations.at(i) + 1;
138  windowOverlap_ |= eff > window_strides.at(i);
139  }
140  }
141 
142  const DOutDataType* p_dout_;
143  const IndexDataType* p_indices_;
144  DInDataType* p_din_;
149  };
150 
151  struct Invoker : public BaseInvoker
152  {
153  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
154  {
     // Scatters dout into din at the positions in p_indices and returns the time
     // reported by launch_and_time_kernel (summed over both kernels on the
     // workspace path; the hipMemsetAsync calls are outside the timed launches).
     // Paths:
     //  1. DInDataType is float/double: zero din, then put directly into din with
     //     the AtomicAdd kernel (overlapping windows) or the Set kernel.
     //  2. Other DInDataType, overlapping windows: zero the workspace, AtomicAdd
     //     into it as DInDataType_AutomicAddPreCast, then UnaryConvert it into din.
     //     Throws std::runtime_error if no workspace pointer has been set.
     //  3. Other DInDataType, no overlap: zero din, then the Set kernel.
     // din is zeroed first because the put kernels only write positions named by
     // p_indices. The 1-D put grid has one block per available compute unit, and
     // both descriptors are padded (MakeDescriptor_M) with respect to loop_step.
155  index_t gridSize = getAvailableComputeUnitCount(stream_config);
156  index_t loop_step = gridSize * arg.blockSize_ * InOutVectorSize;
157  InOutGrid1dDesc din_grid_desc = MakeDescriptor_M(arg.din_length_raw_, loop_step);
158  InOutGrid1dDesc dout_grid_desc = MakeDescriptor_M(arg.dout_length_raw_, loop_step);
159 
160  if constexpr(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>)
161  {
162  hip_check_error(hipMemsetAsync(arg.p_din_,
163  0,
164  arg.din_length_raw_ * sizeof(DInDataType),
165  stream_config.stream_id_));
166 
167  if(arg.windowOverlap_)
168  {
169  const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
171  DOutDataType,
172  IndexDataType,
173  DInDataType,
174  PassThrough>;
175 
176  return launch_and_time_kernel(stream_config,
177  put_kernel,
178  dim3(gridSize),
179  dim3(arg.blockSize_),
180  0,
181  dout_grid_desc,
182  arg.p_dout_,
183  arg.p_indices_,
184  arg.p_din_,
185  PassThrough{});
186  }
187  else
188  {
189  const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
191  DOutDataType,
192  IndexDataType,
193  DInDataType,
194  PassThrough>;
195 
196  return launch_and_time_kernel(stream_config,
197  put_kernel,
198  dim3(gridSize),
199  dim3(arg.blockSize_),
200  0,
201  dout_grid_desc,
202  arg.p_dout_,
203  arg.p_indices_,
204  arg.p_din_,
205  PassThrough{});
206  }
207  }
208  else
209  {
210  if(arg.windowOverlap_)
211  {
212  if(arg.p_workspace_ == nullptr)
213  throw std::runtime_error("wrong! WorkSpace pointer has not been set");
214 
     // Zero the accumulation buffer before the AtomicAdd scatter.
216  hipMemsetAsync(arg.p_workspace_,
217  0,
219  stream_config.stream_id_));
220 
221  const auto put_kernel = kernel_put_element_1d<GridwisePutElementAtomicAdd,
223  DOutDataType,
224  IndexDataType,
226  PassThrough>;
227 
228  const auto cast_kernel =
235  UnaryConvert>;
236 
237  float elapsed_time = launch_and_time_kernel(
238  stream_config,
239  put_kernel,
240  dim3(gridSize),
241  dim3(arg.blockSize_),
242  0,
243  dout_grid_desc,
244  arg.p_dout_,
245  arg.p_indices_,
246  static_cast<DInDataType_AutomicAddPreCast*>(arg.p_workspace_),
247  PassThrough{});
248 
     // Second pass: view din as 2-D (1, L) and convert workspace -> din.
249  InOutGrid2dDesc din_grid_desc_2d = ExpendDescFirstDim(din_grid_desc);
250  const index_t M = din_grid_desc_2d.GetLength(I0);
251  const index_t N = din_grid_desc_2d.GetLength(I1);
252  const auto block_2_tile_map = Block2TileMap(M, N);
253  const auto cast_kernel_grid_size =
254  block_2_tile_map.CalculateGridSize(din_grid_desc_2d);
255 
256  elapsed_time += launch_and_time_kernel(
257  stream_config,
258  cast_kernel,
259  dim3(cast_kernel_grid_size),
260  dim3(arg.blockSize_),
261  0,
262  ck::make_tuple(din_grid_desc_2d),
263  ck::make_tuple(din_grid_desc_2d),
265  static_cast<const DInDataType_AutomicAddPreCast*>(arg.p_workspace_)),
266  ck::make_tuple(arg.p_din_),
267  block_2_tile_map,
268  UnaryConvert{});
269 
270  return elapsed_time;
271  }
272  else
273  {
274  hip_check_error(hipMemsetAsync(arg.p_din_,
275  0,
276  arg.din_length_raw_ * sizeof(DInDataType),
277  stream_config.stream_id_));
278 
279  const auto put_kernel = kernel_put_element_1d<GridwisePutElementSet,
281  DOutDataType,
282  IndexDataType,
283  DInDataType,
284  PassThrough>;
285 
     // NOTE(review): din was already zero-filled by the identical hipMemsetAsync
     // at source line 274; this second memset is redundant and can be removed.
286  hip_check_error(hipMemsetAsync(arg.p_din_,
287  0,
288  arg.din_length_raw_ * sizeof(DInDataType),
289  stream_config.stream_id_));
290 
291  return launch_and_time_kernel(stream_config,
292  put_kernel,
293  dim3(gridSize),
294  dim3(arg.blockSize_),
295  0,
296  dout_grid_desc,
297  arg.p_dout_,
298  arg.p_indices_,
299  arg.p_din_,
300  PassThrough{});
301  }
302  }
303  }
304 
305  float Run(const BaseArgument* p_arg,
306  const StreamConfig& stream_config = StreamConfig{}) override
307  {
     // Type-erased entry point: downcasts to this op's Argument and forwards.
     // NOTE(review): the dynamic_cast result is dereferenced unchecked; a
     // BaseArgument of any other type yields nullptr and undefined behaviour.
308  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
309  }
310  };
311 
312  size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
313  {
     // Bytes of scratch memory Invoker::Run needs. A workspace is only used when
     // windows overlap and DInDataType is neither float nor double: Run then
     // accumulates din_length_raw_ DInDataType_AutomicAddPreCast values before
     // casting into din. Returns 0 otherwise.
     // NOTE(review): pArg_ is dereferenced without checking the dynamic_cast result.
314  const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);
315 
316  bool needCast = pArg_->windowOverlap_ &&
317  !(is_same_v<DInDataType, float> || is_same_v<DInDataType, double>);
318 
319  if(!needCast)
320  return 0;
321  else
322  return pArg_->din_length_raw_ * sizeof(DInDataType_AutomicAddPreCast);
323  };
324 
325  bool IsSupportedArgument(const BaseArgument* p_arg) override
326  {
     // Accepts the argument only when both packed lengths are multiples of
     // InOutVectorSize, the vector size given to the put kernels and used as
     // NPerThread by the cast kernel.
     // NOTE(review): no null check on the dynamic_cast; a foreign BaseArgument
     // should presumably return false here instead of dereferencing nullptr.
327  const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
328  if(pArg->din_length_raw_ % InOutVectorSize != 0 ||
329  pArg->dout_length_raw_ % InOutVectorSize != 0)
330  {
331  return false;
332  }
333  return true;
334  }
335 
336  std::unique_ptr<BaseArgument>
337  MakeArgumentPointer(const void* p_dout,
338  const void* p_indices,
339  void* p_din,
340  index_t dout_length,
341  index_t din_length,
342  std::vector<ck::index_t> window_lengths,
343  std::vector<ck::index_t> window_strides,
344  std::vector<ck::index_t> window_dilations) override
345  {
     // Builds an Argument from type-erased pointers by static_casting them to
     // DOutDataType / IndexDataType / DInDataType. The window vectors are only
     // read by the Argument constructor to compute windowOverlap_; they are not stored.
346  // Assumes p_dout, p_indices and p_din point to packed buffers; dout_length and din_length are
347  // the element counts of those packed tensors
348  return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
349  static_cast<const IndexDataType*>(p_indices),
350  static_cast<DInDataType*>(p_din),
351  dout_length,
352  din_length,
353  window_lengths,
354  window_strides,
355  window_dilations);
356  }
357 
358  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
359  {
     // Returns a newly allocated Invoker. Invoker declares no data members, so
     // all per-call state comes from the Argument passed to Run.
360  return std::make_unique<Invoker>(Invoker{});
361  }
362 };
363 
364 } // namespace device
365 } // namespace tensor_operation
366 } // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition: helper.hpp:70
void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:10
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, const InDataType *__restrict__ p_in_global, const IndexDataType *__restrict__ p_indices_global, OutDataType *__restrict__ p_out_global, const ElementwiseOperation elementwise_op)
Definition: gridwise_put_element_1d.hpp:17
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition: gridwise_elementwise_2d.hpp:29
Definition: stream_config.hpp:10
Definition: gridwise_elementwise_2d.hpp:162
Definition: gridwise_put_element_1d.hpp:36
Definition: multi_index_transform.hpp:13
Definition: sequence.hpp:43
Definition: tuple.hpp:117
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
void * p_workspace_
Definition: device_base.hpp:57
Definition: device_base.hpp:61
Definition: device_max_pool_bwd.hpp:17
Definition: device_max_pool_bwd_impl.hpp:118
index_t dout_length_raw_
Definition: device_max_pool_bwd_impl.hpp:145
index_t din_length_raw_
Definition: device_max_pool_bwd_impl.hpp:146
index_t blockSize_
Definition: device_max_pool_bwd_impl.hpp:147
const IndexDataType * p_indices_
Definition: device_max_pool_bwd_impl.hpp:143
DInDataType * p_din_
Definition: device_max_pool_bwd_impl.hpp:144
bool windowOverlap_
Definition: device_max_pool_bwd_impl.hpp:148
const DOutDataType * p_dout_
Definition: device_max_pool_bwd_impl.hpp:142
Argument(const DOutDataType *p_dout, const IndexDataType *p_indices, DInDataType *p_din, index_t dout_length, index_t din_length, const std::vector< ck::index_t > &window_lengths, const std::vector< ck::index_t > &window_strides, const std::vector< ck::index_t > &window_dilations)
Definition: device_max_pool_bwd_impl.hpp:119
Definition: device_max_pool_bwd_impl.hpp:152
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_max_pool_bwd_impl.hpp:305
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_max_pool_bwd_impl.hpp:153
Definition: device_max_pool_bwd_impl.hpp:32
decltype(MakeDescriptor_M(1, 1)) InOutGrid1dDesc
Definition: device_max_pool_bwd_impl.hpp:73
static auto ExpendDescFirstDim(Desc_M desc_m)
Definition: device_max_pool_bwd_impl.hpp:64
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_max_pool_bwd_impl.hpp:325
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_dout, const void *p_indices, void *p_din, index_t dout_length, index_t din_length, std::vector< ck::index_t > window_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations) override
Definition: device_max_pool_bwd_impl.hpp:337
static constexpr auto I1
Definition: device_max_pool_bwd_impl.hpp:42
ck::tensor_operation::element_wise::UnaryConvert UnaryConvert
Definition: device_max_pool_bwd_impl.hpp:39
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition: device_max_pool_bwd_impl.hpp:38
decltype(ExpendDescFirstDim(InOutGrid1dDesc{})) InOutGrid2dDesc
Definition: device_max_pool_bwd_impl.hpp:74
static constexpr index_t NPerThread
Definition: device_max_pool_bwd_impl.hpp:94
static constexpr auto I0
Definition: device_max_pool_bwd_impl.hpp:41
GridwisePutElement_1D< InOutGrid1dDesc, DOutDataType, IndexDataType, DInDataType, PassThrough, InMemoryDataOperationEnum::Set, InOutVectorSize > GridwisePutElementSet
Definition: device_max_pool_bwd_impl.hpp:82
static constexpr index_t MPerThread
Definition: device_max_pool_bwd_impl.hpp:93
conditional_t< is_same_v< DInDataType, float >||is_same_v< DInDataType, double >, DInDataType, float > DInDataType_AutomicAddPreCast
Definition: device_max_pool_bwd_impl.hpp:36
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_max_pool_bwd_impl.hpp:358
static constexpr index_t BlockSize
Definition: device_max_pool_bwd_impl.hpp:92
GridwisePutElement_1D< InOutGrid1dDesc, DOutDataType, IndexDataType, DInDataType_AutomicAddPreCast, PassThrough, InMemoryDataOperationEnum::AtomicAdd, InOutVectorSize > GridwisePutElementAtomicAdd
Definition: device_max_pool_bwd_impl.hpp:90
static auto MakeDescriptor_M(index_t length, index_t loop_step)
Definition: device_max_pool_bwd_impl.hpp:57
BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, NPerBlock > Block2TileMap
Definition: device_max_pool_bwd_impl.hpp:98
GridwiseElementwise< Tuple< InOutGrid2dDesc >, Tuple< InOutGrid2dDesc >, Tuple< const DInDataType_AutomicAddPreCast * >, Tuple< DInDataType * >, Block2TileMap, UnaryConvert, BlockSize, MPerBlock, NPerBlock, MPerThread, NPerThread, Sequence< 0, 1 >, Sequence< InOutVectorSize >, Sequence< InOutVectorSize >, I1, I1 > GridwiseCasting
Definition: device_max_pool_bwd_impl.hpp:115
static constexpr index_t NPerBlock
Definition: device_max_pool_bwd_impl.hpp:96
static auto PadDescriptor_M_1d(Desc_M &desc_m, index_t loop_step)
Definition: device_max_pool_bwd_impl.hpp:45
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition: device_max_pool_bwd_impl.hpp:312
static constexpr index_t MPerBlock
Definition: device_max_pool_bwd_impl.hpp:95
Definition: unary_element_wise_operation.hpp:241
Definition: unary_element_wise_operation.hpp:446