82 template <WmmaInstr Instr, index_t WaveSize, 
typename = 
void>
 
   88 template <index_t WaveSize>
 
   98     static constexpr 
index_t src_a_data_size = 2;
 
   99     static constexpr 
index_t src_b_data_size = 2;
 
  103     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  108     static constexpr 
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
 
  109     static constexpr 
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
 
  112     static constexpr 
index_t num_acc_vgprs_per_wave =
 
  113         m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
 
  114     static constexpr 
index_t num_subgroups = wave_size / num_thread_per_subgroups;
 
  116     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  117     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  119         if constexpr(wave_size == 32)
 
  123         else if constexpr(wave_size == 64)
 
  130 template <index_t WaveSize>
 
  143     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  147     static constexpr 
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
 
  148     static constexpr 
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
 
  149     static constexpr 
index_t num_acc_vgprs_per_wave =
 
  150         m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
 
  151     static constexpr 
index_t num_subgroups = wave_size / num_thread_per_subgroups;
 
  153     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  154     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  156         if constexpr(wave_size == 32)
 
  160         else if constexpr(wave_size == 64)
 
  167 template <index_t WaveSize>
 
  180     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  184     static constexpr 
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
 
  185     static constexpr 
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
 
  186     static constexpr 
index_t num_acc_vgprs_per_wave =
 
  187         m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
 
  188     static constexpr 
index_t num_subgroups = wave_size / num_thread_per_subgroups;
 
  190     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  191     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  193         if constexpr(wave_size == 32)
 
  197         else if constexpr(wave_size == 64)
 
  203 template <index_t WaveSize>
 
  216     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  220     static constexpr 
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
 
  221     static constexpr 
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
 
  222     static constexpr 
index_t num_acc_vgprs_per_wave =
 
  223         m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
 
  224     static constexpr 
index_t num_subgroups = wave_size / num_thread_per_subgroups;
 
  232     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  234         if constexpr(wave_size == 32)
 
  238         else if constexpr(wave_size == 64)
 
  245 template <index_t WaveSize>
 
  258     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  262     static constexpr 
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
 
  263     static constexpr 
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
 
  264     static constexpr 
index_t num_acc_vgprs_per_wave =
 
  265         m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
 
  266     static constexpr 
index_t num_subgroups = wave_size / num_thread_per_subgroups;
 
  276     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  278         if constexpr(wave_size == 32)
 
  283         else if constexpr(wave_size == 64)
 
  294 template <index_t WaveSize>
 
  310     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  319     static constexpr 
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
 
  320     static constexpr 
index_t num_subgroups          = wave_size / num_thread_per_subgroups;
 
  322     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  323     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  325         static_assert(wave_size == 32, 
"only support wave32 for gfx12 wmma");
 
  326         if constexpr(wave_size == 32)
 
  333 template <index_t WaveSize>
 
  346     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  352     static constexpr 
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
 
  353     static constexpr 
index_t num_subgroups          = wave_size / num_thread_per_subgroups;
 
  355     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  356     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  358         static_assert(wave_size == 32, 
"only support wave32 for gfx12 wmma");
 
  359         if constexpr(wave_size == 32)
 
  366 template <index_t WaveSize>
 
  379     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  385     static constexpr 
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
 
  386     static constexpr 
index_t num_subgroups          = wave_size / num_thread_per_subgroups;
 
  396     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  398         static_assert(wave_size == 32, 
"only support wave32 for gfx12 wmma");
 
  399         if constexpr(wave_size == 32)
 
  407 template <index_t WaveSize>
 
  418     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  422     static constexpr 
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
 
  423     static constexpr 
index_t num_subgroups          = wave_size / num_thread_per_subgroups;
 
  425     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  426     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  428         static_assert(wave_size == 32, 
"only support wave32 for gfx12 wmma");
 
  429         if constexpr(wave_size == 32)
 
  442 template <index_t WaveSize>
 
  453     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  457     static constexpr 
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
 
  458     static constexpr 
