Composable Kernel: include/ck_tile/host/reference/reference_batched_contraction.hpp Source File
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once
#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

// Helper to apply the elementwise operation with a variable number of D tensors
template <typename EDataType, typename AccDataType, typename CDEElementWise>
struct ApplyElementwiseOp
{
    template <typename... DValues>
    CK_TILE_HOST_DEVICE static void apply(EDataType& result,
                                          AccDataType sum,
                                          const CDEElementWise& cde_elementwise,
                                          DValues... d_vals)
    {
        if constexpr(sizeof...(DValues) == 0)
        {
            result = static_cast<EDataType>(sum);
        }
        else
        {
            cde_elementwise(
                result, ck_tile::type_convert<float>(sum), ck_tile::type_convert<float>(d_vals)...);
        }
    }
};

// Helper to extract D values at a given offset using an index sequence
template <typename DDataType,
          ck_tile::index_t NumDTensor,
          typename Indices = std::make_index_sequence<NumDTensor>>
struct ExtractDValues;

template <typename DDataType, ck_tile::index_t NumDTensor, std::size_t... Is>
struct ExtractDValues<DDataType, NumDTensor, std::index_sequence<Is...>>
{
    template <typename EDataType, typename AccDataType, typename CDEElementWise>
    CK_TILE_HOST static void
    apply_at_offsets(EDataType& result,
                     AccDataType sum,
                     const CDEElementWise& cde_elementwise,
                     const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
                     const std::array<std::size_t, NumDTensor>& d_offsets)
    {
        ApplyElementwiseOp<EDataType, AccDataType, CDEElementWise>::apply(
            result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
    }
};
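
// With NumDTensor == 2, for example, the pack expansion above amounts to
//   ApplyElementwiseOp<EDataType, AccDataType, CDEElementWise>::apply(
//       result, sum, cde_elementwise,
//       ds_tensors[0].mData[d_offsets[0]], ds_tensors[1].mData[d_offsets[1]]);
// i.e. each D tensor contributes one scalar read at its own precomputed offset.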

template <typename ADataType,
          typename BDataType,
          typename DDataType,
          typename EDataType,
          typename AccDataType,
          typename CDEElementWise,
          ck_tile::index_t NumDTensor>
void compute_reference_batched_contraction(
    const ck_tile::HostTensor<ADataType>& a_full_dims,
    const ck_tile::HostTensor<BDataType>& b_full_dims,
    const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
    ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
    ck_tile::index_t G_total,
    ck_tile::index_t M_total,
    ck_tile::index_t N_total,
    ck_tile::index_t K_total,
    const CDEElementWise& cde_elementwise,
    const std::vector<ck_tile::index_t>& G_dims,
    const std::vector<ck_tile::index_t>& M_dims,
    const std::vector<ck_tile::index_t>& N_dims,
    const std::vector<ck_tile::index_t>& K_dims)
{
    std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
              << std::endl;
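
    // For each (g_flat, m_flat, n_flat) this routine computes
    //     acc        = sum over k_flat of A[g, m, k] * B[g, n, k]
    //     E[g, m, n] = cde_elementwise(acc, D_0[g, m, n], ..., D_{NumDTensor-1}[g, m, n])
    // (with NumDTensor == 0 the accumulator is simply cast to EDataType), where every
    // flat index is decoded into a multi-index over G_dims/M_dims/N_dims/K_dims and
    // mapped to memory through the corresponding tensor's strides.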

    // Extract stride information from tensor descriptors
    const auto a_strides = a_full_dims.get_strides();
    const auto b_strides = b_full_dims.get_strides();
    const auto e_strides = e_full_dims_host_ref.get_strides();

    // Extract D tensor strides
    std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
    for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
    {
        ds_strides[d] = ds_full_dims_host[d].get_strides();
    }

    const ck_tile::index_t num_g_dims = G_dims.size();
    const ck_tile::index_t num_m_dims = M_dims.size();
    const ck_tile::index_t num_n_dims = N_dims.size();
    const ck_tile::index_t num_k_dims = K_dims.size();

    // Helper lambda to compute linear index from flat indices using strides
    auto compute_a_offset = [&](ck_tile::index_t g_flat,
                                ck_tile::index_t m_flat,
                                ck_tile::index_t k_flat) -> std::size_t {
        std::size_t offset = 0;

        // Decode G dimensions
        ck_tile::index_t temp = g_flat;
        for(int i = num_g_dims - 1; i >= 0; --i)
        {
            offset += (temp % G_dims[i]) * a_strides[i];
            temp /= G_dims[i];
        }

        // Decode M dimensions
        temp = m_flat;
        for(int i = num_m_dims - 1; i >= 0; --i)
        {
            offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
            temp /= M_dims[i];
        }

        // Decode K dimensions
        temp = k_flat;
        for(int i = num_k_dims - 1; i >= 0; --i)
        {
            offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
            temp /= K_dims[i];
        }

        return offset;
    };
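
    // For example, with G_dims = {4, 3} and g_flat = 7 the decode above yields the
    // multi-index (2, 1) (since 7 = 2 * 3 + 1) and contributes
    // 2 * a_strides[0] + 1 * a_strides[1] to the offset; the M and K groups are
    // decoded the same way into the subsequent stride slots.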

    auto compute_b_offset = [&](ck_tile::index_t g_flat,
                                ck_tile::index_t n_flat,
                                ck_tile::index_t k_flat) -> std::size_t {
        std::size_t offset = 0;

        // Decode G dimensions
        ck_tile::index_t temp = g_flat;
        for(int i = num_g_dims - 1; i >= 0; --i)
        {
            offset += (temp % G_dims[i]) * b_strides[i];
            temp /= G_dims[i];
        }

        // Decode N dimensions
        temp = n_flat;
        for(int i = num_n_dims - 1; i >= 0; --i)
        {
            offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
            temp /= N_dims[i];
        }

        // Decode K dimensions
        temp = k_flat;
        for(int i = num_k_dims - 1; i >= 0; --i)
        {
            offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
            temp /= K_dims[i];
        }

        return offset;
    };

    auto compute_e_offset = [&](ck_tile::index_t g_flat,
                                ck_tile::index_t m_flat,
                                ck_tile::index_t n_flat) -> std::size_t {
        std::size_t offset = 0;

        // Decode G dimensions
        ck_tile::index_t temp = g_flat;
        for(int i = num_g_dims - 1; i >= 0; --i)
        {
            offset += (temp % G_dims[i]) * e_strides[i];
            temp /= G_dims[i];
        }

        // Decode M dimensions
        temp = m_flat;
        for(int i = num_m_dims - 1; i >= 0; --i)
        {
            offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
            temp /= M_dims[i];
        }

        // Decode N dimensions
        temp = n_flat;
        for(int i = num_n_dims - 1; i >= 0; --i)
        {
            offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
            temp /= N_dims[i];
        }

        return offset;
    };

    // Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
    auto compute_d_offset = [&](ck_tile::index_t g_flat,
                                ck_tile::index_t m_flat,
                                ck_tile::index_t n_flat,
                                ck_tile::index_t d_idx) -> std::size_t {
        std::size_t offset = 0;
        const auto& d_strides = ds_strides[d_idx];

        // Decode G dimensions
        ck_tile::index_t temp = g_flat;
        for(int i = num_g_dims - 1; i >= 0; --i)
        {
            offset += (temp % G_dims[i]) * d_strides[i];
            temp /= G_dims[i];
        }

        // Decode M dimensions
        temp = m_flat;
        for(int i = num_m_dims - 1; i >= 0; --i)
        {
            offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
            temp /= M_dims[i];
        }

        // Decode N dimensions
        temp = n_flat;
        for(int i = num_n_dims - 1; i >= 0; --i)
        {
            offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
            temp /= N_dims[i];
        }

        return offset;
    };

    // Parallel computation over G and M dimensions
    auto f_gm = [&](auto g_flat, auto m_flat) {
        for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
        {
            AccDataType sum = 0;

            // Compute dot product over K dimension using stride-aware indexing
            for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
            {
                const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
                const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);

                auto a_val = a_full_dims.mData[a_offset];
                auto b_val = b_full_dims.mData[b_offset];
                sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
            }

            // Compute output offset using strides
            const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);

            // Compute individual D tensor offsets using their respective strides
            std::array<std::size_t, NumDTensor> d_offsets;
            for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
            {
                d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
            }

            // Apply elementwise operation with D tensors using compile-time dispatch
            EDataType result = static_cast<EDataType>(sum);
            ExtractDValues<DDataType, NumDTensor>::apply_at_offsets(
                result, sum, cde_elementwise, ds_full_dims_host, d_offsets);

            // Store result using stride-aware indexing
            e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
        }
    };

    // Execute parallel computation using hardware concurrency
    // Parallelize over G_total and M_total dimensions for optimal CPU utilization
    make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
}

} // namespace ck_tile
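
The listing above is the whole reference implementation. As a rough illustration of how it might be driven from host code, the sketch below (hypothetical, not taken from this file) treats each of G, M, N and K as a single dimension, supplies one D tensor that is added as a bias by a small caller-defined functor, and assumes the usual ck_tile::HostTensor constructor taking a list of dimension lengths together with its public mData storage.

#include <array>
#include <vector>

#include "ck_tile/host/reference/reference_batched_contraction.hpp"

// Hypothetical CDE functor: invoked as op(result, acc, d0) and adds D0 as a bias.
struct AddBias
{
    void operator()(float& e, float acc, float d) const { e = acc + d; }
};

int main()
{
    const ck_tile::index_t G = 2, M = 4, N = 3, K = 8;

    // Packed [G, M, K], [G, N, K] and [G, M, N] host tensors.
    ck_tile::HostTensor<float> a({G, M, K});
    ck_tile::HostTensor<float> b({G, N, K});
    ck_tile::HostTensor<float> d0({G, M, N});
    ck_tile::HostTensor<float> e({G, M, N});

    for(auto& x : a.mData)  x = 1.0f;
    for(auto& x : b.mData)  x = 0.5f;
    for(auto& x : d0.mData) x = 1.0f;

    std::array<ck_tile::HostTensor<float>, 1> ds{d0};

    // One dimension per G/M/N/K group, so each *_total equals its single dim.
    ck_tile::compute_reference_batched_contraction<float, float, float, float, float, AddBias, 1>(
        a, b, ds, e, G, M, N, K, AddBias{}, {G}, {M}, {N}, {K});

    // With these fills every output element is K * 1.0f * 0.5f + 1.0f = 5.0f.
    return 0;
}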