Composable Kernel: include/ck_tile/host/reference/reference_gemm.hpp Source File

reference_gemm.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename ADataType,
          typename QDataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename QuantGroupSize,
          bool aquant,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
                                       const HostTensor<QDataType>& q,
                                       const HostTensor<BDataType>& b_k_n,
                                       HostTensor<CDataType>& c_m_n,
                                       const AElementOp& a_element_op = {},
                                       const BElementOp& b_element_op = {},
                                       const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0, v_block_acc = 0;

        static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
                      std::is_same_v<ADataType, bf8_t>);
        static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
                      std::is_same_v<BDataType, pk_int4_t>);
        static_assert(std::is_same_v<AccDataType, float>);
        static_assert(std::is_same_v<CDataType, float> ||
                      std::is_same_v<CDataType, ck_tile::half_t>);
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<BDataType, fp8_t>)
            {
                v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n)));
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_block_acc += v_a * v_b;

            // Apply group dequant scale
            if((k + 1) % QuantGroupSize::kK == 0)
            {
                float scale = 0.f;
                index_t outer_dim = (aquant) ? (m / QuantGroupSize::kM) : (k / QuantGroupSize::kK);
                index_t inner_dim = (aquant) ? (k / QuantGroupSize::kK) : (n / QuantGroupSize::kN);
                if constexpr(std::is_same_v<QDataType, float>)
                {
                    scale = q(outer_dim, inner_dim);
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
                {
                    scale = fp8_to_float_raw(q(outer_dim, inner_dim));
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
                {
                    scale = bf8_to_float_raw(q(outer_dim, inner_dim));
                }
                else
                {
                    static_assert(false, "Unexpected Q datatype.");
                }
                v_block_acc *= scale;
                v_acc += v_block_acc;
                v_block_acc = 0;
            }
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
    std::cout << std::endl;
}

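// Usage sketch (illustrative, not part of the library): driving reference_gemm_quant with fp8
// inputs and B-side per-group float scales (aquant = false). The group-shape type below is an
// assumption for the example; any type exposing static kM/kN/kK members works, and K is chosen
// as a multiple of kK so every group has a scale applied.
struct example_quant_group_size
{
    static constexpr index_t kM = 1;
    static constexpr index_t kN = 1;
    static constexpr index_t kK = 64;
};

inline void example_reference_gemm_quant()
{
    const std::size_t M = 128, N = 128, K = 256;

    HostTensor<fp8_t> a_m_k({M, K}, {K, std::size_t(1)});
    HostTensor<fp8_t> b_k_n({K, N}, {std::size_t(1), K});
    // One scale per K-group and per N column (B-side quantization).
    HostTensor<float> q({K / example_quant_group_size::kK, N}, {N, std::size_t(1)});
    HostTensor<float> c_m_n({M, N}, {N, std::size_t(1)});

    // ... fill a_m_k, b_k_n and q ...

    reference_gemm_quant<fp8_t, float, fp8_t, float, float, example_quant_group_size, false>(
        a_m_k, q, b_k_n, c_m_n);
}
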
template <typename ADataType,
          typename AQDataType,
          typename BDataType,
          typename BQDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor<ADataType>& a_m_k,
                                              const HostTensor<AQDataType>& aq_m_1,
                                              const HostTensor<BDataType>& b_k_n,
                                              const HostTensor<BQDataType>& bq_1_n,
                                              HostTensor<CDataType>& c_m_n,
                                              const AElementOp& a_element_op = {},
                                              const BElementOp& b_element_op = {},
                                              const ACCElementOp& acc_element_op = {})
{
    static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
    static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
    static_assert(std::is_same_v<AccDataType, float>);
    static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
    static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        // Init accumulator
        AccDataType v_acc = 0;
        // Get row scale for A and column scale for B
        float a_scale = aq_m_1(m, 0);
        float b_scale = bq_1_n(0, n);

        // Compute the dot product
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;

            // Process A data
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }

            // Process B data
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t_signed_conversion(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }

            v_acc += v_a * v_b;
        }

        v_acc = v_acc * a_scale * b_scale;

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

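// Usage sketch (illustrative, not part of the library): reference_gemm_rowcol_quant expects one
// float scale per row of A (shape M x 1) and one per column of B (shape 1 x N); the data types
// and sizes below are arbitrary choices that satisfy the static_asserts above.
inline void example_reference_gemm_rowcol_quant()
{
    const std::size_t M = 64, N = 64, K = 128;

    HostTensor<fp8_t> a_m_k({M, K}, {K, std::size_t(1)});
    HostTensor<fp8_t> b_k_n({K, N}, {std::size_t(1), K});
    HostTensor<float> aq_m_1({M, std::size_t(1)}, {std::size_t(1), std::size_t(1)});
    HostTensor<float> bq_1_n({std::size_t(1), N}, {N, std::size_t(1)});
    HostTensor<ck_tile::half_t> c_m_n({M, N}, {N, std::size_t(1)});

    // ... fill a_m_k, b_k_n, aq_m_1 and bq_1_n ...

    reference_gemm_rowcol_quant<fp8_t, float, fp8_t, float, float, ck_tile::half_t>(
        a_m_k, aq_m_1, b_k_n, bq_1_n, c_m_n);
}
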
template <typename ADataType,
          typename AQDataType,
          typename BDataType,
          typename BQDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k,
                                              const HostTensor<AQDataType>& aq_1_1,
                                              const HostTensor<BDataType>& b_k_n,
                                              const HostTensor<BQDataType>& bq_1_1,
                                              HostTensor<CDataType>& c_m_n,
                                              const AElementOp& a_element_op = {},
                                              const BElementOp& b_element_op = {},
                                              const ACCElementOp& acc_element_op = {})
{
    static_assert(std::is_same_v<ADataType, fp8_t> || std::is_same_v<ADataType, bf8_t>);
    static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t>);
    static_assert(std::is_same_v<AccDataType, float>);
    static_assert(std::is_same_v<CDataType, float> || std::is_same_v<CDataType, ck_tile::half_t>);
    static_assert(std::is_same_v<AQDataType, float> && std::is_same_v<BQDataType, float>);
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        // Init accumulator
        AccDataType v_acc = 0;
        // Get scale for A and scale for B
        const AccDataType a_scale = ck_tile::type_convert<AccDataType>(aq_1_1(0, 0));
        const AccDataType b_scale = ck_tile::type_convert<AccDataType>(bq_1_1(0, 0));

        // Compute the dot product
        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            AccDataType v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));

            v_acc += v_a * v_b;
        }

        v_acc = v_acc * a_scale * b_scale;

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op = {},
                                 const BElementOp& b_element_op = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_acc += v_a * v_b;
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

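// Usage sketch (illustrative, not part of the library): the unquantized reference_gemm with
// plain float tensors and the default identity element-wise operators. AccDataType cannot be
// deduced from the arguments, so all four data types are spelled out explicitly.
inline void example_reference_gemm()
{
    const std::size_t M = 32, N = 32, K = 64;

    HostTensor<float> a_m_k({M, K}, {K, std::size_t(1)}); // row-major A
    HostTensor<float> b_k_n({K, N}, {std::size_t(1), K}); // column-major B
    HostTensor<float> c_m_n({M, N}, {N, std::size_t(1)});

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a_m_k(m, k) = 1.0f;
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b_k_n(k, n) = 0.5f;

    reference_gemm<float, float, float, float>(a_m_k, b_k_n, c_m_n);
    // Every c_m_n(m, n) now equals K * 1.0f * 0.5f = 32.0f.
}
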
template <typename AsDataType,
          typename BsDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp,
          typename BElementOp,
          typename CDElementOp,
          typename ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>,
          typename BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_abd(const std::array<HostTensor<ADataType>, AsDataType::size()>& as_m_k,
                            const std::array<HostTensor<BDataType>, BsDataType::size()>& bs_k_n,
                            const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                            HostTensor<ADataType>& a_m_k,
                            HostTensor<BDataType>& b_k_n,
                            HostTensor<CDataType>& c_m_n,
                            const AElementOp& a_element_op = {},
                            const BElementOp& b_element_op = {},
                            const CDElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto as_m_k_tuple =
        generate_tie([&](auto idx) -> auto& { return as_m_k[idx]; }, number<AsDataType::size()>{});

    auto bs_k_n_tuple =
        generate_tie([&](auto idx) -> auto& { return bs_k_n[idx]; }, number<BsDataType::size()>{});

    auto ds_m_n_tuple =
        generate_tie([&](auto idx) -> auto& { return ds_m_n[idx]; }, number<DsDataType::size()>{});

    // Apply elementwise function to A
    auto a_elementwise_fn = [&](auto i, auto j) {
        ck_tile::apply([&](auto&&... t) { a_element_op(a_m_k(i, j), t(i, j)...); }, as_m_k_tuple);
    };

    make_ParallelTensorFunctor(a_elementwise_fn, M, K)(std::thread::hardware_concurrency());

    // Apply elementwise function to B
    auto b_elementwise_fn = [&](auto i, auto j) {
        ck_tile::apply([&](auto&&... t) { b_element_op(b_k_n(i, j), t(i, j)...); }, bs_k_n_tuple);
    };

    make_ParallelTensorFunctor(b_elementwise_fn, K, N)(std::thread::hardware_concurrency());

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }

        CDataType v_c = 0;

        ck_tile::apply(
            [&](auto&&... t) {
                acc_element_op(v_c,
                               ck_tile::type_convert<float>(v_acc),
                               ck_tile::type_convert<float>(t(m, n))...);
            },
            ds_m_n_tuple);

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}

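// Sketch of the element-wise operator shapes reference_gemm_multiple_abd expects (illustrative,
// not part of the library): each operator receives its destination element by reference, followed
// by one source element per input tensor. The functors below assume two A inputs that should be
// summed and a single float D tensor that is added to the accumulator.
struct example_sum_two_inputs
{
    template <typename Y, typename X0, typename X1>
    void operator()(Y& y, const X0& x0, const X1& x1) const
    {
        y = ck_tile::type_convert<Y>(ck_tile::type_convert<float>(x0) +
                                     ck_tile::type_convert<float>(x1));
    }
};

struct example_acc_plus_bias
{
    // Assumes CDataType == float and one D tensor: c = acc + d.
    void operator()(float& c, float acc, float d) const { c = acc + d; }
};
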
template <typename ADataType,
          typename BDataType,
          typename ScaleDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp = ck_tile::identity,
          typename BElementOp = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mx_gemm(const HostTensor<ADataType>& a_m_k,
                                    const HostTensor<BDataType>& b_k_n,
                                    HostTensor<CDataType>& c_m_n,
                                    const HostTensor<ScaleDataType>& scale_a,
                                    const HostTensor<ScaleDataType>& scale_b,
                                    const AElementOp& = {},
                                    const BElementOp& = {},
                                    const ACCElementOp& = {})
{
    static_assert(std::is_same_v<AElementOp, ck_tile::identity>);
    static_assert(std::is_same_v<BElementOp, ck_tile::identity>);
    static_assert(std::is_same_v<ACCElementOp, ck_tile::identity>);

    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    const std::size_t ScaleBlockSize = K / scale_a.get_length(1);

    HostTensor<AccDataType> a_m_k_scaled({std::size_t(M), std::size_t(K)},
                                         {std::size_t(K), std::size_t(1)});
    HostTensor<AccDataType> b_k_n_scaled({std::size_t(K), std::size_t(N)},
                                         {std::size_t(1), std::size_t(K)});

    for(std::size_t m = 0; m < M; ++m)
    {
        for(std::size_t k = 0; k < K; ++k)
        {
            if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                if(k % 2 == 1)
                    continue; // skip odd k

                auto a_f4x2 = a_m_k(m, k);
                auto a_scale = ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
                auto a_f4_lo =
                    ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<0>{}));
                auto a_f4_hi =
                    ck_tile::type_convert<AccDataType>(a_f4x2.template unpack<>(number<1>{}));

                a_m_k_scaled(m, k) = a_f4_lo * a_scale;
                a_m_k_scaled(m, k + 1) = a_f4_hi * a_scale;
            }
            else
            {
                a_m_k_scaled(m, k) =
                    ck_tile::type_convert<AccDataType>((a_m_k(m, k))) *
                    ck_tile::type_convert<AccDataType>(scale_a(m, k / ScaleBlockSize));
            }
        }
    }

    for(std::size_t n = 0; n < N; n++)
    {
        for(std::size_t k = 0; k < K; k++)
        {
            if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                if(k % 2 == 1)
                    continue; // skip odd k

                auto b_f4x2 = b_k_n(k, n);
                auto b_scale = ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
                auto b_f4_lo =
                    ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<0>{}));
                auto b_f4_hi =
                    ck_tile::type_convert<AccDataType>(b_f4x2.template unpack<>(number<1>{}));

                b_k_n_scaled(k, n) = b_f4_lo * b_scale;
                b_k_n_scaled(k + 1, n) = b_f4_hi * b_scale;
            }
            else
            {
                b_k_n_scaled(k, n) =
                    ck_tile::type_convert<AccDataType>((b_k_n(k, n))) *
                    ck_tile::type_convert<AccDataType>(scale_b(k / ScaleBlockSize, n));
            }
        }
    }

    // call reference gemm
    reference_gemm<AccDataType, AccDataType, AccDataType, CDataType>(
        a_m_k_scaled, b_k_n_scaled, c_m_n);
}

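// Usage sketch (illustrative, not part of the library): reference_mx_gemm with plain float
// inputs and one float scale per 32-wide block along K. scale_a has shape (M, K / ScaleBlock)
// and scale_b has shape (K / ScaleBlock, N); the block size itself is recovered inside the
// function from K / scale_a.get_length(1).
inline void example_reference_mx_gemm()
{
    const std::size_t M = 64, N = 64, K = 128, ScaleBlock = 32;

    HostTensor<float> a_m_k({M, K}, {K, std::size_t(1)});
    HostTensor<float> b_k_n({K, N}, {std::size_t(1), K});
    HostTensor<float> scale_a({M, K / ScaleBlock}, {K / ScaleBlock, std::size_t(1)});
    HostTensor<float> scale_b({K / ScaleBlock, N}, {N, std::size_t(1)});
    HostTensor<float> c_m_n({M, N}, {N, std::size_t(1)});

    // ... fill a_m_k, b_k_n and the per-block scales ...

    reference_mx_gemm<float, float, float, float, float>(a_m_k, b_k_n, c_m_n, scale_a, scale_b);
}
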
template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename ACCElementOp,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
                          const HostTensor<BDataType>& b_k_n,
                          const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                          HostTensor<CDataType>& c_m_n,
                          const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }

        CDataType v_c = 0;
        if constexpr(DsDataType::size() == 0)
        {
            acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
        }
        else if constexpr(DsDataType::size() == 1)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)));
        }
        else if constexpr(DsDataType::size() == 2)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)),
                           ck_tile::type_convert<float>(ds_m_n[1](m, n)));
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
                                  BDataType* B,
                                  CDataType* C,
                                  ck_tile::index_t M,
                                  ck_tile::index_t N,
                                  ck_tile::index_t K,
                                  ck_tile::index_t strideA,
                                  ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N; // Compute row index
    int col = idx % N; // Compute column index

    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }
            acc += v_a * v_b;
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
__global__ void blockwise_gemm_kernel(ADataType* A,
                                      BDataType* B,
                                      CDataType* C,
                                      ck_tile::index_t M,
                                      ck_tile::index_t N,
                                      ck_tile::index_t K,
                                      ck_tile::index_t strideA,
                                      ck_tile::index_t strideB,
                                      ck_tile::index_t strideC,
                                      ck_tile::index_t scale_granularity_m,
                                      ck_tile::index_t scale_granularity_n,
                                      ck_tile::index_t scale_granularity_k,
                                      float* scale_A_ptr,
                                      float* scale_B_ptr)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N; // Compute row index
    int col = idx % N; // Compute column index

    if(row < M && col < N)
    {
        AccDataType acc = 0.0, acc_temp = 0.0;

        index_t scale_A_stride = (M + scale_granularity_m - 1) / scale_granularity_m;
        index_t scale_B_stride = (N + scale_granularity_n - 1) / scale_granularity_n;

        float scale_A = 0;
        float scale_B = 0;

        for(int k = 0; k < K; ++k)
        {
            if(k % scale_granularity_k == 0)
            {
                // update acc
                acc += acc_temp * scale_A * scale_B;
                acc_temp = 0.0;
                // update scale factors
                scale_A = scale_A_ptr[(row / scale_granularity_m) +
                                      (k / scale_granularity_k) * scale_A_stride];
                scale_B = scale_B_ptr[(col / scale_granularity_n) +
                                      (k / scale_granularity_k) * scale_B_stride];
            }

            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<ADataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }

            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else if constexpr(std::is_same_v<BDataType, pk_fp4_t>)
            {
                const fp32x2_t fp32_val = pk_fp4_to_fp32x2(B[b_index / packed_size_b], 1.0f);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }
            acc_temp += v_a * v_b;
        }
        // final accumulation
        acc += acc_temp * scale_A * scale_B;

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr,
                        BDataType* b_ptr,
                        CDataType* c_ptr,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t stride_a,
                        index_t stride_b,
                        index_t stride_c)
{
    int totalElements = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);

    return;
}

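// Usage sketch (illustrative, not part of the library): launching the naive GPU reference for
// row-major float matrices. Error checking is omitted, and the HIP runtime calls
// (hipMalloc/hipMemcpy/hipFree) are assumed to be visible through the including translation
// unit's HIP headers.
inline void example_reference_gemm_gpu(
    const float* a_host, const float* b_host, float* c_host, index_t M, index_t N, index_t K)
{
    using Row = tensor_layout::gemm::RowMajor;

    float* a_dev = nullptr;
    float* b_dev = nullptr;
    float* c_dev = nullptr;
    hipMalloc(reinterpret_cast<void**>(&a_dev), sizeof(float) * M * K);
    hipMalloc(reinterpret_cast<void**>(&b_dev), sizeof(float) * K * N);
    hipMalloc(reinterpret_cast<void**>(&c_dev), sizeof(float) * M * N);
    hipMemcpy(a_dev, a_host, sizeof(float) * M * K, hipMemcpyHostToDevice);
    hipMemcpy(b_dev, b_host, sizeof(float) * K * N, hipMemcpyHostToDevice);

    // Row-major A, B and C: the strides are the leading dimensions K, N and N.
    reference_gemm_gpu<float, float, float, float, Row, Row, Row>(
        a_dev, b_dev, c_dev, M, N, K, K, N, N);

    hipDeviceSynchronize();
    hipMemcpy(c_host, c_dev, sizeof(float) * M * N, hipMemcpyDeviceToHost);
    hipFree(a_dev);
    hipFree(b_dev);
    hipFree(c_dev);
}
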
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_blockwise_gemm_gpu(ADataType* a_ptr,
                                  BDataType* b_ptr,
                                  CDataType* c_ptr,
                                  index_t M,
                                  index_t N,
                                  index_t K,
                                  index_t stride_a,
                                  index_t stride_b,
                                  index_t stride_c,
                                  index_t scale_granularity_m,
                                  index_t scale_granularity_n,
                                  index_t scale_granularity_k,
                                  float* scale_A_ptr,
                                  float* scale_B_ptr)
{
    int totalElements = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    blockwise_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(a_ptr,
                                            b_ptr,
                                            c_ptr,
                                            M,
                                            N,
                                            K,
                                            stride_a,
                                            stride_b,
                                            stride_c,
                                            scale_granularity_m,
                                            scale_granularity_n,
                                            scale_granularity_k,
                                            scale_A_ptr,
                                            scale_B_ptr);

    return;
}

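// Note on the blockwise scale layout consumed above (illustrative helper, not part of the
// library): scale_A is read as scale_A[(row / scale_granularity_m) +
// (k / scale_granularity_k) * scale_A_stride], i.e. (K / scale_granularity_k) blocks of
// ceil(M / scale_granularity_m) consecutive floats, and scale_B analogously with N in place of M.
// The helper below computes how many floats such a buffer must hold, e.g.
// example_blockwise_scale_count(M, K, scale_granularity_m, scale_granularity_k) for scale_A.
inline std::size_t example_blockwise_scale_count(index_t extent_mn,
                                                 index_t extent_k,
                                                 index_t granularity_mn,
                                                 index_t granularity_k)
{
    const index_t per_k_block = (extent_mn + granularity_mn - 1) / granularity_mn;
    const index_t k_blocks = (extent_k + granularity_k - 1) / granularity_k;
    return static_cast<std::size_t>(per_k_block) * static_cast<std::size_t>(k_blocks);
}
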
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
                                BDataType* b_ptr,
                                CDataType* c_ptr,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    int totalElements = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }

    return;
}

} // namespace ck_tile