22 std::vector<ck_tile::long_index_t> conv_strides,
23 std::vector<ck_tile::long_index_t> conv_dilations,
24 std::vector<ck_tile::long_index_t> in_left_pads,
25 std::vector<ck_tile::long_index_t>)
31 throw std::runtime_error(
"wrong! inconsistent dimension");
34 if constexpr(NDimSpatial == 1)
36 auto func = [&](
auto g,
auto k,
auto c,
auto x) {
39 for(std::size_t n = 0; n < output.
get_lengths()[1]; ++n)
41 for(std::size_t wo = 0; wo < output.
get_lengths()[3]; ++wo)
47 if(wi >= 0 && ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[3])
49 InDataType v_in = input(g, n, c, wi);
50 OutDataType v_out = output(g, n, k, wo);
51 v_acc += ck_tile::type_convert<float>(v_out) *
52 ck_tile::type_convert<float>(v_in);
56 OutDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
57 weight(g, k, c, x) = v_acc_converted;
64 weight.
get_lengths()[3])(std::thread::hardware_concurrency());
66 else if constexpr(NDimSpatial == 2)
68 auto func = [&](
auto g,
auto k,
auto c,
auto y,
auto x) {
71 for(std::size_t n = 0; n < output.
get_lengths()[1]; ++n)
73 for(std::size_t ho = 0; ho < output.
get_lengths()[3]; ++ho)
79 for(std::size_t wo = 0; wo < output.
get_lengths()[4]; ++wo)
86 ck_tile::type_convert<std::size_t>(hi) < input.
get_lengths()[3] &&
88 ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[4])
90 InDataType v_in = input(g, n, c, hi, wi);
91 OutDataType v_out = output(g, n, k, ho, wo);
93 v_acc += ck_tile::type_convert<float>(v_out) *
94 ck_tile::type_convert<float>(v_in);
99 WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
100 weight(g, k, c, y, x) = v_acc_converted;
108 weight.
get_lengths()[4])(std::thread::hardware_concurrency());
110 else if constexpr(NDimSpatial == 3)
112 auto func = [&](
auto g,
auto k,
auto c,
auto z,
auto y,
auto x) {
115 for(std::size_t n = 0; n < output.
get_lengths()[1]; ++n)
117 for(std::size_t do_ = 0; do_ < output.
get_lengths()[3]; ++do_)
122 for(std::size_t ho = 0; ho < output.
get_lengths()[4]; ++ho)
127 for(std::size_t wo = 0; wo < output.
get_lengths()[5]; ++wo)
133 ck_tile::type_convert<std::size_t>(di) < input.
get_lengths()[3] &&
135 ck_tile::type_convert<std::size_t>(hi) < input.
get_lengths()[4] &&
137 ck_tile::type_convert<std::size_t>(wi) < input.
get_lengths()[5])
139 InDataType v_in = input(g, n, c, di, hi, wi);
140 OutDataType v_out = output(g, n, k, do_, ho, wo);
142 v_acc += ck_tile::type_convert<float>(v_out) *
143 ck_tile::type_convert<float>(v_in);
149 WeiDataType v_acc_converted = ck_tile::type_convert<WeiDataType>(v_acc);
150 weight(g, k, c, z, y, x) = v_acc_converted;
159 weight.
get_lengths()[5])(std::thread::hardware_concurrency());
163 throw std::runtime_error(
164 "Ref_conv_bwd_weight: number of dimensions must be between 1 and 3.");
#define CK_TILE_HOST
Definition: config.hpp:40
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
int64_t long_index_t
Definition: integer.hpp:11
CK_TILE_HOST void reference_grouped_conv_bwd_weight(const HostTensor< InDataType > &input, HostTensor< WeiDataType > &weight, const HostTensor< OutDataType > &output, std::vector< ck_tile::long_index_t > conv_strides, std::vector< ck_tile::long_index_t > conv_dilations, std::vector< ck_tile::long_index_t > in_left_pads, std::vector< ck_tile::long_index_t >)
Definition: reference_grouped_conv_bwd_weight.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:396