// AQuantGemmHostArgs constructor: the problem shape and strides are forwarded to the
// AQuantGemmProblem base; the pointer and k_batch members are initialized afterwards (elided).
: AQuantGemmProblem(M_, N_, K_, QK_, stride_A_, stride_B_, stride_C_, stride_AQ_),
// The kernel is a struct template over the tile partitioner, the GEMM mainloop pipeline and
// the epilogue pipeline; its type aliases and constants are elided below and only the member
// functions are shown.
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
// GetName(): builds the kernel's identification string from the precision and pipeline names.
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
// GridSize(M, N, KBatch): one workgroup per output tile along X, k_batch slices along Z.
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
// GetSmemSize(): the mainloop and the epilogue reuse the same LDS allocation, so only the
// larger of the two requirements is needed.
return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
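// Illustrative numbers (not from this header): if GemmPipeline::GetSmemSize() reports 32 KiB
// and EpiloguePipeline::GetSmemSize() reports 16 KiB, GetSmemSize() returns 32 KiB, i.e. the
// single LDS buffer handed to both pipelines as smem_ptr_0 is sized for the larger consumer
// rather than for the sum of the two.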
// Split-K bookkeeping: each workgroup along blockIdx.z gets its own slice of the K dimension.
__device__ SplitKBatchOffset(const AQuantGemmKernelArgs& kargs, const std::size_t k_id = blockIdx.z)
{
    // K1 is the warp-tile K extent; every split covers a K range rounded to a multiple of K1.
    constexpr auto K1    = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
    const index_t  K_t   = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
    const index_t  KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);

    if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
    { /* elided: a_k_split_offset for row-major A */ }
    else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
    { /* elided: a_k_split_offset for column-major A */ }

    if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
    { /* elided: b_k_split_offset for row-major B */ }
    else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
    { /* elided: b_k_split_offset for column-major B */ }

    if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
    {
        splitted_k = __builtin_amdgcn_readfirstlane(KRead);
    }
    // elided: the last split (k_id == k_batch - 1) takes the remaining K
}
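// Worked example (illustrative numbers, not taken from this header): with kargs.K = 4096,
// kargs.k_batch = 3 and a warp-tile K1 of 32, K_t = 3 * 32 = 96 and
// KRead = ceil(4096 / 96) * 32 = 43 * 32 = 1376. Splits 0 and 1 then process
// splitted_k = 1376 elements of K each; assuming the elided else branch assigns the remainder
// (as CK's other split-K GEMM kernels do), the last split processes 4096 - 2 * 1376 = 1344,
// and in this example every split's K extent stays a multiple of the warp-tile K.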
// IsSupportedArgument(): host-side validation of the runtime arguments against the
// compile-time tile configuration. In the full source each failed check can additionally emit
// a message gated by EnvIsEnabled(CK_TILE_ENV(...)); those logging guards and a few earlier
// checks in the function body are elided here.
static CK_TILE_HOST bool IsSupportedArgument(const AQuantGemmKernelArgs& kargs)
{
    // The A-scale tensor is only supported in row-major layout, and its QK extent must be
    // loadable with the AQ vector width.
    static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
    if(kargs.QK % GemmPipeline::GetVectorSizeAQ() != 0)
    {
        CK_TILE_ERROR("QK is not a multiple of vector load size for AQ tensor!");
        return false;
    }

    if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
           GemmPipeline::kPadK == false)
        {
            CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                          "without padding!");
            return false;
        }
        if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
        {
            CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
            return false;
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
        {
            CK_TILE_ERROR("Can't support M that is not a multiple of MPerBlock without padding!");
            return false;
        }
        if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
        {
            CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
            return false;
        }
    }

    if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
        {
            CK_TILE_ERROR("Can't support N that is not a multiple of NPerBlock without padding!");
            return false;
        }
        if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
        {
            CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
            return false;
        }
    }
    else
    {
        if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
           GemmPipeline::kPadK == false)
        {
            CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                          "without padding!");
            return false;
        }
        if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
        {
            CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
            return false;
        }
    }

    if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
    {
        if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
        {
            CK_TILE_ERROR("Can't support N that is not a multiple of NPerBlock without padding!");
            return false;
        }
        if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
            return false;
        }
    }
    else
    {
        if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
        {
            CK_TILE_ERROR("Can't support M that is not a multiple of MPerBlock without padding!");
            return false;
        }
        if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
        {
            CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
            return false;
        }
    }

    return true;
}
// Build the global-memory tensor views for A, the per-group A scales (AQ), B, and C.
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr, const BDataType* b_ptr,
                                               const AQDataType* aq_ptr, CDataType* c_ptr,
                                               const AQuantGemmKernelArgs& kargs,
                                               const SplitKBatchOffset& splitk_batch_offset)
{
    static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");

    const auto& a_tensor_view = [&]() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            return make_naive_tensor_view<address_space_enum::global>(
                /* elided: a_ptr, row-major A lengths and strides */
                number<GemmPipeline::GetVectorSizeA()>{}, /* ... */);
        }
        else
        {
            return make_naive_tensor_view<address_space_enum::global>(
                /* elided: column-major variant */
                number<GemmPipeline::GetVectorSizeA()>{}, /* ... */);
        }
    }();

    const auto& aq_tensor_view = [&]() {
        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
        return make_naive_tensor_view<address_space_enum::global>(
            /* elided: aq_ptr, lengths and strides of the per-group scale tensor */
            number<GemmPipeline::GetVectorSizeAQ()>{}, /* ... */);
    }();

    const auto& b_tensor_view = [&]() {
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
            {
                constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                // elided: b_k0_n_k1_desc (naive (K0, N, K1) descriptor) and the transform
                // that merges (K0, K1) back into K, producing b_n_k_desc
                return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    /* elided: b_ptr, row-major B lengths and strides */
                    number<GemmPipeline::GetVectorSizeB()>{}, /* ... */);
            }
        }
        else
        {
            if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
            {
                constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                // elided: b_k0_n_k1_desc (naive (K0, N, K1) descriptor) and the transform
                // that merges (K0, K1) back into K, producing b_n_k_desc
                return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    /* elided: b_ptr, column-major B lengths and strides */
                    number<GemmPipeline::GetVectorSizeB()>{}, /* ... */);
            }
        }
    }();

    const auto& c_tensor_view = [&]() {
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                /* elided: c_ptr, row-major C lengths and strides */
                number<EpiloguePipeline::GetVectorSizeC()>{}, /* ... */);
        }
        else
        {
            return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                /* elided: column-major variant */);
        }
    }();

    return make_tuple(a_tensor_view, aq_tensor_view, b_tensor_view, c_tensor_view);
}
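// Illustrative reading of the views above (inferred from the surviving arguments, not a quote
// of the elided lines): for row-major A the view spans an M x splitted_k region whose last
// dimension is contiguous and is read with GetVectorSizeA()-wide loads, while the AQ view
// spans the matching per-group scales (one scale per QuantGroupSize elements of K, an assumed
// relation) read with GetVectorSizeAQ()-wide loads.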
// Wrap the raw views in pad_tensor_view so partial border tiles are legal when kPadM/kPadN/kPadK are set.
template <typename TensorView>
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views)
{
    const auto& a_pad_view = [&]() {
        const auto& a_tensor_view = views.at(I0);
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        { /* elided: pad_tensor_view over the (MPerBlock, KPerBlock) block tile */ }
        else
        { /* elided: column-major variant */ }
    }();

    const auto& aq_pad_view = [&]() {
        const auto& aq_tensor_view = views.at(I1);
        static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
        return pad_tensor_view(aq_tensor_view,
                               make_tuple(/* elided */,
                                          number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
                               /* elided: pad flags */);
    }();

    const auto& b_pad_view = [&]() {
        const auto& b_tensor_view = views.at(I2);
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
        { /* elided: pad_tensor_view over the (NPerBlock, KPerBlock) block tile */ }
        else
        { /* elided: row-major variant */ }
    }();

    const auto& c_pad_view = [&]() {
        const auto& c_tensor_view = views.at(I3);
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        { /* elided: pad_tensor_view over the (MPerBlock, NPerBlock) block tile */ }
        else
        { /* elided: column-major variant */ }
    }();

    return make_tuple(a_pad_view, aq_pad_view, b_pad_view, c_pad_view);
}
// Create this workgroup's tile windows at block-tile origin (i_m, i_n).
template <typename PadView>
static CK_TILE_DEVICE auto
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
{
    const auto& a_pad_view  = views.at(I0);
    const auto& aq_pad_view = views.at(I1);
    const auto& b_pad_view  = views.at(I2);
    const auto& c_pad_view  = views.at(I3);

    const auto& a_block_window = [&]() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        { /* elided: make_tile_window of (MPerBlock, KPerBlock) at row offset i_m */ }
        else
        { /* elided: column-major variant */ }
    }();

    const auto& aq_block_window = [&]() {
        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
        return make_tile_window(aq_pad_view,
                                make_tuple(/* elided */,
                                           number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}),
                                /* elided: window origin */);
    }();

    const auto& b_block_window = [&]() {
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
        { /* elided: make_tile_window of (NPerBlock, KPerBlock) at column offset i_n */ }
        else
        { /* elided: row-major variant */ }
    }();

    // (c_block_window construction elided)
    return make_tuple(a_block_window, aq_block_window, b_block_window, c_block_window);
}
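// Illustrative AQ window geometry (assumed, but consistent with the window length shown
// above): with KPerBlock = 128 and QuantGroupSize = 32, each workgroup reads a scale window
// that is KPerBlock / QuantGroupSize = 4 columns wide, i.e. one scale column per 32 K
// elements of the A block window, so the AQ window advances 4 columns for every 128-element
// step of the A window along K.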
/// @brief Runs a single GEMM problem cooperatively by the whole workgroup.
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
static CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr, const BDataType* b_ptr,
                                   const AQDataType* aq_ptr, CDataType* c_ptr,
                                   void* smem_ptr_0, const AQuantGemmKernelArgs& kargs,
                                   const SplitKBatchOffset& splitk_batch_offset,
                                   const index_t block_idx_m, const index_t block_idx_n)
{
    const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
        a_ptr, b_ptr, aq_ptr, c_ptr, kargs, splitk_batch_offset);
    // (elided: gemm_pad_views and gemm_tile_windows are derived from the tuple above)

    const index_t num_loop = __builtin_amdgcn_readfirstlane(
        TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

    const auto& a_block_window  = gemm_tile_windows.at(I0);
    const auto& aq_block_window = gemm_tile_windows.at(I1);
    const auto& b_block_window  = gemm_tile_windows.at(I2);

    // Mainloop: the pipeline consumes the A, B and AQ windows and produces the C block tile
    // (the call expression is reconstructed; only its argument list survives in this excerpt).
    const auto c_block_tile = GemmPipeline{}(
        a_block_window, b_block_window, aq_block_window, num_loop, smem_ptr_0);

    // Epilogue: stream the block tile out through the (possibly padded) C window.
    auto& c_block_window = gemm_tile_windows.at(I3);
    EpiloguePipeline{}.template
    operator()<decltype(c_block_window), decltype(c_block_tile), decltype(c_block_window)>(
        c_block_window, c_block_tile, c_block_window, smem_ptr_0);
}
// Kernel entry point: map blockIdx.x to an (iM, iN) output tile, then run the GEMM.
CK_TILE_DEVICE void operator()(AQuantGemmKernelArgs kargs) const
{
    const auto blockId  = __builtin_amdgcn_readfirstlane(blockIdx.x);
    const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
    const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
    const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

    // (elided: SplitKBatchOffset construction from blockIdx.z, the per-split pointer setup,
    //  and the LDS allocation that becomes smem_ptr_0)
    RunGemm(a_ptr, b_ptr, aq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
}
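// --- Hedged host-side launch sketch (not part of this header) ---
// Shows how the host-facing members above (MakeKernelArgs, IsSupportedArgument, GridSize,
// BlockSize) would typically be strung together. The __global__ trampoline, the
// hipLaunchKernelGGL call and the error handling are illustrative assumptions made for this
// sketch; ck_tile ships its own launch helpers, which this excerpt does not show.

#include <hip/hip_runtime.h>
#include <stdexcept>

template <typename Kernel>
__global__ void aquant_gemm_entry(AQuantGemmKernelArgs kargs)
{
    Kernel{}(kargs); // forwards to the operator() shown above
}

template <typename Kernel>
void launch_aquant_gemm(const AQuantGemmHostArgs& host_args, hipStream_t stream)
{
    const auto kargs = Kernel::MakeKernelArgs(host_args);
    if(!Kernel::IsSupportedArgument(kargs))
    {
        throw std::runtime_error("AQuant GEMM: argument not supported by this tile configuration");
    }

    const dim3 grids  = Kernel::GridSize(kargs.M, kargs.N, kargs.k_batch);
    const dim3 blocks = Kernel::BlockSize();

    // Assuming the kernel allocates its LDS statically (as the elided operator() body suggests),
    // no dynamic shared memory is requested at launch; GetSmemSize() still bounds the usage.
    hipLaunchKernelGGL(aquant_gemm_entry<Kernel>, grids, blocks, 0, stream, kargs);
}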