/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp Source File
tile_fmha_traits.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
10 
11 namespace ck_tile {
12 
// Compile-time trait bundle: each template parameter below is re-exported as a
// static constexpr member with the same name minus the trailing underscore.
// NOTE(review): listing line 25 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
13 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
14  bool kPadSeqLenK_ /* padding for seqlen_k */,
15  bool kPadHeadDimQ_ /* padding for hdim_q */,
16  bool kPadHeadDimV_ /* padding for hdim_v */,
17  bool kHasLogitsSoftCap_,
18  BlockAttentionBiasEnum BiasEnum_,
19  bool kHasBiasGrad_,
20  bool kStoreLSE_,
21  bool kHasDropout_,
22  BlockAttentionQuantScaleEnum QScaleEnum_,
23  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
24  bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
26 {
27  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
28  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
29  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
30  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
31  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
    // declared `auto`, so the member takes the enum type of BiasEnum_
32  static constexpr auto BiasEnum = BiasEnum_;
33  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
34  static constexpr bool kStoreLSE = kStoreLSE_;
35  static constexpr bool kHasDropout = kHasDropout_;
    // declared `auto`, so the member takes the enum type of QScaleEnum_
36  static constexpr auto QScaleEnum = QScaleEnum_;
37  static constexpr index_t kBlockPerCu = kBlockPerCu_;
38  static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
39 };
40 
// Compile-time trait bundle: each template parameter below is re-exported as a
// static constexpr member with the same name minus the trailing underscore.
// Unlike the other templates in this file, the head-dim padding parameters are
// index_t values rather than bools.
// NOTE(review): listing line 46 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
41 template <index_t kPadHeadDimQ_ /* padding for hdim_q */,
42  index_t kPadHeadDimV_ /* padding for hdim_v */,
43  BlockAttentionBiasEnum BiasEnum_,
44  bool kHasBiasGrad_,
45  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
47 {
48  static constexpr index_t kPadHeadDimQ = kPadHeadDimQ_;
49  static constexpr index_t kPadHeadDimV = kPadHeadDimV_;
50  static constexpr auto BiasEnum = BiasEnum_;
51  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
52  static constexpr index_t kBlockPerCu = kBlockPerCu_;
53 
    // Only 0, 1 or 8 are accepted for each head-dim padding value; any other
    // value fails to compile.
    // NOTE(review): the meaning of each of the three values is not stated here.
54  static_assert(kPadHeadDimQ == 0 || kPadHeadDimQ == 8 || kPadHeadDimQ == 1);
55  static_assert(kPadHeadDimV == 0 || kPadHeadDimV == 8 || kPadHeadDimV == 1);
56 };
57 
// Compile-time trait bundle: each template parameter below is re-exported as a
// static constexpr member with the same name minus the trailing underscore.
// NOTE(review): listing line 70 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
58 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
59  bool kPadSeqLenK_ /* padding for seqlen_k */,
60  bool kPadHeadDimQ_ /* padding for hdim_q */,
61  bool kPadHeadDimV_ /* padding for hdim_v */,
62  bool kHasLogitsSoftCap_,
63  BlockAttentionBiasEnum BiasEnum_,
64  bool kHasBiasGrad_,
65  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
66  bool kIsPagedKV_,
67  bool kDoFp8StaticQuant_,
68  index_t kBlockPerCu_ = -1, /* overwrite occupancy if not -1 */
69  bool kSkipMinSeqlenQ_ = false /* skip min seqlen q while chunked prefill */>
71 {
72  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
73  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
74  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
75  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
76  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
    // declared `auto`, so the member takes the enum type of BiasEnum_
77  static constexpr auto BiasEnum = BiasEnum_;
78  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
79  static constexpr bool kStoreLSE = kStoreLSE_;
80  static constexpr bool kIsPagedKV = kIsPagedKV_;
81  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
82  static constexpr index_t kBlockPerCu = kBlockPerCu_;
83  static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
84 };
85 
// Compile-time trait bundle: each template parameter below is re-exported as a
// static constexpr member with the same name minus the trailing underscore.
// NOTE(review): kDoFp8StaticQuant_ precedes kIsPagedKV_ in this parameter list,
// the reverse of the order used by the template at listing line 58; both are
// bool, so swapped positional arguments would compile silently.
// NOTE(review): listing line 99 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
86 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
87  bool kPadSeqLenK_ /* padding for seqlen_k */,
88  bool kPadHeadDimQ_ /* padding for hdim_q */,
89  bool kPadHeadDimV_ /* padding for hdim_v */,
90  bool kHasLogitsSoftCap_,
91  BlockAttentionBiasEnum BiasEnum_,
92  bool kHasBiasGrad_,
93  bool kStoreLSE_, /* set to true if either num_splits > 1 or fwd training is running */
94  bool kDoFp8StaticQuant_,
95  bool kIsPagedKV_,
96  bool kHasUnevenSplits_,
97  bool kMergeNumHeadGroupsSeqLenQ_ = false,
98  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
100 {
101  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
102  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
103  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
104  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
105  static constexpr bool kHasLogitsSoftCap = kHasLogitsSoftCap_;
     // declared `auto`, so the member takes the enum type of BiasEnum_
106  static constexpr auto BiasEnum = BiasEnum_;
107  static constexpr bool kHasBiasGrad = kHasBiasGrad_;
108  static constexpr bool kStoreLSE = kStoreLSE_;
109  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
110  static constexpr bool kIsPagedKV = kIsPagedKV_;
111  // determine if some split (length) is not divisible by tile size
112  static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
113  static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
114  static constexpr index_t kBlockPerCu = kBlockPerCu_;
115 };
116 
// Compile-time trait bundle. All template parameters except kLogMaxSplits_ are
// re-exported as static constexpr members with the trailing underscore removed;
// kLogMaxSplits_ is exposed only through the derived constant kMaxSplits.
// NOTE(review): listing line 123 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
117 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
118  bool kPadHeadDimV_ /* padding for hdim_v */,
119  bool kStoreLSE_,
120  bool kDoFp8StaticQuant_,
121  index_t kLogMaxSplits_,
122  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
124 {
125  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
126  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
127  static constexpr bool kStoreLSE = kStoreLSE_;
128  static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
129 
     // kMaxSplits is 2 raised to the power kLogMaxSplits_
130  static constexpr index_t kMaxSplits = (1 << kLogMaxSplits_);
     // kMaxSplits must be no larger than get_warp_size() or an exact multiple of it
131  static_assert(kMaxSplits <= get_warp_size() || kMaxSplits % get_warp_size() == 0);
132  static constexpr index_t kBlockPerCu = kBlockPerCu_;
133 };
134 
// Compile-time trait bundle: each template parameter below is re-exported as a
// static constexpr member with the same name minus the trailing underscore.
// NOTE(review): listing line 140 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
135 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
136  bool kPadSeqLenK_ /* padding for seqlen_k */,
137  bool kPadHeadDimQ_ /* padding for hdim_q */,
138  bool kPadHeadDimV_ /* padding for hdim_v */,
139  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
141 {
142  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
143  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
144  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
145  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
146  static constexpr index_t kBlockPerCu = kBlockPerCu_;
147 };
148 
// Compile-time trait bundle: each template parameter below is re-exported as a
// static constexpr member with the same name minus the trailing underscore.
// Here kBlockPerCu_ defaults to 2, not the -1 default used by several other
// templates in this file.
// NOTE(review): listing line 152 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
149 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
150  bool kPadHeadDimV_ /* padding for hdim_v */,
151  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
153 {
154  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
155  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
156  static constexpr index_t kBlockPerCu = kBlockPerCu_;
157 };
158 
// Compile-time trait bundle: each template parameter below is re-exported as a
// static constexpr member with the same name minus the trailing underscore.
// Here kBlockPerCu_ defaults to 2, not the -1 default used by several other
// templates in this file.
// NOTE(review): listing line 162 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
159 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
160  bool kPadHeadDimQ_ /* padding for hdim_q */,
161  index_t kBlockPerCu_ = 2 /* hint to occupancy */>
163 {
164  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
165  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
166  static constexpr index_t kBlockPerCu = kBlockPerCu_;
167 };
168 
// Compile-time trait bundle: each template parameter below is re-exported as a
// static constexpr member with the same name minus the trailing underscore.
// NOTE(review): listing line 175 (presumably the struct declaration) is absent
// from this extract, so the type's name is not visible here.
169 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
170  bool kPadSeqLenK_ /* padding for seqlen_k */,
171  bool kPadHeadDimQ_ /* padding for hdim_q */,
172  bool kPadHeadDimV_ /* padding for hdim_v */,
173  bool kStoreLSE_,
174  index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
176 {
177  static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
178  static constexpr bool kPadSeqLenK = kPadSeqLenK_;
179  static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
180  static constexpr bool kPadHeadDimV = kPadHeadDimV_;
181  static constexpr bool kStoreLSE = kStoreLSE_;
182  static constexpr index_t kBlockPerCu = kBlockPerCu_;
183 };
184 
185 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
BlockAttentionBiasEnum
Definition: block_attention_bias_enum.hpp:12
int32_t index_t
Definition: integer.hpp:9
BlockAttentionQuantScaleEnum
Definition: block_attention_quant_scale_enum.hpp:12
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
Definition: tile_fmha_traits.hpp:163
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:166
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:165
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:164
Definition: tile_fmha_traits.hpp:153
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:156
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:154
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:155
Definition: tile_fmha_traits.hpp:47
static constexpr index_t kPadHeadDimQ
Definition: tile_fmha_traits.hpp:48
static constexpr index_t kPadHeadDimV
Definition: tile_fmha_traits.hpp:49
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:51
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:50
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:52
Definition: tile_fmha_traits.hpp:141
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:144
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:143
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:146
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:142
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:145
Definition: tile_fmha_traits.hpp:71
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:77
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:81
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:74
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:82
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:76
static constexpr bool kSkipMinSeqlenQ
Definition: tile_fmha_traits.hpp:83
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:79
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:72
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:80
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:75
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:73
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:78
Definition: tile_fmha_traits.hpp:124
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:125
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:126
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:128
static constexpr index_t kMaxSplits
Definition: tile_fmha_traits.hpp:130
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:127
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:132
Definition: tile_fmha_traits.hpp:100
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:114
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:108
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: tile_fmha_traits.hpp:113
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:103
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:105
static constexpr bool kHasUnevenSplits
Definition: tile_fmha_traits.hpp:112
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:102
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:107
static constexpr bool kDoFp8StaticQuant
Definition: tile_fmha_traits.hpp:109
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:106
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:104
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:101
static constexpr bool kIsPagedKV
Definition: tile_fmha_traits.hpp:110
Definition: tile_fmha_traits.hpp:176
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:181
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:179
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:178
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:180
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:182
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:177
Definition: tile_fmha_traits.hpp:26
static constexpr auto QScaleEnum
Definition: tile_fmha_traits.hpp:36
static constexpr bool kHasBiasGrad
Definition: tile_fmha_traits.hpp:33
static constexpr bool kStoreLSE
Definition: tile_fmha_traits.hpp:34
static constexpr auto BiasEnum
Definition: tile_fmha_traits.hpp:32
static constexpr bool kPadSeqLenK
Definition: tile_fmha_traits.hpp:28
static constexpr bool kPadHeadDimQ
Definition: tile_fmha_traits.hpp:29
static constexpr bool kPadSeqLenQ
Definition: tile_fmha_traits.hpp:27
static constexpr bool kHasLogitsSoftCap
Definition: tile_fmha_traits.hpp:31
static constexpr bool kPadHeadDimV
Definition: tile_fmha_traits.hpp:30
static constexpr bool kSkipMinSeqlenQ
Definition: tile_fmha_traits.hpp:38
static constexpr index_t kBlockPerCu
Definition: tile_fmha_traits.hpp:37
static constexpr bool kHasDropout
Definition: tile_fmha_traits.hpp:35