/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp Source File
gemm_quant_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 #include <string>
11 
12 namespace ck_tile {
13 
// Compile-time problem descriptor for the quantized GEMM pipelines. It takes the A/B/C
// element types, the AQDataType_/BQDataType_ types, a block GEMM shape, a traits type, a
// quantization group size and scheduling knobs, and re-exports them as member aliases and
// constants on top of GemmPipelineProblemBase (aliased as `Base` below).
//
// NOTE(review): this page extract skips listing lines 24, 27, 47-48, 50, 65, 67-68, 81
// and 85, so the declarations that lived on those lines are not visible in this block.
14 template <typename ADataType_,
15  typename AQDataType_,
16  typename BDataType_,
17  typename BQDataType_,
18  typename CDataType_,
19  typename BlockGemmShape_,
20  typename Traits_,
21  uint32_t QuantGroupSize_,
22  bool TransposeC_,
23  typename ComputeDataType_ = BDataType_,
// NOTE(review): listing line 24 is missing here. `Scheduler_` is read further down
// (`Scheduler = Scheduler_`) and the alias templates in this file pass it between
// ComputeDataType_ and HasHotLoop_, so it is presumably declared on that line - confirm.
25  bool HasHotLoop_ = true,
26  TailNumber TailNum_ = TailNumber::Full>
// NOTE(review): listing line 27 is missing here; it presumably holds the struct head and
// the first template argument of the base-class list that continues below - confirm.
28  BDataType_,
29  CDataType_,
30  BlockGemmShape_,
31  Traits_,
32  ComputeDataType_>
33 {
    // Shorthand for the base problem, instantiated without the quantization parameters.
34  using Base = GemmPipelineProblemBase<ADataType_,
35  BDataType_,
36  CDataType_,
37  BlockGemmShape_,
38  Traits_,
39  ComputeDataType_>;
40 
41  using Traits = typename Base::Traits;
42 
    // Element types re-exported from Base.
43  using typename Base::ADataType;
44  using typename Base::BDataType;
45  using typename Base::CDataType;
46  using typename Base::ComputeDataType;
    // NOTE(review): listing lines 47, 48 and 50 are missing. The symbol index at the end
    // of this page places AQDataType (remove_cvref_t<AQDataType_>) at :47, BQDataType
    // (remove_cvref_t<BQDataType_>) at :48 and BlockGemmShape (typename
    // Base::BlockGemmShape) at :50.
49 
51 
    // Tensor layouts re-exported from Base.
52  using typename Base::ALayout;
53  using typename Base::BLayout;
54  using typename Base::CLayout;
55 
    // TransposeC comes from the template argument; PreshuffleB and DoubleSmemBuffer are
    // read from Traits.
56  static constexpr bool TransposeC = TransposeC_;
57  static constexpr bool PreshuffleB = Traits::PreshuffleB;
58  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
59  using Base::kBlockSize;
60 
    // Padding flags re-exported from Base.
61  using Base::kPadK;
62  using Base::kPadM;
63  using Base::kPadN;
64 
    // NOTE(review): listing lines 65, 67 and 68 are missing. The symbol index places
    // AQLayout (remove_cvref_t<typename Traits::AQLayout>) at :67 and BQLayout
    // (remove_cvref_t<typename Traits::BQLayout>) at :68. VectorLoadSize is used
    // unqualified below and is indexed under gemm_pipeline_problem.hpp, so line 65
    // presumably brings it in from Base - confirm.
66 
69 
    // Template arguments exposed as members.
70  static constexpr uint32_t kQuantGroupSize = QuantGroupSize_;
71  static constexpr auto Scheduler = Scheduler_;
72  static constexpr auto HasHotLoop = HasHotLoop_;
73  static constexpr auto TailNum = TailNum_;
74 
    // BlockGemmShape::kK must be divisible by the quantization group size.
75  static_assert(BlockGemmShape::kK % kQuantGroupSize == 0);
76 
    // Returns an identifier string assembled with concat(): the literal
    // "gemm_quant_problem", the three padding flags (combined through an inner
    // concat('x', ...)), the scheduler and the literal "QuantGroupSize".
    // NOTE(review): listing lines 81 and 85 (one further concat argument and the tail of
    // the call) are missing from this extract.
77  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
78  {
79  // clang-format off
80  return concat('_', "gemm_quant_problem",
82  concat('x', kPadM, kPadN, kPadK),
83  Scheduler,
84  "QuantGroupSize",
86  // clang-format on
87  }
88 
    // Returns VectorLoadSize / sizeof(AQDataType). Compiles only when AQLayout is
    // tensor_layout::gemm::RowMajor (enforced by the static_assert).
89  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentAQ()
90  {
91  static_assert(std::is_same_v<AQLayout, tensor_layout::gemm::RowMajor>);
92  return VectorLoadSize / sizeof(AQDataType);
93  }
94 
    // 1 when kPadK is set, otherwise GetAlignmentAQ().
    // NOTE(review): this static_assert checks ALayout, whereas GetAlignmentAQ() checks
    // AQLayout - confirm that the difference is intended.
95  static constexpr index_t VectorSizeAQ = []() {
96  static_assert(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>);
97  return kPadK ? 1 : GetAlignmentAQ();
98  }();
99 
    // Returns VectorLoadSize / sizeof(BQDataType); unlike the AQ variant there is no
    // layout assertion.
100  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBQ()
101  {
102  return VectorLoadSize / sizeof(BQDataType);
103  }
104 
    // 1 when kPadK is set, otherwise GetAlignmentBQ().
105  static constexpr index_t VectorSizeBQ = []() { return kPadK ? 1 : GetAlignmentBQ(); }();
106 };
107 
// Alias template for the A-quantized case: it takes no BQDataType_ parameter and passes
// `void` in that position of the target's argument list, forwarding everything else
// unchanged.
// NOTE(review): listing lines 117 and 120 are missing from this extract. Line 117
// presumably declares the `Scheduler_` parameter that is forwarded below, and line 120
// presumably holds the alias name plus the start of the target type (the argument list
// matches GemmQuantPipelineProblemBase once a leading ADataType_ is added) - confirm.
108 template <typename ADataType_,
109  typename AQDataType_,
110  typename BDataType_,
111  typename CDataType_,
112  typename BlockGemmShape_,
113  typename Traits_,
114  uint32_t QuantGroupSize_,
115  bool TransposeC_,
116  typename ComputeDataType_ = BDataType_,
118  bool HasHotLoop_ = true,
119  TailNumber TailNum_ = TailNumber::Full>
121  AQDataType_,
122  BDataType_,
123  void, // no BQDataType for AQuant
124  CDataType_,
125  BlockGemmShape_,
126  Traits_,
127  QuantGroupSize_,
128  TransposeC_,
129  ComputeDataType_,
130  Scheduler_,
131  HasHotLoop_,
132  TailNum_>;
133 
// Alias template for the B-quantized case: it takes no AQDataType_ and no TransposeC_
// parameter, passing `void` and `false` in those positions of the target's argument
// list. Unlike the other templates in this file, ComputeDataType_ defaults to ADataType_
// here rather than BDataType_.
// NOTE(review): listing lines 142 and 145 are missing from this extract. Line 142
// presumably declares the `Scheduler_` parameter that is forwarded below, and line 145
// presumably holds the alias name plus the start of the target type (the argument list
// matches GemmQuantPipelineProblemBase once a leading ADataType_ is added) - confirm.
134 template <typename ADataType_,
135  typename BDataType_,
136  typename BQDataType_,
137  typename CDataType_,
138  typename BlockGemmShape_,
139  typename Traits_,
140  uint32_t QuantGroupSize_,
141  typename ComputeDataType_ = ADataType_,
143  bool HasHotLoop_ = true,
144  TailNumber TailNum_ = TailNumber::Full>
146  void, // no AQDataType for BQuant
147  BDataType_,
148  BQDataType_,
149  CDataType_,
150  BlockGemmShape_,
151  Traits_,
152  QuantGroupSize_,
153  false, // no TransposeC
154  ComputeDataType_,
155  Scheduler_,
156  HasHotLoop_,
157  TailNum_>;
158 
// Alias template that forwards to GemmQuantPipelineProblemBase with AccDataType_ passed
// in both the AQDataType_ and BQDataType_ positions and the quantization group size
// fixed to 1 ("no group size applicable"). TransposeC_ defaults to false here.
// NOTE(review): listing lines 167 and 170 are missing from this extract. Line 167
// presumably declares the `Scheduler_` parameter that is forwarded below, and line 170
// presumably holds the alias name and the `using ... =` head - confirm.
159 template <typename ADataType_,
160  typename BDataType_,
161  typename CDataType_,
162  typename AccDataType_,
163  typename BlockGemmShape_,
164  typename Traits_,
165  bool TransposeC_ = false,
166  typename ComputeDataType_ = BDataType_,
168  bool HasHotLoop_ = true,
169  TailNumber TailNum_ = TailNumber::Full>
171  GemmQuantPipelineProblemBase<ADataType_,
172  AccDataType_,
173  BDataType_,
174  AccDataType_,
175  CDataType_,
176  BlockGemmShape_,
177  Traits_,
178  1, // no group size applicable
179  TransposeC_,
180  ComputeDataType_,
181  Scheduler_,
182  HasHotLoop_,
183  TailNum_>;
184 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
unsigned int uint32_t
Definition: stdint.h:126
Definition: gemm_pipeline_problem.hpp:25
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:34
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
remove_cvref_t< std::tuple_element_t< number< 0 >{}, ComputeDataTypeTuple > > ComputeDataType
Definition: gemm_pipeline_problem.hpp:66
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsLayoutTuple > > ALayout
Definition: gemm_pipeline_problem.hpp:68
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: gemm_pipeline_problem.hpp:69
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: gemm_pipeline_problem.hpp:67
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:41
remove_cvref_t< EDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:30
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:26
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsLayoutTuple > > BLayout
Definition: gemm_pipeline_problem.hpp:70
GemmQuantPipelineProblemBase
Definition: gemm_quant_pipeline_problem.hpp:33
static constexpr uint32_t kQuantGroupSize
Definition: gemm_quant_pipeline_problem.hpp:70
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:80
static constexpr bool PreshuffleB
Definition: gemm_quant_pipeline_problem.hpp:57
remove_cvref_t< BQDataType_ > BQDataType
Definition: gemm_quant_pipeline_problem.hpp:48
remove_cvref_t< AQDataType_ > AQDataType
Definition: gemm_quant_pipeline_problem.hpp:47
remove_cvref_t< typename Traits::BQLayout > BQLayout
Definition: gemm_quant_pipeline_problem.hpp:68
static constexpr bool TransposeC
Definition: gemm_quant_pipeline_problem.hpp:56
remove_cvref_t< typename Traits::AQLayout > AQLayout
Definition: gemm_quant_pipeline_problem.hpp:67
typename Base::BlockGemmShape BlockGemmShape
Definition: gemm_quant_pipeline_problem.hpp:50
static constexpr auto Scheduler
Definition: gemm_quant_pipeline_problem.hpp:71
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:76
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:78
static constexpr index_t VectorSizeBQ
Definition: gemm_quant_pipeline_problem.hpp:105
static constexpr auto TailNum
Definition: gemm_quant_pipeline_problem.hpp:73
static constexpr bool DoubleSmemBuffer
Definition: gemm_quant_pipeline_problem.hpp:58
static CK_TILE_HOST const std::string GetName()
Definition: gemm_quant_pipeline_problem.hpp:77
typename Base::Traits Traits
Definition: gemm_quant_pipeline_problem.hpp:41
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentAQ()
Definition: gemm_quant_pipeline_problem.hpp:89
static constexpr index_t VectorSizeAQ
Definition: gemm_quant_pipeline_problem.hpp:95
static constexpr auto HasHotLoop
Definition: gemm_quant_pipeline_problem.hpp:72
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:79
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:84
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentBQ()
Definition: gemm_quant_pipeline_problem.hpp:100