24 std::vector<ck_tile::long_index_t> conv_strides,
25 std::vector<ck_tile::long_index_t> conv_dilations,
26 std::vector<ck_tile::long_index_t> in_left_pads,
27 std::vector<ck_tile::long_index_t>,
28 Elfunc elfunc = Elfunc{},
35 throw std::runtime_error(
"wrong! inconsistent dimension");
38 if constexpr(NDimSpatial == 1)
40 auto func = [&](
auto g,
auto n,
auto k,
auto wo) {
43 for(std::size_t c = 0; c < weight.
get_lengths()[2]; ++c)
45 for(std::size_t x = 0; x < weight.
get_lengths()[3]; ++x)
51 if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[3])
53 InDataType v_in = input(g, n, c, wi);
54 WeiDataType v_wei = weight(g, k, c, x);
55 v_acc += ck_tile::type_convert<float>(v_in) *
56 ck_tile::type_convert<float>(v_wei);
60 if constexpr(Tuple::size() > 0)
61 elfunc(v_acc, v_acc, ds.at(
ck_tile::
number<0>{})(g, n, k, wo));
64 OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
65 output(g, n, k, wo) = v_acc_out;
72 output.
get_lengths()[3])(std::thread::hardware_concurrency());
74 else if constexpr(NDimSpatial == 2)
76 auto func = [&](
auto g,
auto n,
auto k,
auto ho,
auto wo) {
79 for(std::size_t c = 0; c < weight.
get_lengths()[2]; ++c)
81 for(std::size_t y = 0; y < weight.
get_lengths()[3]; ++y)
87 for(std::size_t x = 0; x < weight.
get_lengths()[4]; ++x)
94 ck_tile::type_convert<std::size_t>(hi) < input.
get_lengths()[3] &&
96 ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[4])
98 InDataType v_in = input(g, n, c, hi, wi);
99 WeiDataType v_wei = weight(g, k, c, y, x);
101 v_acc += ck_tile::type_convert<float>(v_in) *
102 ck_tile::type_convert<float>(v_wei);
107 if constexpr(Tuple::size() > 0)
108 elfunc(v_acc, v_acc, ds.at(
ck_tile::
number<0>{})(g, n, k, ho, wo));
110 elfunc(v_acc, v_acc);
111 OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
112 output(g, n, k, ho, wo) = v_acc_out;
120 output.
get_lengths()[4])(std::thread::hardware_concurrency());
122 else if constexpr(NDimSpatial == 3)
124 auto func = [&](
auto g,
auto n,
auto k,
auto d_o,
auto ho,
auto wo) {
127 for(std::size_t c = 0; c < weight.
get_lengths()[2]; ++c)
129 for(std::size_t z = 0; z < weight.
get_lengths()[3]; ++z)
134 for(std::size_t y = 0; y < weight.
get_lengths()[4]; ++y)
139 for(std::size_t x = 0; x < weight.
get_lengths()[5]; ++x)
145 ck_tile::type_convert<std::size_t>(di) < input.
get_lengths()[3] &&
147 ck_tile::type_convert<std::size_t>(hi) < input.
get_lengths()[4] &&
149 ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[5])
151 InDataType v_in = input(g, n, c, di, hi, wi);
152 WeiDataType v_wei = weight(g, k, c, z, y, x);
154 v_acc += ck_tile::type_convert<float>(v_in) *
155 ck_tile::type_convert<float>(v_wei);
161 if constexpr(Tuple::size() > 0)
162 elfunc(v_acc, v_acc, ds.at(
ck_tile::
number<0>{})(g, n, k, d_o, ho, wo));
164 elfunc(v_acc, v_acc);
165 OutDataType v_acc_out = ck_tile::type_convert<OutDataType>(v_acc);
166 output(g, n, k, d_o, ho, wo) = v_acc_out;
175 output.
get_lengths()[5])(std::thread::hardware_concurrency());
179 throw std::runtime_error(
"Ref_Conv_fwd: number of dimensions must be between 1 and 3.");
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
constant< v > number
Definition: integral_constant.hpp:37
ck_tile::element_wise::PassThrough PassThrough
Definition: grouped_convolution_utils.hpp:47
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST void reference_grouped_conv_fwd(const HostTensor< InDataType > &input, const HostTensor< WeiDataType > &weight, HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >, Elfunc elfunc=Elfunc{}, Tuple ds={})
Definition: reference_grouped_conv_fwd.hpp:21
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396
Definition: tuple.hpp:192