Composable Kernel: include/ck_tile/host/reference/reference_gemm.hpp Source File

reference_gemm.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <thread>
#include <type_traits>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp" // HostTensor, make_ParallelTensorFunctor

namespace ck_tile {

template <typename ADataType,
          typename QDataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          uint32_t QuantGroupSize,
          bool aquant,
          typename AElementOp   = ck_tile::identity,
          typename BElementOp   = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
                                       const HostTensor<QDataType>& q,
                                       const HostTensor<BDataType>& b_k_n,
                                       HostTensor<CDataType>& c_m_n,
                                       const AElementOp& a_element_op     = {},
                                       const BElementOp& b_element_op     = {},
                                       const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0, v_block_acc = 0;

        static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
                      std::is_same_v<ADataType, bf8_t>);
        static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
                      std::is_same_v<BDataType, pk_int4_t>);
        static_assert(std::is_same_v<AccDataType, float>);
        static_assert(std::is_same_v<CDataType, float> ||
                      std::is_same_v<CDataType, ck_tile::half_t>);
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                // Each pk_int4_t packs two 4-bit values; k's parity selects the lo/hi half.
                const pk_int4_t pk_val  = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val  = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<BDataType, fp8_t>)
            {
                v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_block_acc += v_a * v_b;

            // Apply the group dequantization scale once every QuantGroupSize elements along K.
            // With aquant == true the scales are indexed per (m, k-group), otherwise per
            // (k-group, n).
            if((k + 1) % QuantGroupSize == 0)
            {
                float scale       = 0.f;
                index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
                index_t inner_dim = (aquant) ? k / QuantGroupSize : n;

                if constexpr(std::is_same_v<QDataType, float>)
                {
                    scale = q(outer_dim, inner_dim);
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
                {
                    scale = fp8_to_float_raw(q(outer_dim, inner_dim));
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
                {
                    scale = bf8_to_float_raw(q(outer_dim, inner_dim));
                }
                else
                {
                    static_assert(false, "Unexpected Q datatype.");
                }
                v_block_acc *= scale;
                v_acc += v_block_acc;
                v_block_acc = 0;
            }
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
    std::cout << std::endl;
}
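
// Usage sketch (illustrative, not part of the upstream header): checking a group-quantized GEMM
// with fp8 A/B operands, float per-group scales, QuantGroupSize = 128 and B-side quantization
// (aquant = false), so the scale tensor q is indexed as (k / 128, n). The extents and the
// HostTensor brace construction below are assumptions for the example.
//
//   std::size_t M = 256, N = 256, K = 512;
//   ck_tile::HostTensor<ck_tile::fp8_t> a({M, K});
//   ck_tile::HostTensor<float>          q({K / 128, N});
//   ck_tile::HostTensor<ck_tile::fp8_t> b({K, N});
//   ck_tile::HostTensor<float>          c({M, N});
//   // ... fill a, b and q ...
//   ck_tile::reference_gemm_quant<ck_tile::fp8_t, float, ck_tile::fp8_t, float, float,
//                                 /*QuantGroupSize=*/128, /*aquant=*/false>(a, q, b, c);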

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp   = ck_tile::identity,
          typename BElementOp   = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val  = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val  = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_acc += v_a * v_b;
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
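
// Usage sketch (illustrative): producing a host-side reference result for comparison against a
// device kernel. The half_t operand types and the extents below are assumptions.
//
//   std::size_t M = 128, N = 128, K = 64;
//   ck_tile::HostTensor<ck_tile::half_t> a({M, K});
//   ck_tile::HostTensor<ck_tile::half_t> b({K, N});
//   ck_tile::HostTensor<ck_tile::half_t> c({M, N});
//   // ... fill a and b, run the kernel under test into its own output tensor ...
//   ck_tile::reference_gemm<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t>(a, b, c);
//   // compare c element-wise against the device output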

template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename ACCElementOp,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
                          const HostTensor<BDataType>& b_k_n,
                          const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                          HostTensor<CDataType>& c_m_n,
                          const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }

        CDataType v_c = 0;
        if constexpr(DsDataType::size() == 0)
        {
            acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
        }
        else if constexpr(DsDataType::size() == 1)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)));
        }
        else if constexpr(DsDataType::size() == 2)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)),
                           ck_tile::type_convert<float>(ds_m_n[1](m, n)));
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
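
// Usage sketch (illustrative): fusing one extra m-by-n tensor through the accumulator element-op.
// DsDataType is assumed here to be ck_tile::tuple<float>, i.e. a type exposing size() and
// std::tuple_element as required by the DDataType default above; the add_d lambda is a
// hypothetical element-op that writes out = acc + d0.
//
//   std::size_t M = 128, N = 128, K = 64;
//   ck_tile::HostTensor<ck_tile::half_t> a({M, K});
//   ck_tile::HostTensor<ck_tile::half_t> b({K, N});
//   ck_tile::HostTensor<float>           c({M, N});
//   std::array<ck_tile::HostTensor<float>, 1> ds = {ck_tile::HostTensor<float>({M, N})};
//   auto add_d = [](float& out, float acc, float d0) { out = acc + d0; };
//   ck_tile::reference_gemm_multiple_d<ck_tile::half_t,
//                                      ck_tile::half_t,
//                                      ck_tile::tuple<float>,
//                                      float,
//                                      float>(a, b, ds, c, add_d);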

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
                                  BDataType* B,
                                  CDataType* C,
                                  ck_tile::index_t M,
                                  ck_tile::index_t N,
                                  ck_tile::index_t K,
                                  ck_tile::index_t strideA,
                                  ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N; // Compute row index
    int col = idx % N; // Compute column index

    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            // Number of logical elements packed per stored value (2 for pk_int4_t, 1 otherwise);
            // assumed to be exposed as numeric_traits<T>::PackedSize.
            constexpr auto packed_size_a = numeric_traits<ADataType>::PackedSize;
            constexpr auto packed_size_b = numeric_traits<BDataType>::PackedSize;

            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }
            acc += v_a * v_b;
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
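
// Worked indexing example (illustrative numbers): with M = 4, N = 3, K = 8, a row-major A
// (strideA = 8), a column-major B (strideB = 8) and a row-major C (strideC = 3), thread idx = 7
// maps to row = 7 / 3 = 2 and col = 7 % 3 = 1; at k = 5 it reads A[2 * 8 + 5] and B[1 * 8 + 5],
// and finally writes C[2 * 3 + 1]. For pk_int4_t operands the element index is additionally
// divided by the packed size, since two 4-bit values share one storage unit.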

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr,
                        BDataType* b_ptr,
                        CDataType* c_ptr,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t stride_a,
                        index_t stride_b,
                        index_t stride_c)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);

    return;
}
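
// Usage sketch (illustrative): launching the naive reference on device buffers that already hold
// A and B. Plain HIP allocation calls are shown with error checking omitted; the types, extents
// and the RowMajor/ColumnMajor layout choice are assumptions.
//
//   ck_tile::index_t M = 128, N = 128, K = 64;
//   ck_tile::half_t *a_dev, *b_dev, *c_dev;
//   hipMalloc(&a_dev, M * K * sizeof(ck_tile::half_t));
//   hipMalloc(&b_dev, K * N * sizeof(ck_tile::half_t));
//   hipMalloc(&c_dev, M * N * sizeof(ck_tile::half_t));
//   // ... copy host data into a_dev and b_dev ...
//   ck_tile::reference_gemm_gpu<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t,
//                               ck_tile::tensor_layout::gemm::RowMajor,
//                               ck_tile::tensor_layout::gemm::ColumnMajor,
//                               ck_tile::tensor_layout::gemm::RowMajor>(
//       a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);
//   hipDeviceSynchronize();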

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
                                BDataType* b_ptr,
                                CDataType* c_ptr,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }

    return;
}
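
// Usage note (illustrative): for densely packed batches the per-batch offsets are simply
// batch_stride_A = M * K, batch_stride_B = K * N and batch_stride_C = M * N, so batch b starts at
// a_ptr + b * M * K and so on. The loop above launches one grid per batch rather than folding the
// batch index into the grid.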

} // namespace ck_tile