/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp Source File
gemm_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 template <typename ADataType_,
13  typename BDataType_,
14  typename CDataType_,
15  typename BlockGemmShape_,
16  typename Traits_,
17  typename ComputeDataType_ = ADataType_,
18  bool FixedVectorSize_ = false,
19  index_t VectorSizeA_ = 1,
20  index_t VectorSizeB_ = 1>
22 {
24 
29 
30  static constexpr bool FixedVectorSize = FixedVectorSize_;
31 
33 
37 
38  static constexpr bool TransposeC = Traits::TransposeC;
39 
40  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
41 
42  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
43 
44  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
45 
46  static constexpr bool kPadM = Traits::kPadM;
47  static constexpr bool kPadN = Traits::kPadN;
48  static constexpr bool kPadK = Traits::kPadK;
49 
50  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
51  static constexpr auto Scheduler = GemmPipelineScheduler::Default;
52  static constexpr index_t VectorLoadSize = Traits::_VectorSize;
53 
54  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
55  {
56  // clang-format off
57  return concat('_', "gemm_problem",
59  concat('x', kPadM, kPadN, kPadK),
60  Scheduler);
61  // clang-format on
62  }
63 
64  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
65  {
66  constexpr index_t PackedSize =
68  if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
69  {
70  constexpr index_t pixels_per_thread =
71  BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
72  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
73  ? pixels_per_thread
74  : PackedSize * VectorLoadSize / sizeof(ADataType);
75  }
76  else
77  {
78  return VectorLoadSize / sizeof(ADataType);
79  }
80  }
81 
82  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
83  {
84  constexpr index_t PackedSize =
86  if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
87  {
88  constexpr index_t pixels_per_thread =
89  BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
90  return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
91  ? pixels_per_thread
92  : PackedSize * VectorLoadSize / sizeof(BDataType);
93  }
94  else
95  {
96  return PackedSize * VectorLoadSize / sizeof(BDataType);
97  }
98  }
99 
100  CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
101  {
102  if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
103  {
104  constexpr index_t N1 = kBlockSize / get_warp_size();
105  constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
106  constexpr index_t M0 = get_warp_size() / N2;
107  constexpr index_t M1 = BlockGemmShape::kM / M0;
108 
109  return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
110  }
111  else
112  {
113  constexpr index_t M1 = kBlockSize / get_warp_size();
114  constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
115  constexpr index_t N0 = get_warp_size() / M2;
116  constexpr index_t N1 = BlockGemmShape::kN / N0;
117 
118  return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
119  }
120  }
121 
122  static constexpr index_t VectorSizeA = []() {
123  if constexpr(FixedVectorSize)
124  {
125  return VectorSizeA_;
126  }
127  else if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
128  {
129  return kPadK ? 1 : GetAlignmentA();
130  }
131  else
132  {
133  return kPadM ? 1 : GetAlignmentA();
134  }
135  }();
136 
137  static constexpr index_t VectorSizeB = []() {
138  if constexpr(FixedVectorSize)
139  {
140  return VectorSizeB_;
141  }
142  else if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
143  {
144  return kPadN ? 1 : GetAlignmentB();
145  }
146  else
147  {
148  return kPadK ? 1 : GetAlignmentB();
149  }
150  }();
151  static constexpr index_t VectorSizeC = []() {
152  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
153  {
154  return kPadN ? 1 : GetAlignmentC();
155  }
156  else
157  {
158  return kPadM ? 1 : GetAlignmentC();
159  }
160  }();
161 };
162 
163 // Alias for GemmPipelineProblem
164 template <typename ADataType_,
165  typename BDataType_,
166  typename CDataType_,
167  typename BlockGemmShape_,
168  typename Traits_,
169  typename ComputeDataType_ = ADataType_,
170  bool FixedVectorSize_ = false,
171  index_t VectorSizeA_ = 1,
172  index_t VectorSizeB_ = 1>
174  BDataType_,
175  CDataType_,
176  BlockGemmShape_,
177  Traits_,
178  ComputeDataType_,
179  FixedVectorSize_,
180  VectorSizeA_,
181  VectorSizeB_>;
182 
183 template <typename ADataType_,
184  typename BDataType_,
185  typename CDataType_,
186  typename BlockGemmShape_,
187  typename Traits_,
189  bool HasHotLoop_ = true,
190  TailNumber TailNum_ = TailNumber::Full,
191  typename ComputeDataType_ = ADataType_,
192  bool FixedVectorSize_ = false,
193  index_t VectorSizeA_ = 1,
194  index_t VectorSizeB_ = 1>
196 {
198 
203 
204  static constexpr bool FixedVectorSize = FixedVectorSize_;
205  static constexpr index_t VectorSizeA = VectorSizeA_;
206  static constexpr index_t VectorSizeB = VectorSizeB_;
207 
209 
213 
214  static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
215 
216  static constexpr bool kPadM = Traits::kPadM;
217  static constexpr bool kPadN = Traits::kPadN;
218  static constexpr bool kPadK = Traits::kPadK;
219 
220  static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
221 
222  static constexpr auto Scheduler = Scheduler_;
223  static constexpr auto HasHotLoop = HasHotLoop_;
224  static constexpr auto TailNum = TailNum_;
225 
226  static constexpr bool TransposeC = Traits::TransposeC;
227  static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
228 
229  static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
230 };
231 
232 } // namespace ck_tile
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
TailNumber
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:21
int32_t index_t
Definition: integer.hpp:9
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:41
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
GemmPipelineScheduler
Definition: gemm_pipeline_ag_bg_cr_scheduler.hpp:14
Definition: gemm_pipeline_problem.hpp:22
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:122
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:137
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:27
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:50
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:46
static constexpr index_t VectorLoadSize
Definition: gemm_pipeline_problem.hpp:52
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:51
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: gemm_pipeline_problem.hpp:28
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:38
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:23
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:40
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:32
remove_cvref_t< typename Traits::ALayout > ALayout
Definition: gemm_pipeline_problem.hpp:34
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentB()
Definition: gemm_pipeline_problem.hpp:82
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentC()
Definition: gemm_pipeline_problem.hpp:100
static constexpr CK_TILE_HOST_DEVICE auto GetAlignmentA()
Definition: gemm_pipeline_problem.hpp:64
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:47
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:36
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:26
static CK_TILE_HOST const std::string GetName()
Definition: gemm_pipeline_problem.hpp:54
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:25
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:48
remove_cvref_t< typename Traits::BLayout > BLayout
Definition: gemm_pipeline_problem.hpp:35
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:30
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:42
static constexpr index_t VectorSizeC
Definition: gemm_pipeline_problem.hpp:151
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:44
Definition: gemm_pipeline_problem.hpp:196
static constexpr bool TransposeC
Definition: gemm_pipeline_problem.hpp:226
static constexpr index_t VectorSizeB
Definition: gemm_pipeline_problem.hpp:206
static constexpr bool kPadK
Definition: gemm_pipeline_problem.hpp:218
static constexpr index_t VectorSizeA
Definition: gemm_pipeline_problem.hpp:205
remove_cvref_t< typename Traits::ALayout > ALayout
Definition: gemm_pipeline_problem.hpp:210
static constexpr bool DoubleSmemBuffer
Definition: gemm_pipeline_problem.hpp:220
remove_cvref_t< Traits_ > Traits
Definition: gemm_pipeline_problem.hpp:197
remove_cvref_t< typename Traits::CLayout > CLayout
Definition: gemm_pipeline_problem.hpp:212
static constexpr bool UseStructuredSparsity
Definition: gemm_pipeline_problem.hpp:227
remove_cvref_t< BDataType_ > BDataType
Definition: gemm_pipeline_problem.hpp:200
static constexpr bool kPadN
Definition: gemm_pipeline_problem.hpp:217
remove_cvref_t< ADataType_ > ADataType
Definition: gemm_pipeline_problem.hpp:199
static constexpr auto Scheduler
Definition: gemm_pipeline_problem.hpp:222
remove_cvref_t< typename Traits::BLayout > BLayout
Definition: gemm_pipeline_problem.hpp:211
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: gemm_pipeline_problem.hpp:202
static constexpr bool kPadM
Definition: gemm_pipeline_problem.hpp:216
remove_cvref_t< CDataType_ > CDataType
Definition: gemm_pipeline_problem.hpp:201
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_pipeline_problem.hpp:208
static constexpr bool FixedVectorSize
Definition: gemm_pipeline_problem.hpp:204
static constexpr index_t NumWaveGroups
Definition: gemm_pipeline_problem.hpp:229
static constexpr auto TailNum
Definition: gemm_pipeline_problem.hpp:224
static constexpr index_t kBlockSize
Definition: gemm_pipeline_problem.hpp:214
static constexpr auto HasHotLoop
Definition: gemm_pipeline_problem.hpp:223
Definition: numeric.hpp:81