14 template <
typename TensorShape,
typename WindowShape>
20 TensorShape input_shape_,
21 TensorShape output_shape_,
22 TensorShape input_strides_,
23 TensorShape output_strides_,
24 WindowShape window_lengths_,
25 WindowShape window_strides_,
26 WindowShape window_dilations_,
27 WindowShape input_left_pads_,
28 WindowShape input_right_pads_)
58 template <
typename TensorShape,
typename WindowShape>
74 template <
typename Problem_,
typename Policy_ = PoolDefaultPolicy>
91 template <
typename TensorShape,
typename WindowShape>
94 using S =
typename Problem::BlockShape;
97 static_assert(TensorShape::size() == 4,
"2D pooling requires 4D input tensor (N,H,W,C)");
98 static_assert(WindowShape::size() == 2,
"2D pooling requires 2D window shape (Y,X)");
126 const index_t MRaw = N * Ho * Wo * C;
131 auto reduce_op =
typename Problem::ReduceOp{};
158 const auto merged_embed_in_desc =
166 merged_embed_in_desc,
180 const auto out_desc_padded =
188 type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
190 type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
192 auto in_buffer_view = make_buffer_view<address_space_enum::global>(
194 in_desc.get_element_space_size(),
196 const auto in_tensor_padded =
197 tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
200 auto out_buffer_view = make_buffer_view<address_space_enum::global>(
202 out_desc.get_element_space_size(),
204 const auto out_tensor_padded =
205 tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
208 return make_tuple(in_tensor_padded, out_tensor_padded);
211 template <
typename TensorShape,
typename WindowShape>
214 using S =
typename Problem::BlockShape;
217 static_assert(TensorShape::size() == 5,
"3D pooling requires 5D input tensor (N,D,H,W,C)");
218 static_assert(WindowShape::size() == 3,
"3D pooling requires 3D window shape (Z,Y,X)");
253 const index_t MRaw = N * Do * Ho * Wo * C;
254 const index_t KRaw = Z * Y * X;
258 auto reduce_op =
typename Problem::ReduceOp{};
299 merged_embed_in_desc,
313 const auto out_desc_padded =
321 type_convert<InDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
323 type_convert<OutDataType>(reduce_op.template GetIdentityValue<ComputeDataType>());
325 auto in_buffer_view = make_buffer_view<address_space_enum::global>(
327 in_desc.get_element_space_size(),
329 const auto in_tensor_padded =
330 tensor_view<decltype(in_buffer_view), decltype(in_desc_padded)>{in_buffer_view,
333 auto out_buffer_view = make_buffer_view<address_space_enum::global>(
335 out_desc.get_element_space_size(),
337 const auto out_tensor_padded =
338 tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
341 return make_tuple(in_tensor_padded, out_tensor_padded);
345 template <
typename TensorShape,
typename WindowShape>
348 using S =
typename Problem::BlockShape;
351 static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
352 "Only 2D and 3D pooling operations are supported");
354 const auto iM = get_block_id() * S::Block_M;
357 auto [in_tensor_padded, out_tensor_padded] = [&]() {
358 if constexpr(WindowShape::size() == 2)
360 else if constexpr(WindowShape::size() == 3)
363 static_assert(WindowShape::size() == 2 || WindowShape::size() == 3,
364 "Unsupported WindowShape rank: only 2D or 3D pooling is supported");
367 auto reduce_op =
typename Problem::ReduceOp{};
372 Policy::template MakeXBlockTileDistribution<Problem>());
375 __shared__
char smem[Policy::template GetSmemSize<Problem>()];
377 const auto reduce_len =
378 in_tensor_padded.get_tensor_descriptor().get_lengths().at(
number<1>{});
382 auto block_reduce2d = Policy::template GetBlockReduce2d<Problem>();
383 auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync<Problem>();
384 auto block_reduce2d_cross_warp = Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
386 using XTensorTile = decltype(
load_tile(x_window));
387 auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
388 set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
390 for(
int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
393 block_reduce2d(x_tile, y_tile, reduce_op);
397 block_reduce2d_sync(y_tile, reduce_op);
398 block_reduce2d_cross_warp(y_tile, smem, reduce_op);
399 store_tile(y_window, cast_tile<OutDataType>(y_tile));
412 template <
typename TensorShape,
typename WindowShape>
415 constexpr
index_t InputRank = TensorShape::size();
416 constexpr
index_t OutputRank = TensorShape::size();
417 constexpr
index_t WindowRank = WindowShape::size();
420 if constexpr(WindowRank != 2 && WindowRank != 3)
430 if constexpr((WindowRank == 2 && InputRank != 4) || (WindowRank == 3 && InputRank != 5))
434 CK_TILE_ERROR(
"Input tensor rank doesn't match window dimensions!");
444 CK_TILE_ERROR(
"Input tensor's channel dimension must have stride 1!");
453 CK_TILE_ERROR(
"Output tensor's channel dimension must have stride 1!");
463 template <
typename TensorShape,
typename WindowShape>
467 using S =
typename Problem::BlockShape;
474 return (M + S::Block_M - 1) / S::Block_M;
478 template <
typename TensorShape,
typename WindowShape>
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1584
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:274
constexpr CK_TILE_HOST_DEVICE auto integer_least_multiple(X x, Y y)
Definition: math.hpp:155
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1615
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1565
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:203
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: coordinate_transform.hpp:1594
Host arguments for pooling operations.
Definition: pool_kernel.hpp:16
TensorShape input_strides
Definition: pool_kernel.hpp:48
void * output_ptr
Definition: pool_kernel.hpp:44
WindowShape input_left_pads
Definition: pool_kernel.hpp:53
const void * input_ptr
Definition: pool_kernel.hpp:43
WindowShape window_lengths
Definition: pool_kernel.hpp:50
WindowShape window_strides
Definition: pool_kernel.hpp:51
TensorShape input_shape
Definition: pool_kernel.hpp:46
TensorShape output_strides
Definition: pool_kernel.hpp:49
CK_TILE_HOST PoolHostArgs(const void *input_ptr_, void *output_ptr_, TensorShape input_shape_, TensorShape output_shape_, TensorShape input_strides_, TensorShape output_strides_, WindowShape window_lengths_, WindowShape window_strides_, WindowShape window_dilations_, WindowShape input_left_pads_, WindowShape input_right_pads_)
Definition: pool_kernel.hpp:18
TensorShape output_shape
Definition: pool_kernel.hpp:47
WindowShape input_right_pads
Definition: pool_kernel.hpp:54
WindowShape window_dilations
Definition: pool_kernel.hpp:52
Kernel arguments for pooling operations.
Definition: pool_kernel.hpp:60
TensorShape output_shape
Definition: pool_kernel.hpp:64
WindowShape input_right_pads
Definition: pool_kernel.hpp:71
WindowShape window_lengths
Definition: pool_kernel.hpp:67
WindowShape window_dilations
Definition: pool_kernel.hpp:69
TensorShape input_strides
Definition: pool_kernel.hpp:65
const void * input_ptr
Definition: pool_kernel.hpp:61
WindowShape input_left_pads
Definition: pool_kernel.hpp:70
TensorShape input_shape
Definition: pool_kernel.hpp:63
WindowShape window_strides
Definition: pool_kernel.hpp:68
void * output_ptr
Definition: pool_kernel.hpp:62
TensorShape output_strides
Definition: pool_kernel.hpp:66
Definition: pool_kernel.hpp:76
ck_tile::remove_cvref_t< Policy_ > Policy
Definition: pool_kernel.hpp:78
ck_tile::remove_cvref_t< typename Problem::OutDataType > OutDataType
Definition: pool_kernel.hpp:82
ck_tile::remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: pool_kernel.hpp:81
static constexpr CK_TILE_HOST auto BlockSize()
Definition: pool_kernel.hpp:86
static constexpr CK_TILE_HOST index_t CalculateGridSize(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition: pool_kernel.hpp:465
static constexpr index_t kBlockSize
Definition: pool_kernel.hpp:84
static CK_TILE_HOST bool IsSupportedArgument(PoolKernelArgs< TensorShape, WindowShape > kargs)
Validates if the given arguments are supported by the pooling kernel.
Definition: pool_kernel.hpp:413
static CK_TILE_DEVICE auto MakeTensorView2D(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition: pool_kernel.hpp:92
static CK_TILE_DEVICE auto MakeTensorView3D(PoolKernelArgs< TensorShape, WindowShape > kargs)
Definition: pool_kernel.hpp:212
static constexpr CK_TILE_HOST auto MakeKernelArgs(PoolHostArgs< TensorShape, WindowShape > &host_args)
Create kernel arguments from host arguments.
Definition: pool_kernel.hpp:480
ck_tile::remove_cvref_t< typename Problem::InDataType > InDataType
Definition: pool_kernel.hpp:80
CK_TILE_DEVICE void operator()(PoolKernelArgs< TensorShape, WindowShape > kargs) const
Definition: pool_kernel.hpp:346
ck_tile::remove_cvref_t< Problem_ > Problem
Definition: pool_kernel.hpp:77
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
#define CK_TILE_ENV(name)
Definition: env.hpp:145