template <typename ADataType,
          typename QDataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          uint32_t QuantGroupSize,
          bool aquant,
          typename AElementOp   = identity,
          typename BElementOp   = identity,
          typename ACCElementOp = identity>
CK_TILE_HOST void reference_gemm_quant(const HostTensor<ADataType>& a_m_k,
                                       const HostTensor<QDataType>& q,
                                       const HostTensor<BDataType>& b_k_n,
                                       HostTensor<CDataType>& c_m_n,
                                       const AElementOp& a_element_op     = {},
                                       const BElementOp& b_element_op     = {},
                                       const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    // One lambda invocation computes a single C(m, n): K products are accumulated in
    // QuantGroupSize-wide blocks, each scaled before being folded into the final sum.
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0, v_block_acc = 0;

        static_assert(std::is_same_v<ADataType, pk_int4_t> || std::is_same_v<ADataType, fp8_t> ||
                      std::is_same_v<ADataType, bf8_t>);
        static_assert(std::is_same_v<BDataType, fp8_t> || std::is_same_v<BDataType, bf8_t> ||
                      std::is_same_v<BDataType, pk_int4_t>);
        static_assert(std::is_same_v<AccDataType, float>);
        static_assert(std::is_same_v<CDataType, float> ||
                      std::is_same_v<CDataType, ck_tile::half_t>);

        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a = 0;
            AccDataType v_b = 0;

            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // packed int4 pair is widened to float (pk_int4_t_to_fp32x2_t) before use
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }

            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // packed int4 pair is widened to float (pk_int4_t_to_fp32x2_t) before use
            }
            else if constexpr(std::is_same_v<BDataType, fp8_t>)
            {
                // fp8 B values take a dedicated conversion path
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }

            v_block_acc += v_a * v_b;

            // End of a quantization group: fetch its scale and fold the partial sum in.
            if((k + 1) % QuantGroupSize == 0)
            {
                // aquant selects A-side scales indexed (m, k-group); otherwise B-side (k-group, n)
                index_t outer_dim = (aquant) ? m : k / QuantGroupSize;
                index_t inner_dim = (aquant) ? k / QuantGroupSize : n;

                AccDataType scale = 0;
                if constexpr(std::is_same_v<QDataType, float>)
                {
                    scale = q(outer_dim, inner_dim);
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::fp8_t>)
                {
                    // fp8-encoded scale, decoded via fp8_to_float_raw
                }
                else if constexpr(std::is_same_v<QDataType, ck_tile::bf8_t>)
                {
                    // bf8-encoded scale, decoded via bf8_to_float_raw
                }
                else
                {
                    static_assert(false, "Unexpected Q datatype.");
                }

                v_block_acc *= scale;
                v_acc += v_block_acc;
                v_block_acc = 0;
            }
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());

    std::cout << std::endl;
}
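// ---------------------------------------------------------------------------
// Illustrative usage sketch (not part of this header): validating a block-scaled
// FP8 GEMM where B carries per-(K-group, N) scales, i.e. aquant == false.
// The shapes, group size and data types below are assumptions; K is assumed to be
// a multiple of the group size.
//
//   constexpr uint32_t group_size = 128;            // QuantGroupSize
//   ck_tile::HostTensor<ck_tile::fp8_t> a({M, K});
//   ck_tile::HostTensor<ck_tile::fp8_t> b({K, N});
//   ck_tile::HostTensor<float>          q({K / group_size, N});  // one scale per group/column
//   ck_tile::HostTensor<float>          c({M, N});
//   // ... fill a, b and q with test data ...
//   ck_tile::reference_gemm_quant<ck_tile::fp8_t, float, ck_tile::fp8_t,
//                                 float, float, group_size, /*aquant=*/false>(a, q, b, c);
//   // c now holds the dequantized reference result for comparison against a device kernel.
// ---------------------------------------------------------------------------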
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename AElementOp   = identity,
          typename BElementOp   = identity,
          typename ACCElementOp = identity>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;

        for(std::size_t k = 0; k < K; ++k)
        {
            AccDataType v_a = 0;
            AccDataType v_b = 0;

            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // packed int4 pair is widened to float (pk_int4_t_to_fp32x2_t) before use
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            }

            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // packed int4 pair is widened to float (pk_int4_t_to_fp32x2_t) before use
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            }

            v_acc += v_a * v_b;
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
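// ---------------------------------------------------------------------------
// Illustrative usage sketch (assumed shapes and data types): a half-precision
// reference result with float accumulation, for later comparison against a device GEMM.
//
//   ck_tile::HostTensor<ck_tile::half_t> a({M, K});
//   ck_tile::HostTensor<ck_tile::half_t> b({K, N});
//   ck_tile::HostTensor<ck_tile::half_t> c({M, N});
//   // ... fill a and b ...
//   ck_tile::reference_gemm<ck_tile::half_t, ck_tile::half_t, float, ck_tile::half_t>(a, b, c);
//   // c(m, n) now holds sum_k a(m, k) * b(k, n), accumulated in float.
// ---------------------------------------------------------------------------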
template <typename ADataType,
          typename BDataType,
          typename DsDataType,
          typename AccDataType,
          typename CDataType,
          typename ACCElementOp,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
                          const HostTensor<BDataType>& b_k_n,
                          const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                          HostTensor<CDataType>& c_m_n,
                          const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }

        // The epilogue element-wise op combines the accumulator with up to two D tensors.
        float v_c = 0;
        if constexpr(DsDataType::size() == 0)
        {
            acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
        }
        else if constexpr(DsDataType::size() == 1)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)));
        }
        else if constexpr(DsDataType::size() == 2)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)),
                           ck_tile::type_convert<float>(ds_m_n[1](m, n)));
        }

        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f_mk_kn_mn, M, N)(std::thread::hardware_concurrency());
}
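// ---------------------------------------------------------------------------
// Illustrative usage sketch (assumed container type and functor): GEMM + bias,
// where a single D tensor is added to the accumulator by the epilogue op.
//
//   struct AddBias
//   {
//       void operator()(float& c, float acc, float d) const { c = acc + d; }
//   };
//
//   using DsType = ck_tile::tuple<float>;   // DsDataType::size() == 1
//   ck_tile::HostTensor<float> a({M, K}), b({K, N}), c({M, N});
//   std::array<ck_tile::HostTensor<float>, 1> ds = {ck_tile::HostTensor<float>({M, N})};
//   // ... fill a, b and ds[0] (the bias) ...
//   ck_tile::reference_gemm_multiple_d<float, float, DsType, float, float>(
//       a, b, ds, c, AddBias{});
// ---------------------------------------------------------------------------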
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C,
                                  ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                                  ck_tile::index_t strideA, ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    // One thread computes one C(row, col) element.
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N;
    int col = idx % N;

    if(row < M && col < N)
    {
        AccDataType acc = 0.0;

        for(int k = 0; k < K; ++k)
        {
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k
                              : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k
                              : k * strideB + col;

            AccDataType v_a = 0;
            AccDataType v_b = 0;

            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                // packed int4 A values are unpacked to float before the multiply
            }
            else
            {
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            }

            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                // packed int4 B values are unpacked to float before the multiply
            }
            else
            {
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            }

            acc += v_a * v_b;
        }

        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col
                          : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
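// ---------------------------------------------------------------------------
// Index-mapping sketch for the kernel above (values are illustrative):
// with M = 2, N = 3, thread idx = 4 maps to row = 4 / 3 = 1, col = 4 % 3 = 1.
// For row-major A with strideA = K, a_index walks along a row: 1 * strideA + k;
// for column-major B with strideB = K, b_index walks along a column: 1 * strideB + k.
// ---------------------------------------------------------------------------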
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                        index_t M, index_t N, index_t K,
                        index_t stride_a, index_t stride_b, index_t stride_c)
{
    // Launch one thread per C element.
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
}
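// ---------------------------------------------------------------------------
// Illustrative launch sketch (assumed sizes; row-major A/C, column-major B);
// HIP error checking is omitted for brevity.
//
//   using ck_tile::tensor_layout::gemm::RowMajor;
//   using ck_tile::tensor_layout::gemm::ColumnMajor;
//   float *a_dev, *b_dev, *c_dev;
//   hipMalloc(&a_dev, M * K * sizeof(float));
//   hipMalloc(&b_dev, K * N * sizeof(float));
//   hipMalloc(&c_dev, M * N * sizeof(float));
//   // ... hipMemcpy host data into a_dev / b_dev ...
//   ck_tile::reference_gemm_gpu<float, float, float, float,
//                               RowMajor, ColumnMajor, RowMajor>(
//       a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);
//   hipDeviceSynchronize();
// ---------------------------------------------------------------------------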
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                                index_t M, index_t N, index_t K,
                                index_t stride_a, index_t stride_b, index_t stride_c,
                                index_t batch_stride_A, index_t batch_stride_B,
                                index_t batch_stride_C, index_t batch_count)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    // Launch the naive kernel once per batch, offsetting each pointer by its batch stride.
    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;

        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }
}
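// ---------------------------------------------------------------------------
// Illustrative batched launch sketch (assumed sizes), reusing the buffers and layout
// aliases from the previous sketch: batches are laid out back to back in device memory,
// so each batch stride is simply the size of one matrix.
//
//   const ck_tile::index_t batch_count    = 8;
//   const ck_tile::index_t batch_stride_A = M * K;
//   const ck_tile::index_t batch_stride_B = K * N;
//   const ck_tile::index_t batch_stride_C = M * N;
//   ck_tile::reference_batched_gemm_gpu<float, float, float, float,
//                                       RowMajor, ColumnMajor, RowMajor>(
//       a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N,
//       batch_stride_A, batch_stride_B, batch_stride_C, batch_count);
// ---------------------------------------------------------------------------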