14 template <
typename InDataType,
15 typename ComputeDataType,
17 typename IndexDataType,
21 bool OutputIndex =
false>
49 auto f = [&](
auto n,
auto ho,
auto wo,
auto c) {
50 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
52 IndexDataType current_index = 0;
64 if(hi >= 0 && hi < H && wi >= 0 && wi < W)
66 const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
68 if constexpr(OutputIndex)
72 v_acc = reduce_op(v_acc, v_in, changed);
75 current_index = flat_index;
80 v_acc = reduce_op(v_acc, v_in);
87 output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
89 if constexpr(OutputIndex)
91 output_index(n, ho, wo, c) = current_index;
99 template <
typename InDataType,
100 typename ComputeDataType,
101 typename OutDataType,
102 typename IndexDataType,
104 typename TensorShape,
105 typename WindowShape,
106 bool OutputIndex =
false>
140 auto f = [&](
auto n,
auto do_,
auto ho,
auto wo,
auto c) {
141 ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
143 IndexDataType current_index = 0;
160 if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
162 const ComputeDataType v_in =
163 type_convert<ComputeDataType>(input(n, di, hi, wi, c));
165 if constexpr(OutputIndex)
167 IndexDataType flat_index =
169 bool changed =
false;
170 v_acc = reduce_op(v_acc, v_in, changed);
173 current_index = flat_index;
178 v_acc = reduce_op(v_acc, v_in);
186 output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
188 if constexpr(OutputIndex)
191 output_index(n, do_, ho, wo, c) = current_index;
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_HOST void reference_pool2d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition: reference_pool.hpp:22
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_pool3d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, HostTensor< IndexDataType > &output_index, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition: reference_pool.hpp:107
Definition: host_tensor.hpp:336
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition: host_tensor.hpp:531
Kernel arguments for pooling operations.
Definition: pool_kernel.hpp:63
TensorShape output_shape
Definition: pool_kernel.hpp:68
WindowShape window_lengths
Definition: pool_kernel.hpp:71
WindowShape window_dilations
Definition: pool_kernel.hpp:73
WindowShape input_left_pads
Definition: pool_kernel.hpp:74
TensorShape input_shape
Definition: pool_kernel.hpp:67
WindowShape window_strides
Definition: pool_kernel.hpp:72
Definition: integral_constant.hpp:13