47 template <DppInstr instr>
 
   54     static constexpr 
index_t lanegroup_size  = 8;
 
   57     static constexpr 
index_t m_per_lanegroup = 8;
 
   58     static constexpr 
index_t n_per_lanegroup = 8;
 
   59     static constexpr 
index_t m_per_thread    = 8;
 
   60     static constexpr 
index_t n_per_thread    = 1;
 
   62     static constexpr 
bool share_a            = 
true;
 
   65     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
   66     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
   84     static constexpr 
index_t lanegroup_size  = 8;
 
   87     static constexpr 
index_t m_per_lanegroup = 8;
 
   88     static constexpr 
index_t n_per_lanegroup = 8;
 
   89     static constexpr 
index_t m_per_thread    = 8;
 
   90     static constexpr 
index_t n_per_thread    = 1;
 
   92     static constexpr 
bool share_a            = 
true;
 
   95     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
   96     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
  122     static constexpr 
bool share_a            = 
true;
 
  125     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
  126     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
  152     static constexpr 
bool share_a            = 
true;
 
  155     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
  156     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
  182     static constexpr 
bool share_a            = 
true;
 
  185     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
  186     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
  212     static constexpr 
bool share_a            = 
true;
 
  215     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
  216     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
  242     static constexpr 
bool share_a            = 
true;
 
  245     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
  246     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
  272     static constexpr 
bool share_a            = 
true;
 
  275     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
  276     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
  302     static constexpr 
bool share_a            = 
true;
 
  305     template <index_t MPerDpp, index_t NPerDpp, 
class ADataType, 
class BDataType, 
class CDataType>
 
  306     __device__ 
void run(
const ADataType& a, 
const BDataType& b, CDataType& reg_c)
 const 
  320 template <
typename BaseType, index_t MPerDpp, index_t NPerDpp>
 
  323     template <
typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
 
  327     constexpr 
auto GetDpp<half_t, 8, 32>()
 
  333     constexpr 
auto GetDpp<half_t, 8, 16>()
 
  339     constexpr 
auto GetDpp<half_t, 16, 16>()
 
  345     constexpr 
auto GetDpp<half_t, 32, 8>()
 
  351     constexpr 
auto GetDpp<half_t, 1, 32>()
 
  357     constexpr 
auto GetDpp<half_t, 2, 32>()
 
  363     constexpr 
auto GetDpp<half_t, 2, 16>()
 
  369     constexpr 
auto GetDpp<half_t, 4, 16>()
 
  375     constexpr 
auto GetDpp<half_t, 4, 32>()
 
  392         constexpr 
index_t num_dpp_c_elems =
 
  394         static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
 
  395         static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);
 
  424 template <
typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
 
  439         static_assert(KPack % 
dpp_instr.k_per_dpp == 0, 
"KPack must be divisible by k_per_dpp.");
 
  444         return MPerDpp * NPerDpp / 
dpp_instr.wave_size;
 
  447     template <
class ADataType, 
class BDataType, 
class CDataType>
 
  449     Run(
const ADataType& p_a_wave, 
const BDataType& p_b_wave, CDataType& p_c_thread)
 const 
  454                       "base BaseType must be double, float, half, bfloat16, and int8_t!");
 
  457             dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
 
  489         const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
 
  492         const auto m_dpp_idx = dpp_idx[
I0];
 
  493         const auto n_dpp_idx = dpp_idx[
I1];
 
  501         const auto wave_row = laneId / 
dpp_instr.n_per_wave;
 
  516         const auto m_dpp_op_idx = dpp_op_idx[
I0];
 
  517         const auto n_dpp_op_idx = dpp_op_idx[
I1];
 
  522         return CIndex{m_offset, n_offset};
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
_Float16 half_t
Definition: data_type.hpp:30
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
 
DppInstr
Definition: dpp_gemm.hpp:13
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:297
 
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
 
Definition: dpp_gemm.hpp:426
 
__host__ static __device__ auto CalculateBThreadOriginDataIndex_K_N()
Definition: dpp_gemm.hpp:506
 
__device__ void Run(const ADataType &p_a_wave, const BDataType &p_b_wave, CDataType &p_c_thread) const
Definition: dpp_gemm.hpp:449
 
static constexpr auto dpp_instr
Definition: dpp_gemm.hpp:527
 
__host__ static constexpr __device__ auto GetCMNThreadBlkLengths()
Definition: dpp_gemm.hpp:532
 
__host__ constexpr __device__ DppGemm()
Definition: dpp_gemm.hpp:437
 
static constexpr auto I3
Definition: dpp_gemm.hpp:430
 
static constexpr auto I1
Definition: dpp_gemm.hpp:428
 
static __device__ auto GetWaveId()
Definition: dpp_gemm.hpp:466
 
static constexpr auto I5
Definition: dpp_gemm.hpp:432
 
static constexpr __device__ index_t GetRegSizePerDpp()
Definition: dpp_gemm.hpp:442
 
static __device__ auto GetLaneGroupIdInWave()
Definition: dpp_gemm.hpp:473
 
static __device__ CIndex GetBeginOfThreadBlk()
Definition: dpp_gemm.hpp:512
 
static constexpr auto I4
Definition: dpp_gemm.hpp:431
 
static constexpr auto I2
Definition: dpp_gemm.hpp:429
 
static __device__ auto GetLaneIdInLaneGroup()
Definition: dpp_gemm.hpp:468
 
static constexpr auto K1PerDpp
Definition: dpp_gemm.hpp:530
 
static constexpr auto dpp
Definition: dpp_gemm.hpp:525
 
__host__ static __device__ auto CalculateAThreadOriginDataIndex_K_M()
Definition: dpp_gemm.hpp:498
 
static constexpr auto I0
Definition: dpp_gemm.hpp:427
 
static __device__ auto GetDppOpIdx()
Definition: dpp_gemm.hpp:478
 
static __device__ auto GetLaneIdInWave()
Definition: dpp_gemm.hpp:461
 
static constexpr auto K0PerDpp
Definition: dpp_gemm.hpp:529
 
Definition: dpp_gemm.hpp:322
 
static constexpr index_t GetK1PerDpp()
Definition: dpp_gemm.hpp:421
 
static constexpr auto selected_dpp
Definition: dpp_gemm.hpp:380
 
static constexpr auto GetDpp()
 
__host__ constexpr __device__ DppSelector()
Definition: dpp_gemm.hpp:382
 
Definition: sequence.hpp:43
 
Definition: amd_gemm_dpp.hpp:37
 
__device__ void Run(const AVecDataType &a_vec, const BVecDataType &b_vec, CVecDataType &c_vec)
Definition: amd_gemm_dpp.hpp:43
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:156
 
half_t BaseType
Definition: dpp_gemm.hpp:153
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:246
 
half_t BaseType
Definition: dpp_gemm.hpp:243
 
half_t BaseType
Definition: dpp_gemm.hpp:303
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:306
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:276
 
half_t BaseType
Definition: dpp_gemm.hpp:273
 
half_t BaseType
Definition: dpp_gemm.hpp:63
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:66
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:216
 
half_t BaseType
Definition: dpp_gemm.hpp:213
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:186
 
half_t BaseType
Definition: dpp_gemm.hpp:183
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:126
 
half_t BaseType
Definition: dpp_gemm.hpp:123
 
half_t BaseType
Definition: dpp_gemm.hpp:93
 
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:96
 
Definition: dpp_gemm.hpp:48
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33