namespace detail {

// "Detect-or-default" helpers: each primary template yields Default, and a
// constrained partial specialization (C++20 requires-clause) yields the
// pipeline's nested type when it exists.
template <typename T, typename Default>
struct get_aq_layout_or
{
    using type = Default;
};

template <typename T, typename Default>
    requires requires { typename T::AQLayout; }
struct get_aq_layout_or<T, Default>
{
    using type = typename T::AQLayout;
};

template <typename T, typename Default>
struct get_bq_layout_or
{
    using type = Default;
};

template <typename T, typename Default>
    requires requires { typename T::BQLayout; }
struct get_bq_layout_or<T, Default>
{
    using type = typename T::BQLayout;
};

template <typename T, typename Default>
struct get_aq_data_type_or
{
    using type = Default;
};

template <typename T, typename Default>
    requires requires { typename T::AQDataType; }
struct get_aq_data_type_or<T, Default>
{
    using type = typename T::AQDataType;
};

template <typename T, typename Default>
struct get_bq_data_type_or
{
    using type = Default;
};

template <typename T, typename Default>
    requires requires { typename T::BQDataType; }
struct get_bq_data_type_or<T, Default>
{
    using type = typename T::BQDataType;
};

// Same pattern for T::PreshuffleQuant; the helper's name below follows the
// naming of the helpers above (the name itself is not visible in this listing).
template <typename T, typename Default>
struct get_preshuffle_quant_or
{
    using type = Default;
};

template <typename T, typename Default>
    requires requires { typename T::PreshuffleQuant; }
struct get_preshuffle_quant_or<T, Default>
{
    using type = typename T::PreshuffleQuant;
};

} // namespace detail
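For readers unfamiliar with the idiom, here is a minimal, self-contained sketch (standalone hypothetical types, not part of this header) showing how the fallback behaves:

#include <type_traits>

template <typename T, typename Default>
struct get_aq_layout_or
{
    using type = Default; // fallback when T has no AQLayout
};

template <typename T, typename Default>
    requires requires { typename T::AQLayout; }
struct get_aq_layout_or<T, Default>
{
    using type = typename T::AQLayout; // picked when T::AQLayout exists
};

struct RowMajor {};
struct ColumnMajor {};
struct PipelineWithScales { using AQLayout = ColumnMajor; };
struct PlainPipeline {};

static_assert(std::is_same_v<get_aq_layout_or<PipelineWithScales, RowMajor>::type, ColumnMajor>);
static_assert(std::is_same_v<get_aq_layout_or<PlainPipeline, RowMajor>::type, RowMajor>);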
// QuantGemmHostArgs (host-side argument pack). Its constructor forwards the
// problem shape and strides to the QuantGemmProblem base; only the
// base-forwarding part of the initializer list survives in this listing.
CK_TILE_HOST QuantGemmHostArgs(const void* a_ptr_, const void* b_ptr_, void* c_ptr_,
                               const void* aq_ptr_, const void* bq_ptr_, index_t k_batch_,
                               index_t M_, index_t N_, index_t K_, index_t QK_A_, index_t QK_B_,
                               index_t stride_A_, index_t stride_B_, index_t stride_C_,
                               index_t stride_AQ_, index_t stride_BQ_)
    : QuantGemmProblem(
          M_, N_, K_, QK_A_, QK_B_, stride_A_, stride_B_, stride_C_, stride_AQ_, stride_BQ_)
      // ... pointer and k_batch member initializers elided ...
{
}
// The quantized-GEMM kernel, parameterized by a tile partitioner, a GEMM
// pipeline, and an epilogue pipeline (struct name reconstructed from its
// argument types QuantGemmHostArgs / QuantGemmKernelArgs).
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct QuantGemmKernel
{
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    // ... ALayout/BLayout/CLayout, the AQ/BQ layout and data-type aliases
    //     (resolved via the detail:: helpers above), kBlockSize,
    //     PreshuffleQuant, kQuantType, and the index constants I0..I4 are
    //     declared here; elided in this listing ...
    static CK_TILE_HOST const std::string GetName()
    {
        return concat(
            '_', "gemm_quant", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
    }
    static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }
    // The GEMM pipeline and the epilogue reuse the same LDS buffer, so the
    // requirement is the maximum of the two, not their sum.
    static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
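A quick worked example of the launch geometry (tile sizes are hypothetical):

// With M = N = 4096 and hypothetical 256 x 256 output tiles,
// TilePartitioner::GridSize(M, N) covers 16 * 16 = 256 tiles, so
// GridSize(4096, 4096, /*KBatch=*/2) == dim3(256, 1, 2): each z-slice owns
// one K split of the same output-tile grid.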
    // Split-K bookkeeping, computed per workgroup from blockIdx.z: the K extent
    // owned by this z-slice plus the element offsets into A and B.
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const QuantGemmKernelArgs& kargs,
                                     const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(I2);
            const index_t K_t   = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
            const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
            { /* a_k_split_offset for row-major A; elided in this listing */ }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
            { /* a_k_split_offset for column-major A; elided */ }

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
            { /* b_k_split_offset for row-major B; elided */ }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
            { /* b_k_split_offset for column-major B; elided */ }

            splitted_k = __builtin_amdgcn_readfirstlane(KRead);
            // ... adjustment for the final, possibly shorter split elided ...
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };
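The KRead arithmetic rounds each split up to a whole warp-tile in K; a worked example with hypothetical sizes:

// K = 1000, k_batch = 4, warp-tile K1 = 32 (all hypothetical):
//   K_t   = 4 * 32                = 128
//   KRead = ceil(1000 / 128) * 32 = 8 * 32 = 256
// Every split reads a K1-aligned chunk of 256 K-elements; 4 * 256 = 1024
// covers the full K = 1000, with the tail handled by the last split.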
    static CK_TILE_HOST bool IsSupportedArgument(const QuantGemmKernelArgs& kargs)
    {
        // ... preliminary checks elided in this listing; in the full source the
        //     AQ/BQ checks below are also guarded by the active quantization mode ...

        static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
        if(kargs.QK_A % GemmPipeline::GetVectorSizeAQ() != 0)
        {
            CK_TILE_ERROR("K_A is not a multiple of vector load size for A tensor!");
            return false;
        }

        static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
        if(kargs.QK_B % GemmPipeline::GetVectorSizeBQ() != 0)
        {
            CK_TILE_ERROR("K_B is not a multiple of vector load size for B tensor!");
            return false;
        }

        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                              "without padding!");
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
            {
                CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                CK_TILE_ERROR("Can't support M that is not a multiple of MPerBlock without padding!");
                return false;
            }
            if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
                return false;
            }
        }

        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                CK_TILE_ERROR("Can't support N that is not a multiple of NPerBlock without padding!");
                return false;
            }
            if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
                return false;
            }
        }
        else
        {
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                              "without padding!");
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
            {
                CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
                return false;
            }
        }

        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                CK_TILE_ERROR("Can't support N that is not a multiple of NPerBlock without padding!");
                return false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                CK_TILE_ERROR("Can't support M that is not a multiple of MPerBlock without padding!");
                return false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
                return false;
            }
        }
        return true;
    }
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_ptr,
                                                   const AQDataType* aq_ptr,
                                                   const BQDataType* bq_ptr,
                                                   CDataType* c_ptr,
                                                   const QuantGemmKernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");

        // Global view of A; the layout branch only changes lengths/strides.
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    /* (M, splitted_k) lengths and strides elided */
                    number<GemmPipeline::GetVectorSizeA()>{} /*, ... */);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    a_ptr,
                    /* (splitted_k, M) lengths and strides elided */
                    number<GemmPipeline::GetVectorSizeA()>{} /*, ... */);
            }
        }();

        // Padding needed to round `length` up to a multiple of `alignment`.
        const auto get_padding_size = [](index_t length, index_t alignment) {
            /* body elided in this listing */
        };

        // Global view of the A-scale tensor AQ. Under PreshuffleQuant the flat
        // AQ buffer is re-described (right-pad to block tiles, unmerge into wave
        // tiles, pad each wave tile to a warp multiple, merge back) so the
        // scales land in wave-tile order; otherwise it is a plain 2-D view.
        const auto& aq_tensor_view = [&]() {
            if constexpr(PreshuffleQuant)
            {
                static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
                const auto aq_x = kargs.M * GemmPipeline::KPerBlockAQ;
                const auto aq_y = kargs.QK_A / GemmPipeline::KPerBlockAQ;
                // ... naive descriptor with number<GemmPipeline::GetVectorSizeAQ()>{};
                //     right-padded to block tiles -> aq_pad0_desc ...
                const auto block_tile_size = GemmPipeline::MPerBlock * GemmPipeline::KPerBlockAQ;
                const auto pad_aq_x        = aq_pad0_desc.get_lengths()[I1];
                const auto wave_tile_size =
                    TilePartitioner::BlockGemmShape::WarpTile::at(I0) * GemmPipeline::KPerBlockAQ;
                const auto wave_tile_count_x = /* pad_aq_x / wave_tile_size; elided */;
                // ... transform_tensor_descriptor(aq_unmerge_pad0_desc, ...) using
                //     make_right_pad_transform(
                //         wave_tile_size, get_padding_size(wave_tile_size, get_warp_size())),
                //     then merge back -> aq_merge_pad1_desc ...
                const auto pad_wave_size = /* elided */;
                return make_tensor_view<address_space_enum::global>(aq_ptr, aq_merge_pad1_desc);
            }
            else if constexpr(/* grouped A-quant; guard elided */)
            {
                static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
                return make_naive_tensor_view<address_space_enum::global>(
                    aq_ptr,
                    /* lengths and strides elided */
                    number<GemmPipeline::GetVectorSizeAQ()>{} /*, ... */);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    aq_ptr /*, ... other quantization modes; elided ... */);
            }
        }();

        // Global view of B. With PermuteB the buffer is stored pre-shuffled as
        // (K0, N, K1) with K1 = GetSmemPackB(); it is merged back to an N x K view.
        const auto& b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                    constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc     = /* naive (K0, N, K1) descriptor; elided */;
                    // ... merge (K0, K1) -> K to obtain b_n_k_desc ...
                    return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        b_ptr,
                        /* lengths and strides elided */
                        number<GemmPipeline::GetVectorSizeB()>{} /*, ... */);
                }
            }
            else
            {
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                    constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc     = /* naive (K0, N, K1) descriptor; elided */;
                    // ... merge (K0, K1) -> K to obtain b_n_k_desc ...
                    return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        b_ptr,
                        /* lengths and strides elided */
                        number<GemmPipeline::GetVectorSizeB()>{} /*, ... */);
                }
            }
        }();

        // Global view of the B-scale tensor BQ.
        const auto& bq_tensor_view = [&]() {
            if constexpr(/* quantization-mode guard; elided */)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    bq_ptr /*, ... elided ... */);
            }
            else
            {
                static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
                return make_naive_tensor_view<address_space_enum::global>(
                    bq_ptr,
                    /* lengths and strides elided */
                    number<GemmPipeline::GetVectorSizeBQ()>{} /*, ... */);
            }
        }();

        // Global view of C. DstInMemOp selects plain stores vs. atomic updates
        // (the latter used when split-K slices accumulate into the same tile).
        const auto& c_tensor_view = [&]() {
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr,
                    /* (M, N) lengths and strides elided */
                    number<EpiloguePipeline::GetVectorSizeC()>{} /*, ... */);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    c_ptr /*, ... elided ... */);
            }
        }();

        return make_tuple(
            a_tensor_view, aq_tensor_view, b_tensor_view, bq_tensor_view, c_tensor_view);
    }
    // Wrap the raw views in padding transforms so edge tiles in M/N/K can be
    // processed uniformly; the AQ and BQ views pass through unpadded.
    template <typename TensorView>
    static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            { /* return pad_tensor_view(a_tensor_view, ...); elided in this listing */ }
            else
            { /* elided */ }
        }();

        const auto& aq_pad_view = [&]() { return views.at(I1); }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I2);
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            { /* return pad_tensor_view(b_tensor_view, ...); elided */ }
            else
            { /* elided */ }
        }();

        const auto& bq_pad_view = [&]() { return views.at(I3); }();

        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I4);
            if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
            { /* return pad_tensor_view(c_tensor_view, ...); elided */ }
            else
            { /* elided */ }
        }();

        return make_tuple(a_pad_view, aq_pad_view, b_pad_view, bq_pad_view, c_pad_view);
    }
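The elided branches follow the library's pad_tensor_view pattern; a sketch for the A operand, with illustrative tile lengths and pad flags (the exact DoPads spelling in the elided code is assumed, not shown in this listing):

// Illustrative only: pad an M x K view up to whole (MPerBlock, KPerBlock)
// tiles, enabling padding per dimension via compile-time flags.
const auto padded = pad_tensor_view(
    a_tensor_view,
    make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
    sequence<false, GemmPipeline::kPadK>{}); // pad K only; M assumed aligned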
    // Carve this workgroup's tile windows out of the padded views at block
    // origin (i_m, i_n).
    template <typename PadView>
    static CK_TILE_DEVICE auto
    MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view  = views.at(I0);
        const auto& aq_pad_view = views.at(I1);
        const auto& b_pad_view  = views.at(I2);
        const auto& bq_pad_view = views.at(I3);
        const auto& c_pad_view  = views.at(I4);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            { /* (MPerBlock x KPerBlock) window at {i_m, 0}; elided */ }
            else
            { /* elided */ }
        }();

        const auto& aq_block_window = [&]() {
            if constexpr(PreshuffleQuant)
            {
                static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
                constexpr auto block_m = TilePartitioner::MPerBlock;
                constexpr auto warp_m  = TilePartitioner::BlockGemmShape::WarpTile::at(I0);
                constexpr auto aqk_per_block =
                    TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize;
                constexpr auto tile_window_width  = /* elided */;
                constexpr auto tile_window_height = block_m / warp_m;
                auto block_m_idx = i_m / block_m;
                // ... make_tile_window over the preshuffled AQ view at origin
                //     {block_m_idx * tile_window_height, 0}; elided ...
            }
            else
            {
                static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
                constexpr auto block_m = TilePartitioner::MPerBlock;
                constexpr auto block_k = TilePartitioner::KPerBlock;
                // ... (block_m x block_k / QuantGroupSize) window; elided ...
            }
        }();

        const auto& b_block_window = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            { /* (NPerBlock x KPerBlock) window at {i_n, 0}; elided */ }
            else
            { /* elided */ }
        }();

        const auto& bq_block_window = [&]() {
            // ... quantization-mode branches elided; for grouped B-quant:
            static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
            // window lengths include
            //   number<TilePartitioner::KPerBlock / GemmPipeline::QuantGroupSize>{}
            /* elided */
        }();

        // ... c_block_window: (MPerBlock x NPerBlock) window at {i_m, i_n}; elided ...

        return make_tuple(
            a_block_window, aq_block_window, b_block_window, bq_block_window, c_block_window);
    }
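The elided window construction follows the make_tile_window pattern used throughout the library; a sketch for the A operand (window lengths are the block tile shape, the origin is the block's row offset; illustrative, not the exact elided code):

auto a_window = make_tile_window(
    a_pad_view,
    make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
    {i_m, 0});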
    /// Runs a single GEMM problem cooperatively, by the whole workgroup.
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    static CK_TILE_DEVICE void
    RunGemm(const ADataType* a_ptr, const BDataType* b_ptr, const AQDataType* aq_ptr,
            const BQDataType* bq_ptr, CDataType* c_ptr, void* smem_ptr_0,
            const QuantGemmKernelArgs& kargs, const SplitKBatchOffset& splitk_batch_offset,
            const index_t block_idx_m, const index_t block_idx_n)
    {
        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<DstInMemOp>(
            a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
        // ... gemm_pad_views and gemm_tile_windows built via MakeGemmPadViews
        //     and MakeGemmTileWindows; elided ...

        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I2);

        // The pipeline is handed whichever scale window the active quantization
        // mode needs (guards elided; pipeline call expression reconstructed).
        const auto& c_block_tile = [&]() {
            if constexpr(/* A-quant mode */)
            {
                const auto& aq_block_window = gemm_tile_windows.at(I1);
                return GemmPipeline{}(
                    a_block_window, b_block_window, aq_block_window, kargs.M, num_loop, smem_ptr_0);
            }
            else if constexpr(/* B-quant mode */)
            {
                const auto& bq_block_window = gemm_tile_windows.at(I3);
                return GemmPipeline{}(
                    a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0);
            }
            else
            {
                return GemmPipeline{}(a_block_window, b_block_window, num_loop, smem_ptr_0);
            }
        }();

        auto& c_block_window = gemm_tile_windows.at(I4);
        if constexpr(/* epilogue without scale windows */)
        {
            EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
        }
        else
        {
            const auto& aq_block_window = gemm_tile_windows.at(I1);
            const auto& bq_block_window = gemm_tile_windows.at(I3);
            // ... epilogue invocation with scale windows elided ...
        }
    }
    CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const
    {
        const auto blockId  = __builtin_amdgcn_readfirstlane(blockIdx.x);
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);

        // ... SplitKBatchOffset construction, offsetting of kargs' base
        //     pointers, and LDS allocation (smem_ptr_0) elided ...

        RunGemm(
            a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
    }
};
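Putting the host-side pieces together, a hypothetical driver might look as follows; `Kernel` stands for a concrete QuantGemmKernel instantiation, the pointer and size variables are assumed to exist, and the final launch is sketched rather than tied to a specific runtime wrapper:

QuantGemmHostArgs host_args(a, b, c, aq, bq, /*k_batch=*/1,
                            M, N, K, QK_A, QK_B,
                            stride_A, stride_B, stride_C, stride_AQ, stride_BQ);
auto kargs = Kernel::MakeKernelArgs(host_args);
if(!Kernel::IsSupportedArgument(kargs))
{
    throw std::runtime_error("quant GEMM: unsupported argument");
}
const dim3 grids  = Kernel::GridSize(M, N, kargs.k_batch);
const auto blocks = Kernel::BlockSize();
// Launch Kernel{} over (grids, blocks) with Kernel::GetSmemSize() bytes of
// LDS, passing kargs by value (launch wrapper intentionally left abstract).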