/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File
fmha_fwd_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
11 
12 #include <string>
13 #include <type_traits>
14 #include <utility>
15 #include <variant>
16 
17 #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
18 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
19 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
20 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
21 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
22 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
23 
24 namespace ck_tile {
25 
26 template <typename FmhaPipeline_, typename EpiloguePipeline_>
28 {
31  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
32 
33  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
34  static_assert(kBlockPerCu > 0);
35  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
36 
47 
49 
50  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
51  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
52  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
53  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
54  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
55  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
56  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
57  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
58  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
59  static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
60  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
61 
64  static constexpr bool kHasMask = FmhaMask::IsMasking;
65 
66  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
67 
68  static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
69 #if defined(__gfx950__)
70  static constexpr bool kIsAvailable = true;
71 #else
72  static constexpr bool kIsAvailable = !kUseTrLoad;
73 #endif
74  static constexpr std::string_view kPipelineName = FmhaPipeline::name;
75 
76  // clang-format off
77  template <typename T1, typename T2 = T1> struct t2s;
78  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
79  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
80  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
81  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
82  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
83  template <> struct t2s<ck_tile::fp8_t, ck_tile::bf16_t> { static constexpr const char * name = "fp8bf16"; };
84  template <> struct t2s<ck_tile::fp8_t, ck_tile::fp32_t> { static constexpr const char * name = "fp8fp32"; };
85  // clang-format on
86 
// Builds a unique, human-readable name for this kernel instantiation by
// encoding tile shape, data types, mode (batch/group) and every enabled
// feature flag. The format is consumed by the code generator ("sync with
// generate.py" below), so the encoding must not drift from that script.
// NOTE(review): this extraction appears to be missing original line 118
// (the tail of the return expression / its terminating piece); the code
// below is kept byte-identical to the extracted source.
87  CK_TILE_HOST static std::string GetName()
88  {
89  // sync with generate.py
90  // clang-format off
// Shorthands for the block-shape traits used throughout the name.
91  using bfs = typename FmhaPipeline::BlockFmhaShape;
92  using g0br = typename bfs::Gemm0BlockWarps;
93  using g1br = typename bfs::Gemm1BlockWarps;
94  using g0wt = typename bfs::Gemm0WarpTile;
95  using g1wt = typename bfs::Gemm1WarpTile;
96  #define _SS_ std::string
97  #define _TS_ std::to_string
// Collects the padding flags into a "p<s|sk|d|dv>" token, or "" when no
// dimension is padded (rendered later as "_npad").
98  auto pn = [&] () {
99  std::string n;
100  if (kPadSeqLenQ) n += "s";
101  if (kPadSeqLenK) n += "sk";
102  if (kPadHeadDimQ) n += "d";
103  if (kPadHeadDimV) n += "dv";
104  return n.empty() ? n : std::string("p") + n; }();
105  return
106  _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType, ODataType>::name) +
107  "_" + (kIsGroupMode ? "group" : "batch") + "_"
108  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
109  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
110  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
111  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
112  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
113  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
114  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
115  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
116  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
117  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) +
// NOTE(review): original line 118 (final name fragment ending the return
// statement) is missing from this extraction.
119  #undef _SS_
120  #undef _TS_
121  // clang-format on
122  }
123 
124  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
125  // arg
127  {
128  };
129 
130  // kargs use aggregate initializer, so no constructor will provided
131  // use inheritance to minimize karg size
132  // user need to use MakeKargs() function to create kargs.
134  {
135  const void* q_ptr;
136  const void* k_ptr;
137  const void* v_ptr;
138  void* o_ptr;
139 
144 
146  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
147  // if this param is larger than 1, indicate MQA/GQA case
149  float scale_s;
150 
155 
160  };
161 
163  {
165 
166  void init_logits_soft_cap(float logits_soft_cap_)
167  {
168  if(0 < logits_soft_cap_)
169  {
170  logits_soft_cap = logits_soft_cap_;
172  }
173  else
174  {
175  logits_soft_cap = 0.f;
176  logits_soft_cap_rcp = 0.f;
177  }
178  }
179 
182  };
183 
185  {
186  const void* bias_ptr = nullptr;
189  };
190 
192  {
194  };
195 
197  {
198  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
199  const void* alibi_slope_ptr;
200  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
201  };
202 
204  {
205  // ck_tile::index_t window_size_left, window_size_right;
208  };
209 
211  {
212  const void* q_descale_ptr = nullptr;
213  const void* k_descale_ptr = nullptr;
214  const void* v_descale_ptr = nullptr;
215  };
216 
218  {
219  void* lse_ptr = nullptr;
222  };
223 
225  {
226  template <typename T>
228  {
229  T val;
230  const T* ptr;
231  };
232 
236  };
237 
239  {
// Host-side initializer: dropout probability plus seed/offset given as
// immediate values (Philox RNG state is fully known on the host).
// NOTE(review): original lines 243-244 are missing from this extraction
// (presumably the quantized keep-probability setup); code kept byte-identical.
240  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
241  {
// keep-probability; rp_undrop is its reciprocal for rescaling kept values
242  float p_undrop = 1.0 - p_drop;
245  rp_undrop = 1.0 / p_undrop;
246 
247  this->drop_seed.val = seed;
248  this->drop_offset.val = offset;
249  this->is_drop_seed_offset_from_host = true;
250  }
251 
// Device-side initializer: seed/offset live in device memory and are read
// by the kernel at launch time (pointers stored instead of values).
// NOTE(review): original lines 255-256 are missing from this extraction.
252  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
253  {
254  float p_undrop = 1.0 - p_drop;
257  rp_undrop = 1.0 / p_undrop;
258 
259  this->drop_seed.ptr = seed_ptr;
260  this->drop_offset.ptr = offset_ptr;
261  this->is_drop_seed_offset_from_host = false;
262  }
263 
// 1 / (1 - p_drop); multiplied into kept values to preserve expectation
264  float rp_undrop = 1;
266  bool is_store_randval = false;
267  void* rand_val_ptr = nullptr;
268 
// NOTE(review): member declarations at original lines 265 and 269-270
// (drop_seed / drop_offset unions, host-flag) are missing from this extraction.
271  };
272 
274  {
276  };
277 
279  {
281  };
282 
285  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
286  FmhaFwdBatchModeBiasKargs,
287  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
288  FmhaFwdAlibiKargs,
289  FmhaFwdEmptyKargs<0>>>,
290  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
291  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
292  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
293  FmhaFwdCommonQScaleKargs,
294  FmhaFwdEmptyKargs<3>>,
295  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
296  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
297  {
302 
303  // Optional cumulative sequence length pointers for batch mode
304  // If provided, they override seqlen_q / seqlen_k per-batch to skip tail padding.
305  const int32_t* cu_seqlen_q_ptr = nullptr; // cumulative, length without PAD
306  const int32_t* cu_seqlen_k_ptr = nullptr; // cumulative, length without PAD
307  };
308 
311  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
312  FmhaFwdCommonBiasKargs,
313  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
314  FmhaFwdAlibiKargs,
315  FmhaFwdEmptyKargs<0>>>,
316  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
317  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
318  std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR,
319  FmhaFwdCommonQScaleKargs,
320  FmhaFwdEmptyKargs<3>>,
321  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
322  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
323  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
324  {
329 
330  // Optional per-sequence and cumulative logical (excluding padding) sequence length arrays
331  const int32_t* cu_seqlen_q_ptr = nullptr;
332  const int32_t* cu_seqlen_k_ptr = nullptr;
333  };
334 
335  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
336 
338  {
342  };
343 
344  template <bool Cond = !kIsGroupMode>
345  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
346  MakeKargsImpl(const void* q_ptr,
347  const void* k_ptr,
348  const void* v_ptr,
349  const void* bias_ptr,
350  const void* q_descale_ptr,
351  const void* k_descale_ptr,
352  const void* v_descale_ptr,
353  void* rand_val_ptr,
354  void* lse_ptr,
355  void* o_ptr,
356  ck_tile::index_t seqlen_q,
357  ck_tile::index_t seqlen_k,
358  ck_tile::index_t hdim_q,
359  ck_tile::index_t hdim_v,
360  ck_tile::index_t num_head_q,
361  ck_tile::index_t nhead_ratio_qk,
362  float scale_s,
363  float logits_soft_cap,
364  ck_tile::index_t stride_q,
365  ck_tile::index_t stride_k,
366  ck_tile::index_t stride_v,
367  ck_tile::index_t stride_bias,
368  ck_tile::index_t stride_randval,
369  ck_tile::index_t stride_o,
370  ck_tile::index_t nhead_stride_q,
371  ck_tile::index_t nhead_stride_k,
372  ck_tile::index_t nhead_stride_v,
373  ck_tile::index_t nhead_stride_bias,
374  ck_tile::index_t nhead_stride_randval,
375  ck_tile::index_t nhead_stride_lse,
376  ck_tile::index_t nhead_stride_o,
377  ck_tile::index_t batch_stride_q,
378  ck_tile::index_t batch_stride_k,
379  ck_tile::index_t batch_stride_v,
380  ck_tile::index_t batch_stride_bias,
381  ck_tile::index_t batch_stride_randval,
382  ck_tile::index_t batch_stride_lse,
383  ck_tile::index_t batch_stride_o,
384  ck_tile::index_t window_size_left,
385  ck_tile::index_t window_size_right,
386  ck_tile::index_t mask_type,
387  float p_drop,
388  bool s_randval,
389  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
390  drop_seed_offset,
391  const void* cu_seqlen_q_ptr = nullptr,
392  const void* cu_seqlen_k_ptr = nullptr)
393  {
394  Kargs kargs{{q_ptr,
395  k_ptr,
396  v_ptr,
397  o_ptr,
398  seqlen_q,
399  seqlen_k,
400  hdim_q,
401  hdim_v,
402  num_head_q,
403  nhead_ratio_qk,
404 #if CK_TILE_FMHA_FWD_FAST_EXP2
405  static_cast<float>(scale_s * ck_tile::log2e_v<>),
406 #else
407  scale_s,
408 #endif
409  stride_q,
410  stride_k,
411  stride_v,
412  stride_o,
413  nhead_stride_q,
414  nhead_stride_k,
415  nhead_stride_v,
416  nhead_stride_o}, // args for common karg
417  {}, // placeholder for bias
418  {}, // placeholder for mask
419  {}, // placeholder for lse
420  {}, // placeholder for qscale
421  {}, // placeholder for dropout
422  {}, // placeholder for logits_soft_cap
423  batch_stride_q,
424  batch_stride_k,
425  batch_stride_v,
426  batch_stride_o};
427 
429  {
430  kargs.bias_ptr = bias_ptr;
431  kargs.stride_bias = stride_bias;
432  kargs.nhead_stride_bias = nhead_stride_bias;
433  kargs.batch_stride_bias = batch_stride_bias;
434  }
435  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
436  {
437  kargs.alibi_slope_ptr = bias_ptr;
438  kargs.alibi_slope_stride = stride_bias;
439  }
440  if constexpr(kHasMask)
441  {
442  kargs.window_size_left = window_size_left;
443  kargs.window_size_right = window_size_right;
444  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
445  }
446  if constexpr(kStoreLSE)
447  {
448  kargs.lse_ptr = lse_ptr;
449  kargs.nhead_stride_lse = nhead_stride_lse;
450  kargs.batch_stride_lse = batch_stride_lse;
451  }
453  {
454  kargs.q_descale_ptr = q_descale_ptr;
455  kargs.k_descale_ptr = k_descale_ptr;
456  kargs.v_descale_ptr = v_descale_ptr;
457  }
458  if constexpr(kHasDropout)
459  {
460  if(drop_seed_offset.index() == 0) // seed & offset come from host
461  {
462  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
463  kargs.init_dropout(p_drop, seed, offset);
464  }
465  else // seed & offset come from device
466  {
467  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
468  kargs.init_dropout(p_drop,
469  reinterpret_cast<const uint64_t*>(seed_ptr),
470  reinterpret_cast<const uint64_t*>(offset_ptr));
471  }
472 
473  kargs.rand_val_ptr = rand_val_ptr;
474  kargs.stride_randval = stride_randval;
475  kargs.nhead_stride_randval = nhead_stride_randval;
476  kargs.batch_stride_randval = batch_stride_randval;
477  kargs.is_store_randval = s_randval;
478  }
479  if constexpr(kHasLogitsSoftCap)
480  {
481  kargs.init_logits_soft_cap(logits_soft_cap);
482  }
483 
484  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
485  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
486  return kargs;
487  }
488 
489  // std::variant<> can't take in a list initializer, overload for backward compatibility
490  template <bool Cond = !kIsGroupMode>
491  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
492  MakeKargs(const void* q_ptr,
493  const void* k_ptr,
494  const void* v_ptr,
495  const void* bias_ptr,
496  const void* q_descale_ptr,
497  const void* k_descale_ptr,
498  const void* v_descale_ptr,
499  void* rand_val_ptr,
500  void* lse_ptr,
501  void* o_ptr,
502  ck_tile::index_t seqlen_q,
503  ck_tile::index_t seqlen_k,
504  ck_tile::index_t hdim_q,
505  ck_tile::index_t hdim_v,
506  ck_tile::index_t num_head_q,
507  ck_tile::index_t nhead_ratio_qk,
508  float scale_s,
509  float logits_soft_cap,
510  ck_tile::index_t stride_q,
511  ck_tile::index_t stride_k,
512  ck_tile::index_t stride_v,
513  ck_tile::index_t stride_bias,
514  ck_tile::index_t stride_randval,
515  ck_tile::index_t stride_o,
516  ck_tile::index_t nhead_stride_q,
517  ck_tile::index_t nhead_stride_k,
518  ck_tile::index_t nhead_stride_v,
519  ck_tile::index_t nhead_stride_bias,
520  ck_tile::index_t nhead_stride_randval,
521  ck_tile::index_t nhead_stride_lse,
522  ck_tile::index_t nhead_stride_o,
523  ck_tile::index_t batch_stride_q,
524  ck_tile::index_t batch_stride_k,
525  ck_tile::index_t batch_stride_v,
526  ck_tile::index_t batch_stride_bias,
527  ck_tile::index_t batch_stride_randval,
528  ck_tile::index_t batch_stride_lse,
529  ck_tile::index_t batch_stride_o,
530  ck_tile::index_t window_size_left,
531  ck_tile::index_t window_size_right,
532  ck_tile::index_t mask_type,
533  float p_drop,
534  bool s_randval,
535  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
536  const void* cu_seqlen_q_ptr = nullptr,
537  const void* cu_seqlen_k_ptr = nullptr)
538  {
539  return MakeKargsImpl(
540  q_ptr,
541  k_ptr,
542  v_ptr,
543  bias_ptr,
544  q_descale_ptr,
545  k_descale_ptr,
546  v_descale_ptr,
547  rand_val_ptr,
548  lse_ptr,
549  o_ptr,
550  seqlen_q,
551  seqlen_k,
552  hdim_q,
553  hdim_v,
554  num_head_q,
555  nhead_ratio_qk,
556  scale_s,
557  logits_soft_cap,
558  stride_q,
559  stride_k,
560  stride_v,
561  stride_bias,
562  stride_randval,
563  stride_o,
564  nhead_stride_q,
565  nhead_stride_k,
566  nhead_stride_v,
567  nhead_stride_bias,
568  nhead_stride_randval,
569  nhead_stride_lse,
570  nhead_stride_o,
571  batch_stride_q,
572  batch_stride_k,
573  batch_stride_v,
574  batch_stride_bias,
575  batch_stride_randval,
576  batch_stride_lse,
577  batch_stride_o,
578  window_size_left,
579  window_size_right,
580  mask_type,
581  p_drop,
582  s_randval,
583  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
584  cu_seqlen_q_ptr,
585  cu_seqlen_k_ptr);
586  }
587 
588  // std::variant<> can't take in a list initializer, overload for backward compatibility
589  template <bool Cond = !kIsGroupMode>
590  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
591  MakeKargs(const void* q_ptr,
592  const void* k_ptr,
593  const void* v_ptr,
594  const void* bias_ptr,
595  const void* q_descale_ptr,
596  const void* k_descale_ptr,
597  const void* v_descale_ptr,
598  void* rand_val_ptr,
599  void* lse_ptr,
600  void* o_ptr,
601  ck_tile::index_t seqlen_q,
602  ck_tile::index_t seqlen_k,
603  ck_tile::index_t hdim_q,
604  ck_tile::index_t hdim_v,
605  ck_tile::index_t num_head_q,
606  ck_tile::index_t nhead_ratio_qk,
607  float scale_s,
608  float logits_soft_cap,
609  ck_tile::index_t stride_q,
610  ck_tile::index_t stride_k,
611  ck_tile::index_t stride_v,
612  ck_tile::index_t stride_bias,
613  ck_tile::index_t stride_randval,
614  ck_tile::index_t stride_o,
615  ck_tile::index_t nhead_stride_q,
616  ck_tile::index_t nhead_stride_k,
617  ck_tile::index_t nhead_stride_v,
618  ck_tile::index_t nhead_stride_bias,
619  ck_tile::index_t nhead_stride_randval,
620  ck_tile::index_t nhead_stride_lse,
621  ck_tile::index_t nhead_stride_o,
622  ck_tile::index_t batch_stride_q,
623  ck_tile::index_t batch_stride_k,
624  ck_tile::index_t batch_stride_v,
625  ck_tile::index_t batch_stride_bias,
626  ck_tile::index_t batch_stride_randval,
627  ck_tile::index_t batch_stride_lse,
628  ck_tile::index_t batch_stride_o,
629  ck_tile::index_t window_size_left,
630  ck_tile::index_t window_size_right,
631  ck_tile::index_t mask_type,
632  float p_drop,
633  bool s_randval,
634  const std::tuple<const void*, const void*>& drop_seed_offset,
635  const void* cu_seqlen_q_ptr = nullptr,
636  const void* cu_seqlen_k_ptr = nullptr)
637  {
638  return MakeKargsImpl(
639  q_ptr,
640  k_ptr,
641  v_ptr,
642  bias_ptr,
643  q_descale_ptr,
644  k_descale_ptr,
645  v_descale_ptr,
646  rand_val_ptr,
647  lse_ptr,
648  o_ptr,
649  seqlen_q,
650  seqlen_k,
651  hdim_q,
652  hdim_v,
653  num_head_q,
654  nhead_ratio_qk,
655  scale_s,
656  logits_soft_cap,
657  stride_q,
658  stride_k,
659  stride_v,
660  stride_bias,
661  stride_randval,
662  stride_o,
663  nhead_stride_q,
664  nhead_stride_k,
665  nhead_stride_v,
666  nhead_stride_bias,
667  nhead_stride_randval,
668  nhead_stride_lse,
669  nhead_stride_o,
670  batch_stride_q,
671  batch_stride_k,
672  batch_stride_v,
673  batch_stride_bias,
674  batch_stride_randval,
675  batch_stride_lse,
676  batch_stride_o,
677  window_size_left,
678  window_size_right,
679  mask_type,
680  p_drop,
681  s_randval,
682  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
683  cu_seqlen_q_ptr,
684  cu_seqlen_k_ptr);
685  }
686 
687  template <bool Cond = kIsGroupMode>
688  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
689  MakeKargsImpl(const void* q_ptr,
690  const void* k_ptr,
691  const void* v_ptr,
692  const void* bias_ptr,
693  const void* q_descale_ptr,
694  const void* k_descale_ptr,
695  const void* v_descale_ptr,
696  void* rand_val_ptr,
697  void* lse_ptr,
698  void* o_ptr,
699  const void* seqstart_q_ptr,
700  const void* seqstart_k_ptr,
701  const void* seqlen_q_ptr,
702  const void* seqlen_k_ptr,
703  ck_tile::index_t hdim_q,
704  ck_tile::index_t hdim_v,
705  ck_tile::index_t num_head_q,
706  ck_tile::index_t nhead_ratio_qk,
707  float scale_s,
708  float logits_soft_cap,
709  ck_tile::index_t stride_q,
710  ck_tile::index_t stride_k,
711  ck_tile::index_t stride_v,
712  ck_tile::index_t stride_bias,
713  ck_tile::index_t stride_randval,
714  ck_tile::index_t stride_o,
715  ck_tile::index_t nhead_stride_q,
716  ck_tile::index_t nhead_stride_k,
717  ck_tile::index_t nhead_stride_v,
718  ck_tile::index_t nhead_stride_bias,
719  ck_tile::index_t nhead_stride_randval,
720  ck_tile::index_t nhead_stride_lse,
721  ck_tile::index_t nhead_stride_o,
722  ck_tile::index_t window_size_left,
723  ck_tile::index_t window_size_right,
724  ck_tile::index_t mask_type,
725  ck_tile::index_t min_seqlen_q,
726  float p_drop,
727  bool s_randval,
728  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
729  drop_seed_offset,
730  const void* cu_seqlen_q_ptr = nullptr,
731  const void* cu_seqlen_k_ptr = nullptr)
732  {
733  Kargs kargs{{q_ptr,
734  k_ptr,
735  v_ptr,
736  o_ptr,
737  -1, // seqlen will be updated by another pointer
738  -1, //
739  hdim_q,
740  hdim_v,
741  num_head_q,
742  nhead_ratio_qk,
743 #if CK_TILE_FMHA_FWD_FAST_EXP2
744  static_cast<float>(scale_s * ck_tile::log2e_v<>),
745 #else
746  scale_s,
747 #endif
748  stride_q,
749  stride_k,
750  stride_v,
751  stride_o,
752  nhead_stride_q,
753  nhead_stride_k,
754  nhead_stride_v,
755  nhead_stride_o}, // args for common karg
756  {}, // placeholder for bias
757  {}, // placeholder for mask
758  {}, // placeholder for lse
759  {}, // placeholder for qscale
760  {}, // placeholder for dropout
761  {}, // placeholder for logits_soft_cap
762  {}, // placeholder for min_seqlen_q
763  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
764  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
765  reinterpret_cast<const int32_t*>(seqlen_q_ptr),
766  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
767 
769  {
770  kargs.bias_ptr = bias_ptr;
771  kargs.stride_bias = stride_bias;
772  kargs.nhead_stride_bias = nhead_stride_bias;
773  }
774  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
775  {
776  kargs.alibi_slope_ptr = bias_ptr;
777  kargs.alibi_slope_stride = stride_bias;
778  }
779  if constexpr(kHasMask)
780  {
781  kargs.window_size_left = window_size_left;
782  kargs.window_size_right = window_size_right;
783  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
784  }
785  if constexpr(kStoreLSE)
786  {
787  kargs.lse_ptr = lse_ptr;
788  kargs.nhead_stride_lse = nhead_stride_lse;
789  }
791  {
792  kargs.q_descale_ptr = q_descale_ptr;
793  kargs.k_descale_ptr = k_descale_ptr;
794  kargs.v_descale_ptr = v_descale_ptr;
795  }
796  if constexpr(kHasDropout)
797  {
798  if(drop_seed_offset.index() == 0) // seed & offset come from host
799  {
800  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
801  kargs.init_dropout(p_drop, seed, offset);
802  }
803  else // seed & offset come from device
804  {
805  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
806  kargs.init_dropout(p_drop,
807  reinterpret_cast<const uint64_t*>(seed_ptr),
808  reinterpret_cast<const uint64_t*>(offset_ptr));
809  }
810 
811  kargs.rand_val_ptr = rand_val_ptr;
812  kargs.stride_randval = stride_randval;
813  kargs.nhead_stride_randval = nhead_stride_randval;
814  kargs.is_store_randval = s_randval;
815  }
816  if constexpr(kHasLogitsSoftCap)
817  {
818  kargs.init_logits_soft_cap(logits_soft_cap);
819  }
820  if constexpr(kSkipMinSeqlenQ)
821  {
822  kargs.min_seqlen_q = min_seqlen_q;
823  }
824 
825  kargs.cu_seqlen_q_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_q_ptr);
826  kargs.cu_seqlen_k_ptr = reinterpret_cast<const int32_t*>(cu_seqlen_k_ptr);
827  return kargs;
828  }
829 
830  // std::variant<> can't take in a list initializer, overload for backward compatibility
831  template <bool Cond = kIsGroupMode>
832  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
833  MakeKargs(const void* q_ptr,
834  const void* k_ptr,
835  const void* v_ptr,
836  const void* bias_ptr,
837  const void* q_descale_ptr,
838  const void* k_descale_ptr,
839  const void* v_descale_ptr,
840  void* rand_val_ptr,
841  void* lse_ptr,
842  void* o_ptr,
843  const void* seqstart_q_ptr,
844  const void* seqstart_k_ptr,
845  const void* seqlen_q_ptr,
846  const void* seqlen_k_ptr,
847  ck_tile::index_t hdim_q,
848  ck_tile::index_t hdim_v,
849  ck_tile::index_t num_head_q,
850  ck_tile::index_t nhead_ratio_qk,
851  float scale_s,
852  float logits_soft_cap,
853  ck_tile::index_t stride_q,
854  ck_tile::index_t stride_k,
855  ck_tile::index_t stride_v,
856  ck_tile::index_t stride_bias,
857  ck_tile::index_t stride_randval,
858  ck_tile::index_t stride_o,
859  ck_tile::index_t nhead_stride_q,
860  ck_tile::index_t nhead_stride_k,
861  ck_tile::index_t nhead_stride_v,
862  ck_tile::index_t nhead_stride_bias,
863  ck_tile::index_t nhead_stride_randval,
864  ck_tile::index_t nhead_stride_lse,
865  ck_tile::index_t nhead_stride_o,
866  ck_tile::index_t window_size_left,
867  ck_tile::index_t window_size_right,
868  ck_tile::index_t mask_type,
869  ck_tile::index_t min_seqlen_q,
870  float p_drop,
871  bool s_randval,
872  const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
873  const void* cu_seqlen_q_ptr = nullptr,
874  const void* cu_seqlen_k_ptr = nullptr)
875  {
876  return MakeKargsImpl(
877  q_ptr,
878  k_ptr,
879  v_ptr,
880  bias_ptr,
881  q_descale_ptr,
882  k_descale_ptr,
883  v_descale_ptr,
884  rand_val_ptr,
885  lse_ptr,
886  o_ptr,
887  seqstart_q_ptr,
888  seqstart_k_ptr,
889  seqlen_q_ptr,
890  seqlen_k_ptr,
891  hdim_q,
892  hdim_v,
893  num_head_q,
894  nhead_ratio_qk,
895  scale_s,
896  logits_soft_cap,
897  stride_q,
898  stride_k,
899  stride_v,
900  stride_bias,
901  stride_randval,
902  stride_o,
903  nhead_stride_q,
904  nhead_stride_k,
905  nhead_stride_v,
906  nhead_stride_bias,
907  nhead_stride_randval,
908  nhead_stride_lse,
909  nhead_stride_o,
910  window_size_left,
911  window_size_right,
912  mask_type,
913  min_seqlen_q,
914  p_drop,
915  s_randval,
916  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
917  cu_seqlen_q_ptr,
918  cu_seqlen_k_ptr);
919  }
920 
921  // std::variant<> can't take in a list initializer, overload for backward compatibility
922  template <bool Cond = kIsGroupMode>
923  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
924  MakeKargs(const void* q_ptr,
925  const void* k_ptr,
926  const void* v_ptr,
927  const void* bias_ptr,
928  const void* q_descale_ptr,
929  const void* k_descale_ptr,
930  const void* v_descale_ptr,
931  void* rand_val_ptr,
932  void* lse_ptr,
933  void* o_ptr,
934  const void* seqstart_q_ptr,
935  const void* seqstart_k_ptr,
936  const void* seqlen_q_ptr,
937  const void* seqlen_k_ptr,
938  ck_tile::index_t hdim_q,
939  ck_tile::index_t hdim_v,
940  ck_tile::index_t num_head_q,
941  ck_tile::index_t nhead_ratio_qk,
942  float scale_s,
943  float logits_soft_cap,
944  ck_tile::index_t stride_q,
945  ck_tile::index_t stride_k,
946  ck_tile::index_t stride_v,
947  ck_tile::index_t stride_bias,
948  ck_tile::index_t stride_randval,
949  ck_tile::index_t stride_o,
950  ck_tile::index_t nhead_stride_q,
951  ck_tile::index_t nhead_stride_k,
952  ck_tile::index_t nhead_stride_v,
953  ck_tile::index_t nhead_stride_bias,
954  ck_tile::index_t nhead_stride_randval,
955  ck_tile::index_t nhead_stride_lse,
956  ck_tile::index_t nhead_stride_o,
957  ck_tile::index_t window_size_left,
958  ck_tile::index_t window_size_right,
959  ck_tile::index_t mask_type,
960  ck_tile::index_t min_seqlen_q,
961  float p_drop,
962  bool s_randval,
963  const std::tuple<const void*, const void*>& drop_seed_offset,
964  const void* cu_seqlen_q_ptr = nullptr,
965  const void* cu_seqlen_k_ptr = nullptr)
966  {
967  return MakeKargsImpl(
968  q_ptr,
969  k_ptr,
970  v_ptr,
971  bias_ptr,
972  q_descale_ptr,
973  k_descale_ptr,
974  v_descale_ptr,
975  rand_val_ptr,
976  lse_ptr,
977  o_ptr,
978  seqstart_q_ptr,
979  seqstart_k_ptr,
980  seqlen_q_ptr,
981  seqlen_k_ptr,
982  hdim_q,
983  hdim_v,
984  num_head_q,
985  nhead_ratio_qk,
986  scale_s,
987  logits_soft_cap,
988  stride_q,
989  stride_k,
990  stride_v,
991  stride_bias,
992  stride_randval,
993  stride_o,
994  nhead_stride_q,
995  nhead_stride_k,
996  nhead_stride_v,
997  nhead_stride_bias,
998  nhead_stride_randval,
999  nhead_stride_lse,
1000  nhead_stride_o,
1001  window_size_left,
1002  window_size_right,
1003  mask_type,
1004  min_seqlen_q,
1005  p_drop,
1006  s_randval,
1007  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)),
1008  cu_seqlen_q_ptr,
1009  cu_seqlen_k_ptr);
1010  }
1011 
1012  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
1013  ck_tile::index_t nhead_,
1014  ck_tile::index_t seqlen_q_,
1015  ck_tile::index_t hdim_v_,
1016  bool has_padded_seqlen_k = false)
1017  {
1018  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
1019  if(has_padded_seqlen_k)
1020  {
1021  // TODO: this may need tuning
1022  return dim3(nhead_,
1023  batch_size_,
1024  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1025  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
1026  }
1027  else
1028  {
1029  // TODO: this may need tuning
1030  return dim3(nhead_,
1031  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
1032  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
1033  batch_size_);
1034  }
1035  }
1036 
1037  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
1038  {
1039  bool has_padded_seqlen_k = false;
1040 
1041  if constexpr(kIsGroupMode)
1042  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
1043 
1044  if(has_padded_seqlen_k)
1045  {
1046  // const index_t num_tile_m0 = seqlen_q / kM0;
1047  const index_t num_tile_n1 =
1048  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1049 
1050  const index_t i_block = blockIdx.z;
1051  const index_t i_nhead = blockIdx.x;
1052  const index_t i_batch = blockIdx.y;
1053 
1054  const auto f = [](index_t dividend, index_t divisor) {
1055  index_t quotient = dividend / divisor;
1056  index_t modulus = dividend - quotient * divisor;
1057  return ck_tile::make_tuple(quotient, modulus);
1058  };
1059 
1060  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1061 
1062  if constexpr(kHasMask)
1063  {
1064  // assume that num_tile_n1 is always 1
1065  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1066  }
1067  else
1068  {
1069  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1070  }
1071  }
1072  else
1073  {
1074  // const index_t num_tile_m0 = seqlen_q / kM0;
1075  const index_t num_tile_n1 =
1076  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1077 
1078  const index_t i_block = blockIdx.y; // blockIdx.x
1079  const index_t i_nhead = blockIdx.x; // blockIdx.y
1080  const index_t i_batch = blockIdx.z;
1081 
1082  const auto f = [](index_t dividend, index_t divisor) {
1083  index_t quotient = dividend / divisor;
1084  index_t modulus = dividend - quotient * divisor;
1085  return ck_tile::make_tuple(quotient, modulus);
1086  };
1087 
1088  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1089 
1090  if constexpr(kHasMask)
1091  {
1092  // assume that num_tile_n1 is always 1
1093  return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1094  }
1095  else
1096  {
1097  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1098  }
1099  }
1100  }
1101 
1102  CK_TILE_HOST static dim3 BlockSize()
1103  {
1104  if(is_wave32())
1105  {
1106  return dim3(kBlockSize / 2);
1107  }
1108  else
1109  {
1110  return dim3(kBlockSize);
1111  }
1112  }
1113 
1115  {
1116  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
1117  }
1118 
1119  CK_TILE_DEVICE void operator()(Kargs kargs) const
1120  {
1121  if constexpr(kIsAvailable)
1122  run_(std::move(kargs));
1123  }
1124 
1125  CK_TILE_DEVICE void run_(Kargs kargs) const
1126  {
1127  if constexpr(kPipelineName != "qr_async_trload")
1128  {
1129  // allocate LDS
1130  __shared__ char smem_ptr[GetSmemSize()];
1131 
1132  // divide problem
1133  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1134 
1135  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
1136  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
1137 
1138  long_index_t batch_offset_q = 0;
1139  long_index_t batch_offset_k = 0;
1140  long_index_t batch_offset_v = 0;
1141  long_index_t batch_offset_bias = 0;
1142  long_index_t batch_offset_randval = 0;
1143  long_index_t batch_offset_lse = 0;
1144  long_index_t batch_offset_o = 0;
1145 
1146  if constexpr(kIsGroupMode)
1147  {
1148  // Use seqstart_q_ptr and seqstart_k_ptr for physical starts
1149  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1150  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1151 
1152  // DRAM base offsets use physical starts
1153  batch_offset_q = query_start * kargs.stride_q;
1154  batch_offset_k = key_start * kargs.stride_k;
1155  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1156  {
1157  batch_offset_v = key_start * kargs.stride_v;
1158  }
1159  else
1160  {
1161  batch_offset_v = key_start;
1162  }
1164  {
1165  batch_offset_bias = query_start * kargs.stride_bias;
1166  }
1167  if constexpr(kStoreLSE)
1168  {
1169  // LSE follows the physical layout to stay consistent with other tensors
1170  batch_offset_lse = query_start;
1171  }
1172  if constexpr(kHasDropout)
1173  {
1174  batch_offset_randval = query_start * kargs.stride_randval;
1175  }
1176  batch_offset_o = query_start * kargs.stride_o;
1177 
1178  // real logical lengths (exclude PAD)
1179  // Priority: seqlen_q_ptr > cu_seqlen_q_ptr > calculated from seqstart_q_ptr
1180  if(kargs.seqlen_q_ptr != nullptr)
1181  {
1182  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1183  }
1184  else if(kargs.cu_seqlen_q_ptr != nullptr)
1185  {
1186  kargs.seqlen_q =
1187  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1188  }
1189  else
1190  {
1191  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1192  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1193  }
1194 
1195  if constexpr(kSkipMinSeqlenQ)
1196  {
1197  if(kargs.seqlen_q <= kargs.min_seqlen_q)
1198  {
1199  return;
1200  }
1201  }
1202 
1203  // terminate unnecessary blocks earlier
1204  if(kargs.seqlen_q <= i_m0)
1205  {
1206  return;
1207  }
1208 
1209  if(kargs.seqlen_k_ptr != nullptr)
1210  {
1211  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1212  }
1213  else if(kargs.cu_seqlen_k_ptr != nullptr)
1214  {
1215  kargs.seqlen_k =
1216  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1217  }
1218  else
1219  {
1220  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1221  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1222  }
1223  }
1224  else
1225  {
1226  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1227  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1228  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1230  {
1231  batch_offset_bias =
1232  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1233  }
1234  if constexpr(kStoreLSE)
1235  {
1236  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1237  }
1238  if constexpr(kHasDropout)
1239  {
1240  batch_offset_randval =
1241  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1242  }
1243  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1244 
1245  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1246  if(kargs.cu_seqlen_q_ptr != nullptr)
1247  {
1248  kargs.seqlen_q =
1249  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1250  }
1251  if(kargs.cu_seqlen_k_ptr != nullptr)
1252  {
1253  kargs.seqlen_k =
1254  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1255  }
1256  }
1257 
1258  // for simplicity, batch stride we just modify the pointer
1259  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1260  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1261  batch_offset_q;
1262  const KDataType* k_ptr =
1263  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1264  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1265  batch_offset_k;
1266  const VDataType* v_ptr =
1267  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1268  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1269  batch_offset_v;
1270  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1271  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1272  batch_offset_o;
1273 
1274  // Q/K/V DRAM and DRAM window
1275  const auto q_dram = [&]() {
1276  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1277  q_ptr,
1278  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1279  make_tuple(kargs.stride_q, 1),
1281  number<1>{});
1282  if constexpr(FmhaPipeline::kQLoadOnce)
1283  {
1284  return pad_tensor_view(q_dram_naive,
1288  }
1289  else
1290  {
1291  return pad_tensor_view(
1292  q_dram_naive,
1295  }
1296  }();
1297  const auto k_dram = [&]() {
1298  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1299  k_ptr,
1300  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1301  make_tuple(kargs.stride_k, 1),
1303  number<1>{});
1304 
1305  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1306  return pad_tensor_view(
1307  k_dram_naive,
1310  }();
1311  const auto v_dram = [&]() {
1312  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1313  {
1314  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1315  v_ptr,
1316  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1317  make_tuple(kargs.stride_v, 1),
1319  number<1>{});
1320 
1321  const auto v_dram_transposed = transform_tensor_view(
1322  v_dram_naive,
1324  make_pass_through_transform(kargs.seqlen_k)),
1327 
1328  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1329  return pad_tensor_view(
1330  v_dram_transposed,
1333  }
1334  else
1335  {
1336  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1337  v_ptr,
1338  make_tuple(kargs.hdim_v, kargs.seqlen_k),
1339  make_tuple(kargs.stride_v, 1),
1341  number<1>{});
1342 
1343  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
1344  return pad_tensor_view(
1345  v_dram_naive,
1348  }
1349  }();
1350 
1351  auto q_dram_window = make_tile_window(
1352  q_dram,
1353  [&]() {
1354  if constexpr(FmhaPipeline::kQLoadOnce)
1357  else
1359  }(),
1360  {i_m0, 0});
1361 
1362  auto k_dram_window = make_tile_window(
1363  k_dram,
1365  {0, 0});
1366 
1367  auto v_dram_window = make_tile_window(
1368  v_dram,
1370  {i_n1, 0});
1373  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1374  constexpr auto bias_dram_window_lengths =
1377  {
1378  const BiasDataType* bias_ptr =
1379  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1380  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1381  batch_offset_bias;
1382 
1383  const auto bias_dram = [&]() {
1384  const auto bias_dram_naive =
1385  make_naive_tensor_view<address_space_enum::global>(
1386  bias_ptr,
1387  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1388  make_tuple(kargs.stride_bias, 1),
1390  number<1>{});
1391 
1392  return pad_tensor_view(bias_dram_naive,
1393  bias_dram_window_lengths,
1395  }();
1396 
1397  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1398  }
1399  else
1400  {
1401  return make_null_tile_window(bias_dram_window_lengths);
1402  }
1403  }();
1404 
1405  // lse
1406  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1407  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1408  if constexpr(kStoreLSE)
1409  {
1410  LSEDataType* lse_ptr =
1411  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1412  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
1413  batch_offset_lse;
1414 
1415  const auto lse_dram = [&]() {
1416  const auto lse_dram_naive =
1417  make_naive_tensor_view<address_space_enum::global>(
1418  lse_ptr,
1419  make_tuple(kargs.seqlen_q),
1420  make_tuple(1),
1421  number<1>{},
1422  number<1>{});
1423 
1424  return pad_tensor_view(
1425  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1426  }();
1427 
1428  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1429  }
1430  else
1431  {
1432  return make_null_tile_window(lse_dram_window_lengths);
1433  }
1434  }();
1435 
1436  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1437  if constexpr(kHasDropout)
1438  {
1439  return BlockDropout{i_batch_,
1440  i_nhead_,
1441  kargs.num_head_q,
1442  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1443  : *kargs.drop_seed.ptr,
1444  kargs.is_drop_seed_offset_from_host
1445  ? kargs.drop_offset.val
1446  : *kargs.drop_offset.ptr,
1447  kargs.rp_undrop,
1448  kargs.p_undrop_in_uint8_t,
1449  kargs.is_store_randval};
1450  }
1451  else
1452  {
1453  return NullBlockDropout{};
1454  };
1455  }();
1456 
1457  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1458  constexpr auto randval_dram_window_lengths =
1460  if constexpr(kHasDropout)
1461  {
1462  RandValOutputDataType* rand_val_ptr =
1463  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1464  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1465  batch_offset_randval;
1466 
1467  const auto randval_dram = [&]() {
1468  const auto randval_dram_naive =
1469  make_naive_tensor_view<address_space_enum::global>(
1470  rand_val_ptr,
1471  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1472  make_tuple(kargs.stride_randval, 1),
1474  number<1>{});
1475 
1476  return pad_tensor_view(randval_dram_naive,
1477  randval_dram_window_lengths,
1479  }();
1480 
1481  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1482  }
1483  else
1484  {
1485  return make_null_tile_window(randval_dram_window_lengths);
1486  }
1487  }();
1488 
1489  FmhaMask mask = [&]() {
1490  if constexpr(kHasMask)
1491  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1492  kargs.window_size_left,
1493  kargs.window_size_right,
1494  kargs.seqlen_q,
1495  kargs.seqlen_k,
1497  else
1498  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1499  }();
1500 
1501  // WA i_batch capture structure binding before c++20
1502  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1503  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1504  {
1505  // data loading, shared by entire wg
1506  // TODO: how to use s_read?
1507  SaccDataType slope =
1508  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1509  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1510 #if CK_TILE_FMHA_FWD_FAST_EXP2
1511  slope *= ck_tile::log2e_v<>;
1512 #endif
1513  if constexpr(kHasMask)
1514  {
1515  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1516  kargs.window_size_left,
1517  kargs.window_size_right,
1518  kargs.seqlen_q,
1519  kargs.seqlen_k,
1520  kargs.mask_type);
1521  }
1522  else
1523  {
1525  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1526  }
1527  }
1528  else
1529  {
1531  }
1532  }();
1533 
1534  AttentionVariant variant;
1535  const auto variant_params = [&] {
1536  if constexpr(kHasLogitsSoftCap)
1537  {
1539  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1540  }
1541  else
1542  {
1543  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1544  }
1545  }();
1546 
1547  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1548 
1549  auto o_acc_tile = [&]() {
1551  {
1552  // TODO - move global load of descale to pipeline
1553  float q_descale = *(reinterpret_cast<const float*>(kargs.q_descale_ptr));
1554  float k_descale = *(reinterpret_cast<const float*>(kargs.k_descale_ptr));
1555  float v_descale = *(reinterpret_cast<const float*>(kargs.v_descale_ptr));
1556 
1557  float scale_s = kargs.scale_s * q_descale * k_descale;
1558  float scale_p =
1559  ck_tile::type_convert<float>(ck_tile::numeric<PDataType>::max());
1560  float scale_o = v_descale / scale_p;
1561 
1562  auto o_acc_element_func = [&]() {
1563  if constexpr(std::is_same_v<ODataType, ck_tile::fp8_t>)
1565  ck_tile::scales{scale_o});
1566  else
1567  return ck_tile::scales{scale_o};
1568  }();
1569  return FmhaPipeline{}(q_dram_window,
1570  identity{}, // q_element_func
1571  k_dram_window,
1572  identity{}, // k_element_func
1573  v_dram_window,
1574  identity{}, // v_element_func
1575  bias_dram_window,
1576  identity{}, // bias_element_func
1577  randval_dram_window,
1578  lse_dram_window,
1579  identity{}, // lse_element_func
1580  identity{}, // s_acc_element_func
1581  scales{scale_p}, // p_compute_element_func
1582  o_acc_element_func, // o_acc_element_func
1583  mask,
1584  position_encoding,
1585  scale_s,
1586  variant,
1587  variant_params,
1588  block_indices,
1589  smem_ptr,
1590  dropout);
1591  }
1592  else
1593  {
1594  return FmhaPipeline{}(q_dram_window,
1595  k_dram_window,
1596  v_dram_window,
1597  bias_dram_window,
1598  randval_dram_window,
1599  lse_dram_window,
1600  mask,
1601  position_encoding,
1602  kargs.scale_s,
1603  variant,
1604  variant_params,
1605  block_indices,
1606  smem_ptr,
1607  dropout);
1608  }
1609  }();
1610 
1611  // O DRAM and O DRAM window
1612  auto o_dram = [&]() {
1613  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1614  o_ptr,
1615  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1616  make_tuple(kargs.stride_o, 1),
1618  number<1>{});
1619 
1620  return pad_tensor_view(
1621  o_dram_naive,
1624  }();
1625 
1626  auto o_dram_window = make_tile_window(
1627  o_dram,
1629  {i_m0, i_n1});
1630 
1631  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1632  }
1633  else
1634  {
1635  // TODO: Refine the logical here.
1636  // In Decode case
1637  // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache
1638  // 2. limit the LDS usage, as we want higher occupancy
1639  // In Prefill case
1640  // 1. we expect KV data reused by different ThreadGroups, use cache
1641  // 2. use more LDS, as we want better memory latency hiding
1642  // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the
1643  // cache
1644  constexpr bool PrefillCase = FmhaPipeline::kM0 > 64;
1645  // divide problem
1646  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1647 
1648  const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
1649  const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
1650 
1651  long_index_t batch_offset_q = 0;
1652  long_index_t batch_offset_k = 0; // unused for paged-kvcache
1653  long_index_t batch_offset_v = 0; // unused for paged-kvcache
1654  long_index_t batch_offset_bias = 0;
1655  long_index_t batch_offset_lse = 0;
1656  long_index_t batch_offset_o = 0;
1657  // index_t kv_l2p_offset =
1658  // 0; // logical-to-physical offset of seqlen_k coordinate. only used for
1659  // paged-kvcache
1660 
1661  if constexpr(kIsGroupMode)
1662  {
1663  // get starting offset for each batch - use seqstart_q_ptr/seqstart_k_ptr for
1664  // physical starts
1665  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1666  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1667 
1668  batch_offset_q = query_start * kargs.stride_q;
1669  batch_offset_k = key_start * kargs.stride_k;
1670  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1671  {
1672  batch_offset_v = key_start * kargs.stride_v;
1673  }
1674  else
1675  {
1676  // col-major V: offset along seqlen dimension is scalar index
1677  batch_offset_v = key_start;
1678  }
1680  {
1681  batch_offset_bias = query_start * kargs.stride_bias;
1682  }
1683 
1684  // LSE layout is [nhead, total_seqlen] following the physical layout for Q/O
1685  batch_offset_lse = query_start;
1686  batch_offset_o = query_start * kargs.stride_o;
1687 
1688  // get real # queries & # keys under group mode
1689  if(kargs.seqlen_q_ptr != nullptr)
1690  {
1691  kargs.seqlen_q = kargs.seqlen_q_ptr[i_batch];
1692  }
1693  else if(kargs.cu_seqlen_q_ptr != nullptr)
1694  {
1695  kargs.seqlen_q =
1696  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1697  }
1698  else
1699  {
1700  kargs.seqlen_q =
1701  kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
1702  }
1703 
1704  // # of required blocks is different in each groups, terminate unnecessary blocks
1705  // earlier
1706  if(kargs.seqlen_q <= i_m0)
1707  {
1708  return;
1709  }
1710 
1711  if(kargs.seqlen_k_ptr != nullptr)
1712  {
1713  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1714  }
1715  else if(kargs.cu_seqlen_k_ptr != nullptr)
1716  {
1717  kargs.seqlen_k =
1718  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1719  }
1720  else
1721  {
1722  kargs.seqlen_k =
1723  kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
1724  }
1725  }
1726  else
1727  {
1728  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1729  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1730  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1731  if constexpr(kStoreLSE)
1732  {
1733  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1734  }
1735  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1736 
1738  {
1739  batch_offset_bias =
1740  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1741  }
1742 
1743  // If cumulative seqlen pointers are provided, override per-batch effective lengths
1744  if(kargs.cu_seqlen_q_ptr != nullptr)
1745  {
1746  kargs.seqlen_q =
1747  kargs.cu_seqlen_q_ptr[i_batch + 1] - kargs.cu_seqlen_q_ptr[i_batch];
1748  }
1749  if(kargs.cu_seqlen_k_ptr != nullptr)
1750  {
1751  kargs.seqlen_k =
1752  kargs.cu_seqlen_k_ptr[i_batch + 1] - kargs.cu_seqlen_k_ptr[i_batch];
1753  }
1754  }
1755 
1756  // for simplicity, batch stride we just modify the pointer
1757  const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
1758 
1759  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1760  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1761  batch_offset_q;
1762  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1763  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
1764  batch_offset_k;
1765  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1766  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
1767  batch_offset_v;
1768 
1769  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1770  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1771  batch_offset_o;
1772 
1773  // Q/K/V DRAM and DRAM window
1774  const auto q_dram = [&] {
1775  const auto q_dram_naive = [&] {
1776  {
1777  return make_naive_tensor_view<address_space_enum::global,
1778  memory_operation_enum::set,
1780  q_ptr,
1781  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1782  make_tuple(kargs.stride_q, 1),
1784  number<1>{});
1785  }
1786  }();
1787 
1788  if constexpr(FmhaPipeline::kQLoadOnce)
1789  {
1790  const auto seqlen_q = kargs.seqlen_q;
1791  const auto q_dram_pad = pad_tensor_view(
1792  q_dram_naive,
1795 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1796  constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
1797  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1798 
1799  if constexpr(XorLengthFold > 1)
1800  {
1801  const auto q_dram_unmerged = transform_tensor_view(
1802  q_dram_pad,
1803  make_tuple(
1805  make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
1809 
1810  const auto q_dram_merged = transform_tensor_view(
1811  q_dram_unmerged,
1812  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1814  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1817 
1818  const auto q_dram_unmerged_xor = transform_tensor_view(
1819  q_dram_merged,
1820  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1826 
1827  const auto q_dram_permuted = transform_tensor_view(
1828  q_dram_unmerged_xor,
1829  make_tuple(
1831  make_tuple(seqlen_q / XorLengthFold,
1836 
1837  const auto q_dram_tmp = transform_tensor_view(
1838  q_dram_permuted,
1839  make_tuple(
1840  make_pass_through_transform(seqlen_q / XorLengthFold),
1843  number<FmhaPipeline::kQKHeaddim /
1844  FmhaPipeline::kAlignmentQ>{})),
1848 
1849  return transform_tensor_view(
1850  q_dram_tmp,
1851  make_tuple(
1853  make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
1859  }
1860  else
1861 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1862  {
1863  const auto q_dram_unmerged = transform_tensor_view(
1864  q_dram_pad,
1865  make_tuple(
1866  make_pass_through_transform(seqlen_q),
1872 
1873  const auto q_dram_permuted = transform_tensor_view(
1874  q_dram_unmerged,
1875  make_tuple(
1876  make_xor_transform(make_tuple(seqlen_q,
1877  number<FmhaPipeline::kQKHeaddim /
1878  FmhaPipeline::kAlignmentQ>{})),
1882 
1883  return transform_tensor_view(
1884  q_dram_permuted,
1885  make_tuple(
1886  make_pass_through_transform(seqlen_q),
1892  }
1893  }
1894  else
1895  {
1896  return pad_tensor_view(
1897  q_dram_naive,
1900  }
1901  }();
1902 
1903  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1904  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1905  data, // will update this pointer if using paged-kvcache
1906  make_tuple(height, kargs.hdim_q),
1907  make_tuple(kargs.stride_k, 1),
1909  number<1>{});
1910 
1911  const auto k_dram_pad = pad_tensor_view(
1912  k_dram_naive,
1915 
1916  constexpr auto kDramTileK =
1917  FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
1918 
1919 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1920  constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
1921  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1922 
1923  if constexpr(XorLengthFold > 1)
1924  {
1925  const auto k_dram_unmerged = transform_tensor_view(
1926  k_dram_pad,
1928  make_tuple(height / XorLengthFold, XorLengthFold)),
1932 
1933  const auto k_dram_merged = transform_tensor_view(
1934  k_dram_unmerged,
1935  make_tuple(make_pass_through_transform(height / XorLengthFold),
1937  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1940 
1941  const auto k_dram_unmerged_xor = transform_tensor_view(
1942  k_dram_merged,
1943  make_tuple(make_pass_through_transform(height / XorLengthFold),
1949 
1950  const auto k_dram_permuted = transform_tensor_view(
1951  k_dram_unmerged_xor,
1952  make_tuple(
1954  make_tuple(height / XorLengthFold,
1959 
1960  const auto k_dram_tmp = transform_tensor_view(
1961  k_dram_permuted,
1962  make_tuple(
1963  make_pass_through_transform(height / XorLengthFold),
1966  number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
1970 
1971  return transform_tensor_view(
1972  k_dram_tmp,
1973  make_tuple(
1975  make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
1981  }
1982  else
1983 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1984  {
1985  const auto k_dram_unmerged = transform_tensor_view(
1986  k_dram_pad,
1989  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
1990  FmhaPipeline::kAlignmentK>{},
1991  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
1995 
1996  const auto k_dram_permuted = transform_tensor_view(
1997  k_dram_unmerged,
1998  make_tuple(
2002  number<FmhaPipeline::kQKHeaddim / kDramTileK /
2003  FmhaPipeline::kAlignmentK>{}),
2007 
2008  return transform_tensor_view(
2009  k_dram_permuted,
2012  make_tuple(number<FmhaPipeline::kQKHeaddim / kDramTileK /
2013  FmhaPipeline::kAlignmentK>{},
2014  number<kDramTileK / FmhaPipeline::kAlignmentK>{},
2018  }
2019  };
2020  const auto k_dram = [&]() {
2021  {
2022  return make_k_dram(k_ptr, kargs.seqlen_k);
2023  }
2024  }();
2025 
2026  const auto make_v_dram = [&](const VDataType* data, index_t length) {
2027  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
2028  data, // will update this pointer if using paged-kvcache
2029  make_tuple(length, kargs.hdim_v),
2030  make_tuple(kargs.stride_v, 1),
2032  number<1>{});
2033 
2034  // TODO: Add kVHeadDim
2035  constexpr index_t XorGroupSize =
2036  FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
2037 
2038  const auto v_dram_pad = pad_tensor_view(
2039  v_dram_naive,
2042 
2043 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2044  constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
2045  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
2046 
2047  if constexpr(XorLengthFold > 1)
2048  {
2049  const auto v_dram_unmerged = transform_tensor_view(
2050  v_dram_pad,
2052  make_tuple(length / XorLengthFold, XorLengthFold)),
2056 
2057  const auto v_dram_merged = transform_tensor_view(
2058  v_dram_unmerged,
2059  make_tuple(make_pass_through_transform(length / XorLengthFold),
2061  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
2064 
2065  const auto v_dram_unmerged_xor = transform_tensor_view(
2066  v_dram_merged,
2067  make_tuple(
2068  make_pass_through_transform(length / XorLengthFold),
2070  number<XorGroupSize>{}))),
2073 
2074  const auto v_dram_permuted = transform_tensor_view(
2075  v_dram_unmerged_xor,
2076  make_tuple(
2077  make_xor_transform(make_tuple(length / XorLengthFold,
2082 
2083  const auto v_dram_tmp = transform_tensor_view(
2084  v_dram_permuted,
2085  make_tuple(make_pass_through_transform(length / XorLengthFold),
2088  number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
2092 
2093  return transform_tensor_view(
2094  v_dram_tmp,
2096  make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
2099  number<XorGroupSize>{}))),
2102  }
2103  else
2104 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
2105  {
2106  const auto v_dram_unmerged = transform_tensor_view(
2107  v_dram_pad,
2111  number<XorGroupSize>{}))),
2114 
2115  const auto v_dram_permuted = transform_tensor_view(
2116  v_dram_unmerged,
2122 
2123  return transform_tensor_view(
2124  v_dram_permuted,
2128  number<XorGroupSize>{}))),
2131  }
2132  };
2133 
2134  const auto v_dram = [&]() {
2135  {
2136  return make_v_dram(v_ptr, kargs.seqlen_k);
2137  }
2138  }();
2139 
2140  auto q_dram_window = make_tile_window(
2141  q_dram,
2142  [&]() {
2143  if constexpr(FmhaPipeline::kQLoadOnce)
2146  else
2148  }(),
2149  {i_m0, 0});
2150 
2151  auto k_dram_window = make_tile_window(
2152  k_dram,
2154  {0, 0});
2155 
2156  auto v_dram_window = make_tile_window(
2157  v_dram,
2159  {0, 0});
2160 
2163  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
2164  constexpr auto bias_dram_window_lengths =
2167  {
2168  const BiasDataType* bias_ptr =
2169  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
2170  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
2171  batch_offset_bias;
2172 
2173  const auto bias_dram = [&]() {
2174  const auto bias_dram_naive =
2175  make_naive_tensor_view<address_space_enum::global>(
2176  bias_ptr,
2177  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
2178  make_tuple(kargs.stride_bias, 1),
2180  number<1>{});
2181 
2182  return pad_tensor_view(bias_dram_naive,
2183  bias_dram_window_lengths,
2185  }();
2186 
2187  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
2188  }
2189  else
2190  {
2191  return make_null_tile_window(bias_dram_window_lengths);
2192  }
2193  }();
2194 
2195  // lse acc
2196  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
2197  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
2198  if constexpr(kStoreLSE)
2199  {
2200  LSEDataType* lse_ptr =
2201  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
2202  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
2203  batch_offset_lse;
2204 
2205  const auto lse_dram = [&] {
2206  const auto lse_dram_naive = [&] {
2207  {
2208  return make_naive_tensor_view<address_space_enum::global>(
2209  lse_ptr,
2210  make_tuple(kargs.seqlen_q),
2211  make_tuple(1),
2212  number<1>{},
2213  number<1>{});
2214  }
2215  }();
2216  return pad_tensor_view(
2217  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
2218  }();
2219 
2220  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
2221  }
2222  else
2223  {
2224  return make_null_tile_window(lse_dram_window_lengths);
2225  }
2226  }();
2227 
2228  FmhaMask mask = [&]() {
2229  if constexpr(kHasMask)
2230  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
2231  kargs.window_size_left,
2232  kargs.window_size_right,
2233  kargs.seqlen_q,
2234  kargs.seqlen_k,
2236  else
2237  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
2238  }();
2239 
2240  // WA i_batch capture structure binding before c++20
2241  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
2242  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
2243  {
2244  // data loading, shared by entire wg
2245  // TODO: how to use s_read?
2246  SaccDataType slope =
2247  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
2248  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
2249 #if CK_TILE_FMHA_FWD_FAST_EXP2
2250  slope *= ck_tile::log2e_v<>;
2251 #endif
2252  if constexpr(kHasMask)
2253  {
2254  return make_alibi_from_lr_mask<SaccDataType, true, 32>(
2255  slope,
2256  kargs.window_size_left,
2257  kargs.window_size_right,
2258  kargs.seqlen_q,
2259  kargs.seqlen_k,
2260  kargs.mask_type);
2261  }
2262  else
2263  {
2265  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
2266  }
2267  }
2268  else
2269  {
2271  }
2272  }();
2273 
2274  auto o_acc_tile = [&]() {
2275  if constexpr(PrefillCase)
2276  {
2277  // allocate double lds
2278  // add __restrict__ here to avoid aliasing
2279  __shared__ char smem_ptrk0
2280  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2281  true>()];
2282  __shared__ char smem_ptrk1
2283  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2284  true>()];
2285  __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV<
2286  typename FmhaPipeline::Problem>()];
2287  __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV<
2288  typename FmhaPipeline::Problem>()];
2289 
2290  return FmhaPipeline{}(q_dram_window,
2291  k_dram_window,
2292  v_dram_window,
2293  bias_dram_window,
2294  lse_dram_window,
2295  mask,
2296  position_encoding,
2297  kargs.scale_s,
2298  smem_ptrk0,
2299  smem_ptrk1,
2300  smem_ptrv0,
2301  smem_ptrv1);
2302  }
2303  else
2304  {
2305  __shared__ char smem_ptr[GetSmemSize()];
2306  return FmhaPipeline{}(q_dram_window,
2307  k_dram_window,
2308  v_dram_window,
2309  bias_dram_window,
2310  lse_dram_window,
2311  mask,
2312  position_encoding,
2313  kargs.scale_s,
2314  smem_ptr);
2315  }
2316  }();
2317 
2318  // Oacc DRAM and Oacc DRAM window
2319  auto o_dram = [&] {
2320  const auto o_dram_naive = [&] {
2321  {
2322  return make_naive_tensor_view<address_space_enum::global>(
2323  o_ptr,
2324  make_tuple(kargs.seqlen_q, kargs.hdim_v),
2325  make_tuple(kargs.stride_o, 1),
2327  number<1>{});
2328  }
2329  }();
2330 
2331  return pad_tensor_view(
2332  o_dram_naive,
2335  }();
2336 
2337  auto o_dram_window = make_tile_window(
2338  o_dram,
2340  {i_m0, i_n1});
2341 
2342  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
2343  }
2344  }
2345 };
2346 
2347 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
float fp32_t
Definition: pk_fp4.hpp:21
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1622
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1609
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_xor_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1662
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:486
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_attention_quant_scale_enum.hpp:18
Definition: block_dropout.hpp:53
const float rp_undrop
Definition: block_dropout.hpp:371
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_kernel.hpp:338
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_kernel.hpp:341
ck_tile::index_t batch_idx
Definition: fmha_fwd_kernel.hpp:339
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_kernel.hpp:340
Definition: fmha_fwd_kernel.hpp:197
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_kernel.hpp:200
const void * alibi_slope_ptr
Definition: fmha_fwd_kernel.hpp:199
Definition: fmha_fwd_kernel.hpp:192
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_kernel.hpp:193
Definition: fmha_fwd_kernel.hpp:274
ck_tile::index_t batch_stride_randval
Definition: fmha_fwd_kernel.hpp:275
Definition: fmha_fwd_kernel.hpp:297
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_kernel.hpp:301
const int32_t * cu_seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:306
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_kernel.hpp:298
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_kernel.hpp:299
const int32_t * cu_seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:305
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_kernel.hpp:300
Definition: fmha_fwd_kernel.hpp:185
const void * bias_ptr
Definition: fmha_fwd_kernel.hpp:186
ck_tile::index_t stride_bias
Definition: fmha_fwd_kernel.hpp:187
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_kernel.hpp:188
Definition: fmha_fwd_kernel.hpp:239
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_fwd_kernel.hpp:252
float rp_undrop
Definition: fmha_fwd_kernel.hpp:264
ck_tile::index_t stride_randval
Definition: fmha_fwd_kernel.hpp:269
ck_tile::index_t nhead_stride_randval
Definition: fmha_fwd_kernel.hpp:270
void * rand_val_ptr
Definition: fmha_fwd_kernel.hpp:267
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_fwd_kernel.hpp:240
bool is_store_randval
Definition: fmha_fwd_kernel.hpp:266
uint8_t p_undrop_in_uint8_t
Definition: fmha_fwd_kernel.hpp:265
Definition: fmha_fwd_kernel.hpp:134
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_kernel.hpp:157
float scale_s
Definition: fmha_fwd_kernel.hpp:149
ck_tile::index_t seqlen_k
Definition: fmha_fwd_kernel.hpp:141
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_kernel.hpp:159
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_kernel.hpp:148
ck_tile::index_t num_head_q
Definition: fmha_fwd_kernel.hpp:145
ck_tile::index_t hdim_q
Definition: fmha_fwd_kernel.hpp:142
const void * v_ptr
Definition: fmha_fwd_kernel.hpp:137
void * o_ptr
Definition: fmha_fwd_kernel.hpp:138
const void * k_ptr
Definition: fmha_fwd_kernel.hpp:136
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_kernel.hpp:156
ck_tile::index_t stride_k
Definition: fmha_fwd_kernel.hpp:152
ck_tile::index_t stride_o
Definition: fmha_fwd_kernel.hpp:154
ck_tile::index_t stride_v
Definition: fmha_fwd_kernel.hpp:153
ck_tile::index_t hdim_v
Definition: fmha_fwd_kernel.hpp:143
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_kernel.hpp:158
const void * q_ptr
Definition: fmha_fwd_kernel.hpp:135
ck_tile::index_t seqlen_q
Definition: fmha_fwd_kernel.hpp:140
ck_tile::index_t stride_q
Definition: fmha_fwd_kernel.hpp:151
Definition: fmha_fwd_kernel.hpp:218
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_kernel.hpp:221
void * lse_ptr
Definition: fmha_fwd_kernel.hpp:219
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_kernel.hpp:220
Definition: fmha_fwd_kernel.hpp:211
const void * v_descale_ptr
Definition: fmha_fwd_kernel.hpp:214
const void * k_descale_ptr
Definition: fmha_fwd_kernel.hpp:213
const void * q_descale_ptr
Definition: fmha_fwd_kernel.hpp:212
Definition: fmha_fwd_kernel.hpp:225
bool is_drop_seed_offset_from_host
Definition: fmha_fwd_kernel.hpp:235
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_fwd_kernel.hpp:233
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_fwd_kernel.hpp:234
Definition: fmha_fwd_kernel.hpp:127
Definition: fmha_fwd_kernel.hpp:324
const int32_t * seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:327
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_kernel.hpp:325
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:328
const int32_t * cu_seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:332
const int32_t * cu_seqlen_q_ptr
Definition: fmha_fwd_kernel.hpp:331
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_kernel.hpp:326
Definition: fmha_fwd_kernel.hpp:163
float logits_soft_cap
Definition: fmha_fwd_kernel.hpp:180
float logits_soft_cap_rcp
Definition: fmha_fwd_kernel.hpp:181
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_kernel.hpp:166
Definition: fmha_fwd_kernel.hpp:204
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_kernel.hpp:207
ck_tile::index_t window_size_right
Definition: fmha_fwd_kernel.hpp:206
ck_tile::index_t window_size_left
Definition: fmha_fwd_kernel.hpp:206
Definition: fmha_fwd_kernel.hpp:279
ck_tile::index_t min_seqlen_q
Definition: fmha_fwd_kernel.hpp:280
Definition: fmha_fwd_kernel.hpp:77
Definition: fmha_fwd_kernel.hpp:28
static constexpr bool kHasDropout
Definition: fmha_fwd_kernel.hpp:58
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_kernel.hpp:87
static constexpr bool kIsAvailable
Definition: fmha_fwd_kernel.hpp:72
static constexpr bool kStoreLSE
Definition: fmha_fwd_kernel.hpp:57
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:591
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_kernel.hpp:38
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_kernel.hpp:335
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_kernel.hpp:33
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_kernel.hpp:45
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_kernel.hpp:48
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_kernel.hpp:31
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_kernel.hpp:41
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
Definition: fmha_fwd_kernel.hpp:1012
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:492
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_kernel.hpp:39
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:689
static constexpr auto QScaleEnum
Definition: fmha_fwd_kernel.hpp:59
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_kernel.hpp:35
static constexpr auto BiasEnum
Definition: fmha_fwd_kernel.hpp:56
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_kernel.hpp:54
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_kernel.hpp:1037
static constexpr bool kSkipMinSeqlenQ
Definition: fmha_fwd_kernel.hpp:60
static constexpr std::string_view kPipelineName
Definition: fmha_fwd_kernel.hpp:74
ck_tile::remove_cvref_t< typename FmhaPipeline::PDataType > PDataType
Definition: fmha_fwd_kernel.hpp:40
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_kernel.hpp:44
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_kernel.hpp:37
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_fwd_kernel.hpp:1102
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:1125
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_kernel.hpp:62
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_kernel.hpp:1114
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:833
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:346
static constexpr bool kUseTrLoad
Definition: fmha_fwd_kernel.hpp:68
static constexpr bool kHasMask
Definition: fmha_fwd_kernel.hpp:64
static constexpr bool kUseAsyncCopy
Definition: fmha_fwd_kernel.hpp:66
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_kernel.hpp:29
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_kernel.hpp:46
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_kernel.hpp:63
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_fwd_kernel.hpp:43
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, const void *q_descale_ptr, const void *k_descale_ptr, const void *v_descale_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_q_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset, const void *cu_seqlen_q_ptr=nullptr, const void *cu_seqlen_k_ptr=nullptr)
Definition: fmha_fwd_kernel.hpp:924
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_kernel.hpp:30
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_kernel.hpp:52
static constexpr bool kIsGroupMode
Definition: fmha_fwd_kernel.hpp:50
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:1119
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:39
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: numeric.hpp:18
Definition: coordinate_transform.hpp:1392
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:49
const T * ptr
Definition: fmha_fwd_kernel.hpp:230