82 template <WmmaInstr Instr, index_t WaveSize,
typename =
void>
88 template <index_t WaveSize>
99 static constexpr
index_t src_a_data_size = 2;
104 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
109 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
110 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
113 static constexpr
index_t num_acc_vgprs_per_wave =
114 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
115 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
117 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
118 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
120 if constexpr(wave_size == 32)
124 else if constexpr(wave_size == 64)
131 template <index_t WaveSize>
145 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
149 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
150 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
151 static constexpr
index_t num_acc_vgprs_per_wave =
152 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
153 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
155 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
156 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
158 if constexpr(wave_size == 32)
162 else if constexpr(wave_size == 64)
169 template <index_t WaveSize>
183 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
187 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
188 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
189 static constexpr
index_t num_acc_vgprs_per_wave =
190 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
191 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
193 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
194 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
196 if constexpr(wave_size == 32)
200 else if constexpr(wave_size == 64)
206 template <index_t WaveSize>
220 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
224 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
225 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
226 static constexpr
index_t num_acc_vgprs_per_wave =
227 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
228 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
236 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
238 if constexpr(wave_size == 32)
242 else if constexpr(wave_size == 64)
249 template <index_t WaveSize>
263 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
267 static constexpr
index_t num_src_a_vgprs_per_wave = m_per_wmma * src_a_data_size / 4;
268 static constexpr
index_t num_src_b_vgprs_per_wave = n_per_wmma * src_b_data_size / 4;
269 static constexpr
index_t num_acc_vgprs_per_wave =
270 m_per_wmma * n_per_wmma * acc_data_size * acc_pack_number / wave_size / 4;
271 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
281 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
283 if constexpr(wave_size == 32)
288 else if constexpr(wave_size == 64)
299 template <index_t WaveSize>
316 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
325 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
326 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
328 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
329 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
331 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
332 if constexpr(wave_size == 32)
339 template <index_t WaveSize>
353 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
359 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
360 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
362 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
363 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
365 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
366 if constexpr(wave_size == 32)
373 template <index_t WaveSize>
387 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
393 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
394 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
404 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
406 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
407 if constexpr(wave_size == 32)
415 template <index_t WaveSize>
427 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
431 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
432 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
434 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
435 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
437 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
438 if constexpr(wave_size == 32)
451 template <index_t WaveSize>
463 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
467 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
468 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
470 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
471 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
473 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
474 if constexpr(wave_size == 32)
487 template <index_t WaveSize>
499 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
503 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
504 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
506 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
507 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
509 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
510 if constexpr(wave_size == 32)
523 template <index_t WaveSize>
535 static constexpr
index_t num_thread_per_subgroups = n_per_wmma;
539 static constexpr
index_t num_acc_vgprs_per_wave = m_per_wmma * n_per_wmma / wave_size;
540 static constexpr
index_t num_subgroups = wave_size / num_thread_per_subgroups;
542 template <index_t MPerWmma, index_t NPerWmma,
class FloatA,
class FloatB,
class FloatC>
543 __device__
void run(
const FloatA&
a,
const FloatB& b, FloatC& reg_c)
const
545 static_assert(wave_size == 32,
"only support wave32 for gfx12 wmma");
546 if constexpr(wave_size == 32)
559 template <
typename src_type_a,
566 template <
typename src_type_a_,
567 typename src_type_b_,
574 constexpr
auto GetWmma<half_t, half_t, float, 16, 16>()
584 constexpr
auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
594 constexpr
auto GetWmma<half_t, half_t, half_t, 16, 16>()
600 constexpr
auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
606 constexpr
auto GetWmma<int8_t, int8_t, int, 16, 16>()
615 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
617 constexpr
auto GetWmma<int4_t, int4_t, int, 16, 16>()
624 constexpr
auto GetWmma<f8_t, f8_t, float, 16, 16>()
630 constexpr
auto GetWmma<f8_t, bf8_t, float, 16, 16>()
636 constexpr
auto GetWmma<bf8_t, f8_t, float, 16, 16>()
642 constexpr
auto GetWmma<bf8_t, bf8_t, float, 16, 16>()
653 static_assert(
selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
655 static_assert(
selected_wmma.m_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
657 static_assert(
selected_wmma.k_per_wmma == 16,
"WRONG! WMMA_M must equal to 16");
662 "WRONG! Invalid Number of Accumulator Register");
666 template <
typename src_type_a,
672 bool TransposeC =
false,
673 bool AssemblyBackend =
false>
688 static_assert(NPerWmma == 16 && MPerWmma == 16,
689 "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
691 static_assert(KPack %
wmma_instr.k_per_wmma == 0,
"KPack should be multiple of k_per_wmma");
697 template <
typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
698 __host__ __device__
static constexpr
auto
700 const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
701 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
703 const auto MBlockxRepeat =
704 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I0);
705 const auto NBlockxRepeat =
706 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I3);
708 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I1);
710 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I4);
713 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
737 template <
typename CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA>
738 __host__ __device__
static constexpr
auto
740 const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA&
741 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
743 const auto MBlockxRepeat =
744 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I0);
745 const auto NBlockxRepeat =
746 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I3);
748 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I1);
750 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma.GetLength(
I4);
753 c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma,
785 template <
class FloatA,
class FloatB,
class FloatC>
786 __device__
void Run(
const FloatA& p_a_wave,
const FloatB& p_b_wave, FloatC& p_c_thread)
const
802 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
807 "base type couple must be (half, float), (bhalf, float), (half, half), (bhalf, bhalf), "
808 "((f8 or bf8, f8 or bf8), float), (int8, int32) or (int4, int32)!");
813 if constexpr(!TransposeC)
815 wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
819 wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
866 return TransposeC ?
CIndex{n_offset, m_offset} :
CIndex{m_offset, n_offset};
881 __host__ __device__
static constexpr
auto
__host__ constexpr __device__ T clamp(const T &x, const T &lowerbound, const T &upperbound)
Definition: math.hpp:148
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
WmmaInstr
Definition: wmma_gemm.hpp:13
@ wmma_f32_16x16x16_bf16_gfx12
@ wmma_i32_16x16x16_iu8_gfx12
@ wmma_f32_16x16x16_bf8f8_gfx12
@ wmma_f32_16x16x16_f16_gfx12
@ wmma_f32_16x16x16_bf8bf8_gfx12
@ wmma_f32_16x16x16_f8f8_gfx12
@ wmma_bf16_16x16x16_bf16
@ wmma_f32_16x16x16_f8bf8_gfx12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
Definition: sequence.hpp:43
Definition: wmma_gemm.hpp:675
static constexpr auto I0
Definition: wmma_gemm.hpp:676
static __device__ auto GetLaneId()
Definition: wmma_gemm.hpp:824
__device__ void Run(const FloatA &p_a_wave, const FloatB &p_b_wave, FloatC &p_c_thread) const
Definition: wmma_gemm.hpp:786
static constexpr __device__ index_t GetWaveSize()
Definition: wmma_gemm.hpp:781
static constexpr auto wmma
Definition: wmma_gemm.hpp:877
__host__ static constexpr __device__ auto GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()
Definition: wmma_gemm.hpp:882
__host__ static __device__ auto CalculateAThreadOriginDataIndex()
Definition: wmma_gemm.hpp:843
static __device__ auto GetSubGroupId()
Definition: wmma_gemm.hpp:826
static __device__ auto GetSwizzledLaneIdLow()
Definition: wmma_gemm.hpp:838
static constexpr __device__ index_t GetKPerWaveBlk()
Definition: wmma_gemm.hpp:783
static constexpr auto I3
Definition: wmma_gemm.hpp:679
static constexpr auto I5
Definition: wmma_gemm.hpp:681
__host__ static __device__ auto CalculateBThreadOriginDataIndex()
Definition: wmma_gemm.hpp:852
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:739
static __device__ CIndex GetBeginOfThreadBlk()
Definition: wmma_gemm.hpp:861
static constexpr auto I4
Definition: wmma_gemm.hpp:680
static constexpr __device__ index_t GetRegSizePerWmma()
Definition: wmma_gemm.hpp:776
__host__ static constexpr __device__ auto MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(const CDesc_MBlockxRepeat_MWave_MPerWMMA_NBlockxRepeat_NWave_NPerWMMA &c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma)
Definition: wmma_gemm.hpp:699
__host__ constexpr __device__ WmmaGemm()
Definition: wmma_gemm.hpp:686
static constexpr auto I2
Definition: wmma_gemm.hpp:678
static __device__ CIndex3D GetBeginOfThreadBlk3D()
Definition: wmma_gemm.hpp:869
static constexpr auto I1
Definition: wmma_gemm.hpp:677
static __device__ auto GetLaneIdUnderSubGroup()
Definition: wmma_gemm.hpp:834
static constexpr auto wmma_instr
Definition: wmma_gemm.hpp:879
Definition: wmma_gemm.hpp:565
static constexpr auto selected_wmma
Definition: wmma_gemm.hpp:648
__host__ constexpr __device__ WmmaSelector()
Definition: wmma_gemm.hpp:651
static constexpr auto GetWmma()
Definition: integral_constant.hpp:20
Definition: amd_wmma.hpp:96
Definition: amd_wmma.hpp:216
Definition: amd_wmma.hpp:72
Definition: amd_wmma.hpp:192
Definition: amd_wmma.hpp:297
Definition: amd_wmma.hpp:50
Definition: amd_wmma.hpp:170
Definition: amd_wmma.hpp:418
Definition: amd_wmma.hpp:394
Definition: amd_wmma.hpp:271
Definition: amd_wmma.hpp:25
Definition: amd_wmma.hpp:149
Definition: amd_wmma.hpp:370
Definition: amd_wmma.hpp:346
Definition: amd_wmma.hpp:319
Definition: amd_wmma.hpp:121
Definition: amd_wmma.hpp:241
Definition: functional2.hpp:33
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:236
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:194
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:156
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:363
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:543
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:507
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:118
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:329
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:471
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:435
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:281
__device__ void run(const FloatA &a, const FloatB &b, FloatC &reg_c) const
Definition: wmma_gemm.hpp:404
Definition: wmma_gemm.hpp:84