Composable Kernel: include/ck_tile/ops/fmha/kernel/fmha_fwd_pagedkv_kernel.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21 
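For orientation, the five comment lines above can be read as the following host-side reference (an illustrative sketch only; the kernel computes the same result tiled, fused, and with an online softmax):

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Shapes: Q[seqlen_q][hdim_q], K[seqlen_k][hdim_q], V[seqlen_k][hdim_v],
    // Bias[seqlen_q][seqlen_k], O[seqlen_q][hdim_v].
    void reference_attention(const std::vector<std::vector<float>>& Q,
                             const std::vector<std::vector<float>>& K,
                             const std::vector<std::vector<float>>& V,
                             const std::vector<std::vector<float>>& Bias,
                             float scale,
                             std::vector<std::vector<float>>& O)
    {
        const std::size_t seqlen_q = Q.size(), seqlen_k = K.size();
        const std::size_t hdim_q = Q[0].size(), hdim_v = V[0].size();
        for(std::size_t i = 0; i < seqlen_q; ++i)
        {
            // S''[i][j] = (Q[i] . K[j]) * scale + Bias[i][j]
            std::vector<float> s(seqlen_k, 0.f);
            float row_max = -INFINITY;
            for(std::size_t j = 0; j < seqlen_k; ++j)
            {
                for(std::size_t d = 0; d < hdim_q; ++d)
                    s[j] += Q[i][d] * K[j][d];
                s[j] = s[j] * scale + Bias[i][j];
                row_max = std::fmax(row_max, s[j]);
            }
            // P[i] = Softmax(S''[i]), with max subtraction for numerical stability
            float denom = 0.f;
            for(std::size_t j = 0; j < seqlen_k; ++j)
            {
                s[j] = std::exp(s[j] - row_max);
                denom += s[j];
            }
            // O[i] = P[i] @ V
            for(std::size_t d = 0; d < hdim_v; ++d)
            {
                float acc = 0.f;
                for(std::size_t j = 0; j < seqlen_k; ++j)
                    acc += (s[j] / denom) * V[j][d];
                O[i][d] = acc;
            }
        }
    }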
22 namespace ck_tile {
23 
24 // TODO: This class is a variant of the existing FmhaFwdSplitKVKernel pipeline.
25 // Refactoring to extract shared logic is recommended as future work.
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
27 struct FmhaFwdPagedKVKernel
28 {
29  using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
30  using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33  static_assert(kBlockPerCu > 0);
34  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
35 
36  using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
37  using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
38  using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
39  using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
40  using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
41  using ODataType = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
42  using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
43 
44  using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
45 
46  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
47  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
48  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
49  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
50  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
51  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
52  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
53  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
54  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
55  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
56  static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
57 
58  using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
59  using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
60  static constexpr bool kHasMask = FmhaMask::IsMasking;
61 
62  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
63 
64  // clang-format off
65  template <typename T> struct t2s;
66  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
67  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
68  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
69  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
70  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
71  // clang-format on
72 
73  CK_TILE_HOST static std::string GetName()
74  {
75  // sync with generate.py
76  // clang-format off
77  using bfs = typename FmhaPipeline::BlockFmhaShape;
78  using g0br = typename bfs::Gemm0BlockWarps;
79  using g1br = typename bfs::Gemm1BlockWarps;
80  using g0wt = typename bfs::Gemm0WarpTile;
81  using g1wt = typename bfs::Gemm1WarpTile;
82  #define _SS_ std::string
83  #define _TS_ std::to_string
84  auto pn = [&] () {
85  std::string n;
86  if (kPadSeqLenQ) n += "s";
87  if (kPadSeqLenK) n += "sk";
88  if (kPadHeadDimQ) n += "d";
89  if (kPadHeadDimV) n += "dv";
90  return n.empty() ? n : std::string("p") + n; }();
91  return
92  _SS_("fmha_fwd_pagedkv_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
93  "_" + (kIsGroupMode ? "group" : "batch") + "_"
94  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
95  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
96  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
97  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
98  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
99  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
100  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
101  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
102  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
103  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kIsPagedKV ? "_pagedkv" : "_npagedkv" );
104  #undef _SS_
105  #undef _TS_
106  // clang-format on
107  }
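The string built above is matched by the generator script, so every trait lands at a fixed position in the name. A standalone re-run of just the pad-suffix lambda (illustrative flag values):

    #include <iostream>
    #include <string>

    int main()
    {
        const bool kPadSeqLenQ = true, kPadSeqLenK = true;
        const bool kPadHeadDimQ = false, kPadHeadDimV = true;

        std::string n;
        if(kPadSeqLenQ)  n += "s";
        if(kPadSeqLenK)  n += "sk";
        if(kPadHeadDimQ) n += "d";
        if(kPadHeadDimV) n += "dv";
        const std::string pn = n.empty() ? n : std::string("p") + n;

        std::cout << (pn.empty() ? "_npad" : "_" + pn) << '\n'; // prints "_psskdv"
    }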
108 
109  template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a
110  // template arg
111  struct FmhaFwdEmptyKargs
112  {
113  };
114 
115  // kargs use the aggregate initializer, so no constructor is provided
116  // use inheritance to minimize karg size
117  // users need to use the MakeKargs() function to create kargs.
118  struct FmhaFwdCommonKargs
119  {
120  const void* q_ptr;
121  const void* k_ptr;
122  const void* v_ptr;
123  void* o_ptr;
124 
125  ck_tile::index_t seqlen_q;
126  ck_tile::index_t seqlen_k;
127  ck_tile::index_t hdim_q;
128  ck_tile::index_t hdim_v;
129 
130  ck_tile::index_t num_head_q;
131  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
132  // if it is larger than 1, it indicates the MQA/GQA case
133  ck_tile::index_t nhead_ratio_qk;
134  float scale_s;
135 
136  ck_tile::index_t stride_q;
137  ck_tile::index_t stride_k;
138  ck_tile::index_t stride_v;
139  ck_tile::index_t stride_o;
140 
141  ck_tile::index_t nhead_stride_q;
142  ck_tile::index_t nhead_stride_k;
143  ck_tile::index_t nhead_stride_v;
144  ck_tile::index_t nhead_stride_o;
145  };
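nhead_ratio_qk is how MQA/GQA is realized: query head i_nhead reads KV head i_nhead / nhead_ratio_qk (see the k_ptr/v_ptr arithmetic in operator() below). A small standalone sketch:

    #include <cassert>

    int main()
    {
        const int nhead_q = 32, nhead_k = 8;
        const int nhead_ratio_qk = nhead_q / nhead_k; // 4 query heads share one KV head
        const int i_nhead = 13;
        assert(i_nhead / nhead_ratio_qk == 3);        // query head 13 -> KV head 3
    }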
146 
147  struct FmhaFwdLogitsSoftCapKargs
148  {
150 
151  void init_logits_soft_cap(float logits_soft_cap_)
152  {
153  if(0 < logits_soft_cap_)
154  {
155  logits_soft_cap = logits_soft_cap_;
156  logits_soft_cap_rcp = 1.f / logits_soft_cap_;
157  }
158  else
159  {
160  logits_soft_cap = 0.f;
161  logits_soft_cap_rcp = 0.f;
162  }
163  }
164 
165  float logits_soft_cap;
166  float logits_soft_cap_rcp;
167  };
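A standalone sketch of how this cap/reciprocal pair is typically consumed, assuming the conventional soft-cap formula score' = cap * tanh(score / cap) (the precomputed reciprocal turns the hot-loop division into a multiply):

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float cap = 30.f;
        const float rcp = 1.f / cap;           // logits_soft_cap_rcp
        const float s   = 55.f;                // unbounded attention logit
        const float capped = cap * std::tanh(s * rcp);
        assert(-cap < capped && capped < cap); // result bounded by +/- cap
    }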
168 
169  struct FmhaFwdCommonBiasKargs
170  {
171  const void* bias_ptr = nullptr;
172  ck_tile::index_t stride_bias = 0;
173  ck_tile::index_t nhead_stride_bias = 0;
174  };
175 
176  struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs
177  {
178  ck_tile::index_t batch_stride_bias = 0;
179  };
180 
181  struct FmhaFwdAlibiKargs
182  {
183  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
184  const void* alibi_slope_ptr;
185  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
186  };
187 
188  struct FmhaFwdMaskKargs
189  {
190  // ck_tile::index_t window_size_left, window_size_right;
191  ck_tile::index_t window_size_left, window_size_right;
192  ck_tile::GenericAttentionMaskEnum mask_type;
193  };
194 
195  struct FmhaFwdFp8StaticQuantKargs
196  {
197  float scale_p;
198  float scale_o;
199  };
200 
201  struct FmhaFwdCommonLSEKargs
202  {
203  void* lse_ptr = nullptr;
204  ck_tile::index_t nhead_stride_lse = 0;
205  ck_tile::index_t batch_stride_lse = 0;
206  };
207 
208  struct FmhaFwdSkipMinSeqlenQKargs
209  {
210  ck_tile::index_t min_seqlen_q;
211  };
212 
213  struct CommonPageBlockTableKargs
214  {
215  const int32_t* block_table_ptr;
216  ck_tile::index_t batch_stride_block_table;
217  ck_tile::index_t page_block_size;
218  };
219 
220  struct GroupModePageBlockTableKargs : CommonPageBlockTableKargs
221  {
222  bool is_gappy = false;
223  };
224 
225  struct CacheBatchIdxKargs
226  {
227  const int32_t* cache_batch_idx;
228  };
229 
230  struct FmhaFwdBatchModeKargs
231  : FmhaFwdCommonKargs,
232  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
233  FmhaFwdBatchModeBiasKargs,
234  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
235  FmhaFwdAlibiKargs,
236  FmhaFwdEmptyKargs<0>>>,
237  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
238  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
239  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
240  std::conditional_t<kIsPagedKV, CommonPageBlockTableKargs, CacheBatchIdxKargs>,
241  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
242  {
243  const int32_t* seqlen_k_ptr;
244 
245  ck_tile::index_t batch_stride_q;
246  ck_tile::index_t batch_stride_k; // when using paged-kvcache, this will be stride/size for
247  // single kcache page-block
248  ck_tile::index_t batch_stride_v; // when using paged-kvcache, this will be stride/size for
249  // single vcache page-block
250  ck_tile::index_t batch_stride_o;
251  };
252 
253  struct FmhaFwdGroupModeKargs
254  : FmhaFwdCommonKargs,
255  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
256  FmhaFwdCommonBiasKargs,
257  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
258  FmhaFwdAlibiKargs,
259  FmhaFwdEmptyKargs<0>>>,
260  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
261  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
262  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
263  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>,
264  std::conditional_t<kIsPagedKV, GroupModePageBlockTableKargs, FmhaFwdEmptyKargs<5>>,
265  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
266  {
267  const int32_t* seqstart_q_ptr;
268  const int32_t* seqstart_k_ptr;
269  const int32_t* seqlen_k_ptr;
270 
271  ck_tile::index_t batch_stride_k; // only used for paged-kvcache, this will be stride/size
272  // for single kcache page-block
273  ck_tile::index_t batch_stride_v; // only used for paged-kvcache, this will be stride/size
274  // for single vcache page-block
275  };
276 
277  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
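Kargs is assembled from conditional base classes so that disabled features cost zero bytes; FmhaFwdEmptyKargs&lt;I&gt; keeps disabled slots distinct. A simplified standalone sketch of that pattern (types invented for illustration):

    #include <type_traits>

    struct BiasArgs { const void* bias_ptr; };
    struct MaskArgs { int window_left, window_right; };
    template <int I> struct EmptyArgs {}; // distinct tag per slot, like FmhaFwdEmptyKargs<I>

    template <bool kHasBias, bool kHasMask>
    struct MiniKargs : std::conditional_t<kHasBias, BiasArgs, EmptyArgs<0>>,
                       std::conditional_t<kHasMask, MaskArgs, EmptyArgs<1>>
    {
        int seqlen_q;
    };

    // empty bases occupy no storage, so disabled features are free:
    static_assert(sizeof(MiniKargs<false, false>) == sizeof(int));
    static_assert(sizeof(MiniKargs<true, true>) > sizeof(int));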
278 
279  struct BlockIndices
280  {
281  ck_tile::index_t batch_idx;
282  ck_tile::index_t qo_head_idx;
283  ck_tile::index_t kv_head_idx;
284  };
285 
286  template <bool Cond = !kIsGroupMode>
287  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
288  MakeKargsImpl(const void* q_ptr,
289  const void* k_ptr,
290  const void* v_ptr,
291  const void* bias_ptr,
292  void* lse_ptr,
293  void* o_ptr,
294  ck_tile::index_t seqlen_q,
295  ck_tile::index_t seqlen_k,
296  const void* seqlen_k_ptr, // only used for (paged-) kvcache
297  ck_tile::index_t hdim_q,
298  ck_tile::index_t hdim_v,
299  ck_tile::index_t num_head_q,
300  ck_tile::index_t nhead_ratio_qk,
301  const void* block_table_ptr,
302  ck_tile::index_t batch_stride_block_table,
303  ck_tile::index_t page_block_size,
304  const void* cache_batch_idx,
305  float scale_s,
306  float scale_p,
307  float scale_o,
308  float logits_soft_cap,
309  ck_tile::index_t stride_q,
310  ck_tile::index_t stride_k,
311  ck_tile::index_t stride_v,
312  ck_tile::index_t stride_bias,
313  ck_tile::index_t stride_o,
314  ck_tile::index_t nhead_stride_q,
315  ck_tile::index_t nhead_stride_k,
316  ck_tile::index_t nhead_stride_v,
317  ck_tile::index_t nhead_stride_bias,
318  ck_tile::index_t nhead_stride_lse,
319  ck_tile::index_t nhead_stride_o,
320  ck_tile::index_t batch_stride_q,
321  ck_tile::index_t batch_stride_k,
322  ck_tile::index_t batch_stride_v,
323  ck_tile::index_t batch_stride_bias,
324  ck_tile::index_t batch_stride_lse,
325  ck_tile::index_t batch_stride_o,
326  ck_tile::index_t window_size_left,
327  ck_tile::index_t window_size_right,
328  ck_tile::index_t mask_type)
329  {
330  Kargs kargs{{q_ptr,
331  k_ptr,
332  v_ptr,
333  o_ptr,
334  seqlen_q,
335  seqlen_k,
336  hdim_q,
337  hdim_v,
338  num_head_q,
339  nhead_ratio_qk,
340 #if CK_TILE_FMHA_FWD_FAST_EXP2
341  static_cast<float>(scale_s * ck_tile::log2e_v<>),
342 #else
343  scale_s,
344 #endif
345  stride_q,
346  stride_k,
347  stride_v,
348  stride_o,
349  nhead_stride_q,
350  nhead_stride_k,
351  nhead_stride_v,
352  nhead_stride_o}, // args for common karg
353  {}, // placeholder for bias
354  {}, // placeholder for mask
355  {}, // placeholder for lse
356  {}, // placeholder for fp8_static_quant args
357  {}, // placeholder for pagedkv
358  {}, // placeholder for logits_soft_cap
359  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
360  batch_stride_q,
361  batch_stride_k,
362  batch_stride_v,
363  batch_stride_o};
364 
365  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
366  {
367  kargs.bias_ptr = bias_ptr;
368  kargs.stride_bias = stride_bias;
369  kargs.nhead_stride_bias = nhead_stride_bias;
370  kargs.batch_stride_bias = batch_stride_bias;
371  }
372  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
373  {
374  kargs.alibi_slope_ptr = bias_ptr;
375  kargs.alibi_slope_stride = stride_bias;
376  }
377  if constexpr(kHasMask)
378  {
379  kargs.window_size_left = window_size_left;
380  kargs.window_size_right = window_size_right;
381  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
382  }
383  if constexpr(kStoreLSE)
384  {
385  kargs.lse_ptr = lse_ptr;
386  kargs.nhead_stride_lse = nhead_stride_lse;
387  kargs.batch_stride_lse = batch_stride_lse;
388  }
389  if constexpr(kDoFp8StaticQuant)
390  {
391  kargs.scale_p = scale_p;
392  kargs.scale_o = scale_o;
393  }
394  if constexpr(kIsPagedKV)
395  {
396  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
397  kargs.batch_stride_block_table = batch_stride_block_table;
398  kargs.page_block_size = page_block_size;
399  }
400  else
401  {
402  kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
403  }
404  if constexpr(kHasLogitsSoftCap)
405  {
406  kargs.init_logits_soft_cap(logits_soft_cap);
407  }
408 
409  return kargs;
410  }
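Note the CK_TILE_FMHA_FWD_FAST_EXP2 branch above folds log2(e) into scale_s. A standalone sketch of the identity this exploits, exp(x) == exp2(x * log2(e)), which lets the softmax use the cheaper exp2:

    #include <cassert>
    #include <cmath>

    int main()
    {
        const float scale_s = 0.125f;
        const float s       = 3.7f;                // one attention logit
        const float log2e   = 1.4426950408889634f; // ck_tile::log2e_v<>

        const float ref  = std::exp(s * scale_s);            // plain softmax exponent
        const float fast = std::exp2(s * (scale_s * log2e)); // with folded scale
        assert(std::fabs(ref - fast) < 1e-5f);
    }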
411 
412  // std::variant<> can't take a list initializer; this overload is kept for backward compatibility
413  template <bool Cond = !kIsGroupMode>
414  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
415  MakeKargs(const void* q_ptr,
416  const void* k_ptr,
417  const void* v_ptr,
418  const void* bias_ptr,
419  void* lse_ptr,
420  void* o_ptr,
421  ck_tile::index_t seqlen_q,
422  ck_tile::index_t seqlen_k,
423  const void* seqlen_k_ptr, // only used for (paged-) kvcache
424  ck_tile::index_t hdim_q,
425  ck_tile::index_t hdim_v,
426  ck_tile::index_t num_head_q,
427  ck_tile::index_t nhead_ratio_qk,
428  const void* block_table_ptr,
429  ck_tile::index_t batch_stride_block_table,
430  ck_tile::index_t page_block_size,
431  const void* cache_batch_idx,
432  float scale_s,
433  float scale_p,
434  float scale_o,
435  float logits_soft_cap,
436  ck_tile::index_t stride_q,
437  ck_tile::index_t stride_k,
438  ck_tile::index_t stride_v,
439  ck_tile::index_t stride_bias,
440  ck_tile::index_t stride_o,
441  ck_tile::index_t nhead_stride_q,
442  ck_tile::index_t nhead_stride_k,
443  ck_tile::index_t nhead_stride_v,
444  ck_tile::index_t nhead_stride_bias,
445  ck_tile::index_t nhead_stride_lse,
446  ck_tile::index_t nhead_stride_o,
447  ck_tile::index_t batch_stride_q,
448  ck_tile::index_t batch_stride_k,
449  ck_tile::index_t batch_stride_v,
450  ck_tile::index_t batch_stride_bias,
451  ck_tile::index_t batch_stride_lse,
452  ck_tile::index_t batch_stride_o,
453  ck_tile::index_t window_size_left,
454  ck_tile::index_t window_size_right,
455  ck_tile::index_t mask_type)
456  {
457  return MakeKargsImpl(q_ptr,
458  k_ptr,
459  v_ptr,
460  bias_ptr,
461  lse_ptr,
462  o_ptr,
463  seqlen_q,
464  seqlen_k,
465  seqlen_k_ptr,
466  hdim_q,
467  hdim_v,
468  num_head_q,
469  nhead_ratio_qk,
470  block_table_ptr,
471  batch_stride_block_table,
472  page_block_size,
473  cache_batch_idx,
474  scale_s,
475  scale_p,
476  scale_o,
477  logits_soft_cap,
478  stride_q,
479  stride_k,
480  stride_v,
481  stride_bias,
482  stride_o,
483  nhead_stride_q,
484  nhead_stride_k,
485  nhead_stride_v,
486  nhead_stride_bias,
487  nhead_stride_lse,
488  nhead_stride_o,
489  batch_stride_q,
490  batch_stride_k,
491  batch_stride_v,
492  batch_stride_bias,
493  batch_stride_lse,
494  batch_stride_o,
495  window_size_left,
496  window_size_right,
497  mask_type);
498  }
499 
500  template <bool Cond = kIsGroupMode>
501  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
502  MakeKargsImpl(const void* q_ptr,
503  const void* k_ptr,
504  const void* v_ptr,
505  const void* bias_ptr,
506  void* lse_ptr,
507  void* o_ptr,
508  const void* seqstart_q_ptr,
509  const void* seqstart_k_ptr,
510  const void* seqlen_k_ptr,
511  ck_tile::index_t hdim_q,
512  ck_tile::index_t hdim_v,
513  ck_tile::index_t num_head_q,
514  ck_tile::index_t nhead_ratio_qk,
515  const void* block_table_ptr,
516  ck_tile::index_t batch_stride_block_table,
517  ck_tile::index_t page_block_size,
518  bool is_gappy,
519  float scale_s,
520  float scale_p,
521  float scale_o,
522  float logits_soft_cap,
523  ck_tile::index_t stride_q,
524  ck_tile::index_t stride_k,
525  ck_tile::index_t stride_v,
526  ck_tile::index_t stride_bias,
527  ck_tile::index_t stride_o,
528  ck_tile::index_t nhead_stride_q,
529  ck_tile::index_t nhead_stride_k,
530  ck_tile::index_t nhead_stride_v,
531  ck_tile::index_t nhead_stride_bias,
532  ck_tile::index_t nhead_stride_lse,
533  ck_tile::index_t nhead_stride_o,
534  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
535  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
536  ck_tile::index_t window_size_left,
537  ck_tile::index_t window_size_right,
538  ck_tile::index_t mask_type,
539  ck_tile::index_t min_seqlen_q)
540  {
541  Kargs kargs{{q_ptr,
542  k_ptr,
543  v_ptr,
544  o_ptr,
545  -1, // seqlen will be updated by another pointer
546  -1, //
547  hdim_q,
548  hdim_v,
549  num_head_q,
550  nhead_ratio_qk,
551 #if CK_TILE_FMHA_FWD_FAST_EXP2
552  static_cast<float>(scale_s * ck_tile::log2e_v<>),
553 #else
554  scale_s,
555 #endif
556  stride_q,
557  stride_k,
558  stride_v,
559  stride_o,
560  nhead_stride_q,
561  nhead_stride_k,
562  nhead_stride_v,
563  nhead_stride_o}, // args for common karg
564  {}, // placeholder for bias
565  {}, // placeholder for mask
566  {}, // placeholder for lse
567  {}, // placeholder for fp8_static_quant args
568  {}, // placeholder for logits_soft_cap
569  {}, // placeholder for pagedkv
570  {}, // placeholder for min_seqlen_q
571  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
572  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
573  reinterpret_cast<const int32_t*>(seqlen_k_ptr),
574  batch_stride_k,
575  batch_stride_v};
576 
577  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
578  {
579  kargs.bias_ptr = bias_ptr;
580  kargs.stride_bias = stride_bias;
581  kargs.nhead_stride_bias = nhead_stride_bias;
582  }
583  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
584  {
585  kargs.alibi_slope_ptr = bias_ptr;
586  kargs.alibi_slope_stride = stride_bias;
587  }
588  if constexpr(kHasMask)
589  {
590  kargs.window_size_left = window_size_left;
591  kargs.window_size_right = window_size_right;
592  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
593  }
594  if constexpr(kStoreLSE)
595  {
596  kargs.lse_ptr = lse_ptr;
597  kargs.nhead_stride_lse = nhead_stride_lse;
598  }
599  if constexpr(kDoFp8StaticQuant)
600  {
601  kargs.scale_p = scale_p;
602  kargs.scale_o = scale_o;
603  }
604  if constexpr(kHasLogitsSoftCap)
605  {
606  kargs.init_logits_soft_cap(logits_soft_cap);
607  }
608  if constexpr(kIsPagedKV)
609  {
610  kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
611  kargs.batch_stride_block_table = batch_stride_block_table;
612  kargs.page_block_size = page_block_size;
613  kargs.is_gappy = is_gappy;
614  }
615  if constexpr(kSkipMinSeqlenQ)
616  {
617  kargs.min_seqlen_q = min_seqlen_q;
618  }
619 
620  return kargs;
621  }
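Group mode passes no scalar lengths; operator() below recovers them from the seqstart arrays, which are prefix sums over the packed (varlen) batch. A standalone sketch of that layout with illustrative values:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    int main()
    {
        // three sequences of lengths 5, 3, 8 packed back to back
        std::vector<int32_t> seqlen_q   = {5, 3, 8};
        std::vector<int32_t> seqstart_q = {0};
        for(int32_t len : seqlen_q)
            seqstart_q.push_back(seqstart_q.back() + len); // {0, 5, 8, 16}

        const int b = 1; // recover batch 1's length exactly as operator() does
        assert(seqstart_q[b + 1] - seqstart_q[b] == seqlen_q[b]);
    }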
622 
623  // std::variant<> can't take a list initializer; this overload is kept for backward compatibility
624  template <bool Cond = kIsGroupMode>
625  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
626  MakeKargs(const void* q_ptr,
627  const void* k_ptr,
628  const void* v_ptr,
629  const void* bias_ptr,
630  void* lse_ptr,
631  void* o_ptr,
632  const void* seqstart_q_ptr,
633  const void* seqstart_k_ptr,
634  const void* seqlen_k_ptr,
635  ck_tile::index_t hdim_q,
636  ck_tile::index_t hdim_v,
637  ck_tile::index_t num_head_q,
638  ck_tile::index_t nhead_ratio_qk,
639  const void* block_table_ptr,
640  ck_tile::index_t batch_stride_block_table,
641  ck_tile::index_t page_block_size,
642  bool is_gappy,
643  float scale_s,
644  float scale_p,
645  float scale_o,
646  float logits_soft_cap,
647  ck_tile::index_t stride_q,
648  ck_tile::index_t stride_k,
649  ck_tile::index_t stride_v,
650  ck_tile::index_t stride_bias,
651  ck_tile::index_t stride_o,
652  ck_tile::index_t nhead_stride_q,
653  ck_tile::index_t nhead_stride_k,
654  ck_tile::index_t nhead_stride_v,
655  ck_tile::index_t nhead_stride_bias,
656  ck_tile::index_t nhead_stride_lse,
657  ck_tile::index_t nhead_stride_o,
658  ck_tile::index_t batch_stride_k, // only used for paged-kvcache
659  ck_tile::index_t batch_stride_v, // only used for paged-kvcache
660  ck_tile::index_t window_size_left,
661  ck_tile::index_t window_size_right,
662  ck_tile::index_t mask_type,
663  ck_tile::index_t min_seqlen_q)
664  {
665  return MakeKargsImpl(q_ptr,
666  k_ptr,
667  v_ptr,
668  bias_ptr,
669  lse_ptr,
670  o_ptr,
671  seqstart_q_ptr,
672  seqstart_k_ptr,
673  seqlen_k_ptr,
674  hdim_q,
675  hdim_v,
676  num_head_q,
677  nhead_ratio_qk,
678  block_table_ptr,
679  batch_stride_block_table,
680  page_block_size,
681  is_gappy,
682  scale_s,
683  scale_p,
684  scale_o,
685  logits_soft_cap,
686  stride_q,
687  stride_k,
688  stride_v,
689  stride_bias,
690  stride_o,
691  nhead_stride_q,
692  nhead_stride_k,
693  nhead_stride_v,
694  nhead_stride_bias,
695  nhead_stride_lse,
696  nhead_stride_o,
697  batch_stride_k,
698  batch_stride_v,
699  window_size_left,
700  window_size_right,
701  mask_type,
702  min_seqlen_q);
703  }
704 
705  CK_TILE_HOST static void PrintParameters(const Kargs& kargs, int num_batches)
706  {
707  static bool dummy = [&]() {
708  std::cout << std::endl;
709 
710  std::cout << " q_ptr: " << kargs.q_ptr << " k_ptr:" << kargs.k_ptr
711  << " v_ptr: " << kargs.v_ptr << " o_ptr:" << kargs.o_ptr
712  << " hdim_q: " << kargs.hdim_q << " hdim_v: " << kargs.hdim_v
713  << " num_head_q:" << kargs.num_head_q
714  << " nhead_ratio_qk: " << kargs.nhead_ratio_qk << " scale_s:" << kargs.scale_s
715  << " stride_q:" << kargs.stride_q << " stride_k:" << kargs.stride_k
716  << " stride_v:" << kargs.stride_v << " stride_o:" << kargs.stride_o
717  << " nhead_stride_q: " << kargs.nhead_stride_q
718  << " nhead_stride_k: " << kargs.nhead_stride_k
719  << " nhead_stride_v:" << kargs.nhead_stride_v
720  << " nhead_stride_o: " << kargs.nhead_stride_o;
721  if constexpr(!kIsGroupMode)
722  {
723  std::cout << " batch_stride_q:" << kargs.batch_stride_q;
724  }
725  std::cout << " batch_stride_k:" << kargs.batch_stride_k
726  << " batch_stride_v:" << kargs.batch_stride_v;
727 
728  if constexpr(kIsGroupMode)
729  {
730  if constexpr(kSkipMinSeqlenQ)
731  {
732  std::cout << " min_seqlen_q: " << kargs.min_seqlen_q;
733  }
734 
735  std::cout << " seqstart_q_ptr:" << kargs.seqstart_q_ptr
736  << " seqstart_k_ptr: " << kargs.seqstart_k_ptr
737  << " seqlen_k_ptr:" << kargs.seqlen_k_ptr;
738  if(kargs.seqlen_k_ptr != nullptr)
739  {
740  std::cout << "{";
741  for(int i_batch = 0; i_batch < num_batches; i_batch++)
742  std::cout << kargs.seqlen_k_ptr[i_batch] << ",";
743  std::cout << "}";
744  }
745  }
746  if constexpr(kHasMask)
747  {
748  std::cout << " window_size_left: " << kargs.window_size_left
749  << " window_size_right:" << kargs.window_size_right
750  << " mask_type: " << static_cast<int>(kargs.mask_type);
751  }
752 
753  if constexpr(kIsPagedKV)
754  {
755  std::cout << " block_table_ptr: " << kargs.block_table_ptr
756  << " batch_stride_block_table:" << kargs.batch_stride_block_table
757  << " page_block_size: " << kargs.page_block_size;
758 
759  std::cout << "table value: [";
760  for(int b = 0; b < num_batches; b++)
761  {
762  std::cout << "[ ";
763  for(int i = 0; i < kargs.batch_stride_block_table; i++)
764  {
765  std::cout << kargs.block_table_ptr[b * kargs.batch_stride_block_table + i]
766  << ",";
767  }
768  std::cout << " ]";
769  }
770  std::cout << " ]";
771  }
772  std::cout << std::endl;
773  return true;
774  }();
775  (void)dummy;
776  }
777  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
778  ck_tile::index_t nhead_,
779  ck_tile::index_t seqlen_q_,
780  ck_tile::index_t hdim_v_,
781  bool has_padded_seqlen_k)
782  {
783  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
784  if(has_padded_seqlen_k)
785  {
786  // TODO: this may need tuning
787  return dim3(nhead_,
788  batch_size_,
789  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
790  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
791  }
792  else
793  {
794  // TODO: this may need tuning
795  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
796  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
797  nhead_,
798  batch_size_);
799  }
800  }
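Both branches launch nhead * batch * (ceil(seqlen_q / kM0) * ceil(hdim_v / kN1)) blocks and differ only in which dim3 axis carries the combined tile index. A quick numeric sketch of the tile count (illustrative sizes):

    #include <cassert>

    int main()
    {
        auto ceil_div = [](int x, int y) { return (x + y - 1) / y; };
        const int kM0 = 128, kN1 = 128;          // illustrative tile sizes
        const int seqlen_q = 1000, hdim_v = 128; // illustrative problem sizes
        const int tiles = ceil_div(seqlen_q, kM0) * ceil_div(hdim_v, kN1);
        assert(tiles == 8); // 8 tile-blocks per (batch, head) pair
    }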
801 
802  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
803  {
804  bool has_padded_seqlen_k = false;
805 
806  if constexpr(kIsGroupMode)
807  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
808 
809  if(has_padded_seqlen_k)
810  {
811  // const index_t num_tile_m0 = seqlen_q / kM0;
812  const index_t num_tile_n1 =
813  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
814 
815  const index_t i_block = blockIdx.z;
816  const index_t i_nhead = blockIdx.x;
817  const index_t i_batch = blockIdx.y;
818 
819  const auto f = [](index_t dividend, index_t divisor) {
820  index_t quotient = dividend / divisor;
821  index_t modulus = dividend - quotient * divisor;
822  return ck_tile::make_tuple(quotient, modulus);
823  };
824 
825  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
826 
827  if constexpr(kHasMask)
828  {
829  // assume that num_tile_n1 is always 1
830  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
831  }
832  else
833  {
834  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
835  }
836  }
837  else
838  {
839  // const index_t num_tile_m0 = seqlen_q / kM0;
840  const index_t num_tile_n1 =
841  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
842 
843  const index_t i_block = blockIdx.x;
844  const index_t i_nhead = blockIdx.y;
845  const index_t i_batch = blockIdx.z;
846 
847  const auto f = [](index_t dividend, index_t divisor) {
848  index_t quotient = dividend / divisor;
849  index_t modulus = dividend - quotient * divisor;
850  return ck_tile::make_tuple(quotient, modulus);
851  };
852 
853  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
854 
855  if constexpr(kHasMask)
856  {
857  // assume that num_tile_n1 is always 1
858  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
859  }
860  else
861  {
862  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
863  }
864  }
865  }
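When a mask is present, the query-tile index is reversed (gridDim - 1 - i_tile_m). A plausible reading, sketched below with illustrative numbers: under a causal mask the highest query tiles attend to the most keys, so issuing them first starts the longest-running blocks early.

    #include <cstdio>

    int main()
    {
        const int kM0 = 128, num_tile_m = 4;
        for(int i = 0; i < num_tile_m; ++i)
        {
            const int tile     = num_tile_m - 1 - i; // reversed issue order
            const int keys_max = (tile + 1) * kM0;   // keys visible to this tile (causal)
            std::printf("issue #%d -> query tile %d (~%d keys)\n", i, tile, keys_max);
        }
    }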
866 
867  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
868 
869  CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
870  {
871  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
872  }
873 
874  CK_TILE_DEVICE void operator()(Kargs kargs) const
875  {
876  // allocate LDS
877  __shared__ char smem_ptr[GetSmemSize()];
878 
879  // divide problem
880  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
881 
882  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
883  const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
884 
885  long_index_t batch_offset_q = 0;
886  long_index_t batch_offset_k = 0;
887  long_index_t batch_offset_v = 0;
888  long_index_t batch_offset_bias = 0;
889  long_index_t batch_offset_lse = 0;
890  long_index_t batch_offset_o = 0;
891  index_t kv_l2p_offset = 0;
892 
893  if constexpr(kIsGroupMode)
894  {
895  // get starting offset for each batch
896  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
897  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
898 
899  batch_offset_q = query_start * kargs.stride_q;
900  batch_offset_k = key_start * kargs.stride_k;
901  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
902  {
903  batch_offset_v = key_start * kargs.stride_v;
904  }
905  else
906  {
907  batch_offset_v = key_start;
908  }
909  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
910  {
911  batch_offset_bias = query_start * kargs.stride_bias;
912  }
913  if constexpr(kStoreLSE)
914  {
915  batch_offset_lse = query_start;
916  }
917 
918  batch_offset_o = query_start * kargs.stride_o;
919 
920  // get real # queries & # keys under group mode
921  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
922  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
923 
924  if constexpr(kSkipMinSeqlenQ)
925  {
926  if(kargs.seqlen_q <= kargs.min_seqlen_q)
927  {
928  return;
929  }
930  }
931 
932  // # of required blocks is different in each group, terminate unnecessary blocks
933  // earlier
934  if(kargs.seqlen_q <= i_m0)
935  {
936  return;
937  }
938 
939  if(kargs.seqlen_k_ptr != nullptr)
940  {
941  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
942  }
943  else
944  {
945  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
946  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
947  }
948 
949  if constexpr(kIsPagedKV)
950  {
951  if(kargs.is_gappy)
952  {
953  // seqstart_k_ptr has different meaning in this case
954  kv_l2p_offset = kargs.seqstart_k_ptr[i_batch];
955  }
956  }
957  }
958  else
959  {
960  const index_t i_cache_batch = [&, i_batch_ = i_batch] {
961  if constexpr(kIsPagedKV)
962  {
963  return i_batch_;
964  }
965  else
966  {
967  return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
968  : i_batch_);
969  }
970  }();
971 
972  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
973  batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
974  batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
975  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
976  {
977  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
978  }
979  if constexpr(kStoreLSE)
980  {
981  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
982  }
983 
984  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
985 
986  if(kargs.seqlen_k_ptr != nullptr)
987  {
988  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
989  }
990  }
991 
992  // for simplicity, we fold the batch stride directly into the pointer
993  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
994  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
995  batch_offset_q;
996  const KDataType* k_ptr =
997  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
998  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
999  batch_offset_k;
1000  const VDataType* v_ptr =
1001  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1002  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1003  batch_offset_v;
1004  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1005  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1006  batch_offset_o;
1007 
1008  // Q/K/V DRAM and DRAM window
1009  const auto q_dram = [&]() {
1010  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1011  q_ptr,
1012  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1013  make_tuple(kargs.stride_q, 1),
1014  number<FmhaPipeline::kAlignmentQ>{},
1015  number<1>{});
1016  if constexpr(FmhaPipeline::kQLoadOnce)
1017  {
1018  return pad_tensor_view(
1019  q_dram_naive,
1020  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
1021  sequence<kPadSeqLenQ, kPadHeadDimQ>{});
1022  }
1023  else
1024  {
1025  return pad_tensor_view(
1026  q_dram_naive,
1027  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
1028  sequence<kPadSeqLenQ, kPadHeadDimQ>{});
1029  }
1030  }();
1031 
1032  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1033  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1034  data, // will update this pointer if using paged-kvcache
1035  make_tuple(height, kargs.hdim_q),
1036  make_tuple(kargs.stride_k, 1),
1037  number<FmhaPipeline::kAlignmentK>{},
1038  number<1>{});
1039 
1040  return pad_tensor_view(
1041  k_dram_naive,
1042  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
1043  sequence<kPadSeqLenK, kPadHeadDimQ>{});
1044  };
1045  const auto k_dram = [&]() {
1046  if constexpr(kIsPagedKV)
1047  {
1048  return make_k_dram(nullptr, kargs.page_block_size);
1049  }
1050  else
1051  {
1052  return make_k_dram(k_ptr, kargs.seqlen_k);
1053  }
1054  }();
1055 
1056  const auto make_v_dram = [&](const VDataType* data, index_t length) {
1057  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1058  {
1059  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1060  data, // will update this pointer if using paged-kvcache
1061  make_tuple(length, kargs.hdim_v),
1062  make_tuple(kargs.stride_v, 1),
1063  number<FmhaPipeline::kAlignmentV>{},
1064  number<1>{});
1065 
1066  const auto v_dram_transposed =
1067  transform_tensor_view(v_dram_naive,
1068  make_tuple(make_pass_through_transform(kargs.hdim_v),
1069  make_pass_through_transform(length)),
1070  make_tuple(sequence<1>{}, sequence<0>{}),
1071  make_tuple(sequence<0>{}, sequence<1>{}));
1072 
1073  return pad_tensor_view(
1074  v_dram_transposed,
1075  make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
1076  sequence<kPadHeadDimV, kPadSeqLenK>{});
1077  }
1078  else
1079  {
1080  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1081  data, // will update this pointer if using paged-kvcache
1082  make_tuple(kargs.hdim_v, length),
1083  make_tuple(kargs.stride_v, 1),
1084  number<FmhaPipeline::kAlignmentV>{},
1085  number<1>{});
1086 
1087  return pad_tensor_view(
1088  v_dram_naive,
1089  make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
1090  sequence<kPadHeadDimV, kPadSeqLenK>{});
1091  }
1092  };
1093  const auto v_dram = [&]() {
1094  if constexpr(kIsPagedKV)
1095  {
1096  return make_v_dram(nullptr, kargs.page_block_size);
1097  }
1098  else
1099  {
1100  return make_v_dram(v_ptr, kargs.seqlen_k);
1101  }
1102  }();
1103 
1104  auto q_dram_window = make_tile_window(
1105  q_dram,
1106  [&]() {
1107  if constexpr(FmhaPipeline::kQLoadOnce)
1108  return make_tuple(number<FmhaPipeline::kM0>{},
1109  number<FmhaPipeline::kSubQKHeaddim>{});
1110  else
1111  return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
1112  }(),
1113  {i_m0, 0});
1114 
1115  auto k_page_block_navigator =
1116  [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1117  if constexpr(kIsPagedKV)
1118  {
1119  const auto* block_indices =
1120  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1121  i_batch_ * kargs.batch_stride_block_table;
1122  const index_t num_blocks =
1123  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1124 
1125  const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_k;
1126 
1127  return make_page_block_navigator<const KDataType, 0>(
1128  kargs.k_ptr,
1129  kargs.batch_stride_k, // kcache page-block stride/size
1130  fixed_offset,
1131  block_indices,
1132  num_blocks,
1133  kargs.page_block_size,
1134  k_dram,
1135  make_k_dram(nullptr,
1136  (kv_l2p_offset + kargs.seqlen_k) -
1137  (num_blocks - 1) * kargs.page_block_size));
1138  }
1139  else
1140  {
1141  return make_page_block_navigator(k_dram);
1142  }
1143  }();
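The navigator above walks K through the page table: a logical key position splits into a page index (looked up in block_table) and an in-page offset, so one sequence's KV cache need not be contiguous in memory. A minimal sketch of the translation with illustrative values:

    #include <cassert>
    #include <vector>

    int main()
    {
        const int page_block_size = 4;
        std::vector<int> block_table = {7, 2, 5}; // logical page -> physical page

        const int logical_token = 9;              // token within this sequence
        const int logical_page  = logical_token / page_block_size; // 2
        const int in_page       = logical_token % page_block_size; // 1
        const int physical_page = block_table[logical_page];       // 5

        assert(physical_page * page_block_size + in_page == 21);   // physical slot
    }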
1144 
1145  auto v_page_block_navigator =
1146  [&, i_batch_ = i_batch, i_nhead_ = i_nhead / kargs.nhead_ratio_qk]() {
1147  if constexpr(kIsPagedKV)
1148  {
1149  const auto* block_indices =
1150  reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
1151  i_batch_ * kargs.batch_stride_block_table;
1152  const index_t num_blocks =
1153  integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
1154 
1155  const long_index_t fixed_offset = i_nhead_ * kargs.nhead_stride_v;
1156 
1157  return make_page_block_navigator<const VDataType, 1>(
1158  kargs.v_ptr,
1159  kargs.batch_stride_v, // vcache page-block stride/size
1160  fixed_offset,
1161  block_indices,
1162  num_blocks,
1163  kargs.page_block_size,
1164  v_dram,
1165  make_v_dram(nullptr,
1166  (kv_l2p_offset + kargs.seqlen_k) -
1167  (num_blocks - 1) * kargs.page_block_size));
1168  }
1169  else
1170  {
1171  return make_page_block_navigator(v_dram);
1172  }
1173  }();
1174 
1175  auto k_dram_window_lengths =
1176  make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
1177  auto v_dram_window_lengths =
1178  make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{});
1179 
1182  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1183  constexpr auto bias_dram_window_lengths =
1184  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
1185  if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
1186  {
1187  const BiasDataType* bias_ptr =
1188  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1189  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1190  batch_offset_bias;
1191 
1192  const auto bias_dram = [&]() {
1193  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1194  bias_ptr,
1195  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1196  make_tuple(kargs.stride_bias, 1),
1197  number<FmhaPipeline::kAlignmentBias>{},
1198  number<1>{});
1199 
1200  return pad_tensor_view(bias_dram_naive,
1201  bias_dram_window_lengths,
1202  sequence<kPadSeqLenQ, kPadSeqLenK>{});
1203  }();
1204 
1205  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1206  }
1207  else
1208  {
1209  return make_null_tile_window(bias_dram_window_lengths);
1210  }
1211  }();
1212 
1213  // lse
1214  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1215  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1216  if constexpr(kStoreLSE)
1217  {
1218  LSEDataType* lse_ptr =
1219  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1220  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1221 
1222  const auto lse_dram = [&]() {
1223  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1224  lse_ptr,
1225  make_tuple(kargs.seqlen_q),
1226  make_tuple(1),
1227  number<1>{},
1228  number<1>{});
1229 
1230  return pad_tensor_view(
1231  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1232  }();
1233 
1234  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1235  }
1236  else
1237  {
1238  return make_null_tile_window(lse_dram_window_lengths);
1239  }
1240  }();
1241 
1242  FmhaMask mask = [&]() {
1243  if constexpr(kHasMask)
1244  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1245  kargs.window_size_left,
1246  kargs.window_size_right,
1247  kargs.seqlen_q,
1248  kargs.seqlen_k,
1249  kargs.mask_type);
1250  else
1251  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1252  }();
1253 
1254  // workaround: structured bindings (i_batch etc.) cannot be captured by lambdas before C++20
1255  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1256  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1257  {
1258  // data loading, shared by entire wg
1259  // TODO: how to use s_read?
1260  SaccDataType slope =
1261  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1262  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1263 #if CK_TILE_FMHA_FWD_FAST_EXP2
1264  slope *= ck_tile::log2e_v<>;
1265 #endif
1266  if constexpr(kHasMask)
1267  {
1268  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1269  kargs.window_size_left,
1270  kargs.window_size_right,
1271  kargs.seqlen_q,
1272  kargs.seqlen_k,
1273  kargs.mask_type);
1274  }
1275  else
1276  {
1277  return Alibi<SaccDataType, true>{
1278  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1279  }
1280  }
1281  else
1282  {
1283  return EmptyPositionEncoding<SaccDataType>{};
1284  }
1285  }();
1286 
1287  AttentionVariant variant;
1288  const auto variant_params = [&] {
1289  if constexpr(kHasLogitsSoftCap)
1290  {
1291  return ck_tile::LogitsSoftCapParams<FmhaMask>{
1292  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1293  }
1294  else
1295  {
1296  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1297  }
1298  }();
1299 
1300  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1301 
1302  auto o_acc_tile = [&]() {
1303  if constexpr(kDoFp8StaticQuant)
1304  {
1305  return FmhaPipeline{}(
1306  q_dram_window,
1307  identity{}, // q_element_func
1308  k_dram_window_lengths,
1309  k_page_block_navigator,
1310  identity{}, // k_element_func
1311  v_dram_window_lengths,
1312  v_page_block_navigator,
1313  identity{}, // v_element_func
1314  bias_dram_window,
1315  identity{}, // bias_element_func
1316  lse_dram_window,
1317  identity{}, // lse_element_func
1318  identity{}, // s_acc_element_func
1319  scales{kargs.scale_p}, // p_compute_element_func
1320  composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1321  mask,
1322  position_encoding,
1323  kargs.scale_s,
1324  variant,
1325  variant_params,
1326  block_indices,
1327  kv_l2p_offset,
1328  smem_ptr);
1329  }
1330  else
1331  {
1332  return FmhaPipeline{}(q_dram_window,
1333  k_dram_window_lengths,
1334  k_page_block_navigator,
1335  v_dram_window_lengths,
1336  v_page_block_navigator,
1337  bias_dram_window,
1338  lse_dram_window,
1339  mask,
1340  position_encoding,
1341  kargs.scale_s,
1342  variant,
1343  variant_params,
1344  block_indices,
1345  kv_l2p_offset,
1346  smem_ptr);
1347  }
1348  }();
1349 
1350  // O DRAM and O DRAM window
1351  auto o_dram = [&]() {
1352  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1353  o_ptr,
1354  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1355  make_tuple(kargs.stride_o, 1),
1356  number<FmhaPipeline::kAlignmentO>{},
1357  number<1>{});
1358 
1359  return pad_tensor_view(
1360  o_dram_naive,
1361  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1362  sequence<kPadSeqLenQ, kPadHeadDimV>{});
1363  }();
1364 
1365  auto o_dram_window =
1366  make_tile_window(o_dram,
1367  make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
1368  {i_m0, i_n1});
1369 
1370  EpiloguePipeline{}(o_dram_window, o_acc_tile);
1371  }
1372 };
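For orientation, a host-side launch sketch; the names here are hypothetical and the real dispatch lives in the generated fmha_fwd API and the example code:

    // using Kernel = FmhaFwdPagedKVKernel<SomePipeline, SomeEpilogue>; // hypothetical aliases
    // auto kargs = Kernel::MakeKargs(q_ptr, k_ptr, v_ptr, bias_ptr, lse_ptr, o_ptr,
    //                                /* lengths, head counts, page table, scales, strides... */);
    // const dim3 grids  = Kernel::GridSize(batch, nhead, max_seqlen_q, hdim_v,
    //                                      /*has_padded_seqlen_k=*/seqlen_k_ptr != nullptr);
    // const dim3 blocks = Kernel::BlockSize();
    // // then launch Kernel{} with grids/blocks/GetSmemSize() through ck_tile's
    // // usual make_kernel/launch_kernel helpers.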
1373 
1374 } // namespace ck_tile