index_t num_subgroups          = wave_size / num_thread_per_subgroups;
 
  460     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  461     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  463         static_assert(wave_size == 32, 
"only support wave32 for gfx12 wmma");
 
  464         if constexpr(wave_size == 32)
 
  477 template <index_t WaveSize>
 
  488     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  492     static constexpr 
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
 
  493     static constexpr 
index_t num_subgroups          = wave_size / num_thread_per_subgroups;
 
  495     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  496     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  498         static_assert(wave_size == 32, 
"only support wave32 for gfx12 wmma");
 
  499         if constexpr(wave_size == 32)
 
  512 template <index_t WaveSize>
 
  523     static constexpr 
index_t num_thread_per_subgroups = n_per_wmma;
 
  527     static constexpr 
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
 
  528     static constexpr 
index_t num_subgroups          = wave_size / num_thread_per_subgroups;
 
  530     template <index_t MPerWmma, index_t NPerWmma, 
class FloatA, 
class FloatB, 
class FloatC>
 
  531     __device__ 
void run(
const FloatA& a, 
const FloatB& b, FloatC& reg_c)
 const 
  533         static_assert(wave_size == 32, 
"only support wave32 for gfx12 wmma");
 
  534         if constexpr(wave_size == 32)
 
  547 template <
typename src_type_a,
 
  554     template <
typename src_type_a_,
 
  555               typename src_type_b_,
 
  562     constexpr 
auto GetWmma<half_t, half_t, float, 16, 16>()
 
  572     constexpr 
auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
 
  582     constexpr 
auto GetWmma<half_t, half_t, half_t, 16, 16>()
 
  588     constexpr 
auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
 
  594     constexpr 
auto GetWmma<int8_t, int8_t, int, 16, 16>()
 
  603 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 
  605     constexpr 
auto GetWmma<int4_t, int4_t, int, 16, 16>()
 
  612     constexpr 
auto GetWmma<f8_t, f8_t, float, 16, 16>()
 
  618     constexpr 
auto GetWmma<f8_t, bf8_t, float, 16, 16>()
 
  624     constexpr 
auto GetWmma<bf8_t, f8_t, float, 16, 16>()
 
  630     constexpr 
auto GetWmma<bf8_t, bf8_t, float, 16, 16>()
 
  641         static_assert(
selected_wmma.m_per_wmma == 16, 
"WRONG! WMMA_M must equal to 16");
 
  643         static_assert(
selected_wmma.m_per_wmma == 16, 
"WRONG! WMMA_M must equal to 16");
 
  645         static_assert(
selected_wmma.k_per_wmma == 16, 
"WRONG! WMMA_M must equal to 16");
 
  650                       "WRONG! Invalid Number of Accumulator Register");
 
  654 template <
typename src_type_a,
 
  660           bool TransposeC      = 
false,
 
  661           bool AssemblyBackend = 
false>
 
  676         static_assert(NPerWmma == 16 && MPerWmma == 16,
 
  677                       "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
 
  679         static_assert(KPack % 
wmma_instr.k_per_wmma == 0, 
"KPack should be multiple of k_per_wmma");
 
  685     template <
typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
 
  686     __host__ __device__ 
static constexpr 
auto 
  688         const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
 
  689             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
 
  691         const auto MBlockxRepeat =
 
  692             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I0);
 
  693         const auto NBlockxRepeat =
 
  694             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I3);
 
  696             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I1);
 
  698             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I4);
 
  701             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
 
  725     template <
typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
 
  726     __host__ __device__ 
static constexpr 
auto 
  728         const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
 
  729             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
 
  731         const auto MBlockxRepeat =
 
  732             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I0);
 
  733         const auto NBlockxRepeat =
 
  734             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I3);
 
  736             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I1);
 
  738             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I4);
 
  741             c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
 
  771     template <
class FloatA, 
class FloatB, 
class FloatC>
 
  772     __device__ 
void Run(
const FloatA& p_a_wave, 
const FloatB& p_b_wave, FloatC& p_c_thread)
 const 
  788 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
 
  793             "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), " 
  794             "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
 
  796             if constexpr(!TransposeC)
 
  798                 wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
 
  802                 wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
 
  849         return TransposeC ? 
