template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename AElementOp, typename BElementOp, typename ACCElementOp>
CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    // ... (M, N, K are taken from the tensor lengths)
    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            // ... v_a / v_b hold the converted operands; packed int4 is unpacked via pk_int4_t_to_fp32x2_t
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = a_element_op(a_m_k(m, k));
                // ...
            }
            else
                v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
            // B operand: same packed-int4 special case
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            {
                const pk_int4_t pk_val = b_element_op(b_k_n(k, n));
                // ...
            }
            else
                v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
            // ... v_acc accumulates v_a * v_b
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };
    // ... f_mn is evaluated over the (M, N) grid via make_ParallelTensorFunctor
}
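// --------------------------------------------------------------------------------------
// Usage sketch (illustration only, not part of this header): a minimal host-side call of
// reference_gemm on float data. It assumes HostTensor can be constructed from a list of
// dimension lengths and that this header plus the ck_tile host utilities are included;
// the pass-through lambdas stand in for the element-wise operators.
inline void example_reference_gemm()
{
    const std::size_t M = 4, N = 4, K = 8;
    ck_tile::HostTensor<float> a_m_k({M, K});
    ck_tile::HostTensor<float> b_k_n({K, N});
    ck_tile::HostTensor<float> c_m_n({M, N});

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a_m_k(m, k) = static_cast<float>(m + k);
    for(std::size_t k = 0; k < K; ++k)
        for(std::size_t n = 0; n < N; ++n)
            b_k_n(k, n) = (k == n) ? 1.0f : 0.0f; // B is an identity-like K x N matrix

    auto pass = [](auto x) { return x; }; // pass-through element-wise operator
    // ADataType/BDataType/AccDataType = float; CDataType and the operator types are deduced
    ck_tile::reference_gemm<float, float, float>(a_m_k, b_k_n, c_m_n, pass, pass, pass);
    // with the identity-like B, c_m_n(m, n) equals a_m_k(m, n)
}
// --------------------------------------------------------------------------------------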
template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename DsDataType,
          typename CDataType,
          typename ACCElementOp,
          typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
CK_TILE_HOST void
reference_gemm_multiple_d(const HostTensor<ADataType>& a_m_k,
                          const HostTensor<BDataType>& b_k_n,
                          const std::array<HostTensor<DDataType>, DsDataType::size()>& ds_m_n,
                          HostTensor<CDataType>& c_m_n,
                          const ACCElementOp& acc_element_op = {})
{
    // ... (M, N, K are taken from the tensor lengths)
    auto f_mk_kn_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_m_k(m, k);
            BDataType v_b = b_k_n(k, n);
            v_acc +=
                ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
        }
        // ... v_c is declared here; the epilogue operator writes into it by reference
        if constexpr(DsDataType::size() == 0)
        {
            acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
        }
        else if constexpr(DsDataType::size() == 1)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)));
        }
        else if constexpr(DsDataType::size() == 2)
        {
            acc_element_op(v_c,
                           ck_tile::type_convert<float>(v_acc),
                           ck_tile::type_convert<float>(ds_m_n[0](m, n)),
                           ck_tile::type_convert<float>(ds_m_n[1](m, n)));
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
    };
    // ... evaluated over the (M, N) grid via make_ParallelTensorFunctor
}
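// --------------------------------------------------------------------------------------
// Usage sketch (illustration only, not part of this header): one fused D tensor, i.e.
// C = A*B + D0 with float accumulation. DsDataType is assumed here to be ck_tile::tuple
// (any type exposing a static size() and a std::tuple_element specialization would do),
// and the epilogue functor writes its result into the first argument, matching the calls
// above.
struct AddBiasExample
{
    template <typename C, typename Acc, typename D0>
    void operator()(C& c, const Acc& acc, const D0& d0) const
    {
        c = acc + d0; // fuse the single D tensor as a bias
    }
};

inline void example_reference_gemm_multiple_d()
{
    const std::size_t M = 4, N = 4, K = 8;
    ck_tile::HostTensor<float> a_m_k({M, K});
    ck_tile::HostTensor<float> b_k_n({K, N});
    ck_tile::HostTensor<float> d0_m_n({M, N});
    ck_tile::HostTensor<float> c_m_n({M, N});
    std::array<ck_tile::HostTensor<float>, 1> ds_m_n{d0_m_n};

    using DsDataType = ck_tile::tuple<float>; // assumed container for the D data types
    ck_tile::reference_gemm_multiple_d<float, float, float, DsDataType>(
        a_m_k, b_k_n, ds_m_n, c_m_n, AddBiasExample{});
}
// --------------------------------------------------------------------------------------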
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
__global__ void naive_gemm_kernel(ADataType* A, BDataType* B, CDataType* C,
                                  ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                                  ck_tile::index_t strideA, ck_tile::index_t strideB,
                                  ck_tile::index_t strideC)
{
    // one thread per element of the M x N output
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int row = idx / N;
    int col = idx % N;
    if(row < M && col < N)
    {
        AccDataType acc = 0.0;
        for(int k = 0; k < K; ++k)
        {
            // layout-dependent addressing of A and B
            int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                              ? row * strideA + k : k * strideA + row;
            int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                              ? col * strideB + k : k * strideB + col;
            // ... v_a / v_b: packed int4 operands are unpacked, other types convert directly
            if constexpr(std::is_same_v<ADataType, pk_int4_t>)
            { /* ... unpack via pk_int4_t_to_fp32x2_t(A[a_index]) ... */ }
            else
                v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
            if constexpr(std::is_same_v<BDataType, pk_int4_t>)
            { /* ... */ }
            else
                v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
            // ... acc accumulates v_a * v_b
        }
        int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                          ? row * strideC + col : col * strideC + row;
        C[c_index] = ck_tile::type_convert<CDataType>(acc);
    }
}
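// Worked addressing example (illustration only): with a row-major 2 x 3 matrix A and
// strideA = 3, element (row = 1, k = 2) lives at 1 * 3 + 2 = 5. With a column-major
// 3 x 2 matrix B and strideB = 3, element (k = 2, col = 1) also lives at 1 * 3 + 2 = 5,
// because a column-major B is addressed as col * strideB + k. The same pattern gives
// c_index above: row * strideC + col for a row-major C.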
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
void reference_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                        index_t M, index_t N, index_t K,
                        index_t stride_a, index_t stride_b, index_t stride_c)
{
    // launch one thread per output element
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
        <<<numBlocks, numThreadsPerBlock>>>(
            a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
    // ...
}
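// --------------------------------------------------------------------------------------
// Usage sketch (illustration only, not part of this header): launching the reference GPU
// GEMM on float device buffers. It assumes <hip/hip_runtime.h> is included and uses the
// ck_tile::tensor_layout::gemm::RowMajor / ColumnMajor tags; error checking is omitted.
inline void example_reference_gemm_gpu()
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    const ck_tile::index_t M = 128, N = 128, K = 64;
    float *a_dev = nullptr, *b_dev = nullptr, *c_dev = nullptr;
    hipMalloc(&a_dev, M * K * sizeof(float));
    hipMalloc(&b_dev, K * N * sizeof(float));
    hipMalloc(&c_dev, M * N * sizeof(float));
    // ... hipMemcpy host data for A and B into a_dev / b_dev

    // row-major A (stride K), column-major B (stride K), row-major C (stride N)
    ck_tile::reference_gemm_gpu<float, float, float, float, Row, Col, Row>(
        a_dev, b_dev, c_dev, M, N, K, /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N);
    hipDeviceSynchronize();

    // ... hipMemcpy c_dev back and compare against reference_gemm on the host
    hipFree(a_dev);
    hipFree(b_dev);
    hipFree(c_dev);
}
// --------------------------------------------------------------------------------------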
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
          typename LayoutA, typename LayoutB, typename LayoutC>
void reference_batched_gemm_gpu(ADataType* a_ptr, BDataType* b_ptr, CDataType* c_ptr,
                                index_t M, index_t N, index_t K,
                                index_t stride_a, index_t stride_b, index_t stride_c,
                                index_t batch_stride_A, index_t batch_stride_B,
                                index_t batch_stride_C, index_t batch_count)
{
    int totalElements      = M * N;
    int numThreadsPerBlock = 256;
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    // launch the naive kernel once per batch, offsetting each operand by its batch stride
    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
        BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
        CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;

        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }
    // ...
}
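// --------------------------------------------------------------------------------------
// Usage sketch (illustration only, not part of this header): a batched call over four
// GEMMs packed contiguously in device memory, so each batch advances A by M*K, B by K*N
// and C by M*N elements. Buffers are assumed to be allocated and filled as in the
// single-GEMM example above.
inline void example_reference_batched_gemm_gpu(float* a_dev, float* b_dev, float* c_dev)
{
    using Row = ck_tile::tensor_layout::gemm::RowMajor;
    using Col = ck_tile::tensor_layout::gemm::ColumnMajor;

    const ck_tile::index_t M = 64, N = 64, K = 32, batch_count = 4;
    ck_tile::reference_batched_gemm_gpu<float, float, float, float, Row, Col, Row>(
        a_dev, b_dev, c_dev, M, N, K,
        /*stride_a=*/K, /*stride_b=*/K, /*stride_c=*/N,
        /*batch_stride_A=*/M * K, /*batch_stride_B=*/K * N, /*batch_stride_C=*/M * N,
        batch_count);
    hipDeviceSynchronize();
}
// --------------------------------------------------------------------------------------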