/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp Source File
block_fmha_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 
8 namespace ck_tile {
9 
// Compile-time problem descriptor for an FMHA block pipeline. Every member visible
// here is a static constexpr constant derived from the template arguments.
// NOTE(review): this is a documentation-listing extract. Listing line 26 (the struct
// declaration) and lines 28-42 are absent; the cross-reference index at the end of
// this page lists lines 28-42 as aliases of the form remove_cvref_t<X_>
// (QDataType ... Traits), which is where the BlockFmhaShape and Traits names used
// below come from. Restore those lines before treating this text as compilable.
10 template <typename QDataType_,
11  typename KDataType_,
12  typename VDataType_,
13  typename SaccDataType_,
14  typename SMPLComputeDataType_,
15  typename BiasDataType_,
16  typename RandValOutputDataType_,
17  typename LSEDataType_,
18  typename PDataType_,
19  typename OaccDataType_,
20  typename ODataType_,
21  typename BlockFmhaShape_,
22  bool kIsGroupMode_,
23  typename AttentionVariant_,
24  typename FmhaMask_,
25  typename Traits_>
27 {
43 
    // Warp counts are re-exported from the block shape.
44  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
45  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
    // Block size = number of warps in the block shape times the warp size.
46  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
47 
    // Re-exports the kIsGroupMode_ template argument unchanged.
48  static constexpr bool kIsGroupMode = kIsGroupMode_;
49 
    // Each constant below is copied verbatim from the same-named member of Traits,
    // so users of this descriptor do not have to reach into Traits themselves.
50  // attributes from traits
51  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
52  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
53  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
54  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
55  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
56  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
    // Declared `auto`: the constant takes whatever type Traits::BiasEnum has.
57  static constexpr auto BiasEnum = Traits::BiasEnum;
58  static constexpr bool kStoreLSE = Traits::kStoreLSE;
59  static constexpr bool kHasDropout = Traits::kHasDropout;
60  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
61  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
62 };
63 
// Compile-time problem descriptor whose visible members are all static constexpr
// constants derived from the template arguments. Compared with the descriptor that
// precedes it in this file, it has no RandValOutputDataType_ parameter and no
// kHasDropout constant, and it additionally exposes kIsPagedKV.
// NOTE(review): this is a documentation-listing extract. Listing line 79 (the struct
// declaration) and lines 81-94 are absent; the cross-reference index at the end of
// this page lists lines 81-94 as aliases of the form remove_cvref_t<X_>
// (QDataType ... Traits), which is where BlockFmhaShape and Traits come from.
64 template <typename QDataType_,
65  typename KDataType_,
66  typename VDataType_,
67  typename SaccDataType_,
68  typename SMPLComputeDataType_,
69  typename BiasDataType_,
70  typename LSEDataType_,
71  typename PDataType_,
72  typename OaccDataType_,
73  typename ODataType_,
74  typename BlockFmhaShape_,
75  bool kIsGroupMode_,
76  typename AttentionVariant_,
77  typename FmhaMask_,
78  typename Traits_>
80 {
95 
    // Warp counts are re-exported from the block shape.
96  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
97  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
    // Block size = number of warps in the block shape times the warp size.
98  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
99 
    // Re-exports the kIsGroupMode_ template argument unchanged.
100  static constexpr bool kIsGroupMode = kIsGroupMode_;
101 
    // Each constant below is copied verbatim from the same-named member of Traits.
102  // attributes from traits
103  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
104  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
105  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
106  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
107  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
108  static constexpr bool kSkipMinSeqlenQ = Traits::kSkipMinSeqlenQ;
    // Declared `auto`: the constant takes whatever type Traits::BiasEnum has.
109  static constexpr auto BiasEnum = Traits::BiasEnum;
110  static constexpr bool kStoreLSE = Traits::kStoreLSE;
111  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
112  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
113  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
114 };
115 
// Compile-time problem descriptor whose visible members are all static constexpr
// constants derived from the template arguments. It exposes kIsPagedKV,
// kHasUnevenSplits and kMergeNumHeadGroupsSeqLenQ, and has no kSkipMinSeqlenQ or
// kHasDropout constant.
// NOTE(review): this is a documentation-listing extract. Listing line 131 (the struct
// declaration) and lines 133-146 are absent; the cross-reference index at the end of
// this page lists lines 133-146 as aliases of the form remove_cvref_t<X_>
// (QDataType ... Traits), which is where BlockFmhaShape and Traits come from.
116 template <typename QDataType_,
117  typename KDataType_,
118  typename VDataType_,
119  typename SaccDataType_,
120  typename SMPLComputeDataType_,
121  typename BiasDataType_,
122  typename LSEDataType_,
123  typename PDataType_,
124  typename OaccDataType_,
125  typename ODataType_,
126  typename BlockFmhaShape_,
127  bool kIsGroupMode_,
128  typename AttentionVariant_,
129  typename FmhaMask_,
130  typename Traits_>
132 {
147 
    // Warp counts are re-exported from the block shape.
148  static constexpr index_t kNumGemm0Warps = BlockFmhaShape::NumGemm0Warps;
149  static constexpr index_t kNumGemm1Warps = BlockFmhaShape::NumGemm1Warps;
    // Block size = number of warps in the block shape times the warp size.
150  static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
151 
    // Re-exports the kIsGroupMode_ template argument unchanged.
152  static constexpr bool kIsGroupMode = kIsGroupMode_;
153 
    // Unless noted otherwise, each constant below is copied verbatim from the
    // same-named member of Traits.
154  // attributes from traits
155  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
156  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
157  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
158  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
159  static constexpr bool kHasLogitsSoftCap = Traits::kHasLogitsSoftCap;
    // Declared `auto`: the constant takes whatever type Traits::BiasEnum has.
160  static constexpr auto BiasEnum = Traits::BiasEnum;
161  static constexpr bool kStoreLSE = Traits::kStoreLSE;
162  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
163  static constexpr bool kIsPagedKV = Traits::kIsPagedKV;
    // Not a plain copy: group mode forces this to true regardless of the traits flag;
    // otherwise the value of Traits::kHasUnevenSplits is used.
164  static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
165  static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
166  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
167 };
168 
// Helper that derives tile-size constants from the accumulator element type and the
// kN1_ tile extent alone, without consulting any traits type.
// NOTE(review): this is a documentation-listing extract and listing line 171 (the
// struct declaration) is absent. The name is presumably
// BlockFmhaSplitKVCombinePipelineTileSizes, the base class named with the same
// <OaccDataType_, kN1_> arguments by the next definition in this file — confirm
// against the original header.
169 // extract tile size attributes to remove dependency on traits
170 template <typename OaccDataType_, ck_tile::index_t kN1_>
172 {
    // Number of OaccDataType_ elements that fit in 16 bytes.
173  static constexpr index_t MaxVectorSize = 16 / sizeof(OaccDataType_);
174 
    // Re-exports the kN1_ template argument unchanged.
175  static constexpr index_t kN1 = kN1_;
    // kN1 divided by MaxVectorSize. This is integer division, so any remainder is
    // dropped; no divisibility check is visible in this block.
176  static constexpr index_t NThreads = kN1 / MaxVectorSize;
    // Warp size divided by NThreads (integer division again).
177  static constexpr index_t kM0 = get_warp_size() / NThreads; // MThreadPerWarp
178 };
179 
// Compile-time problem descriptor that inherits its tile-size constants from
// BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_> and adds traits flags,
// block-size constants and several static_assert consistency checks.
// NOTE(review): this is a documentation-listing extract. Listing lines 187 (the
// struct declaration), 190 and 192-195 are absent. The cross-reference index at the
// end of this page lists 192-195 as the LSEDataType, OaccDataType, ODataType and
// Traits aliases (remove_cvref_t<X_>). Line 190 presumably defines the BaseType
// alias used below — confirm against the original header.
180 template <typename LSEDataType_,
181  typename OaccDataType_,
182  typename ODataType_,
183  index_t HeadDimV_,
184  bool kIsGroupMode_,
185  ck_tile::index_t kN1_,
186  typename Traits_>
188  : BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType_, kN1_>
189 {
191 
196 
    // The LSE type and the output-accumulator type must be the same type.
197  static_assert(std::is_same_v<LSEDataType, OaccDataType>);
198 
    // Re-export the HeadDimV_ and kIsGroupMode_ template arguments unchanged.
199  static constexpr index_t kHeadDimV = HeadDimV_;
200  static constexpr bool kIsGroupMode = kIsGroupMode_;
201 
    // Pull the base-class tile-size constants into this scope.
202  using BaseType::kM0;
203  using BaseType::kN1;
204 
    // kN1 may not exceed kHeadDimV and must divide it exactly.
205  static_assert(kN1 <= kHeadDimV && kHeadDimV % kN1 == 0);
206 
    // Each constant below is copied verbatim from the same-named member of Traits.
207  // attributes from traits
208  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
209  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
210  static constexpr bool kStoreLSE = Traits::kStoreLSE;
211  static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
212  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
213  static constexpr index_t kMaxSplits = Traits::kMaxSplits;
    // Traits::kMaxSplits must be at least 8.
214  static_assert(8 <= kMaxSplits);
215 
    // Block size is fixed at four warps' worth of threads.
216  static constexpr index_t kNumWarps = 4; // always use 4 warps for each workgroup
217  static constexpr index_t kBlockSize = kNumWarps * get_warp_size();
218 
    // kM0 * kMaxSplits must be at least one warp size and an exact multiple of it.
219  static_assert(get_warp_size() <= (kM0 * kMaxSplits) &&
220  (kM0 * kMaxSplits) % get_warp_size() == 0);
221 };
222 
// Compile-time problem descriptor parameterised directly by tile extents
// (kM0_/kN0_/kK0_/kN1_), a V-layout selector, a rotary-embedding enum value and a
// paged-KV flag rather than by a block-shape type.
// NOTE(review): this is a documentation-listing extract. Listing lines 234 (the
// struct declaration), 236-239 and 249-250 are absent. The cross-reference index at
// the end of this page lists 236-239 as the QDataType, KDataType, VDataType and
// Traits aliases (remove_cvref_t<X_>), and shows the complete VLayout alias as
// std::conditional_t<kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor,
// ck_tile::tensor_layout::gemm::ColumnMajor>.
223 template <typename QDataType_,
224  typename KDataType_,
225  typename VDataType_,
226  index_t kM0_,
227  index_t kN0_,
228  index_t kK0_,
229  index_t kN1_,
230  bool kIsVLayoutRowMajor_,
231  RotaryEmbeddingEnum RotaryEnum_,
232  bool kIsPagedKV_,
233  typename Traits_>
235 {
240 
    // Block size is hard-coded here rather than derived from a warp count.
241  static constexpr index_t kBlockSize = 256;
242 
    // Re-export the tile-extent template arguments unchanged.
243  static constexpr index_t kM0 = kM0_;
244  static constexpr index_t kN0 = kN0_;
245  static constexpr index_t kK0 = kK0_;
246  static constexpr index_t kN1 = kN1_;
247 
    // Layout tag selected by kIsVLayoutRowMajor_. The statement is truncated in this
    // extract: its two remaining lines (249-250) are missing.
248  using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
251 
    // Re-export the RotaryEnum_ and kIsPagedKV_ template arguments unchanged.
252  static constexpr auto RotaryEnum = RotaryEnum_;
253  static constexpr bool kIsPagedKV = kIsPagedKV_;
254 
    // Each constant below is copied verbatim from the same-named member of Traits.
255  // attributes from traits
256  static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
257  static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
258  static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
259  static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
260  static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
261 };
262 
263 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
RotaryEmbeddingEnum
Definition: block_rotary_embedding.hpp:12
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_fmha_pipeline_problem.hpp:235
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:236
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:257
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:258
std::conditional_t< kIsVLayoutRowMajor_, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor > VLayout
Definition: block_fmha_pipeline_problem.hpp:250
static constexpr auto RotaryEnum
Definition: block_fmha_pipeline_problem.hpp:252
static constexpr index_t kK0
Definition: block_fmha_pipeline_problem.hpp:245
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:256
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:239
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:253
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:238
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:243
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:246
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:260
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:241
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:259
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:237
static constexpr index_t kN0
Definition: block_fmha_pipeline_problem.hpp:244
Definition: block_fmha_pipeline_problem.hpp:80
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:85
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:105
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:111
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:94
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:113
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:91
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:106
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:83
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:92
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:93
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:88
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:84
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:82
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:109
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:81
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:86
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:89
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:107
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:103
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:98
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:104
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:90
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:96
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:100
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:87
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:112
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:110
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:108
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:97
Definition: block_fmha_pipeline_problem.hpp:132
static constexpr bool kHasUnevenSplits
Definition: block_fmha_pipeline_problem.hpp:164
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:135
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:159
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:145
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:157
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:162
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:148
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:133
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:141
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:139
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:152
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:165
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:149
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:136
static constexpr bool kIsPagedKV
Definition: block_fmha_pipeline_problem.hpp:163
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:137
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:143
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:134
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:155
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:150
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:166
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:140
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:142
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:160
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:144
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:156
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:161
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:138
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:158
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:146
Definition: block_fmha_pipeline_problem.hpp:27
remove_cvref_t< AttentionVariant_ > AttentionVariant
Definition: block_fmha_pipeline_problem.hpp:40
remove_cvref_t< KDataType_ > KDataType
Definition: block_fmha_pipeline_problem.hpp:29
remove_cvref_t< SaccDataType_ > SaccDataType
Definition: block_fmha_pipeline_problem.hpp:31
static constexpr bool kSkipMinSeqlenQ
Definition: block_fmha_pipeline_problem.hpp:56
remove_cvref_t< SMPLComputeDataType_ > SMPLComputeDataType
Definition: block_fmha_pipeline_problem.hpp:32
remove_cvref_t< FmhaMask_ > FmhaMask
Definition: block_fmha_pipeline_problem.hpp:41
remove_cvref_t< BiasDataType_ > BiasDataType
Definition: block_fmha_pipeline_problem.hpp:33
static constexpr bool kHasDropout
Definition: block_fmha_pipeline_problem.hpp:59
remove_cvref_t< QDataType_ > QDataType
Definition: block_fmha_pipeline_problem.hpp:28
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:38
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:61
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:58
remove_cvref_t< BlockFmhaShape_ > BlockFmhaShape
Definition: block_fmha_pipeline_problem.hpp:39
static constexpr auto BiasEnum
Definition: block_fmha_pipeline_problem.hpp:57
static constexpr bool kPadHeadDimQ
Definition: block_fmha_pipeline_problem.hpp:53
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:35
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:42
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:54
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:37
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:48
static constexpr bool kPadSeqLenK
Definition: block_fmha_pipeline_problem.hpp:52
remove_cvref_t< PDataType_ > PDataType
Definition: block_fmha_pipeline_problem.hpp:36
static constexpr index_t kNumGemm0Warps
Definition: block_fmha_pipeline_problem.hpp:44
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_pipeline_problem.hpp:55
remove_cvref_t< RandValOutputDataType_ > RandValOutputDataType
Definition: block_fmha_pipeline_problem.hpp:34
static constexpr index_t kNumGemm1Warps
Definition: block_fmha_pipeline_problem.hpp:45
remove_cvref_t< VDataType_ > VDataType
Definition: block_fmha_pipeline_problem.hpp:30
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:60
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:46
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:51
Definition: block_fmha_pipeline_problem.hpp:189
remove_cvref_t< ODataType_ > ODataType
Definition: block_fmha_pipeline_problem.hpp:194
static constexpr index_t kNumWarps
Definition: block_fmha_pipeline_problem.hpp:216
remove_cvref_t< Traits_ > Traits
Definition: block_fmha_pipeline_problem.hpp:195
static constexpr index_t kHeadDimV
Definition: block_fmha_pipeline_problem.hpp:199
static constexpr index_t kBlockSize
Definition: block_fmha_pipeline_problem.hpp:217
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:177
static constexpr index_t kBlockPerCu
Definition: block_fmha_pipeline_problem.hpp:212
static constexpr bool kIsGroupMode
Definition: block_fmha_pipeline_problem.hpp:200
static constexpr index_t kMaxSplits
Definition: block_fmha_pipeline_problem.hpp:213
static constexpr bool kPadHeadDimV
Definition: block_fmha_pipeline_problem.hpp:209
static constexpr bool kStoreLSE
Definition: block_fmha_pipeline_problem.hpp:210
static constexpr bool kDoFp8StaticQuant
Definition: block_fmha_pipeline_problem.hpp:211
remove_cvref_t< LSEDataType_ > LSEDataType
Definition: block_fmha_pipeline_problem.hpp:192
static constexpr bool kPadSeqLenQ
Definition: block_fmha_pipeline_problem.hpp:208
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:175
remove_cvref_t< OaccDataType_ > OaccDataType
Definition: block_fmha_pipeline_problem.hpp:193
Definition: block_fmha_pipeline_problem.hpp:172
static constexpr index_t NThreads
Definition: block_fmha_pipeline_problem.hpp:176
static constexpr index_t kM0
Definition: block_fmha_pipeline_problem.hpp:177
static constexpr index_t MaxVectorSize
Definition: block_fmha_pipeline_problem.hpp:173
static constexpr index_t kN1
Definition: block_fmha_pipeline_problem.hpp:175
Definition: tensor_layout.hpp:22
Definition: tensor_layout.hpp:17