/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp Source File
fmha_fwd_splitkv_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 
14 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
15 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
16 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
17 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
18 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
19 
20 namespace ck_tile {
21 
22 template <typename FmhaPipeline_, typename EpiloguePipeline_>
24 {
27  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
28  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
29  static_assert(kBlockPerCu > 0);
30  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
31 
40 
42 
43  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
44  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
45  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
46  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
47  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
48  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
49  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
50  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
51  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
52  static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
53  static constexpr bool kMergeNumHeadGroupsSeqLenQ =
54  FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
57  static constexpr bool kHasMask = FmhaMask::IsMasking;
58 
59  static_assert(!kMergeNumHeadGroupsSeqLenQ ||
61  !kHasMask));
62 
63  // clang-format off
// clang-format off
// Compile-time map from element data type to its short textual name; used by
// GetName() when composing the kernel-instance string.
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
71 
// Builds the canonical kernel-instance name from the compile-time
// configuration: head dim, data type, batch/group mode, block tile sizes,
// warp layouts, occupancy override, pipeline name, V layout, padding,
// logits-soft-cap / bias / mask / LSE / fp8-quant / paged-KV flags.
// The format must stay in sync with generate.py.
__host__ static std::string GetName()
{
    // sync with generate.py
    // clang-format off
    using bfs = typename FmhaPipeline::BlockFmhaShape;
    using g0br = typename bfs::Gemm0BlockWarps;
    using g1br = typename bfs::Gemm1BlockWarps;
    using g0wt = typename bfs::Gemm0WarpTile;
    using g1wt = typename bfs::Gemm1WarpTile;
    #define _SS_ std::string
    #define _TS_ std::to_string
    // pn: which dimensions get padded, e.g. "psdv"; empty when no padding
    auto pn = [&] () {
    std::string n;
    if (kPadSeqLenQ) n += "s";
    if (kPadSeqLenK) n += "sk";
    if (kPadHeadDimQ) n += "d";
    if (kPadHeadDimV) n += "dv";
    return n.empty() ? n : std::string("p") + n; }();
    return
    _SS_("fmha_fwd_splitkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
    "_" + (kIsGroupMode ? "group" : "batch") + "_"
    "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
    _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
    "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
    "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
    "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
    "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
    (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
    "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
    (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
    (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) +
    (kDoFp8StaticQuant ? "_squant" : "_nsquant") + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
    #undef _SS_
    #undef _TS_
    // clang-format on
}
108 
// Empty placeholder karg base. The template argument I only creates distinct
// types, avoiding the duplicated-base-class problem when several feature
// slots in a Kargs struct are disabled at the same time.
template <ck_tile::index_t I>
struct EmptyKargs
{
};
114 
// kargs use aggregate initializers, so no constructor is provided.
// Use inheritance to minimize the karg size.
// Users need to use the MakeKargs() function to create kargs.
118  struct CommonKargs
119  {
120  const void* q_ptr;
121  const void* k_ptr;
122  const void* v_ptr;
123  void* lse_acc_ptr;
124  void* o_acc_ptr;
125 
127 
132 
134  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
135  // if this param is larger than 1, indicate MQA/GQA case
138 
139  float scale_s;
140 
145 
151 
154  };
155 
157  {
158  LogitsSoftCapKargs() = default;
159 
160  void init_logits_soft_cap(float logits_soft_cap_)
161  {
162  if(0 < logits_soft_cap_)
163  {
164  logits_soft_cap = logits_soft_cap_;
166  }
167  else
168  {
169  logits_soft_cap = 0.f;
170  logits_soft_cap_rcp = 0.f;
171  }
172  }
173 
176  };
177 
179  {
180  const void* bias_ptr = nullptr;
183  };
184 
186  {
188  };
189 
// Kernel arguments for ALIBI position-encoding bias.
struct AlibiKargs
{
    // alibi slopes are laid out as batch*nhead*1; the layout is identical in
    // batch and group mode
    const void* alibi_slope_ptr;
    ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
};
196 
197  struct MaskKargs
198  {
199  // ck_tile::index_t window_size_left, window_size_right;
202  };
203 
205  {
206  float scale_p;
207  };
208 
210  {
214  };
215 
217  {
218  bool is_gappy = false;
219  };
220 
222  {
224  };
225 
227  : CommonKargs,
228  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
229  BatchModeBiasKargs,
230  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
231  AlibiKargs,
232  EmptyKargs<0>>>,
233  std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
234  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
235  std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
236  std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<3>>
237  {
239 
241  ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
242  // single kcache page-block
243  ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
244  // single vcache page-block
247  };
248 
250  : CommonKargs,
251  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
252  CommonBiasKargs,
253  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
254  AlibiKargs,
255  EmptyKargs<0>>>,
256  std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
257  std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
258  std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, EmptyKargs<3>>,
259  std::conditional_t<kHasLogitsSoftCap, LogitsSoftCapKargs, EmptyKargs<4>>
260  {
264 
265  ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
266  // for single kcache page-block
267  ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
268  // for single vcache page-block
269  };
270 
271  using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
272 
274  {
278  };
279 
280  template <bool Cond = !kIsGroupMode>
281  __host__ static constexpr std::enable_if_t<Cond, Kargs>
282  MakeKargs(const void* q_ptr,
283  const void* k_ptr,
284  const void* v_ptr,
285  const void* bias_ptr,
286  void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
287  final lse */
288  void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
289  o */
290  ck_tile::index_t batch,
291  ck_tile::index_t seqlen_q,
292  ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
293  const void* seqlen_k_ptr, // only used for (paged-) kvcache
294  ck_tile::index_t hdim_q,
295  ck_tile::index_t hdim_v,
296  ck_tile::index_t num_head_q,
297  ck_tile::index_t nhead_ratio_qk,
298  ck_tile::index_t num_splits,
299  const void* block_table_ptr,
300  ck_tile::index_t batch_stride_block_table,
301  ck_tile::index_t page_block_size,
302  const void* cache_batch_idx,
303  float scale_s,
304  float scale_p,
305  float logits_soft_cap,
306  ck_tile::index_t stride_q,
307  ck_tile::index_t stride_k,
308  ck_tile::index_t stride_v,
309  ck_tile::index_t stride_bias,
310  ck_tile::index_t stride_o_acc,
311  ck_tile::index_t nhead_stride_q,
312  ck_tile::index_t nhead_stride_k,
313  ck_tile::index_t nhead_stride_v,
314  ck_tile::index_t nhead_stride_bias,
315  ck_tile::index_t nhead_stride_lse_acc,
316  ck_tile::index_t nhead_stride_o_acc,
317  ck_tile::index_t batch_stride_q,
318  ck_tile::index_t batch_stride_k,
319  ck_tile::index_t batch_stride_v,
320  ck_tile::index_t batch_stride_bias,
321  ck_tile::index_t batch_stride_lse_acc,
322  ck_tile::index_t batch_stride_o_acc,
323  ck_tile::index_t split_stride_lse_acc,
324  ck_tile::index_t split_stride_o_acc,
325  ck_tile::index_t window_size_left,
326  ck_tile::index_t window_size_right,
327  ck_tile::index_t mask_type)
328  {
329  Kargs kargs{{q_ptr,
330  k_ptr,
331  v_ptr,
332  lse_acc_ptr,
333  o_acc_ptr,
334  batch,
335  seqlen_q,
336  seqlen_k,
337  hdim_q,
338  hdim_v,
339  num_head_q,
340  nhead_ratio_qk,
341  num_splits,
342 #if CK_TILE_FMHA_FWD_FAST_EXP2
343  static_cast<float>(scale_s * ck_tile::log2e_v<>),
344 #else
345  scale_s,
346 #endif
347  stride_q,
348  stride_k,
349  stride_v,
350  stride_o_acc,
351  nhead_stride_q,
352  nhead_stride_k,
353  nhead_stride_v,
354  nhead_stride_lse_acc,
355  nhead_stride_o_acc,
356  split_stride_lse_acc,
357  split_stride_o_acc}, // args for common karg
358  {}, // placeholder for bias
359  {}, // placeholder for mask
360  {}, // placeholder for fp8_static_quant args
361  {}, // placeholder for paged-block table or cache_batch_idx
362  {}, // placeholder for logits_soft_cap
363  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
364  batch_stride_q,
365  batch_stride_k,
366  batch_stride_v,
367  batch_stride_lse_acc,
368  batch_stride_o_acc};
369 
371  {
372  kargs.bias_ptr = bias_ptr;
373  kargs.stride_bias = stride_bias;
374  kargs.nhead_stride_bias = nhead_stride_bias;
375  kargs.batch_stride_bias = batch_stride_bias;
376  }
377  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
378  {
379  kargs.alibi_slope_ptr = bias_ptr;
380  kargs.alibi_slope_stride = stride_bias;
381  }
382  if constexpr(kHasMask)
383  {
384  kargs.window_size_left = window_size_left;
385  kargs.window_size_right = window_size_right;
386  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
387  }
388  if constexpr(kDoFp8StaticQuant)
389  {
390  kargs.scale_p = scale_p;
391  }
392  if constexpr(kIsPagedKV)
393  {
394  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
395  kargs.batch_stride_block_table = batch_stride_block_table;
396  kargs.page_block_size = page_block_size;
397  }
398  else
399  {
400  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
401  }
402  if constexpr(kHasLogitsSoftCap)
403  {
404  kargs.init_logits_soft_cap(logits_soft_cap);
405  }
406 
407  return kargs;
408  }
409 
410  template <bool Cond = kIsGroupMode>
411  __host__ static constexpr std::enable_if_t<Cond, Kargs>
412  MakeKargs(const void* q_ptr,
413  const void* k_ptr,
414  const void* v_ptr,
415  const void* bias_ptr,
416  void* lse_acc_ptr, /* workspace for lse accumulation when num_splits > 1, otherwise
417  final lse */
418  void* o_acc_ptr, /* workspace for o accumulation when num_splits > 1, otherwise final
419  o */
420  ck_tile::index_t batch,
421  const void* seqstart_q_ptr,
422  const void* seqstart_k_ptr,
423  const void* seqlen_k_ptr,
424  ck_tile::index_t hdim_q,
425  ck_tile::index_t hdim_v,
426  ck_tile::index_t num_head_q,
427  ck_tile::index_t nhead_ratio_qk,
428  ck_tile::index_t num_splits,
429  const void* block_table_ptr,
430  ck_tile::index_t batch_stride_block_table,
431  ck_tile::index_t page_block_size,
432  bool is_gappy,
433  float scale_s,
434  float scale_p,
435  float logits_soft_cap,
436  ck_tile::index_t stride_q,
437  ck_tile::index_t stride_k,
438  ck_tile::index_t stride_v,
439  ck_tile::index_t stride_bias,
440  ck_tile::index_t stride_o_acc,
441  ck_tile::index_t nhead_stride_q,
442  ck_tile::index_t nhead_stride_k,
443  ck_tile::index_t nhead_stride_v,
444  ck_tile::index_t nhead_stride_bias,
445  ck_tile::index_t nhead_stride_lse_acc,
446  ck_tile::index_t nhead_stride_o_acc,
447  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
448  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
449  ck_tile::index_t split_stride_lse_acc,
450  ck_tile::index_t split_stride_o_acc,
451  ck_tile::index_t window_size_left,
452  ck_tile::index_t window_size_right,
453  ck_tile::index_t mask_type)
454  {
455  Kargs kargs{{q_ptr,
456  k_ptr,
457  v_ptr,
458  lse_acc_ptr,
459  o_acc_ptr,
460  batch,
461  -1, // seqlen_q will be updated by another pointer
462  -1, // seqlen_k will be updated by another pointer
463  hdim_q,
464  hdim_v,
465  num_head_q,
466  nhead_ratio_qk,
467  num_splits,
468 #if CK_TILE_FMHA_FWD_FAST_EXP2
469  static_cast<float>(scale_s * ck_tile::log2e_v<>),
470 #else
471  scale_s,
472 #endif
473  stride_q,
474  stride_k,
475  stride_v,
476  stride_o_acc,
477  nhead_stride_q,
478  nhead_stride_k,
479  nhead_stride_v,
480  nhead_stride_lse_acc,
481  nhead_stride_o_acc,
482  split_stride_lse_acc,
483  split_stride_o_acc}, // args for common karg
484  {}, // placeholder for bias
485  {}, // placeholder for mask
486  {}, // placeholder for fp8_static_quant args
487  {}, // placeholder for paged-block table
488  {}, // placeholder for logits_soft_cap
489  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
490  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
491  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
492  batch_stride_k,
493  batch_stride_v};
494 
496  {
497  kargs.bias_ptr = bias_ptr;
498  kargs.stride_bias = stride_bias;
499  kargs.nhead_stride_bias = nhead_stride_bias;
500  }
501  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
502  {
503  kargs.alibi_slope_ptr = bias_ptr;
504  kargs.alibi_slope_stride = stride_bias;
505  }
506  if constexpr(kHasMask)
507  {
508  kargs.window_size_left = window_size_left;
509  kargs.window_size_right = window_size_right;
510  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
511  }
512  if constexpr(kDoFp8StaticQuant)
513  {
514  kargs.scale_p = scale_p;
515  }
516  if constexpr(kIsPagedKV)
517  {
518  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
519  kargs.batch_stride_block_table = batch_stride_block_table;
520  kargs.page_block_size = page_block_size;
521  kargs.is_gappy = is_gappy;
522  }
523  if constexpr(kHasLogitsSoftCap)
524  {
525  kargs.init_logits_soft_cap(logits_soft_cap);
526  }
527 
528  return kargs;
529  }
530 
531  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
532  ck_tile::index_t nhead_q,
533  ck_tile::index_t nhead_kv,
534  ck_tile::index_t max_seqlen_q,
535  ck_tile::index_t hdim_v,
536  ck_tile::index_t num_splits)
537  {
538  ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
539  ck_tile::index_t max_seqlen_q_ =
540  max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
541 
542  // TODO: this may need tuning
543  return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
544  ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
545  nhead_,
546  batch_size);
547  }
548 
549  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
550  {
551  const index_t num_tile_n1 = ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
552 
553  const auto f = [](index_t dividend, index_t divisor) {
554  index_t quotient = dividend / divisor;
555  index_t modulus = dividend - quotient * divisor;
556  return ck_tile::make_tuple(quotient, modulus);
557  };
558 
559  const auto [mn, i_split] = f(blockIdx.x, kargs.num_splits);
560  const auto [i_tile_m, i_tile_n] = f(mn, num_tile_n1);
561  const index_t i_nhead = blockIdx.y;
562  const index_t i_batch = blockIdx.z;
563 
564  if constexpr(kHasMask)
565  {
566  // assume that num_tile_n1 is always 1
567  return ck_tile::make_tuple(
568  (gridDim.x / kargs.num_splits) - 1 - i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
569  }
570  else
571  {
572  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_split, i_nhead, i_batch);
573  }
574  }
575 
// Workgroup size: kBlockSize threads in one dimension.
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
577 
579  {
580  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
581  }
582 
583  CK_TILE_DEVICE void operator()(Kargs kargs) const
584  {
585  // allocate LDS
586  __shared__ char smem_ptr[GetSmemSize()];
587 
588  // divide problem
589  const auto [i_tile_m, i_tile_n, i_split, i_nhead, i_batch] = GetTileIndex(kargs);
590 
591  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
592  const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
593 
594  long_index_t batch_offset_q = 0;
595  long_index_t batch_offset_k = 0; // unused for paged-kvcache
596  long_index_t batch_offset_v = 0; // unused for paged-kvcache
597  long_index_t batch_offset_bias = 0;
598  long_index_t batch_offset_lse_acc = 0;
599  long_index_t batch_offset_o_acc = 0;
600  index_t kv_l2p_offset =
601  0; // logical-to-physical offset of seqlen_k coordinate. only used for paged-kvcache
602 
603  if constexpr(kIsGroupMode)
604  {
605  // get starting offset for each batch
606  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
607  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
608 
609  batch_offset_q = query_start * kargs.stride_q;
610  batch_offset_k = key_start * kargs.stride_k;
611  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
612  {
613  batch_offset_v = key_start * kargs.stride_v;
614  }
615  else
616  {
617  batch_offset_v = key_start;
618  }
620  {
621  batch_offset_bias = query_start * kargs.stride_bias;
622  }
623 
624  batch_offset_lse_acc = query_start;
625  batch_offset_o_acc = query_start * kargs.stride_o_acc;
626 
627  // get real # queries & # keys under group mode
628  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
629 
630  // # of required blocks is different in each groups, terminate unnecessary blocks
631  // earlier
632  if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
633  {
634  return;
635  }
636 
637  if(kargs.seqlen_k_ptr != nullptr)
638  {
639  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
640  }
641  else
642  {
643  kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
644  }
645 
646  if constexpr(kIsPagedKV)
647  {
648  if(kargs.is_gappy)
649  {
650  // seqstart_k_ptr has different meaning in this case
651  kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
652  }
653  }
654  }
655  else
656  {
657  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
658  if constexpr(kIsPagedKV)
659  {
660  return i_batch_;
661  }
662  else
663  {
664  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
665  : i_batch_);
666  }
667  }();
668 
669  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
670  batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
671  batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
672  batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
673  batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
674 
676  {
677  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
678  }
679 
680  if(kargs.seqlen_k_ptr != nullptr)
681  {
682  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
683  }
684  }
685 
686  // for simplicity, batch stride we just modify the pointer
687  const index_t i_nhead_k =
688  (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
689 
690  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
691  static_cast<long_index_t>(i_nhead) *
692  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
693  kargs.nhead_stride_q +
694  batch_offset_q;
695  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
696  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
697  batch_offset_k;
698  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
699  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
700  batch_offset_v;
701 
702  ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
703  static_cast<long_index_t>(i_nhead) *
704  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
705  kargs.nhead_stride_o_acc +
706  batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
707 
708  // Q/K/V DRAM and DRAM window
709  const auto q_dram = [&] {
710  const auto q_dram_naive = [&] {
711  if constexpr(kMergeNumHeadGroupsSeqLenQ)
712  {
713  // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
714  // hdim_q)
715  const auto view = make_naive_tensor_view<address_space_enum::global>(
716  q_ptr,
717  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
718  make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
720  number<1>{});
721 
722  return transform_tensor_view(
723  view,
724  make_tuple(
725  make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
726  make_pass_through_transform(kargs.hdim_q)),
729  }
730  else
731  {
732  return make_naive_tensor_view<address_space_enum::global>(
733  q_ptr,
734  make_tuple(kargs.seqlen_q, kargs.hdim_q),
735  make_tuple(kargs.stride_q, 1),
737  number<1>{});
738  }
739  }();
740 
741  if constexpr(FmhaPipeline::kQLoadOnce)
742  {
743  return pad_tensor_view(
744  q_dram_naive,
747  }
748  else
749  {
750  return pad_tensor_view(
751  q_dram_naive,
754  }
755  }();
756 
757  const auto make_k_dram = [&](const KDataType* data, index_t height) {
758  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
759  data, // will update this pointer if using paged-kvcache
760  make_tuple(height, kargs.hdim_q),
761  make_tuple(kargs.stride_k, 1),
763  number<1>{});
764 
765  return pad_tensor_view(
766  k_dram_naive,
769  };
770  const auto k_dram = [&]() {
771  if constexpr(kIsPagedKV)
772  {
773  return make_k_dram(nullptr, kargs.page_block_size);
774  }
775  else
776  {
777  return make_k_dram(k_ptr, kargs.seqlen_k);
778  }
779  }();
780 
781  const auto make_v_dram = [&](const VDataType* data, index_t length) {
782  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
783  {
784  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
785  data, // will update this pointer if using paged-kvcache
786  make_tuple(length, kargs.hdim_v),
787  make_tuple(kargs.stride_v, 1),
789  number<1>{});
790 
791  const auto v_dram_transposed =
792  transform_tensor_view(v_dram_naive,
797 
798  return pad_tensor_view(
799  v_dram_transposed,
802  }
803  else
804  {
805  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
806  data, // will update this pointer if using paged-kvcache
807  make_tuple(kargs.hdim_v, length),
808  make_tuple(kargs.stride_v, 1),
810  number<1>{});
811 
812  return pad_tensor_view(
813  v_dram_naive,
816  }
817  };
818  const auto v_dram = [&]() {
819  if constexpr(kIsPagedKV)
820  {
821  return make_v_dram(nullptr, kargs.page_block_size);
822  }
823  else
824  {
825  return make_v_dram(v_ptr, kargs.seqlen_k);
826  }
827  }();
828 
829  auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
830  if constexpr(kIsPagedKV)
831  {
832  const auto* block_indices =
833  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
834  i_batch_ * kargs.batch_stride_block_table;
835  const index_t num_blocks =
836  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
837 
838  const long_index_t fixed_offset =
839  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;
840 
841  return make_page_block_navigator<const KDataType, 0>(
842  kargs.k_ptr,
843  kargs.batch_stride_k, // kcache page-block stride/size
844  fixed_offset,
845  block_indices,
846  num_blocks,
847  kargs.page_block_size,
848  k_dram,
849  make_k_dram(nullptr,
850  (kv_l2p_offset + kargs.seqlen_k) -
851  (num_blocks - 1) * kargs.page_block_size));
852  }
853  else
854  {
855  return make_page_block_navigator(k_dram);
856  }
857  }();
858 
859  auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
860  if constexpr(kIsPagedKV)
861  {
862  const auto* block_indices =
863  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
864  i_batch_ * kargs.batch_stride_block_table;
865  const index_t num_blocks =
866  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
867 
868  const long_index_t fixed_offset =
869  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;
870 
871  return make_page_block_navigator<const VDataType, 1>(
872  kargs.v_ptr,
873  kargs.batch_stride_v, // vcache page-block stride/size
874  fixed_offset,
875  block_indices,
876  num_blocks,
877  kargs.page_block_size,
878  v_dram,
879  make_v_dram(nullptr,
880  (kv_l2p_offset + kargs.seqlen_k) -
881  (num_blocks - 1) * kargs.page_block_size));
882  }
883  else
884  {
885  return make_page_block_navigator(v_dram);
886  }
887  }();
888 
889  auto q_dram_window = make_tile_window(
890  q_dram,
891  [&]() {
892  if constexpr(FmhaPipeline::kQLoadOnce)
895  else
897  }(),
898  {i_m0, 0});
899 
900  auto k_dram_window_lengths =
902  auto v_dram_window_lengths =
904 
907  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
908  constexpr auto bias_dram_window_lengths =
911  {
912  const BiasDataType* bias_ptr =
913  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
914  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
915  batch_offset_bias;
916 
917  const auto bias_dram = [&]() {
918  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
919  bias_ptr,
920  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
921  make_tuple(kargs.stride_bias, 1),
923  number<1>{});
924 
925  return pad_tensor_view(
926  bias_dram_naive, bias_dram_window_lengths, sequence<false, kPadSeqLenK>{});
927  }();
928 
929  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
930  }
931  else
932  {
933  return make_null_tile_window(bias_dram_window_lengths);
934  }
935  }();
936 
937  // lse acc
938  auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
939  constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
940  LSEDataType* lse_acc_ptr = reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
941  static_cast<long_index_t>(i_nhead_) *
942  (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
943  kargs.nhead_stride_lse_acc +
944  batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;
945 
946  const auto lse_acc_dram = [&] {
947  const auto lse_acc_dram_naive = [&] {
948  if constexpr(kMergeNumHeadGroupsSeqLenQ)
949  {
950  // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
951  const auto view = make_naive_tensor_view<address_space_enum::global>(
952  lse_acc_ptr,
953  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
954  make_tuple(kargs.nhead_stride_lse_acc, 1),
955  number<1>{},
956  number<1>{});
957 
958  return transform_tensor_view(view,
960  kargs.nhead_ratio_qk, kargs.seqlen_q))),
963  }
964  else
965  {
966  return make_naive_tensor_view<address_space_enum::global>(
967  lse_acc_ptr,
968  make_tuple(kargs.seqlen_q),
969  make_tuple(1),
970  number<1>{},
971  number<1>{});
972  }
973  }();
974  return pad_tensor_view(
975  lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
976  }();
977 
978  return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
979  }();
980 
981  FmhaMask mask = [&]() {
982  if constexpr(kHasMask)
983  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
984  kargs.window_size_left,
985  kargs.window_size_right,
986  kargs.seqlen_q,
987  kargs.seqlen_k,
989  else
990  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
991  }();
992 
993  // WA i_batch capture structure binding before c++20
994  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
995  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
996  {
997  // data loading, shared by entire wg
998  // TODO: how to use s_read?
999  SaccDataType slope =
1000  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1001  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1002 #if CK_TILE_FMHA_FWD_FAST_EXP2
1003  slope *= ck_tile::log2e_v<>;
1004 #endif
1005  if constexpr(kHasMask)
1006  {
1007  return make_alibi_from_lr_mask<SaccDataType, true, 32>(slope,
1008  kargs.window_size_left,
1009  kargs.window_size_right,
1010  kargs.seqlen_q,
1011  kargs.seqlen_k,
1012  kargs.mask_type);
1013  }
1014  else
1015  {
1016  return Alibi<SaccDataType, true, 32>{
1017  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1018  }
1019  }
1020  else
1021  {
1022  return EmptyPositionEncoding<SaccDataType>{};
1023  }
1024  }();
1025 
1026  AttentionVariant variant;
1027  const auto variant_params = [&] {
1028  if constexpr(kHasLogitsSoftCap)
1029  {
1031  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1032  }
1033  else
1034  {
1035  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1036  }
1037  }();
1038 
1039  BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
1040 
1041  auto o_acc_tile = [&, i_split_ = i_split]() {
1042  if constexpr(kDoFp8StaticQuant)
1043  {
1044  return FmhaPipeline{}(q_dram_window,
1045  identity{}, // q_element_func
1046  k_dram_window_lengths,
1047  k_page_block_navigator,
1048  identity{}, // k_element_func
1049  v_dram_window_lengths,
1050  v_page_block_navigator,
1051  identity{}, // v_element_func
1052  bias_dram_window,
1053  identity{}, // bias_element_func
1054  lse_acc_dram_window,
1055  identity{}, // lse_element_func
1056  identity{}, // s_acc_element_func
1057  scales{kargs.scale_p}, // p_compute_element_func
1058  identity{}, // o_acc_element_func
1059  kargs.num_splits,
1060  i_split_,
1061  mask,
1062  position_encoding,
1063  kargs.scale_s,
1064  variant,
1065  variant_params,
1066  block_indices,
1067  kv_l2p_offset,
1068  smem_ptr);
1069  }
1070  else
1071  {
1072  return FmhaPipeline{}(q_dram_window,
1073  k_dram_window_lengths,
1074  k_page_block_navigator,
1075  v_dram_window_lengths,
1076  v_page_block_navigator,
1077  bias_dram_window,
1078  lse_acc_dram_window,
1079  kargs.num_splits,
1080  i_split_,
1081  mask,
1082  position_encoding,
1083  kargs.scale_s,
1084  variant,
1085  variant_params,
1086  block_indices,
1087  kv_l2p_offset,
1088  smem_ptr);
1089  }
1090  }();
1091 
1092  // Oacc DRAM and Oacc DRAM window
1093  auto o_acc_dram = [&] {
1094  const auto o_acc_dram_naive = [&] {
1095  if constexpr(kMergeNumHeadGroupsSeqLenQ)
1096  {
1097  // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
1098  // hdim_v)
1099  const auto view = make_naive_tensor_view<address_space_enum::global>(
1100  o_acc_ptr,
1101  make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
1102  make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
1103  number<FmhaPipeline::kAlignmentOacc>{},
1104  number<1>{});
1105 
1106  return transform_tensor_view(
1107  view,
1108  make_tuple(
1109  make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
1110  make_pass_through_transform(kargs.hdim_v)),
1111  make_tuple(sequence<0, 1>{}, sequence<2>{}),
1112  make_tuple(sequence<0>{}, sequence<1>{}));
1113  }
1114  else
1115  {
1116  return make_naive_tensor_view<address_space_enum::global>(
1117  o_acc_ptr,
1118  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1119  make_tuple(kargs.stride_o_acc, 1),
1120  number<FmhaPipeline::kAlignmentOacc>{},
1121  number<1>{});
1122  }
1123  }();
1124 
1125  return pad_tensor_view(
1126  o_acc_dram_naive,
1127  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1128  sequence<kPadSeqLenQ, kPadHeadDimV>{});
1129  }();
1130 
1131  auto o_acc_dram_window =
1132  make_tile_window(o_acc_dram,
1133  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1134  {i_m0, i_n1});
1135 
1136  EpiloguePipeline{}(o_acc_dram_window, o_acc_tile);
1137  }
1138 };
1139 
1140 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
#define _TS_
#define _SS_
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView &tensor_view)
Definition: page_block_navigator.hpp:333
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:510
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1672
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:529
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
__host__ __device__ scales(Scale) -> scales< Scale >
FIXME: create macro to replace 'host device' and nothing more.
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_attention_bias_enum.hpp:19
Definition: fmha_fwd_splitkv_kernel.hpp:191
const void * alibi_slope_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:193
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_splitkv_kernel.hpp:194
Definition: fmha_fwd_splitkv_kernel.hpp:186
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:187
Definition: fmha_fwd_splitkv_kernel.hpp:237
ck_tile::index_t batch_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:245
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:243
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:240
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:241
ck_tile::index_t batch_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:246
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:238
Definition: fmha_fwd_splitkv_kernel.hpp:274
ck_tile::index_t batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:275
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:277
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_splitkv_kernel.hpp:276
Definition: fmha_fwd_splitkv_kernel.hpp:222
const int32_t * cache_batch_idx
Definition: fmha_fwd_splitkv_kernel.hpp:223
Definition: fmha_fwd_splitkv_kernel.hpp:179
const void * bias_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:180
ck_tile::index_t stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:181
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_splitkv_kernel.hpp:182
Definition: fmha_fwd_splitkv_kernel.hpp:119
ck_tile::index_t split_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:153
ck_tile::index_t nhead_stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:150
const void * k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:121
ck_tile::index_t num_splits
Definition: fmha_fwd_splitkv_kernel.hpp:137
void * lse_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:123
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:146
void * o_acc_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:124
ck_tile::index_t nhead_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:149
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:148
ck_tile::index_t hdim_q
Definition: fmha_fwd_splitkv_kernel.hpp:130
const void * v_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:122
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:147
ck_tile::index_t split_stride_lse_acc
Definition: fmha_fwd_splitkv_kernel.hpp:152
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_splitkv_kernel.hpp:136
ck_tile::index_t stride_q
Definition: fmha_fwd_splitkv_kernel.hpp:141
ck_tile::index_t seqlen_k
Definition: fmha_fwd_splitkv_kernel.hpp:129
ck_tile::index_t batch
Definition: fmha_fwd_splitkv_kernel.hpp:126
const void * q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:120
ck_tile::index_t stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:143
ck_tile::index_t stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:142
ck_tile::index_t num_head_q
Definition: fmha_fwd_splitkv_kernel.hpp:133
ck_tile::index_t seqlen_q
Definition: fmha_fwd_splitkv_kernel.hpp:128
ck_tile::index_t stride_o_acc
Definition: fmha_fwd_splitkv_kernel.hpp:144
float scale_s
Definition: fmha_fwd_splitkv_kernel.hpp:139
ck_tile::index_t hdim_v
Definition: fmha_fwd_splitkv_kernel.hpp:131
Definition: fmha_fwd_splitkv_kernel.hpp:210
ck_tile::index_t page_block_size
Definition: fmha_fwd_splitkv_kernel.hpp:213
ck_tile::index_t batch_stride_block_table
Definition: fmha_fwd_splitkv_kernel.hpp:212
const int32_t * block_table_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:211
Definition: fmha_fwd_splitkv_kernel.hpp:112
Definition: fmha_fwd_splitkv_kernel.hpp:205
float scale_p
Definition: fmha_fwd_splitkv_kernel.hpp:206
Definition: fmha_fwd_splitkv_kernel.hpp:260
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:263
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_splitkv_kernel.hpp:265
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:261
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_splitkv_kernel.hpp:262
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_splitkv_kernel.hpp:267
Definition: fmha_fwd_splitkv_kernel.hpp:217
bool is_gappy
Definition: fmha_fwd_splitkv_kernel.hpp:218
Definition: fmha_fwd_splitkv_kernel.hpp:157
float logits_soft_cap_rcp
Definition: fmha_fwd_splitkv_kernel.hpp:175
float logits_soft_cap
Definition: fmha_fwd_splitkv_kernel.hpp:174
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_splitkv_kernel.hpp:160
Definition: fmha_fwd_splitkv_kernel.hpp:198
ck_tile::index_t window_size_right
Definition: fmha_fwd_splitkv_kernel.hpp:200
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_splitkv_kernel.hpp:201
ck_tile::index_t window_size_left
Definition: fmha_fwd_splitkv_kernel.hpp:200
Definition: fmha_fwd_splitkv_kernel.hpp:64
Definition: fmha_fwd_splitkv_kernel.hpp:24
static constexpr auto BiasEnum
Definition: fmha_fwd_splitkv_kernel.hpp:49
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_splitkv_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_splitkv_kernel.hpp:35
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_splitkv_kernel.hpp:578
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_splitkv_kernel.hpp:26
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_splitkv_kernel.hpp:41
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_splitkv_kernel.hpp:46
std::conditional_t< kIsGroupMode, GroupModeKargs, BatchModeKargs > Kargs
Definition: fmha_fwd_splitkv_kernel.hpp:271
static constexpr __host__ std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, const void *cache_batch_idx, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_lse_acc, ck_tile::index_t batch_stride_o_acc, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition: fmha_fwd_splitkv_kernel.hpp:282
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size, ck_tile::index_t nhead_q, ck_tile::index_t nhead_kv, ck_tile::index_t max_seqlen_q, ck_tile::index_t hdim_v, ck_tile::index_t num_splits)
Definition: fmha_fwd_splitkv_kernel.hpp:531
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_splitkv_kernel.hpp:45
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:44
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_splitkv_kernel.hpp:33
remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_splitkv_kernel.hpp:39
static constexpr bool kMergeNumHeadGroupsSeqLenQ
Definition: fmha_fwd_splitkv_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_splitkv_kernel.hpp:56
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_splitkv_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:37
static constexpr bool kHasMask
Definition: fmha_fwd_splitkv_kernel.hpp:57
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_splitkv_kernel.hpp:549
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_splitkv_kernel.hpp:47
static __host__ std::string GetName()
Definition: fmha_fwd_splitkv_kernel.hpp:72
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_splitkv_kernel.hpp:34
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_splitkv_kernel.hpp:30
static constexpr __host__ auto BlockSize()
Definition: fmha_fwd_splitkv_kernel.hpp:576
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_splitkv_kernel.hpp:36
remove_cvref_t< typename FmhaPipeline::OaccDataType > OaccDataType
Definition: fmha_fwd_splitkv_kernel.hpp:38
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_splitkv_kernel.hpp:55
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_splitkv_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_splitkv_kernel.hpp:32
static constexpr bool kStoreLSE
Definition: fmha_fwd_splitkv_kernel.hpp:50
static constexpr bool kIsPagedKV
Definition: fmha_fwd_splitkv_kernel.hpp:52
static constexpr bool kIsGroupMode
Definition: fmha_fwd_splitkv_kernel.hpp:43
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_splitkv_kernel.hpp:583
static constexpr __host__ std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *lse_acc_ptr, void *o_acc_ptr, ck_tile::index_t batch, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, ck_tile::index_t num_splits, const void *block_table_ptr, ck_tile::index_t batch_stride_block_table, ck_tile::index_t page_block_size, bool is_gappy, float scale_s, float scale_p, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_o_acc, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_o_acc, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type)
Definition: fmha_fwd_splitkv_kernel.hpp:412
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_splitkv_kernel.hpp:48
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_splitkv_kernel.hpp:25
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52