15 template <
typename EDataType,
typename AccDataType,
typename CDEElementWise>
18 template <
typename... DValues>
21 const CDEElementWise& cde_elementwise,
24 if constexpr(
sizeof...(DValues) == 0)
26 result =
static_cast<EDataType
>(sum);
31 result, ck_tile::type_convert<float>(sum), ck_tile::type_convert<float>(d_vals)...);
37 template <
typename DDataType,
39 typename Indices = std::make_index_sequence<NumDTensor>>
42 template <
typename DDataType,
ck_tile::index_t NumDTensor, std::size_t... Is>
45 template <
typename EDataType,
typename AccDataType,
typename CDEElementWise>
49 const CDEElementWise& cde_elementwise,
51 const std::array<std::size_t, NumDTensor>& d_offsets)
54 result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
58 template <
typename ADataType,
63 typename CDEElementWise,
75 const CDEElementWise& cde_elementwise,
76 const std::vector<ck_tile::index_t>& G_dims,
77 const std::vector<ck_tile::index_t>& M_dims,
78 const std::vector<ck_tile::index_t>& N_dims,
79 const std::vector<ck_tile::index_t>& K_dims)
81 std::cout <<
"Calculating reference using stride-aware indexing with parallel processing..."
87 const auto e_strides = e_full_dims_host_ref.
get_strides();
90 std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
93 ds_strides[d] = ds_full_dims_host[d].get_strides();
109 for(
int i = num_g_dims - 1; i >= 0; --i)
111 offset += (temp % G_dims[i]) * a_strides[i];
117 for(
int i = num_m_dims - 1; i >= 0; --i)
119 offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
125 for(
int i = num_k_dims - 1; i >= 0; --i)
127 offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
141 for(
int i = num_g_dims - 1; i >= 0; --i)
143 offset += (temp % G_dims[i]) * b_strides[i];
149 for(
int i = num_n_dims - 1; i >= 0; --i)
151 offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
157 for(
int i = num_k_dims - 1; i >= 0; --i)
159 offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
173 for(
int i = num_g_dims - 1; i >= 0; --i)
175 offset += (temp % G_dims[i]) * e_strides[i];
181 for(
int i = num_m_dims - 1; i >= 0; --i)
183 offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
189 for(
int i = num_n_dims - 1; i >= 0; --i)
191 offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
204 const auto& d_strides = ds_strides[d_idx];
208 for(
int i = num_g_dims - 1; i >= 0; --i)
210 offset += (temp % G_dims[i]) * d_strides[i];
216 for(
int i = num_m_dims - 1; i >= 0; --i)
218 offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
224 for(
int i = num_n_dims - 1; i >= 0; --i)
226 offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
234 auto f_gm = [&](
auto g_flat,
auto m_flat) {
242 const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
243 const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);
245 auto a_val = a_full_dims.
mData[a_offset];
246 auto b_val = b_full_dims.
mData[b_offset];
247 sum +=
static_cast<AccDataType
>(a_val) *
static_cast<AccDataType
>(b_val);
251 const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
254 std::array<std::size_t, NumDTensor> d_offsets;
257 d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
261 EDataType result =
static_cast<EDataType
>(sum);
263 result, sum, cde_elementwise, ds_full_dims_host, d_offsets);
266 e_full_dims_host_ref.
mData[e_offset] =
static_cast<EDataType
>(result);
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
void compute_reference_batched_contraction(const ck_tile::HostTensor< ADataType > &a_full_dims, const ck_tile::HostTensor< BDataType > &b_full_dims, const std::array< ck_tile::HostTensor< DDataType >, NumDTensor > &ds_full_dims_host, ck_tile::HostTensor< EDataType > &e_full_dims_host_ref, ck_tile::index_t G_total, ck_tile::index_t M_total, ck_tile::index_t N_total, ck_tile::index_t K_total, const CDEElementWise &cde_elementwise, const std::vector< ck_tile::index_t > &G_dims, const std::vector< ck_tile::index_t > &M_dims, const std::vector< ck_tile::index_t > &N_dims, const std::vector< ck_tile::index_t > &K_dims)
Definition: reference_batched_contraction.hpp:66
int32_t index_t
Definition: integer.hpp:9
Definition: reference_batched_contraction.hpp:17
static CK_TILE_HOST_DEVICE void apply(EDataType &result, AccDataType sum, const CDEElementWise &cde_elementwise, DValues... d_vals)
Definition: reference_batched_contraction.hpp:19
Definition: host_tensor.hpp:336
decltype(auto) get_strides() const
Definition: host_tensor.hpp:394
Data mData
Definition: host_tensor.hpp:802
Definition: coordinate_transform.hpp:1392