Composable Kernel: include/ck_tile/host/reference/reference_batched_contraction.hpp Source File
reference_batched_contraction.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <thread>
#include <vector>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp" // HostTensor, make_ParallelTensorFunctor

namespace ck_tile {

template <typename ADataType,
          typename BDataType,
          typename DDataType,
          typename EDataType,
          typename AccDataType,
          typename CDEElementWise>
void calculate_reference_flat_indexing(
    const ck_tile::HostTensor<ADataType>& a_full_dims,
    const ck_tile::HostTensor<BDataType>& b_full_dims,
    const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
    ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
    ck_tile::index_t G_total,
    ck_tile::index_t M_total,
    ck_tile::index_t N_total,
    ck_tile::index_t K_total,
    const CDEElementWise& cde_elementwise)
{
    std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
              << std::endl;

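    // Flattened tensors are addressed as contiguous row-major buffers:
    //   A: [G_total, M_total, K_total], B: [G_total, N_total, K_total],
    //   D_i and E: [G_total, M_total, N_total]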
    // Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
    auto f_gm = [&](auto g_flat, auto m_flat) {
        for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
        {
            AccDataType sum = 0;

            // Compute dot product over K dimension
            for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
            {
                auto a_val =
                    a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
                auto b_val =
                    b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
                sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
            }

            // Apply elementwise operation with D tensors
            EDataType result = static_cast<EDataType>(sum);
            if(ds_full_dims_host.size() == 0)
            {
                ;
            }
            else if(ds_full_dims_host.size() == 1)
            {
                cde_elementwise(result,
                                ck_tile::type_convert<float>(sum),
                                ck_tile::type_convert<float>(
                                    ds_full_dims_host[0].mData[g_flat * M_total * N_total +
                                                               m_flat * N_total + n_flat]));
            }
            else if(ds_full_dims_host.size() == 2)
            {
                cde_elementwise(
                    result,
                    ck_tile::type_convert<float>(sum),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[0]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[1]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
            }
            else if(ds_full_dims_host.size() == 3)
            {
                cde_elementwise(
                    result,
                    ck_tile::type_convert<float>(sum),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[0]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[1]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[2]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
            }
            else if(ds_full_dims_host.size() == 4)
            {
                cde_elementwise(
                    result,
                    ck_tile::type_convert<float>(sum),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[0]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[1]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[2]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
                    ck_tile::type_convert<float>(
                        ds_full_dims_host[3]
                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
            }
            else
            {
                throw std::runtime_error("Unsupported NumDTensor for reference calculation");
            }

            // Store result
            e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
                static_cast<EDataType>(result);
        }
    };

    // Execute parallel computation using hardware concurrency
    // Parallelize over G_total and M_total dimensions for optimal CPU utilization
    make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
}
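// Example (hypothetical, as a comment sketch): with flattened extents G, M, N, K and no D
// tensors, a call could look like
//   calculate_reference_flat_indexing<float, float, float, float, float>(
//       a, b, /*ds=*/{}, e, G, M, N, K, [](auto&, float) {});
// where a, b and e hold G*M*K, G*N*K and G*M*N contiguous elements respectively.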

template <typename ADataType,
          typename BDataType,
          typename DDataType,
          typename EDataType,
          typename AccDataType,
          typename CDEElementWise>
void calculate_reference_multi_dimensional(
    const HostTensor<ADataType>& a_full_dims,
    const HostTensor<BDataType>& b_full_dims,
    const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
    HostTensor<EDataType>& e_full_dims_host_ref,
    const std::vector<index_t>& G_dims,
    const std::vector<index_t>& M_dims,
    const std::vector<index_t>& N_dims,
    const std::vector<index_t>& K_dims,
    const std::vector<index_t>& A_dims,
    const std::vector<index_t>& B_dims,
    const std::vector<index_t>& E_dims,
    const CDEElementWise& cde_elementwise)
{
    std::cout << "Calculating reference using multi-dimensional indexing..." << std::endl;

    std::vector<std::size_t> g_idx(G_dims.size());
    std::vector<std::size_t> m_idx(M_dims.size());
    std::vector<std::size_t> n_idx(N_dims.size());
    std::vector<std::size_t> k_idx(K_dims.size());
    std::vector<std::size_t> a_idx, b_idx, e_idx;

    a_idx.reserve(A_dims.size());
    b_idx.reserve(B_dims.size());
    e_idx.reserve(E_dims.size());

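    // Each flat loop index below is decoded into per-dimension coordinates with a row-major
    // (mixed-radix) decomposition: idx[last] = flat % dims[last]; flat /= dims[last]; and so on.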
    for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
    {
        ck_tile::index_t temp = g_flat;
        for(int i = G_dims.size() - 1; i >= 0; --i)
        {
            g_idx[i] = temp % G_dims[i];
            temp /= G_dims[i];
        }

        for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
        {
            temp = m_flat;
            for(int i = M_dims.size() - 1; i >= 0; --i)
            {
                m_idx[i] = temp % M_dims[i];
                temp /= M_dims[i];
            }

            for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
            {
                temp = n_flat;
                for(int i = N_dims.size() - 1; i >= 0; --i)
                {
                    n_idx[i] = temp % N_dims[i];
                    temp /= N_dims[i];
                }

                AccDataType sum = 0;

                for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
                    ++k_flat)
                {
                    temp = k_flat;
                    for(int i = K_dims.size() - 1; i >= 0; --i)
                    {
                        k_idx[i] = temp % K_dims[i];
                        temp /= K_dims[i];
                    }

                    a_idx.clear();
                    b_idx.clear();

                    a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
                    a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
                    a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());

                    b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
                    b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
                    b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());

                    auto a_val = a_full_dims(a_idx);
                    auto b_val = b_full_dims(b_idx);

                    sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
                }

                e_idx.clear();
                e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
                e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
                e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());

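                // Apply the elementwise operation with the D tensors (same dispatch as the
                // flat-indexing variant above)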
                EDataType result = static_cast<EDataType>(sum);
                if(ds_full_dims_host.size() == 0)
                {
                    ;
                }
                else if(ds_full_dims_host.size() == 1)
                {
                    cde_elementwise(result,
                                    ck_tile::type_convert<float>(sum),
                                    ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)));
                }
                else if(ds_full_dims_host.size() == 2)
                {
                    cde_elementwise(result,
                                    ck_tile::type_convert<float>(sum),
                                    ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
                                    ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)));
                }
                else if(ds_full_dims_host.size() == 3)
                {
                    cde_elementwise(result,
                                    ck_tile::type_convert<float>(sum),
                                    ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
                                    ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
                                    ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)));
                }
                else if(ds_full_dims_host.size() == 4)
                {
                    cde_elementwise(result,
                                    ck_tile::type_convert<float>(sum),
                                    ck_tile::type_convert<float>(ds_full_dims_host[0](e_idx)),
                                    ck_tile::type_convert<float>(ds_full_dims_host[1](e_idx)),
                                    ck_tile::type_convert<float>(ds_full_dims_host[2](e_idx)),
                                    ck_tile::type_convert<float>(ds_full_dims_host[3](e_idx)));
                }
                else
                {
                    throw std::runtime_error("Unsupported NumDTensor for reference calculation");
                }

                e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
            }
        }
    }
}

} // namespace ck_tile
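
A minimal usage sketch of the multi-dimensional reference. The shapes, the driver function, the no-op elementwise lambda, and the HostTensor construction from a list of lengths are assumptions for illustration, not part of this header:

// Hypothetical driver: run the multi-dimensional reference contraction.
#include <vector>
#include "ck_tile/host/reference/reference_batched_contraction.hpp" // this header (path assumed)

void run_reference_check()
{
    // Two batch (G) dims, one M/N/K dim each (hypothetical sizes).
    const std::vector<ck_tile::index_t> G_dims{2, 3}, M_dims{8}, N_dims{16}, K_dims{32};
    const std::vector<ck_tile::index_t> A_dims{2, 3, 8, 32};  // G dims + M dims + K dims
    const std::vector<ck_tile::index_t> B_dims{2, 3, 16, 32}; // G dims + N dims + K dims
    const std::vector<ck_tile::index_t> E_dims{2, 3, 8, 16};  // G dims + M dims + N dims

    // HostTensor built from a list of lengths (assumed constructor form).
    ck_tile::HostTensor<float> a({2, 3, 8, 32});
    ck_tile::HostTensor<float> b({2, 3, 16, 32});
    ck_tile::HostTensor<float> e({2, 3, 8, 16});
    std::vector<ck_tile::HostTensor<float>> ds{}; // no D tensors

    ck_tile::calculate_reference_multi_dimensional<float, float, float, float, float>(
        a, b, ds, e, G_dims, M_dims, N_dims, K_dims, A_dims, B_dims, E_dims,
        [](auto&, float) {}); // with zero D tensors the elementwise op is never invoked
}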