Composable Kernel: include/ck_tile/ops/pooling/kernel/pool_kernel.hpp Source File

pool_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include "ck_tile/ops/common.hpp"
9 #include <type_traits>
10 
11 namespace ck_tile {
12 
13 /// @brief Host arguments for pooling operations.
14 template <typename TensorShape, typename WindowShape>
15 struct PoolHostArgs
16 {
17 
18  CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
19  void* output_ptr_,
20  TensorShape input_shape_,
21  TensorShape output_shape_,
22  TensorShape input_strides_,
23  TensorShape output_strides_,
24  WindowShape window_lengths_,
25  WindowShape window_strides_,
26  WindowShape window_dilations_,
27  WindowShape input_left_pads_,
28  WindowShape input_right_pads_)
29  : input_ptr(input_ptr_),
30  output_ptr(output_ptr_),
31  input_shape(input_shape_),
32  output_shape(output_shape_),
33  input_strides(input_strides_),
34  output_strides(output_strides_),
35  window_lengths(window_lengths_),
36  window_strides(window_strides_),
37  window_dilations(window_dilations_),
38  input_left_pads(input_left_pads_),
39  input_right_pads(input_right_pads_)
40  {
41  }
42 
43  const void* input_ptr;
44  void* output_ptr;
45 
46  TensorShape input_shape;
47  TensorShape output_shape;
48  TensorShape input_strides;
49  TensorShape output_strides;
50  WindowShape window_lengths;
51  WindowShape window_strides;
52  WindowShape window_dilations;
53  WindowShape input_left_pads;
54  WindowShape input_right_pads;
55 };
56 
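// Illustrative sketch (not part of this header): how a caller might populate
// PoolHostArgs for a packed-NHWC 2D max pool with a 2x2 window and stride 2.
// The concrete TensorShape/WindowShape containers are whatever the caller
// instantiates the templates with; the types, pointers and variable names
// below are assumptions for the example only.
//
//   TensorShape in_shape{1, 8, 8, 16};                  // N, H, W, C
//   TensorShape out_shape{1, 4, 4, 16};                 // N, Ho, Wo, C
//   TensorShape in_strides{8 * 8 * 16, 8 * 16, 16, 1};  // packed NHWC strides
//   TensorShape out_strides{4 * 4 * 16, 4 * 16, 16, 1};
//   WindowShape win_lengths{2, 2}, win_strides{2, 2}, win_dilations{1, 1};
//   WindowShape left_pads{0, 0}, right_pads{0, 0};
//
//   PoolHostArgs<TensorShape, WindowShape> host_args{p_in_dev, p_out_dev,
//       in_shape, out_shape, in_strides, out_strides,
//       win_lengths, win_strides, win_dilations, left_pads, right_pads};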
57 /// @brief Kernel arguments for pooling operations.
58 template <typename TensorShape, typename WindowShape>
59 struct PoolKernelArgs
60 {
61  const void* input_ptr;
62  void* output_ptr;
63  TensorShape input_shape;
64  TensorShape output_shape;
65  TensorShape input_strides;
66  TensorShape output_strides;
67  WindowShape window_lengths;
68  WindowShape window_strides;
69  WindowShape window_dilations;
70  WindowShape input_left_pads;
71  WindowShape input_right_pads;
72 };
73 
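// PoolKernelArgs mirrors PoolHostArgs field-for-field but is a plain aggregate
// with no constructor, so it can be copied by value into the device-side
// operator() below; MakeKernelArgs() at the bottom of PoolKernel performs that
// field-wise copy from the host-side argument struct.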
74 template <typename Problem_, typename Policy_ = PoolDefaultPolicy>
75 struct PoolKernel
76 {
77  using Problem = remove_cvref_t<Problem_>;
78  using Policy  = remove_cvref_t<Policy_>;
79 
80  using InDataType      = remove_cvref_t<typename Problem::InDataType>;
81  using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
82  using OutDataType     = remove_cvref_t<typename Problem::OutDataType>;
83 
84  static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
85 
86  CK_TILE_HOST static constexpr auto BlockSize()
87  {
88  return is_wave32() ? kBlockSize / 2 : kBlockSize;
89  }
90 
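// The MakeTensorView* helpers below recast pooling as a tiled 2D reduction:
// each output coordinate (N, Ho, Wo, C) becomes one row of an M x K view and
// each element of the pooling window becomes one column, so operator() only
// ever reduces a right-padded M x K matrix along K.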
91  template <typename TensorShape, typename WindowShape>
92  CK_TILE_DEVICE static auto MakeTensorView2D(PoolKernelArgs<TensorShape, WindowShape> kargs)
93  {
94  using S = typename Problem::BlockShape;
95 
96  // Compile-time validation for 2D pooling
97  static_assert(TensorShape::size() == 4, "2D pooling requires 4D input tensor (N,H,W,C)");
98  static_assert(WindowShape::size() == 2, "2D pooling requires 2D window shape (Y,X)");
99 
100  // Extract dimension values
101  const index_t N = kargs.input_shape.at(number<0>{});
102  const index_t H = kargs.input_shape.at(number<1>{});
103  const index_t W = kargs.input_shape.at(number<2>{});
104  const index_t C = kargs.input_shape.at(number<3>{});
105 
106  const index_t No = kargs.output_shape.at(number<0>{});
107  const index_t Ho = kargs.output_shape.at(number<1>{});
108  const index_t Wo = kargs.output_shape.at(number<2>{});
109  const index_t Co = kargs.output_shape.at(number<3>{});
110 
111  const index_t Y = kargs.window_lengths.at(number<0>{});
112  const index_t X = kargs.window_lengths.at(number<1>{});
113 
114  const index_t WindowStrideH = kargs.window_strides.at(number<0>{});
115  const index_t WindowStrideW = kargs.window_strides.at(number<1>{});
116 
117  const index_t WindowDilationH = kargs.window_dilations.at(number<0>{});
118  const index_t WindowDilationW = kargs.window_dilations.at(number<1>{});
119 
120  const index_t InLeftPadH = kargs.input_left_pads.at(number<0>{});
121  const index_t InLeftPadW = kargs.input_left_pads.at(number<1>{});
122 
123  const index_t InRightPadH = kargs.input_right_pads.at(number<0>{});
124  const index_t InRightPadW = kargs.input_right_pads.at(number<1>{});
125 
126  const index_t MRaw = N * Ho * Wo * C;
127  const index_t KRaw = Y * X;
128  const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
129  const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
130 
131  auto reduce_op = typename Problem::ReduceOp{};
132 
133  // Create input descriptor with all transformations
134  auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
135 
136  // Apply spatial padding to input descriptor
137  const auto padded_in_desc = transform_tensor_descriptor(
138  in_desc,
139  make_tuple(make_pass_through_transform(N),
140  make_pad_transform(H, InLeftPadH, InRightPadH),
141  make_pad_transform(W, InLeftPadW, InRightPadW),
142  make_pass_through_transform(C)),
143  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
144  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
145 
146  // Create sliding windows by embedding pooling windows into descriptor
147  const auto embed_in_desc = transform_tensor_descriptor(
148  padded_in_desc,
149  make_tuple(
150  make_pass_through_transform(N),
151  make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
152  make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
153  make_pass_through_transform(C)),
154  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
155  make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
156 
157  // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
158  const auto merged_embed_in_desc =
159  transform_tensor_descriptor(embed_in_desc,
160  make_tuple(make_merge_transform(make_tuple(N, Ho, Wo, C)),
161  make_merge_transform(make_tuple(Y, X))),
162  make_tuple(sequence<0, 2, 4, 5>{}, sequence<1, 3>{}),
163  make_tuple(sequence<0>{}, sequence<1>{}));
164 
165  const auto in_desc_padded = transform_tensor_descriptor(
166  merged_embed_in_desc,
167  make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
168  make_tuple(sequence<0>{}, sequence<1>{}),
169  make_tuple(sequence<0>{}, sequence<1>{}));
170 
171  // Create output descriptor with transformations
172  auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
173 
174  const auto merged_out_desc = transform_tensor_descriptor(
175  out_desc,
176  make_tuple(make_merge_transform(make_tuple(No, Ho, Wo, Co))),
177  make_tuple(sequence<0, 1, 2, 3>{}),
178  make_tuple(sequence<0>{}));
179 
180  const auto out_desc_padded =
181  transform_tensor_descriptor(merged_out_desc,
182  make_tuple(make_right_pad_transform(MRaw, MPad)),
183  make_tuple(sequence<0>{}),
184  make_tuple(sequence<0>{}));
185 
186  // Now create buffer views and tensor views with the fully transformed descriptors
187  const InDataType in_identity =
188  type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
189  const OutDataType out_identity =
190  type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
191 
192  auto in_buffer_view = make_buffer_view<address_space_enum::global>(
193  static_cast<const InDataType*>(kargs.input_ptr),
194  in_desc.get_element_space_size(),
195  in_identity);
196  const auto in_tensor_padded =
197  tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
198  in_desc_padded};
199 
200  auto out_buffer_view = make_buffer_view<address_space_enum::global>(
201  static_cast<OutDataType*>(kargs.output_ptr),
202  out_desc.get_element_space_size(),
203  out_identity);
204  const auto out_tensor_padded =
205  tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
206  out_desc_padded};
207 
208  return make_tuple(in_tensor_padded, out_tensor_padded);
209  }
210 
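// Worked example with hypothetical sizes (not taken from this file): for
// N = 1, H = W = 4, C = 8 and a 2x2 window with stride 2 and no padding,
// Ho = Wo = 2, so MRaw = N * Ho * Wo * C = 32 and KRaw = Y * X = 4. If the
// block shape were Block_M = 64 and Block_N = 16, the view above would be
// right-padded by MPad = 32 and KPad = 12 so that whole tiles can always be
// loaded; the identity value handed to the buffer views is what the padded
// region reads back, which keeps it from affecting the reduction result.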
211  template <typename TensorShape, typename WindowShape>
212  CK_TILE_DEVICE static auto MakeTensorView3D(PoolKernelArgs<TensorShape, WindowShape> kargs)
213  {
214  using S = typename Problem::BlockShape;
215 
216  // Compile-time validation for 3D pooling
217  static_assert(TensorShape::size() == 5, "3D pooling requires 5D input tensor (N,D,H,W,C)");
218  static_assert(WindowShape::size() == 3, "3D pooling requires 3D window shape (Z,Y,X)");
219 
220  // Extract dimension values
221  const index_t N = kargs.input_shape.at(number<0>{});
222  const index_t D = kargs.input_shape.at(number<1>{});
223  const index_t H = kargs.input_shape.at(number<2>{});
224  const index_t W = kargs.input_shape.at(number<3>{});
225  const index_t C = kargs.input_shape.at(number<4>{});
226 
227  const index_t No = kargs.output_shape.at(number<0>{});
228  const index_t Do = kargs.output_shape.at(number<1>{});
229  const index_t Ho = kargs.output_shape.at(number<2>{});
230  const index_t Wo = kargs.output_shape.at(number<3>{});
231  const index_t Co = kargs.output_shape.at(number<4>{});
232 
233  const index_t Z = kargs.window_lengths.at(number<0>{});
234  const index_t Y = kargs.window_lengths.at(number<1>{});
235  const index_t X = kargs.window_lengths.at(number<2>{});
236 
237  const index_t WindowStrideD = kargs.window_strides.at(number<0>{});
238  const index_t WindowStrideH = kargs.window_strides.at(number<1>{});
239  const index_t WindowStrideW = kargs.window_strides.at(number<2>{});
240 
241  const index_t WindowDilationD = kargs.window_dilations.at(number<0>{});
242  const index_t WindowDilationH = kargs.window_dilations.at(number<1>{});
243  const index_t WindowDilationW = kargs.window_dilations.at(number<2>{});
244 
245  const index_t InLeftPadD = kargs.input_left_pads.at(number<0>{});
246  const index_t InLeftPadH = kargs.input_left_pads.at(number<1>{});
247  const index_t InLeftPadW = kargs.input_left_pads.at(number<2>{});
248 
249  const index_t InRightPadD = kargs.input_right_pads.at(number<0>{});
250  const index_t InRightPadH = kargs.input_right_pads.at(number<1>{});
251  const index_t InRightPadW = kargs.input_right_pads.at(number<2>{});
252 
253  const index_t MRaw = N * Do * Ho * Wo * C;
254  const index_t KRaw = Z * Y * X;
255  const index_t MPad = integer_least_multiple(MRaw, S::Block_M) - MRaw;
256  const index_t KPad = integer_least_multiple(KRaw, S::Block_N) - KRaw;
257 
258  auto reduce_op = typename Problem::ReduceOp{};
259 
260  // Create input descriptor with all transformations
261  auto in_desc = make_naive_tensor_descriptor(kargs.input_shape, kargs.input_strides);
262 
263  // Apply spatial padding to input descriptor (all 3D dimensions)
264  const auto padded_in_desc = transform_tensor_descriptor(
265  in_desc,
266  make_tuple(make_pass_through_transform(N),
267  make_pad_transform(D, InLeftPadD, InRightPadD),
268  make_pad_transform(H, InLeftPadH, InRightPadH),
269  make_pad_transform(W, InLeftPadW, InRightPadW),
270  make_pass_through_transform(C)),
271  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
272  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}));
273 
274  // Create 3D sliding windows by embedding pooling windows into descriptor
275  const auto embed_in_desc = transform_tensor_descriptor(
276  padded_in_desc,
277  make_tuple(
278  make_pass_through_transform(N),
279  make_embed_transform(make_tuple(Z, Do), make_tuple(WindowDilationD, WindowStrideD)),
280  make_embed_transform(make_tuple(Y, Ho), make_tuple(WindowDilationH, WindowStrideH)),
281  make_embed_transform(make_tuple(X, Wo), make_tuple(WindowDilationW, WindowStrideW)),
282  make_pass_through_transform(C)),
283  make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}, sequence<4>{}),
284  make_tuple(sequence<0>{},
285  sequence<1, 2>{},
286  sequence<3, 4>{},
287  sequence<5, 6>{},
288  sequence<7>{}));
289 
290  // Reshape into 2D matrix: output positions (M) x pooling window elements (K)
291  const auto merged_embed_in_desc = transform_tensor_descriptor(
292  embed_in_desc,
293  make_tuple(make_merge_transform(make_tuple(N, Do, Ho, Wo, C)),
294  make_merge_transform(make_tuple(Z, Y, X))),
295  make_tuple(sequence<0, 2, 4, 6, 7>{}, sequence<1, 3, 5>{}),
296  make_tuple(sequence<0>{}, sequence<1>{}));
297 
298  const auto in_desc_padded = transform_tensor_descriptor(
299  merged_embed_in_desc,
300  make_tuple(make_right_pad_transform(MRaw, MPad), make_right_pad_transform(KRaw, KPad)),
301  make_tuple(sequence<0>{}, sequence<1>{}),
302  make_tuple(sequence<0>{}, sequence<1>{}));
303 
304  // Create output descriptor with transformations
305  auto out_desc = make_naive_tensor_descriptor(kargs.output_shape, kargs.output_strides);
306 
307  const auto merged_out_desc = transform_tensor_descriptor(
308  out_desc,
309  make_tuple(make_merge_transform(make_tuple(No, Do, Ho, Wo, Co))),
310  make_tuple(sequence<0, 1, 2, 3, 4>{}),
311  make_tuple(sequence<0>{}));
312 
313  const auto out_desc_padded =
314  transform_tensor_descriptor(merged_out_desc,
315  make_tuple(make_right_pad_transform(MRaw, MPad)),
316  make_tuple(sequence<0>{}),
317  make_tuple(sequence<0>{}));
318 
319  // Now create buffer views and tensor views with the fully transformed descriptors
320  const InDataType in_identity =
321  type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
322  const OutDataType out_identity =
323  type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
324 
325  auto in_buffer_view = make_buffer_view<address_space_enum::global>(
326  static_cast<const InDataType*>(kargs.input_ptr),
327  in_desc.get_element_space_size(),
328  in_identity);
329  const auto in_tensor_padded =
330  tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
331  in_desc_padded};
332 
333  auto out_buffer_view = make_buffer_view<address_space_enum::global>(
334  static_cast<OutDataType*>(kargs.output_ptr),
335  out_desc.get_element_space_size(),
336  out_identity);
337  const auto out_tensor_padded =
338  tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
339  out_desc_padded};
340 
341  return make_tuple(in_tensor_padded, out_tensor_padded);
342  }
343 
344  public:
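// operator(): each workgroup owns one Block_M-wide slice of output rows (iM),
// sweeps the K dimension in Block_N-wide tiles while accumulating with
// Problem::ReduceOp, then finishes the reduction across lanes and wavefronts
// before casting and storing the Block_M results.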
345  template <typename TensorShape, typename WindowShape>
346  CK_TILE_DEVICE void operator()(PoolKernelArgs<TensorShape, WindowShape> kargs) const
347  {
348  using S = typename Problem::BlockShape;
349 
350  // Compile-time validation for supported window dimensions
351  static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
352  "Only 2D and 3D pooling operations are supported");
353 
354  const auto iM = get_block_id() * S::Block_M;
355 
356  // Get tensors based on dimensionality
357  auto [in_tensor_padded, out_tensor_padded] = [&]() {
358  if constexpr(WindowShape::size() == 2)
359  return MakeTensorView2D(kargs);
360  else if constexpr(WindowShape::size() == 3)
361  return MakeTensorView3D(kargs);
362  else
363  static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
364  "Unsupported WindowShape rank: only 2D or 3D pooling is supported");
365  }();
366 
367  auto reduce_op = typename Problem::ReduceOp{};
368 
369  auto x_window = make_tile_window(in_tensor_padded,
370  make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
371  {iM, 0},
372  Policy::template MakeXBlockTileDistribution<Problem>());
373  auto y_window = make_tile_window(out_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
374 
375  __shared__ char smem[Policy::template GetSmemSize<Problem>()];
376 
377  const auto reduce_len =
378  in_tensor_padded.get_tensor_descriptor().get_lengths().at(number<1>{});
379  index_t num_k_tiles =
380  __builtin_amdgcn_readfirstlane(integer_divide_ceil(reduce_len, S::Block_N));
381 
382  auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
383  auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
384  auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
385 
386  using XTensorTile = decltype(load_tile(x_window));
387  auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
388  set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
389 
390  for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
391  {
392  const auto x_tile = load_tile(x_window);
393  block_reduce2d(x_tile, y_tile, reduce_op);
394  move_tile_window(x_window, {0, S::Block_N});
395  }
396 
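// Each lane now holds a partial result for its rows: the sync step below
// combines partials within a wavefront, and the cross-warp step combines
// across wavefronts through the shared-memory scratch buffer declared above.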
397  block_reduce2d_sync(y_tile, reduce_op);
398  block_reduce2d_cross_warp(y_tile, smem, reduce_op);
399  store_tile(y_window, cast_tile<OutDataType>(y_tile));
400  }
401 
412  template <typename TensorShape, typename WindowShape>
413  CK_TILE_HOST static bool IsSupportedArgument(PoolKernelArgs<TensorShape, WindowShape> kargs)
414  {
415  constexpr index_t InputRank = TensorShape::size();
416  constexpr index_t OutputRank = TensorShape::size(); // Same as input rank
417  constexpr index_t WindowRank = WindowShape::size();
418 
419  // Validate window dimensions (only 2D and 3D supported)
420  if constexpr(WindowRank != 2 && WindowRank != 3)
421  {
422  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
423  {
424  CK_TILE_ERROR("Only 2D and 3D pooling are supported!");
425  }
426  return false;
427  }
428 
429  // Validate that input rank matches expected rank for window dimensions
430  if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
431  {
432  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
433  {
434  CK_TILE_ERROR("Input tensor rank doesn't match window dimensions!");
435  }
436  return false;
437  }
438 
439  // Check that channel dimension (last dimension) is contiguous for both input and output
440  if(kargs.input_strides.at(number<InputRank - 1>{}) != 1)
441  {
442  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
443  {
444  CK_TILE_ERROR("Input tensor's channel dimension must have stride 1!");
445  }
446  return false;
447  }
448 
449  if(kargs.output_strides.at(number<OutputRank - 1>{}) != 1)
450  {
451  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
452  {
453  CK_TILE_ERROR("Output tensor's channel dimension must have stride 1!");
454  }
455  return false;
456  }
457 
458  return true;
459  }
460 
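// Example with hypothetical sizes: an output shape of (2, 16, 16, 64) flattens
// to M = 2 * 16 * 16 * 64 = 32768 output elements, so with Block_M = 256 the
// grid computed below would be ceil(32768 / 256) = 128 workgroups.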
463  template <typename TensorShape, typename WindowShape>
464  CK_TILE_HOST static constexpr index_t
465  CalculateGridSize(PoolKernelArgs<TensorShape, WindowShape> kargs)
466  {
467  using S = typename Problem::BlockShape;
468 
469  // Calculate total output elements (M dimension)
470  index_t M = 1;
471  static_for<0, TensorShape::size(), 1>{}([&](auto i) { M *= kargs.output_shape.at(i); });
472 
473  // Calculate grid size: ceil(M / Block_M)
474  return (M + S::Block_M - 1) / S::Block_M;
475  }
476 
477  /// @brief Create kernel arguments from host arguments.
478  template <typename TensorShape, typename WindowShape>
479  CK_TILE_HOST static constexpr auto
480  MakeKernelArgs(PoolHostArgs<TensorShape, WindowShape>& host_args)
481  {
482  return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
483  host_args.output_ptr,
484  host_args.input_shape,
485  host_args.output_shape,
486  host_args.input_strides,
487  host_args.output_strides,
488  host_args.window_lengths,
489  host_args.window_strides,
490  host_args.window_dilations,
491  host_args.input_left_pads,
492  host_args.input_right_pads};
493  }
494 };
495 
496 } // namespace ck_tile
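The header only defines the kernel functor and its argument helpers; wiring it up is left to the caller. The sketch below shows one minimal way this kernel might be driven from the host. The problem type MyPoolProblem (providing InDataType/OutDataType/ComputeDataType, BlockShape and ReduceOp), the pool_kernel_entry launch wrapper and run_pool are assumptions for illustration and are not part of this header; real code would more likely use the launch utilities ck_tile provides elsewhere.

// Hypothetical host-side glue, for illustration only. MyPoolProblem,
// pool_kernel_entry and run_pool are not part of ck_tile.
#include <hip/hip_runtime.h>
#include <stdexcept>
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"

template <typename Kernel, typename KernelArgs>
__global__ void pool_kernel_entry(KernelArgs kargs)
{
    Kernel{}(kargs); // forwards to PoolKernel::operator()
}

template <typename MyPoolProblem, typename TensorShape, typename WindowShape>
void run_pool(ck_tile::PoolHostArgs<TensorShape, WindowShape> host_args, hipStream_t stream)
{
    using Kernel = ck_tile::PoolKernel<MyPoolProblem>; // default PoolDefaultPolicy

    auto kargs = Kernel::MakeKernelArgs(host_args);

    if(!Kernel::IsSupportedArgument(kargs))
        throw std::runtime_error("pool: unsupported argument configuration");

    const ck_tile::index_t grid  = Kernel::CalculateGridSize(kargs);
    const ck_tile::index_t block = Kernel::BlockSize(); // already wave32-adjusted

    pool_kernel_entry<Kernel><<<dim3(grid), dim3(block), 0, stream>>>(kargs);
}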