Composable Kernel: include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_problem.hpp Source File
gemm_aquant_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 #include <string>
11 
12 namespace ck_tile {
13 
// Compile-time "problem" descriptor for a GEMM whose A operand is quantized
// (presumably consumed by a pipeline as its Problem type). It only holds type
// aliases, static constexpr values and static functions. It forwards the
// generic GEMM arguments to GemmPipelineProblemBase (see `Base` below) and adds:
//   - AQDataType_: element type of the AQ tensor, used by GetAlignmentAQ.
//   - QuantGroupSize_: exposed as kQuantGroupSize.
// NOTE(review): this text is a Doxygen extraction. Listing lines 22, 25, 45,
// 47, 61-62, 64, 72, 78 and 82 were dropped, so the declaration below is
// incomplete as shown; consult the real header before editing.
14 template <typename ADataType_,
15  typename AQDataType_,
16  typename BDataType_,
17  typename CDataType_,
18  typename BlockGemmShape_,
19  typename Traits_,
20  uint32_t QuantGroupSize_,
21  typename ComputeDataType_ = BDataType_,
// NOTE(review): listing line 22 is missing here. Listing line 67 reads
// `Scheduler_`, so line 22 presumably declared a scheduler template parameter
// (GemmPipelineScheduler appears in this page's cross-references). Confirm
// the parameter and its default against the header.
23  bool HasHotLoop_ = true,
24  TailNumber TailNum_ = TailNumber::Full>
// NOTE(review): listing line 25 (the struct name and the head of its base
// list) is missing. Lines 26-30 below are the tail of that base's template
// argument list. They match `Base` (listing lines 32-37), so the base is
// presumably GemmPipelineProblemBase<ADataType_, ...>.
26  BDataType_,
27  CDataType_,
28  BlockGemmShape_,
29  Traits_,
30  ComputeDataType_>
31 {
// The generic GEMM pipeline problem; most members below are re-exported from it.
32  using Base = GemmPipelineProblemBase<ADataType_,
33  BDataType_,
34  CDataType_,
35  BlockGemmShape_,
36  Traits_,
37  ComputeDataType_>;
38 
39  using Traits = typename Base::Traits;
40 
41  using typename Base::ADataType;
42  using typename Base::BDataType;
43  using typename Base::CDataType;
44  using typename Base::ComputeDataType;
// NOTE(review): listing lines 45 and 47 are missing. Per this page's
// cross-references they declare:
//   AQDataType = remove_cvref_t<AQDataType_>       (line 45)
//   BlockGemmShape = typename Base::BlockGemmShape (line 47)
46 
48 
49  using typename Base::ALayout;
50  using typename Base::BLayout;
51  using typename Base::CLayout;
52 
// Always false for this problem type.
53  static constexpr bool TransposeC = false;
54 
55  using Base::kBlockSize;
56 
// Padding flags from Base. kPadK is read by VectorSizeAQ below.
57  using Base::kPadK;
58  using Base::kPadM;
59  using Base::kPadN;
60 
// NOTE(review): listing lines 61-62 and 64 are missing.
// - Lines 61-62 presumably re-export Base::DoubleSmemBuffer and
//   Base::VectorLoadSize. Both appear in the cross-references, and
//   VectorLoadSize is used unqualified at listing line 89, which requires a
//   using-declaration for a dependent base.
// - Line 64 declares AQLayout = remove_cvref_t<typename Traits::AQLayout>,
//   per the cross-references.
63 
65 
// Template arguments re-exposed as static members.
66  static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
67  static constexpr auto Scheduler = Scheduler_;
68  static constexpr auto HasHotLoop = HasHotLoop_;
69  static constexpr auto TailNum = TailNum_;
70 
// The block's K extent must be a whole multiple of the quantization group size.
71  static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
// NOTE(review): listing line 72 is missing; its content cannot be recovered here.
73 
// Builds a '_'-separated descriptive name from concat(). It starts with
// "gemm_aquant_problem" and includes:
//   - the kPadM/kPadN/kPadK flags joined by 'x',
//   - the Scheduler value,
//   - the literal "QuantGroupSize".
// NOTE(review): listing lines 78 and 82 are missing, so the concat call is
// unbalanced as shown. Line 82 presumably supplies the group-size value (e.g.
// kQuantGroupSize) and closes the call; confirm.
74  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
75  {
76  // clang-format off
77  return concat('_', "gemm_aquant_problem",
79  concat('x', kPadM, kPadN, kPadK),
80  Scheduler,
81  "QuantGroupSize",
83  // clang-format on
84  }
85 
// Number of AQDataType elements that fit in one vector load. It is computed
// with integer division as VectorLoadSize / sizeof(AQDataType). VectorLoadSize
// comes from Base and is presumably a byte count; confirm.
// A compile-time check rejects any AQLayout other than row-major.
86  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ()
87  {
88  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
89  return VectorLoadSize / sizeof(AQDataType);
90  }
91 
// Vector size for the AQ tensor: 1 when kPadK is set, otherwise
// GetAlignmentAQ(). It is computed once by an immediately invoked lambda.
// NOTE(review): this static_assert checks ALayout, whereas GetAlignmentAQ
// checks AQLayout. Confirm that requiring a row-major A here is intended.
92  static constexpr index_t VectorSizeAQ = []() {
93  static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
94  return kPadK ? 1 : GetAlignmentAQ();
95  }();
96 };
97 
// Convenience template that forwards every template parameter, in declaration
// order, to another template. Judging by the trailing `>;`, it is presumably
// an alias template for the A-quant problem struct defined above.
// NOTE(review): listing lines 106 and 109 were dropped by the Doxygen
// extraction; confirm both against the header.
98 template <typename ADataType_,
99  typename AQDataType_,
100  typename BDataType_,
101  typename CDataType_,
102  typename BlockGemmShape_,
103  typename Traits_,
104  uint32_t QuantGroupSize_,
105  typename ComputeDataType_ = BDataType_,
// NOTE(review): listing line 106 is missing. It presumably declares Scheduler_,
// which is forwarded at listing line 117.
107  bool HasHotLoop_ = true,
108  TailNumber TailNum_ = TailNumber::Full>
// NOTE(review): listing line 109 is missing. It presumably holds the
// `using <Name> = <Target><ADataType_,` head of the alias.
110  AQDataType_,
111  BDataType_,
112  CDataType_,
113  BlockGemmShape_,
114  Traits_,
115  QuantGroupSize_,
116  ComputeDataType_,
117  Scheduler_,
118  HasHotLoop_,
119  TailNum_>;
120 
121 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
Definition: gemm_aquant_pipeline_problem.hpp:31
static constexpr auto TailNum
Definition: gemm_aquant_pipeline_problem.hpp:69
static constexpr bool TransposeC
Definition: gemm_aquant_pipeline_problem.hpp:53
typename Base::BlockGemmShape BlockGemmShape
Definition: gemm_aquant_pipeline_problem.hpp:47
typename Base::Traits Traits
Definition: gemm_aquant_pipeline_problem.hpp:39
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:44
static constexpr index_t VectorSizeAQ
Definition: gemm_aquant_pipeline_problem.hpp:92
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:50
remove_cvref_t< AQDataType_ > AQDataType
Definition: gemm_aquant_pipeline_problem.hpp:45
static CK_TILE_HOST const std::string GetName()
Definition: gemm_aquant_pipeline_problem.hpp:74
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentAQ()
Definition: gemm_aquant_pipeline_problem.hpp:86
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:45
static constexpr auto HasHotLoop
Definition: gemm_aquant_pipeline_problem.hpp:68
static constexpr uint32_t kQuantGroupSize
Definition: gemm_aquant_pipeline_problem.hpp:66
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:46
remove_cvref_t< typename Traits::AQLayout > AQLayout
Definition: gemm_aquant_pipeline_problem.hpp:64
static constexpr auto Scheduler
Definition: gemm_aquant_pipeline_problem.hpp:67
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:42
Definition: gemm_pipeline_problem.hpp:22
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:27
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:48
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:44
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:50
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: gemm_pipeline_problem.hpp:28
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:23
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:32
remove_cvref_t< typename Traits::ALayout > ALayout
Definition: gemm_pipeline_problem.hpp:34
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:45
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:36
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:26
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:25
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:46
remove_cvref_t< typename Traits::BLayout > BLayout
Definition: gemm_pipeline_problem.hpp:35
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:42