78 template <WmmaInstr Instr, index_t WaveSize,
typename =
void>
84 template <index_t WaveSize>
94 static constexpr
index_t src_a_data_size = 2;
95 static constexpr
index_t src_b_data_size = 2;
96 static constexpr
index_t acc_data_size = 4;
97 static constexpr
index_t acc_pack_number = 1;
99 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
104 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
105 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
108 static constexpr
index_t num_acc_vgprs_per_wave =
109 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
110 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
112 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
113 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
115 if constexpr(wave_size == 32)
119 else if constexpr(wave_size == 64)
126 template <index_t WaveSize>
139 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
143 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
144 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
145 static constexpr
index_t num_acc_vgprs_per_wave =
146 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
147 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
149 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
150 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
152 if constexpr(wave_size == 32)
156 else if constexpr(wave_size == 64)
163 template <index_t WaveSize>
176 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
180 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
181 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
182 static constexpr
index_t num_acc_vgprs_per_wave =
183 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
184 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
186 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
187 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
189 if constexpr(wave_size == 32)
193 else if constexpr(wave_size == 64)
199 template <index_t WaveSize>
212 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
216 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
217 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
218 static constexpr
index_t num_acc_vgprs_per_wave =
219 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
220 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
228 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
230 if constexpr(wave_size == 32)
234 else if constexpr(wave_size == 64)
241 template <index_t WaveSize>
254 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
258 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
259 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
260 static constexpr
index_t num_acc_vgprs_per_wave =
261 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
262 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
272 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
274 if constexpr(wave_size == 32)
279 else if constexpr(wave_size == 64)
290 template <index_t WaveSize>
306 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
315 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
316 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
318 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
319 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
321 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
322 if constexpr(wave_size == 32)
329 template <index_t WaveSize>
342 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
348 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
349 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
351 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
352 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
354 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
355 if constexpr(wave_size == 32)
362 template <index_t WaveSize>
375 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
381 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
382 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
392 __device__
void run(
const FloatA& a,
const FloatB& b, FloatC& reg_c)
const
394 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
395 if constexpr(wave_size == 32)
403 template <
typename src_type_a,
410 template <
typename src_type_a_,
411 typename src_type_b_,
418 constexpr
auto GetWmma<half_t, half_t, float, 16, 16>()
428 constexpr
auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
438 constexpr
auto GetWmma<half_t, half_t, half_t, 16, 16>()
444 constexpr
auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
450 constexpr
auto GetWmma<int8_t, int8_t, int, 16, 16>()
459 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
461 constexpr
auto GetWmma<int4_t, int4_t, int, 16, 16>()
472 static_assert(
selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
474 static_assert(
selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
476 static_assert(
selected_wmma.k_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
481 "WRONG! Invalid Number of Accumulator Register");
485 template <
typename src_type_a,
491 bool TransposeC =
false,
492 bool AssemblyBackend =
false>
507 static_assert(NPerWmma == 16 && MPerWmma == 16,
508 "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
510 static_assert(KPack %
wmma_instr.k_per_wmma == 0,
"KPack should be multiple of k_per_wmma");
516 template <
typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
517 __host__ __device__
static constexpr
auto
519 const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
520 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
522 const auto MBlockxRepeat =
523 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I0);
524 const auto NBlockxRepeat =
525 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I3);
527 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I1);
529 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I4);
532 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
556 template <
typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
557 __host__ __device__
static constexpr
auto
559 const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
560 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
562 const auto MBlockxRepeat =
563 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I0);
564 const auto NBlockxRepeat =
565 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I3);
567 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I1);
569 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I4);
572 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
602 template <
class FloatA,
class FloatB,
class FloatC>
603 __device__
void Run(
const FloatA& p_a_wave,
const FloatB& p_b_wave, FloatC& p_c_thread)
const
616 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
621 "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
622 "(int8, int32) or (int4, int32)!");
624 if constexpr(!TransposeC)
626 wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
630 wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
677 return TransposeC ?
CIndex{n_offset, m_offset} :
CIndex{m_offset, n_offset};
692 __host__ __device__
static constexpr
auto
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant<bool, Use24BitIntegerCalculation> = integral_constant<bool, false>{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:16
WmmaInstr
Definition: wmma_gemm.hpp:13
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_bf16_16x16x16_bf16
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: sequence.hpp:43
Definition: wmma_gemm.hpp:494
static constexpr auto I0
Definition: wmma_gemm.hpp:495
static __device__ auto GetLaneId()
Definition: wmma_gemm.hpp:635
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: wmma_gemm.hpp:603
static constexpr __device__ index_t GetWaveSize()
Definition: wmma_gemm.hpp:600
static constexpr auto wmma
Definition: wmma_gemm.hpp:688
__host__ static constexpr __device__ auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition: wmma_gemm.hpp:693
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: wmma_gemm.hpp:654
static __device__ auto GetSubGroupId()
Definition: wmma_gemm.hpp:637
static __device__ auto GetSwizzledLaneIdLow()
Definition: wmma_gemm.hpp:649
static constexpr auto I3
Definition: wmma_gemm.hpp:498
static constexpr auto I5
Definition: wmma_gemm.hpp:500
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: wmma_gemm.hpp:663
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:558
static __device__ CIndex GetBeginOfThreadBlk()
Definition: wmma_gemm.hpp:672
static constexpr auto I4
Definition: wmma_gemm.hpp:499
static constexpr __device__ index_t GetRegSizePerWmma()
Definition: wmma_gemm.hpp:595
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:518
__host__ constexpr __device__ WmmaGemm()
Definition: wmma_gemm.hpp:505
static constexpr auto I2
Definition: wmma_gemm.hpp:497
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition: wmma_gemm.hpp:680
static constexpr auto I1
Definition: wmma_gemm.hpp:496
static __device__ auto GetLaneIdUnderSubGroup()
Definition: wmma_gemm.hpp:645
static constexpr auto wmma_instr
Definition: wmma_gemm.hpp:690
Definition: wmma_gemm.hpp:409
static constexpr auto selected_wmma
Definition: wmma_gemm.hpp:467
__host__ constexpr __device__ WmmaSelector()
Definition: wmma_gemm.hpp:470
static constexpr auto GetWmma()
Definition: integral_constant.hpp:10
Definition: amd_wmma.hpp:96
Definition: amd_wmma.hpp:216
Definition: amd_wmma.hpp:72
Definition: amd_wmma.hpp:192
Definition: amd_wmma.hpp:297
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:170
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:149
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: amd_wmma.hpp:241
Definition: functional2.hpp:31
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:228
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:187
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:150
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:352
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:113
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:319
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:272
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:392
Definition: wmma_gemm.hpp:80