16 template <
typename ADataType,
21 typename CDEElementWise>
32 const CDEElementWise& cde_elementwise)
34 std::cout <<
"Calculating reference using optimized flat indexing with parallel processing..."
38 auto f_gm = [&](
auto g_flat,
auto m_flat) {
47 a_full_dims.
mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
49 b_full_dims.
mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
50 sum +=
static_cast<AccDataType
>(a_val) *
static_cast<AccDataType
>(b_val);
54 EDataType result =
static_cast<EDataType
>(sum);
55 if(ds_full_dims_host.size() == 0)
59 else if(ds_full_dims_host.size() == 1)
61 cde_elementwise(result,
62 ck_tile::type_convert<float>(sum),
63 ck_tile::type_convert<float>(
64 ds_full_dims_host[0].mData[g_flat * M_total * N_total +
65 m_flat * N_total + n_flat]));
67 else if(ds_full_dims_host.size() == 2)
71 ck_tile::type_convert<float>(sum),
72 ck_tile::type_convert<float>(
74 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
75 ck_tile::type_convert<float>(
77 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
79 else if(ds_full_dims_host.size() == 3)
83 ck_tile::type_convert<float>(sum),
84 ck_tile::type_convert<float>(
86 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
87 ck_tile::type_convert<float>(
89 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
90 ck_tile::type_convert<float>(
92 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
94 else if(ds_full_dims_host.size() == 4)
98 ck_tile::type_convert<float>(sum),
99 ck_tile::type_convert<float>(
101 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
102 ck_tile::type_convert<float>(
104 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
105 ck_tile::type_convert<float>(
107 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
108 ck_tile::type_convert<float>(
110 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
114 throw std::runtime_error(
"Unsupported NumDTensor for reference calculation");
118 e_full_dims_host_ref.
mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
119 static_cast<EDataType
>(result);
128 template <
typename ADataType,
132 typename AccDataType,
133 typename CDEElementWise>
139 const std::vector<index_t>& G_dims,
140 const std::vector<index_t>& M_dims,
141 const std::vector<index_t>& N_dims,
142 const std::vector<index_t>& K_dims,
143 const std::vector<index_t>& A_dims,
144 const std::vector<index_t>& B_dims,
145 const std::vector<index_t>& E_dims,
146 const CDEElementWise& cde_elementwise)
148 std::cout <<
"Calculating reference using multi-dimensional indexing..." << std::endl;
150 std::vector<std::size_t> g_idx(G_dims.size());
151 std::vector<std::size_t> m_idx(M_dims.size());
152 std::vector<std::size_t> n_idx(N_dims.size());
153 std::vector<std::size_t> k_idx(K_dims.size());
154 std::vector<std::size_t> a_idx, b_idx, e_idx;
156 a_idx.reserve(A_dims.size());
157 b_idx.reserve(B_dims.size());
158 e_idx.reserve(E_dims.size());
160 auto calculate_total_elements = [](
const std::vector<ck_tile::index_t>& dims) {
161 return std::accumulate(dims.begin(), dims.end(), 1, std::multiplies<ck_tile::index_t>());
164 for(
ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
167 for(
int i = G_dims.size() - 1; i >= 0; --i)
169 g_idx[i] = temp % G_dims[i];
173 for(
ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
176 for(
int i = M_dims.size() - 1; i >= 0; --i)
178 m_idx[i] = temp % M_dims[i];
182 for(
ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
185 for(
int i = N_dims.size() - 1; i >= 0; --i)
187 n_idx[i] = temp % N_dims[i];
197 for(
int i = K_dims.size() - 1; i >= 0; --i)
199 k_idx[i] = temp % K_dims[i];
206 a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
207 a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
208 a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
210 b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
211 b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
212 b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
214 auto a_val = a_full_dims(a_idx);
215 auto b_val = b_full_dims(b_idx);
217 sum +=
static_cast<AccDataType
>(a_val) *
static_cast<AccDataType
>(b_val);
221 e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
222 e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
223 e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
225 EDataType result =
static_cast<EDataType
>(sum);
226 if(ds_full_dims_host.size() == 0)
230 else if(ds_full_dims_host.size() == 1)
232 cde_elementwise(result,
233 ck_tile::type_convert<float>(sum),
234 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
236 else if(ds_full_dims_host.size() == 2)
238 cde_elementwise(result,
239 ck_tile::type_convert<float>(sum),
240 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
241 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
243 else if(ds_full_dims_host.size() == 3)
245 cde_elementwise(result,
246 ck_tile::type_convert<float>(sum),
247 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
248 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
249 ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
251 else if(ds_full_dims_host.size() == 4)
253 cde_elementwise(result,
254 ck_tile::type_convert<float>(sum),
255 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
256 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
257 ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
258 ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
262 throw std::runtime_error(
"Unsupported NumDTensor for reference calculation");
265 e_full_dims_host_ref(e_idx) =
static_cast<EDataType
>(result);
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
void calculate_reference_flat_indexing(const ck_tile::HostTensor< ADataType > &a_full_dims, const ck_tile::HostTensor< BDataType > &b_full_dims, const std::vector< ck_tile::HostTensor< DDataType >> &ds_full_dims_host, ck_tile::HostTensor< EDataType > &e_full_dims_host_ref, ck_tile::index_t G_total, ck_tile::index_t M_total, ck_tile::index_t N_total, ck_tile::index_t K_total, const CDEElementWise &cde_elementwise)
Definition: reference_batched_contraction.hpp:23
void calculate_reference_multi_dimensional(const HostTensor< ADataType > &a_full_dims, const HostTensor< BDataType > &b_full_dims, const std::vector< HostTensor< DDataType >> &ds_full_dims_host, HostTensor< EDataType > &e_full_dims_host_ref, const std::vector< index_t > &G_dims, const std::vector< index_t > &M_dims, const std::vector< index_t > &N_dims, const std::vector< index_t > &K_dims, const std::vector< index_t > &A_dims, const std::vector< index_t > &B_dims, const std::vector< index_t > &E_dims, const CDEElementWise &cde_elementwise)
Definition: reference_batched_contraction.hpp:134
Definition: host_tensor.hpp:336
Data mData
Definition: host_tensor.hpp:801