14 template <
typename ADataType,
 
   24                                  const AElementOp& a_element_op     = {},
 
   25                                  const BElementOp& b_element_op     = {},
 
   26                                  const ACCElementOp& acc_element_op = {})
 
   32     auto f_mn = [&](
auto m, 
auto n) {
 
   33         AccDataType v_acc = 0;
 
   35         for(std::size_t k = 0; k < K; ++k)
 
   39             if constexpr(std::is_same_v<ADataType, pk_int4_t>)
 
   41                 const pk_int4_t pk_val  = a_element_op(a_m_k(m, k));
 
   50                 v_a = ck_tile::type_convert<AccDataType>(a_element_op(a_m_k(m, k)));
 
   52             if constexpr(std::is_same_v<BDataType, pk_int4_t>)
 
   54                 const pk_int4_t pk_val  = b_element_op(b_k_n(k, n));
 
   63                 v_b = ck_tile::type_convert<AccDataType>(b_element_op(b_k_n(k, n)));
 
   68         c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
 
   74 template <
typename ADataType,
 
   79           typename ACCElementOp,
 
   80           typename DDataType = remove_cvref_t<std::tuple_element_t<0, DsDataType>>>
 
   86                           const ACCElementOp& acc_element_op = {})
 
   92     auto f_mk_kn_mn = [&](
auto m, 
auto n) {
 
   93         AccDataType v_acc = 0;
 
   94         for(std::size_t k = 0; k < K; ++k)
 
   96             ADataType v_a = a_m_k(m, k);
 
   97             BDataType v_b = b_k_n(k, n);
 
   99                 ck_tile::type_convert<AccDataType>(v_a) * ck_tile::type_convert<AccDataType>(v_b);
 
  103         if constexpr(DsDataType::size() == 0)
 
  105             acc_element_op(v_c, ck_tile::type_convert<float>(v_acc));
 
  107         else if constexpr(DsDataType::size() == 1)
 
  110                            ck_tile::type_convert<float>(v_acc),
 
  111                            ck_tile::type_convert<float>(ds_m_n[0](m, n)));
 
  113         else if constexpr(DsDataType::size() == 2)
 
  116                            ck_tile::type_convert<float>(v_acc),
 
  117                            ck_tile::type_convert<float>(ds_m_n[0](m, n)),
 
  118                            ck_tile::type_convert<float>(ds_m_n[1](m, n)));
 
  120         c_m_n(m, n) = ck_tile::type_convert<CDataType>(v_c);
 
  126 template <
typename ADataType,
 
  128           typename AccDataType,
 
  143     int idx = blockIdx.x * blockDim.x + threadIdx.x;
 
  147     if(row < M && col < N)
 
  149         AccDataType acc = 0.0;
 
  150         for(
int k = 0; k < K; ++k)
 
  155             int a_index = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
 
  158             int b_index = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
 
  164             if constexpr(std::is_same_v<ADataType, pk_int4_t>)
 
  174                 v_a = ck_tile::type_convert<AccDataType>(A[a_index]);
 
  176             if constexpr(std::is_same_v<BDataType, pk_int4_t>)
 
  186                 v_b = ck_tile::type_convert<AccDataType>(B[b_index]);
 
  191         int c_index = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
 
  192                           ? row * strideC + col
 
  193                           : col * strideC + row;
 
  194         C[c_index]  = ck_tile::type_convert<CDataType>(acc);
 
  198 template <
typename ADataType,
 
  200           typename AccDataType,
 
  215     int totalElements      = M * N;
 
  216     int numThreadsPerBlock = 256; 
 
  217     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
  219     naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
 
  220         <<<numBlocks, numThreadsPerBlock>>>(
 
  221             a_ptr, b_ptr, c_ptr, M, N, K, stride_a, stride_b, stride_c);
 
  226 template <
typename ADataType,
 
  228           typename AccDataType,
 
  247     int totalElements      = M * N;
 
  248     int numThreadsPerBlock = 256; 
 
  249     int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;
 
  251     for(
index_t batch_id = 0; batch_id < batch_count; ++batch_id)
 
  253         ADataType* d_ATemp = a_ptr + batch_id * batch_stride_A;
 
  254         BDataType* d_BTemp = b_ptr + batch_id * batch_stride_B;
 
  255         CDataType* d_CTemp = c_ptr + batch_id * batch_stride_C;
 
  256         naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
 
  257             <<<numBlocks, numThreadsPerBlock>>>(
 
  258                 d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
 
#define CK_TILE_HOST
Definition: config.hpp:39
 
Definition: cluster_descriptor.hpp:13
 
void reference_batched_gemm_gpu(ADataType *a_ptr, BDataType *b_ptr, CDataType *c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c, index_t batch_stride_A, index_t batch_stride_B, index_t batch_stride_C, index_t batch_count)
Definition: reference_gemm.hpp:233
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
__global__ void naive_gemm_kernel(ADataType *A, BDataType *B, CDataType *C, ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K, ck_tile::index_t strideA, ck_tile::index_t strideB, ck_tile::index_t strideC)
Definition: reference_gemm.hpp:133
 
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:104
 
float fp32x2_t
Definition: pk_int4.hpp:100
 
int32_t index_t
Definition: integer.hpp:9
 
void reference_gemm_gpu(ADataType *a_ptr, BDataType *b_ptr, CDataType *c_ptr, index_t M, index_t N, index_t K, index_t stride_a, index_t stride_b, index_t stride_c)
Definition: reference_gemm.hpp:205
 
CK_TILE_HOST void reference_gemm_multiple_d(const HostTensor< ADataType > &a_m_k, const HostTensor< BDataType > &b_k_n, const std::array< HostTensor< DDataType >, DsDataType::size()> &ds_m_n, HostTensor< CDataType > &c_m_n, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:82
 
CK_TILE_HOST void reference_gemm(const HostTensor< ADataType > &a_m_k, const HostTensor< BDataType > &b_k_n, HostTensor< CDataType > &c_m_n, const AElementOp &a_element_op={}, const BElementOp &b_element_op={}, const ACCElementOp &acc_element_op={})
Definition: reference_gemm.hpp:21
 
Definition: host_tensor.hpp:336
 
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:388
 
Definition: functional.hpp:86
 
Definition: numeric.hpp:81