reference_gemm.hpp

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <thread>

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

// CPU reference GEMM: c_m_n(m, n) = acc_op(sum_k a_op(a_m_k(m, k)) * b_op(b_k_n(k, n)))
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp   = ck_tile::identity,
          typename BElementOp   = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                // Each pk_int4_t packs two int4 values; select the half that corresponds to k.
                const pk_int4_t pk_val  = a_element_op(a_m_k(m, k));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val  = b_element_op(b_k_n(k, n));
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(pk_val);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }
            v_acc += v_a * v_b;
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
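
// Usage sketch (illustrative only; the shapes, float types, and HostTensor
// construction below are assumptions for the example, not part of this header):
// build small float tensors, fill A and B, then compute the reference result.
//
//   std::size_t M = 4, N = 4, K = 8;
//   ck_tile::HostTensor<float> a_m_k({M, K});
//   ck_tile::HostTensor<float> b_k_n({K, N});
//   ck_tile::HostTensor<float> c_m_n({M, N});
//   for(std::size_t m = 0; m < M; ++m)
//       for(std::size_t k = 0; k < K; ++k)
//           a_m_k(m, k) = 1.0f;
//   for(std::size_t k = 0; k < K; ++k)
//       for(std::size_t n = 0; n < N; ++n)
//           b_k_n(k, n) = 1.0f;
//   // With all-ones inputs every output element equals K.
//   ck_tile::reference_gemm<float, float, float, float>(a_m_k, b_k_n, c_m_n);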

// CPU reference GEMM with up to two additional "D" tensors fused into the epilogue:
// c_m_n(m, n) = acc_op(sum_k A(m, k) * B(k, n), D0(m, n), D1(m, n), ...)
template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename ACCElementOp,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
                          const HostTensor<BDataType>& b_k_n,
                          const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                          HostTensor<CDataType>& c_m_n,
                          const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }

        CDataType v_c = 0;
        if constexpr(DsDataType::size() == 0)
        {
            acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
        }
        else if constexpr(DsDataType::size() == 1)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)));
        }
        else if constexpr(DsDataType::size() == 2)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)),
                           ck_tile::type_convert<float>(ds_m_n[1](m, n)));
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
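
// Usage sketch (illustrative only): with one D tensor the epilogue functor
// receives (c_out, acc, d0) and writes the fused result into c_out. The Ds1
// type-list and AddBias functor below are hypothetical stand-ins made up for
// this example; real callers typically use the type-lists shipped with ck_tile.
//
//   struct Ds1 { static constexpr std::size_t size() { return 1; } };
//   struct AddBias
//   {
//       template <typename C, typename Acc, typename D>
//       void operator()(C& c, const Acc& acc, const D& d0) const { c = acc + d0; }
//   };
//
//   std::array<ck_tile::HostTensor<float>, 1> ds_m_n = {ck_tile::HostTensor<float>({M, N})};
//   ck_tile::reference_gemm_multiple_d<float, float, Ds1, float, float, AddBias, float>(
//       a_m_k, b_k_n, ds_m_n, c_m_n, AddBias{});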

// Naive GEMM kernel used as the GPU reference: each thread computes one C(row, col).
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A,
                                  BDataType* B,
                                  CDataType* C,
                                  ck_tile::index_t M,
                                  ck_tile::index_t N,
                                  ck_tile::index_t K,
                                  ck_tile::index_t strideA,
                                  ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N; // Compute row index
    int col = idx % N; // Compute column index

    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            // Number of logical values packed into one storage element (2 for pk_int4_t).
            constexpr int packed_size_a = std::is_same_v<ADataType, pk_int4_t> ? 2 : 1;
            constexpr int packed_size_b = std::is_same_v<BDataType, pk_int4_t> ? 2 : 1;

            // Adjust indexing based on matrix layout
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            AccDataType v_a;
            AccDataType v_b;
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(A[a_index / packed_size_a]);
                if(k % 2 == 1)
                    v_a = fp32_val.hi;
                else
                    v_a = fp32_val.lo;
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const fp32x2_t fp32_val = pk_int4_t_to_fp32x2_t(B[b_index / packed_size_b]);
                if(k % 2 == 1)
                    v_b = fp32_val.hi;
                else
                    v_b = fp32_val.lo;
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }
            acc += v_a * v_b;
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
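
// Mapping note: with a 1-D launch, thread index idx covers the M * N output
// elements in row-major order (row = idx / N, col = idx % N). For example, with
// M = 2 and N = 3, idx 0..5 maps to (0,0), (0,1), (0,2), (1,0), (1,1), (1,2).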

// Launch the naive GEMM kernel on the GPU as a reference implementation.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr,
                        BDataType* b_ptr,
                        CDataType* c_ptr,
                        index_t M,
                        index_t N,
                        index_t K,
                        index_t stride_a,
                        index_t stride_b,
                        index_t stride_c)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);

    return;
}
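
// Usage sketch (illustrative only; the buffer management below is an assumption,
// using plain HIP allocations rather than any particular ck_tile helper; requires
// <hip/hip_runtime.h>). All pointers are device pointers; strides are in elements.
//
//   ck_tile::index_t M = 64, N = 64, K = 32;
//   float *a_dev, *b_dev, *c_dev;
//   hipMalloc(&a_dev, M * K * sizeof(float));
//   hipMalloc(&b_dev, K * N * sizeof(float));
//   hipMalloc(&c_dev, M * N * sizeof(float));
//   // ... copy host A and B to the device with hipMemcpy ...
//   using Row = ck_tile::tensor_layout::gemm::RowMajor;
//   ck_tile::reference_gemm_gpu<float, float, float, float, Row, Row, Row>(
//       a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/N, /*stride_c=*/N);
//   hipDeviceSynchronize();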

// Batched variant: launch one naive GEMM kernel per batch, offsetting each
// pointer by its batch stride.
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr,
                                BDataType* b_ptr,
                                CDataType* c_ptr,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }

    return;
}
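
// Usage note: batch_stride_* is the element offset between consecutive batches
// of each matrix. For densely packed batches of M x K, K x N and M x N matrices
// this is batch_stride_A = M * K, batch_stride_B = K * N, batch_stride_C = M * N.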
} // namespace ck_tile