CIndex{n_offset, m_offset} : 
CIndex{m_offset, n_offset};
 
  864     __host__ __device__ 
static constexpr 
auto 
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:297
 
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
 
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
 
WmmaInstr
Definition: wmma_gemm.hpp:13
 
@ wmma_f32_16x16x16_bf16_gfx12
 
@ wmma_i32_16x16x16_iu8_gfx12
 
@ wmma_f32_16x16x16_bf8f8_gfx12
 
@ wmma_f32_16x16x16_f16_gfx12
 
@ wmma_f32_16x16x16_bf8bf8_gfx12
 
@ wmma_f32_16x16x16_f8f8_gfx12
 
@ wmma_bf16_16x16x16_bf16
 
@ wmma_f32_16x16x16_f8bf8_gfx12
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
Definition: sequence.hpp:43
 
Definition: wmma_gemm.hpp:663
 
static constexpr auto I0
Definition: wmma_gemm.hpp:664
 
static __device__ auto GetLaneId()
Definition: wmma_gemm.hpp:807
 
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: wmma_gemm.hpp:772
 
static constexpr __device__ index_t GetWaveSize()
Definition: wmma_gemm.hpp:769
 
static constexpr auto wmma
Definition: wmma_gemm.hpp:860
 
__host__ static constexpr __device__ auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition: wmma_gemm.hpp:865
 
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: wmma_gemm.hpp:826
 
static __device__ auto GetSubGroupId()
Definition: wmma_gemm.hpp:809
 
static __device__ auto GetSwizzledLaneIdLow()
Definition: wmma_gemm.hpp:821
 
static constexpr auto I3
Definition: wmma_gemm.hpp:667
 
static constexpr auto I5
Definition: wmma_gemm.hpp:669
 
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: wmma_gemm.hpp:835
 
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:727
 
static __device__ CIndex GetBeginOfThreadBlk()
Definition: wmma_gemm.hpp:844
 
static constexpr auto I4
Definition: wmma_gemm.hpp:668
 
static constexpr __device__ index_t GetRegSizePerWmma()
Definition: wmma_gemm.hpp:764
 
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:687
 
__host__ constexpr __device__ WmmaGemm()
Definition: wmma_gemm.hpp:674
 
static constexpr auto I2
Definition: wmma_gemm.hpp:666
 
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition: wmma_gemm.hpp:852
 
static constexpr auto I1
Definition: wmma_gemm.hpp:665
 
static __device__ auto GetLaneIdUnderSubGroup()
Definition: wmma_gemm.hpp:817
 
static constexpr auto wmma_instr
Definition: wmma_gemm.hpp:862
 
Definition: wmma_gemm.hpp:553
 
static constexpr auto selected_wmma
Definition: wmma_gemm.hpp:636
 
__host__ constexpr __device__ WmmaSelector()
Definition: wmma_gemm.hpp:639
 
static constexpr auto GetWmma()
 
Definition: integral_constant.hpp:20
 
Definition: amd_wmma.hpp:96
 
Definition: amd_wmma.hpp:216
 
Definition: amd_wmma.hpp:72
 
Definition: amd_wmma.hpp:192
 
Definition: amd_wmma.hpp:297
 
Definition: amd_wmma.hpp:50
 
Definition: amd_wmma.hpp:170
 
Definition: amd_wmma.hpp:418
 
Definition: amd_wmma.hpp:394
 
Definition: amd_wmma.hpp:271
 
Definition: amd_wmma.hpp:25
 
Definition: amd_wmma.hpp:149
 
Definition: amd_wmma.hpp:370
 
Definition: amd_wmma.hpp:346
 
Definition: amd_wmma.hpp:319
 
Definition: amd_wmma.hpp:121
 
Definition: amd_wmma.hpp:241
 
Definition: functional2.hpp:33
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:232
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:191
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:154
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:356
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:531
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:496
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:117
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:323
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:461
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:426
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:276
 
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:396
 
Definition: wmma_gemm.hpp:84