// NOTE(review): this listing is a lossy extract. The leading integers are line numbers of the
// original file, and many original lines (most of the signature, braces, loop headers and
// several declarations) are absent. The comments below cover only what the visible fragments
// establish; anything depending on a missing line is marked as a review note.
//
// Flat-indexing reference path. For one (g_flat, m_flat) pair the visible code multiplies
// elements of A and B read through mData with hand-computed flat offsets, adds the products
// into sum, optionally hands the sum plus one to four D-tensor values to cde_elementwise,
// and stores the outcome in e_full_dims_host_ref.
14 template <
typename ADataType,
19 typename CDEElementWise>
// cde_elementwise is the callable invoked further down as cde_elementwise(result, sum, d...).
30 const CDEElementWise& cde_elementwise)
// Log line announcing which reference path is running.
32 std::cout <<
"Calculating reference using optimized flat indexing with parallel processing..."
// Work item taking a flat G index and a flat M index; everything else is captured by
// reference.
// NOTE(review): the code that dispatches f_gm is not visible in this extract - confirm how
// (and on how many threads) it is invoked.
36 auto f_gm = [&](
auto g_flat,
auto m_flat) {
// A element at flat offset g_flat*M_total*K_total + m_flat*K_total + k_flat.
// NOTE(review): the loops that define n_flat and k_flat are on lines not shown here.
45 a_full_dims.
mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
// B element at flat offset g_flat*N_total*K_total + n_flat*K_total + k_flat, i.e. B is
// addressed as (g, n, k) with k as the innermost term.
47 b_full_dims.
mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
// Both operands are converted to AccDataType before the multiply; the product is added to
// sum (its declaration is not visible in this extract).
48 sum +=
static_cast<AccDataType
>(a_val) *
static_cast<AccDataType
>(b_val);
// Initial output value: the sum converted to EDataType. It is also the first argument of
// every cde_elementwise call below (presumably the output slot - confirm against the
// functor definition).
52 EDataType result =
static_cast<EDataType
>(sum);
// Runtime dispatch on the number of D tensors; counts 0 through 4 have a branch.
// NOTE(review): the body of the zero-D branch is not visible in this extract.
53 if(ds_full_dims_host.size() == 0)
57 else if(ds_full_dims_host.size() == 1)
// One D tensor: sum and the D value are both converted to float. The D value is read at flat
// offset g_flat*M_total*N_total + m_flat*N_total + n_flat, the same offset used for the
// output store at the end.
59 cde_elementwise(result,
60 ck_tile::type_convert<float>(sum),
61 ck_tile::type_convert<float>(
62 ds_full_dims_host[0].mData[g_flat * M_total * N_total +
63 m_flat * N_total + n_flat]));
65 else if(ds_full_dims_host.size() == 2)
// Two D operands, each read at the same (g, m, n) flat offset and converted to float.
// NOTE(review): the lines naming the call and the individual D tensors are missing here.
69 ck_tile::type_convert<float>(sum),
70 ck_tile::type_convert<float>(
72 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
73 ck_tile::type_convert<float>(
75 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
77 else if(ds_full_dims_host.size() == 3)
// Three D operands, same offset and float conversion as above.
81 ck_tile::type_convert<float>(sum),
82 ck_tile::type_convert<float>(
84 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
85 ck_tile::type_convert<float>(
87 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
88 ck_tile::type_convert<float>(
90 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
92 else if(ds_full_dims_host.size() == 4)
// Four D operands, same offset and float conversion as above.
96 ck_tile::type_convert<float>(sum),
97 ck_tile::type_convert<float>(
99 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
100 ck_tile::type_convert<float>(
102 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
103 ck_tile::type_convert<float>(
105 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
106 ck_tile::type_convert<float>(
108 .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
// Error raised for an unsupported D-tensor count, per its message (the line introducing this
// branch is not shown).
112 throw std::runtime_error(
"Unsupported NumDTensor for reference calculation");
// Store the final value at flat offset g_flat*M_total*N_total + m_flat*N_total + n_flat.
116 e_full_dims_host_ref.
mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
117 static_cast<EDataType
>(result);
// NOTE(review): this listing is a lossy extract. The leading integers are line numbers of the
// original file, and many original lines (part of the signature, braces, the K loop header and
// several declarations) are absent. The comments below cover only what the visible fragments
// establish; anything depending on a missing line is marked as a review note.
//
// Multi-dimensional reference path. The visible code turns flat G, M, N and K counters into
// per-dimension indices, concatenates them into index vectors for A, B and E, adds A*B products
// into sum, optionally hands the sum plus one to four D-tensor values to cde_elementwise, and
// stores the outcome through e_full_dims_host_ref(e_idx).
126 template <
typename ADataType,
130 typename AccDataType,
131 typename CDEElementWise>
// G_dims, M_dims, N_dims and K_dims hold the extents of each dimension group; their sizes fix
// the lengths of the scratch index vectors below and their entries are the moduli used when
// splitting the flat counters.
137 const std::vector<index_t>& G_dims,
138 const std::vector<index_t>& M_dims,
139 const std::vector<index_t>& N_dims,
140 const std::vector<index_t>& K_dims,
// In the visible code A_dims, B_dims and E_dims are used only for their size(), to reserve
// capacity for the combined index vectors.
141 const std::vector<index_t>& A_dims,
142 const std::vector<index_t>& B_dims,
143 const std::vector<index_t>& E_dims,
// cde_elementwise is the callable invoked further down as cde_elementwise(result, sum, d...).
144 const CDEElementWise& cde_elementwise)
// Log line announcing which reference path is running; std::endl also flushes the stream.
146 std::cout <<
"Calculating reference using multi-dimensional indexing..." << std::endl;
// Scratch index vectors, one slot per dimension of each group.
148 std::vector<std::size_t> g_idx(G_dims.size());
149 std::vector<std::size_t> m_idx(M_dims.size());
150 std::vector<std::size_t> n_idx(N_dims.size());
151 std::vector<std::size_t> k_idx(K_dims.size());
// Combined index vectors for A, B and E; they start empty and get their capacity reserved once.
152 std::vector<std::size_t> a_idx, b_idx, e_idx;
154 a_idx.reserve(A_dims.size());
155 b_idx.reserve(B_dims.size());
156 e_idx.reserve(E_dims.size());
// Loop over the flattened G range.
// NOTE(review): calculate_total_elements(G_dims) sits in the loop condition, so it is evaluated
// on every iteration; the same holds for the M and N loops below. Its definition is not visible
// here - consider hoisting if it is not trivially cheap.
158 for(
ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
// Fill g_idx walking the dimensions from last to first, taking temp modulo each extent.
// NOTE(review): temp is declared and updated on lines not shown in this extract.
161 for(
int i = G_dims.size() - 1; i >= 0; --i)
163 g_idx[i] = temp % G_dims[i];
// Loop over the flattened M range.
167 for(
ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
// Same last-to-first split for the M indices.
170 for(
int i = M_dims.size() - 1; i >= 0; --i)
172 m_idx[i] = temp % M_dims[i];
// Loop over the flattened N range.
176 for(
ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
// Same last-to-first split for the N indices.
179 for(
int i = N_dims.size() - 1; i >= 0; --i)
181 n_idx[i] = temp % N_dims[i];
// Same last-to-first split for the K indices.
// NOTE(review): the loop over the flattened K range is on lines not shown in this extract.
191 for(
int i = K_dims.size() - 1; i >= 0; --i)
193 k_idx[i] = temp % K_dims[i];
// A is addressed as (g..., m..., k...).
// NOTE(review): insert at end() appends, so a_idx and b_idx must be emptied before each refill;
// that reset is not visible in this extract - confirm it exists.
200 a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
201 a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
202 a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
// B is addressed as (g..., n..., k...).
204 b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
205 b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
206 b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
// Element reads go through the tensor call operator with the index vector.
208 auto a_val = a_full_dims(a_idx);
209 auto b_val = b_full_dims(b_idx);
// Both operands are converted to AccDataType before the multiply; the product is added to
// sum (its declaration is not visible in this extract).
211 sum +=
static_cast<AccDataType
>(a_val) *
static_cast<AccDataType
>(b_val);
// E and every D tensor are addressed as (g..., m..., n...).
// NOTE(review): as with a_idx and b_idx, the reset of e_idx is not visible here.
215 e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
216 e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
217 e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
// Initial output value: the sum converted to EDataType. It is also the first argument of
// every cde_elementwise call below (presumably the output slot - confirm against the
// functor definition).
219 EDataType result =
static_cast<EDataType
>(sum);
// Runtime dispatch on the number of D tensors; counts 0 through 4 have a branch. In every
// call the sum and each D value are converted to float first.
// NOTE(review): the body of the zero-D branch is not visible in this extract.
220 if(ds_full_dims_host.size() == 0)
224 else if(ds_full_dims_host.size() == 1)
226 cde_elementwise(result,
227 ck_tile::type_convert<float>(sum),
228 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
230 else if(ds_full_dims_host.size() == 2)
232 cde_elementwise(result,
233 ck_tile::type_convert<float>(sum),
234 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
235 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
237 else if(ds_full_dims_host.size() == 3)
239 cde_elementwise(result,
240 ck_tile::type_convert<float>(sum),
241 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
242 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
243 ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
245 else if(ds_full_dims_host.size() == 4)
247 cde_elementwise(result,
248 ck_tile::type_convert<float>(sum),
249 ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
250 ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
251 ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
252 ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
// Error raised for an unsupported D-tensor count, per its message (the line introducing this
// branch is not shown).
256 throw std::runtime_error(
"Unsupported NumDTensor for reference calculation");
// Store the final value at the (g..., m..., n...) position.
259 e_full_dims_host_ref(e_idx) =
static_cast<EDataType
>(result);
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
int32_t index_t
Definition: integer.hpp:9
void calculate_reference_flat_indexing(const ck_tile::HostTensor< ADataType > &a_full_dims, const ck_tile::HostTensor< BDataType > &b_full_dims, const std::vector< ck_tile::HostTensor< DDataType >> &ds_full_dims_host, ck_tile::HostTensor< EDataType > &e_full_dims_host_ref, ck_tile::index_t G_total, ck_tile::index_t M_total, ck_tile::index_t N_total, ck_tile::index_t K_total, const CDEElementWise &cde_elementwise)
Definition: reference_batched_contraction.hpp:21
void calculate_reference_multi_dimensional(const HostTensor< ADataType > &a_full_dims, const HostTensor< BDataType > &b_full_dims, const std::vector< HostTensor< DDataType >> &ds_full_dims_host, HostTensor< EDataType > &e_full_dims_host_ref, const std::vector< index_t > &G_dims, const std::vector< index_t > &M_dims, const std::vector< index_t > &N_dims, const std::vector< index_t > &K_dims, const std::vector< index_t > &A_dims, const std::vector< index_t > &B_dims, const std::vector< index_t > &E_dims, const CDEElementWise &cde_elementwise)
Definition: reference_batched_contraction.hpp:132
Definition: host_tensor.hpp:336
Data mData
Definition: host_tensor.hpp:801