template <typename GridwiseGemm, bool HasMainKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
#if CK_USE_WAVES_PER_EU
    __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
#endif
        kernel_gemm_dpp(const typename GridwiseGemm::Argument karg)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx103__) || defined(__gfx11__))
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    const auto a_grid_desc_ak0_m_ak1 = amd_wave_read_first_lane(
        GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(karg.M, karg.K, karg.AK0, karg.StrideA));
    const auto b_grid_desc_bk0_n_bk1 = amd_wave_read_first_lane(
        GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(karg.K, karg.N, karg.BK0, karg.StrideB));
    const auto c_grid_desc_m_n = amd_wave_read_first_lane(
        GridwiseGemm::MakeCGridDescriptor_M_N(karg.M, karg.N, karg.StrideC));

    GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
                                                  karg.p_b_grid,
                                                  karg.p_c_grid,
                                                  p_shared,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_m_n);
#else
    ignore = karg;
#endif
}
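// --------------------------------------------------------------------------------------
// Illustrative host-side launch sketch (not part of this header). A minimal sketch,
// assuming a HIP host translation unit that includes this header and <hip/hip_runtime.h>;
// `launch_gemm_dpp_sketch` and `kBlockSize` are hypothetical names, and kBlockSize must
// match the BlockSize the GridwiseGemm was instantiated with. Only members that appear
// in this listing are used.
template <typename GridwiseGemm, index_t kBlockSize>
void launch_gemm_dpp_sketch(const typename GridwiseGemm::Argument& karg, hipStream_t stream)
{
    // reject problem sizes this tile configuration cannot handle
    if(!GridwiseGemm::CheckValidity(karg))
        return;

    const index_t grid_size = GridwiseGemm::CalculateGridSize(karg.M, karg.N);

    // pick the kernel instantiation with or without the main K loop
    if(GridwiseGemm::CalculateHasMainKBlockLoop(karg.K))
    {
        hipLaunchKernelGGL((kernel_gemm_dpp<GridwiseGemm, true>),
                           dim3(grid_size), dim3(kBlockSize), 0, stream, karg);
    }
    else
    {
        hipLaunchKernelGGL((kernel_gemm_dpp<GridwiseGemm, false>),
                           dim3(grid_size), dim3(kBlockSize), 0, stream, karg);
    }
}
// --------------------------------------------------------------------------------------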
// ... (start of the gridwise DPP GEMM struct's template parameter list elided)
template <// ... (block size, data types, tile and wave sizes elided)
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          index_t NumGemmKPrefetchStage = 1,
          // ... (pipeline version and remaining parameters elided)>
// ... (struct head, index constants I0..I5, tile constants, helper functions, and the
//      Problem struct's constructor elided; Problem::Print() follows)
        __host__ void Print() const
        {
            std::cout << "problem {"
                      // ... (M, N, K, strides, and padded sizes elided)
                      << "AK0:" << AK0 << ", "
                      << "BK0:" << BK0 << "}" << std::endl;
        }
        // ... (Problem data members: M, N, K, StrideA/B/C, MPadded, NPadded, AK0, BK0)
    };

    struct Argument : public Problem
    {
        __host__ Argument(const ABDataType* p_a_grid_,
                          const ABDataType* p_b_grid_,
                          CDataType* p_c_grid_,
                          index_t M_, index_t N_, index_t K_,
                          index_t StrideA_, index_t StrideB_, index_t StrideC_)
            : Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
              p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, p_c_grid{p_c_grid_}
        {
        }

        const ABDataType* p_a_grid;
        const ABDataType* p_b_grid;
        CDataType* p_c_grid;
    };
    using GridwiseGemmPipe = remove_cvref_t<
        decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
    __host__ __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
    {
        // A tile in LDS, destination of the blockwise copy
        constexpr auto a_block_desc_ak0_m_ak1 = [&]() {
            if constexpr(ABlockLdsExtraM)
            {
                // ... (descriptor with extra padding along M elided)
            }
            // ... (max_lds_align-aligned descriptor elided)
        }();

        return a_block_desc_ak0_m_ak1;
    }
    __host__ __device__ static constexpr auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
    {
        // B tile in LDS, destination of the blockwise copy
        constexpr auto b_block_desc_bk0_n_bk1 = [&]() {
            if constexpr(BBlockLdsExtraN)
            {
                // ... (descriptor with extra padding along N elided)
            }
            // ... (max_lds_align-aligned descriptor elided)
        }();

        return b_block_desc_bk0_n_bk1;
    }
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS holds an A tile followed by a B tile, each aligned to max_lds_align
        constexpr auto a_block_desc_ak0_m_ak1 = GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1();
        constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1();

        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
        constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
            b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);

        return (a_block_space_size_aligned + b_block_space_size_aligned) * sizeof(ABDataType);
    }
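    // Worked example of the formula above (hypothetical tile sizes, not values taken from
    // this header): with MPerBlock = NPerBlock = 128, KPerBlock = 32, a 2-byte ABDataType,
    // and no rounding by max_lds_align:
    //   A tile elements = KPerBlock * MPerBlock = 32 * 128 = 4096
    //   B tile elements = KPerBlock * NPerBlock = 32 * 128 = 4096
    //   LDS bytes       = (4096 + 4096) * 2 = 16384 (16 KiB per workgroup)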
    __host__ static constexpr bool CheckValidity(const Problem& problem)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(AK1)>>::value,
                      "Wrong! AK1 must be known at the time of compilation.");
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(BK1)>>::value,
                      "Wrong! BK1 must be known at the time of compilation.");

        static_assert(
            MPerBlock % (MPerDpp * MDppPerWave) == 0,
            "Invalid tuning parameters! MPerBlock must be divisible by MPerDpp * MDppPerWave.");
        static_assert(
            NPerBlock % (NPerDpp * NDppPerWave) == 0,
            "Invalid tuning parameters! NPerBlock must be divisible by NPerDpp * NDppPerWave.");

        static_assert(
            KPerBlock % AK1Value == 0 && KPerBlock % BK1Value == 0,
            "Invalid tuning parameters! KPerBlock must be divisible by both AK1 and BK1.");

        static_assert(AK1Value % ABlockTransferDstScalarPerVector_K1 == 0,
                      "Invalid tuning parameters! AK1Value must be divisible by "
                      "ABlockTransferDstScalarPerVector_K1");

        static_assert(BK1Value % BBlockTransferDstScalarPerVector_K1 == 0,
                      "Invalid tuning parameters! BK1Value must be divisible by "
                      "BBlockTransferDstScalarPerVector_K1");
        // runtime checks on the problem sizes (padding-specialization guards elided)
        if(!(problem.M % MPerBlock == 0)) { return false; }
        if(!(problem.N % NPerBlock == 0)) { return false; }

        // vectorized-access checks on A; the if constexpr on ABlockTransferSrcVectorDim
        // that selects which dimension is checked is elided
        if(problem.K % ABlockTransferSrcScalarPerVector != 0) { return false; }
        if(problem.M % ABlockTransferSrcScalarPerVector != 0) { return false; }

        // vectorized-access checks on B; gating on BBlockTransferSrcVectorDim elided
        if(problem.N % BBlockTransferSrcScalarPerVector != 0) { return false; }
        if(problem.K % BBlockTransferSrcScalarPerVector != 0) { return false; }

        if(problem.K % KPerBlock != 0) { return false; }

        // check that the gridwise GEMM pipeline supports this number of K loops
        const auto num_k_loop = problem.K / KPerBlock;
        if(!GridwiseGemmPipe::IsSupported(num_k_loop)) { return false; }

        return true;
    }
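    // Numerical illustration of the runtime checks above (hypothetical sizes): with
    // MPerBlock = 128 and NPerBlock = 64, a problem with M = 1000 fails the
    // `problem.M % MPerBlock == 0` check (1000 % 128 = 104), while M = 1024, N = 512
    // passes both size checks; K must likewise be a multiple of KPerBlock and of the
    // chosen source-access vector widths.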
    __host__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
    {
        const auto num_loop = K / KPerBlock;
        return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
    }
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(const CGridDesc& c_grid_desc_m_n)
    {
        // BlockwiseGemm is the blockwise DPP GEMM type; its (elided) template arguments
        // include decltype(a_block_desc_ak0_m_ak1) and decltype(b_block_desc_bk0_n_bk1).
        using BlockwiseGemm = /* ... */;

        return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);
    }
    // matrix_padder pads M, N, and K according to GemmSpec (type elided)
    static constexpr auto matrix_padder =
        /* ... */ {MPerBlock, NPerBlock, KPerBlock};
    __device__ static auto
    MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t K, index_t AK0, index_t StrideA)
    {
        const auto a_grid_desc_mraw_kraw = [&]() {
            // ... (layout-dependent raw M x K descriptor built from StrideA)
        }();

        const auto a_grid_desc_m_k = matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
        // ... (split K into AK0 x AK1 and return the AK0 x M x AK1 descriptor)
    }
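    // Illustrative sketch (a hypothetical helper, not the elided body above): how an
    // (M, K) descriptor is typically reshaped into (AK0, M, AK1) with the transform
    // helpers referenced in this file. K is unmerged into AK0 * AK1 and M passes
    // through unchanged.
    template <typename AGridDescM_K>
    __device__ static auto MakeAK0MAK1DescriptorSketch(const AGridDescM_K& a_grid_desc_m_k,
                                                       index_t AK0)
    {
        const auto M = a_grid_desc_m_k.GetLength(I0);

        return transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),    // K and M in the source descriptor
            make_tuple(Sequence<0, 2>{}, Sequence<1>{})  // (AK0, AK1) and M in the result
        );
    }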
    __device__ static auto
    MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t N, index_t BK0, index_t StrideB)
    {
        const auto b_grid_desc_nraw_kraw = [&]() {
            // ... (layout-dependent raw N x K descriptor built from StrideB)
        }();

        const auto b_grid_desc_n_k = matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
        // ... (split K into BK0 x BK1 and return the BK0 x N x BK1 descriptor)
    }
    __device__ static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        const auto c_grid_desc_mraw_nraw = [&]() {
            // ... (layout-dependent raw M x N descriptor built from StrideC)
        }();

        return matrix_padder.PadCDescriptor_M_N(c_grid_desc_mraw_nraw);
    }
    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename CGridDesc_M_N>
    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                               const ABDataType* __restrict__ p_b_grid,
                               CDataType* __restrict__ p_c_grid,
                               void* __restrict__ p_shared,
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
                               const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
                               const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 =
            MakeCGridDescriptor_M0_N0_M1_N1_M2_N2(c_grid_desc_m_n);

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_n2.GetElementSpaceSize());
        const AElementwiseOperation a_element_op{};
        const BElementwiseOperation b_element_op{};
        const CElementwiseOperation c_element_op{};
        // divide the C tile grid among workgroups
        const auto block_2_ctile_map = /* ... (block-to-C-tile map, type elided) */;

        const auto block_work_idx =
            block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

        if(!block_2_ctile_map.ValidCTileIndex(
               block_work_idx,
               make_tuple(c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I0),
                          c_grid_desc_m0_n0_m1_n1_m2_n2.GetLength(I1))))
        {
            return;
        }

        // readfirstlane forces the block offsets into scalar registers
        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
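        // Illustration of the mapping above (an assumption: a plain row-major tile order,
        // not necessarily what the elided block_2_ctile_map computes). For M = 512,
        // N = 384, MPerBlock = 128, NPerBlock = 128 there are 4 x 3 C tiles; workgroup
        // get_block_1d_id() == 7 would get block_work_idx = (7 / 3, 7 % 3) = (2, 1), i.e.
        // m_block_data_idx_on_grid = 256 and n_block_data_idx_on_grid = 128.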
        // ... (local constexpr LDS block descriptors a_block_desc_ak0_m_ak1 and
        //      b_block_desc_bk0_n_bk1, obtained from the Get*BlockDescriptor helpers, elided)

        // blockwise copy of the A tile: global memory -> LDS
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                AElementwiseOperation,
                                                // ... (dst element op, memory op, slice lengths)
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                // ... (src/dst data types)
                                                decltype(a_grid_desc_ak0_m_ak1),
                                                decltype(a_block_desc_ak0_m_ak1),
                                                ABlockTransferSrcAccessOrder,
                                                // ...
                                                ABlockTransferSrcVectorDim,
                                                // ... (dst vector dim)
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                // ...
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                // ...
                                                NumGemmKPrefetchStage>(
                a_grid_desc_ak0_m_ak1,
                // ... (grid origin for this block and the A element op)
                a_block_desc_ak0_m_ak1,
                /* ... (LDS origin and destination element op) */);
        // blockwise copy of the B tile: global memory -> LDS
        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
                                                BElementwiseOperation,
                                                // ... (dst element op, memory op, slice lengths)
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                // ... (src/dst data types)
                                                decltype(b_grid_desc_bk0_n_bk1),
                                                decltype(b_block_desc_bk0_n_bk1),
                                                BBlockTransferSrcAccessOrder,
                                                // ...
                                                BBlockTransferSrcVectorDim,
                                                // ... (dst vector dim)
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                // ...
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                // ...
                                                NumGemmKPrefetchStage>(
                b_grid_desc_bk0_n_bk1,
                // ... (grid origin for this block and the B element op)
                b_block_desc_bk0_n_bk1,
                /* ... (LDS origin and destination element op) */);
        // blockwise GEMM: consumes the A/B tiles staged in LDS, accumulates into registers.
        // Its (elided) type is the blockwise DPP GEMM, parameterized on
        // decltype(a_block_desc_ak0_m_ak1) and decltype(b_block_desc_bk0_n_bk1).
        auto blockwise_gemm = /* ... */;

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
        // carve the static LDS allocation into an A tile followed by a B tile
        constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
            a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());

        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            static_cast<ABDataType*>(p_shared) + a_block_space_size_aligned,
            b_block_desc_bk0_n_bk1.GetElementSpaceSize());
        const auto AK0 = a_grid_desc_ak0_m_ak1.GetLength(I0);

        const index_t num_k_block_main_loop =
            __builtin_amdgcn_readfirstlane(AK0 / AK0PerBlock);

        GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_ak0_m_ak1,
                                                          a_block_desc_ak0_m_ak1,
                                                          // ... (A copy, grid/LDS buffers)
                                                          a_block_slice_copy_step,
                                                          b_grid_desc_bk0_n_bk1,
                                                          b_block_desc_bk0_n_bk1,
                                                          // ... (B copy, grid/LDS buffers)
                                                          b_block_slice_copy_step,
                                                          // ... (blockwise GEMM, C thread buffer)
                                                          num_k_block_main_loop);
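        // Conceptual sketch of what GridwiseGemmPipe::Run does with the arguments above
        // (an illustration of a single-prefetch-stage pipeline, not the actual
        // GridwiseGemmPipe implementation; exact ordering and synchronization differ):
        //
        //   a_blockwise_copy.RunRead(a_grid_desc_ak0_m_ak1, a_grid_buf);   // prefetch tile 0
        //   b_blockwise_copy.RunRead(b_grid_desc_bk0_n_bk1, b_grid_buf);
        //   for(index_t i = 0; i < num_k_block_main_loop; ++i)
        //   {
        //       a_blockwise_copy.RunWrite(a_block_desc_ak0_m_ak1, a_block_buf); // stage in LDS
        //       b_blockwise_copy.RunWrite(b_block_desc_bk0_n_bk1, b_block_buf);
        //       a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc_ak0_m_ak1, a_block_slice_copy_step);
        //       b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc_bk0_n_bk1, b_block_slice_copy_step);
        //       a_blockwise_copy.RunRead(a_grid_desc_ak0_m_ak1, a_grid_buf);   // prefetch next tile
        //       b_blockwise_copy.RunRead(b_grid_desc_bk0_n_bk1, b_grid_buf);
        //       block_sync_lds();
        //       blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);    // accumulate
        //       block_sync_lds();
        //   }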
        // write out the per-thread C results
        constexpr auto c_thread_desc_m0_n0_m1_n1_m2_n2 =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_N2();

        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2();

        constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
        constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
        constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
        constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
        constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
        constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);

        constexpr auto MPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I4);
        constexpr auto NPerThread = c_thread_desc_m0_n0_m1_n1_m2_n2.GetLength(I5);
        // locate this thread's slice of C on the grid
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0);

        const index_t m_thread_data_on_grid =
            m_block_data_idx_on_grid + c_thread_mtx_on_block[I0];

        const index_t n_thread_data_on_grid =
            n_block_data_idx_on_grid + c_thread_mtx_on_block[I1];

        // ... (adaptors splitting the flat grid index into (M0, M1, M2) / (N0, N1, N2) elided)

        const auto m_thread_data_on_grid_idx =
            m_thread_data_on_grid_to_m0_m1_m2_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_grid));

        const auto n_thread_data_on_grid_idx =
            n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_grid));
        // threadwise copy of the per-thread C results to global memory; the transfer type
        // is elided, and its template arguments include:
        //     decltype(c_thread_desc_m0_n0_m1_n1_m2_n2),
        //     decltype(c_grid_desc_m0_n0_m1_n1_m2_n2),
        //     CElementwiseOperation,
        //     CThreadTransferSrcDstAccessOrder,
        //     CThreadTransferSrcDstVectorDim,
        //     CThreadTransferDstScalarPerVector,
        //     CGlobalMemoryDataOperation, ...
        // It is constructed at this thread's destination coordinate in C:
        auto c_thread_copy = /* ... */ {
            c_grid_desc_m0_n0_m1_n1_m2_n2,
            make_multi_index(m_thread_data_on_grid_idx[I0],
                             n_thread_data_on_grid_idx[I0],
                             m_thread_data_on_grid_idx[I1],
                             n_thread_data_on_grid_idx[I1],
                             m_thread_data_on_grid_idx[I2],
                             n_thread_data_on_grid_idx[I2]),
            /* ... */};

        c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_n2,
                          // ... (source origin and c_thread_buf)
                          c_grid_desc_m0_n0_m1_n1_m2_n2,
                          /* ... (c_grid_buf) */);
    }
};