/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_pool.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_pool.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_pool.hpp Source File
reference_pool.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 #include <thread>
9 
10 namespace ck_tile {
11 
12 template <typename InDataType,
13  typename ComputeDataType,
14  typename OutDataType,
15  typename ReduceOp,
16  typename TensorShape,
17  typename WindowShape>
21  ReduceOp reduce_op)
22 {
23  const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
24  const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<1>{});
25  const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<2>{});
26  const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<3>{});
27 
28  const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<1>{});
29  const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<2>{});
30 
33 
34  const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<0>{});
35  const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<1>{});
36 
39 
40  const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<0>{});
41  const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<1>{});
42  // Right padding is handled implicitly by bounds checking
43 
44  auto f = [&](auto n, auto ho, auto wo, auto c) {
45  ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
46 
47  for(ck_tile::index_t y = 0; y < Y; ++y)
48  {
49  // Calculate input height index with stride, dilation, and padding
50  ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
51 
52  for(ck_tile::index_t x = 0; x < X; ++x)
53  {
54  // Calculate input width index with stride, dilation, and padding
55  ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
56 
57  if(hi >= 0 && hi < H && wi >= 0 && wi < W)
58  {
59  const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
60  v_acc = reduce_op(v_acc, v_in);
61  }
62  // For positions outside bounds, we implicitly use identity value
63  }
64  }
65 
66  output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
67  };
68 
69  // Parallelize over all output dimensions
70  make_ParallelTensorFunctor(f, N, Ho, Wo, C)(std::thread::hardware_concurrency());
71 }
72 
73 template <typename InDataType,
74  typename ComputeDataType,
75  typename OutDataType,
76  typename ReduceOp,
77  typename TensorShape,
78  typename WindowShape>
82  ReduceOp reduce_op)
83 {
84  const ck_tile::index_t N = kargs.input_shape.at(ck_tile::number<0>{});
85  const ck_tile::index_t D = kargs.input_shape.at(ck_tile::number<1>{});
86  const ck_tile::index_t H = kargs.input_shape.at(ck_tile::number<2>{});
87  const ck_tile::index_t W = kargs.input_shape.at(ck_tile::number<3>{});
88  const ck_tile::index_t C = kargs.input_shape.at(ck_tile::number<4>{});
89 
90  const ck_tile::index_t Do = kargs.output_shape.at(ck_tile::number<1>{});
91  const ck_tile::index_t Ho = kargs.output_shape.at(ck_tile::number<2>{});
92  const ck_tile::index_t Wo = kargs.output_shape.at(ck_tile::number<3>{});
93 
97 
98  const ck_tile::index_t Sz = kargs.window_strides.at(ck_tile::number<0>{});
99  const ck_tile::index_t Sy = kargs.window_strides.at(ck_tile::number<1>{});
100  const ck_tile::index_t Sx = kargs.window_strides.at(ck_tile::number<2>{});
101 
105 
106  const ck_tile::index_t LeftPz = kargs.input_left_pads.at(ck_tile::number<0>{});
107  const ck_tile::index_t LeftPy = kargs.input_left_pads.at(ck_tile::number<1>{});
108  const ck_tile::index_t LeftPx = kargs.input_left_pads.at(ck_tile::number<2>{});
109  // Right padding is handled implicitly by bounds checking
110 
111  auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
112  ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
113 
114  for(ck_tile::index_t z = 0; z < Z; ++z)
115  {
116  // Calculate input depth index with stride, dilation, and padding
117  ck_tile::index_t di = do_ * Sz + z * Dz - LeftPz;
118 
119  for(ck_tile::index_t y = 0; y < Y; ++y)
120  {
121  // Calculate input height index with stride, dilation, and padding
122  ck_tile::index_t hi = ho * Sy + y * Dy - LeftPy;
123 
124  for(ck_tile::index_t x = 0; x < X; ++x)
125  {
126  // Calculate input width index with stride, dilation, and padding
127  ck_tile::index_t wi = wo * Sx + x * Dx - LeftPx;
128 
129  if(di >= 0 && di < D && hi >= 0 && hi < H && wi >= 0 && wi < W)
130  {
131  const ComputeDataType v_in =
132  type_convert<ComputeDataType>(input(n, di, hi, wi, c));
133  v_acc = reduce_op(v_acc, v_in);
134  }
135  // For positions outside bounds, we implicitly use identity value
136  }
137  }
138  }
139 
140  output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
141  };
142 
143  // Parallelize over all output dimensions
144  make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
145 }
146 
147 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_HOST void reference_pool3d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition: reference_pool.hpp:79
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST void reference_pool2d(const HostTensor< InDataType > &input, HostTensor< OutDataType > &output, PoolKernelArgs< TensorShape, WindowShape > kargs, ReduceOp reduce_op)
Definition: reference_pool.hpp:18
Definition: host_tensor.hpp:336
Kernel arguments for pooling operations.
Definition: pool_kernel.hpp:60
TensorShape output_shape
Definition: pool_kernel.hpp:64
WindowShape window_lengths
Definition: pool_kernel.hpp:67
WindowShape window_dilations
Definition: pool_kernel.hpp:69
WindowShape input_left_pads
Definition: pool_kernel.hpp:70
TensorShape input_shape
Definition: pool_kernel.hpp:63
WindowShape window_strides
Definition: pool_kernel.hpp:68
Definition: integral_constant.hpp:13