/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp Source File
fmha_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 #define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
17 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
18 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
19 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
20 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
21 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
22 
23 namespace ck_tile {
24 
25 template <typename FmhaPipeline_, typename EpiloguePipeline_>
27 {
30  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
31 
32  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33  static_assert(kBlockPerCu > 0);
34  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
35 
45 
47 
48  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
49  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
50  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
51  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
52  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
53  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
54  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
55  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
56  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
57  static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
58  static constexpr bool kSkipMinSeqlenQ = FmhaPipeline::Problem::kSkipMinSeqlenQ;
59 
62  static constexpr bool kHasMask = FmhaMask::IsMasking;
63 
64  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
65 
66  static constexpr bool kUseTrLoad = FmhaPipeline::Problem::kUseTrLoad;
67 #if defined(__gfx950__)
68  static constexpr bool kIsAvailable = true;
69 #else
70  static constexpr bool kIsAvailable = !kUseTrLoad;
71 #endif
72  static constexpr std::string_view kPipelineName = FmhaPipeline::name;
73 
74  // clang-format off
    // clang-format off
    // t2s: compile-time map from an element data type to its short string tag,
    // used by GetName() when composing the kernel-name string (sync with generate.py).
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on
82 
    // Compose the canonical kernel name from the compile-time configuration:
    // tile shape, warp layout, data type, mode, padding/bias/mask/LSE/dropout/
    // quantization flags. The format must stay in sync with generate.py, which
    // parses these names on the codegen side.
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        using bfs = typename FmhaPipeline::BlockFmhaShape;
        using g0br = typename bfs::Gemm0BlockWarps;
        using g1br = typename bfs::Gemm1BlockWarps;
        using g0wt = typename bfs::Gemm0WarpTile;
        using g1wt = typename bfs::Gemm1WarpTile;
        #define _SS_ std::string
        #define _TS_ std::to_string
        // padding tag: "p" followed by one letter group per padded dimension,
        // or the empty string when nothing is padded (rendered as "_npad" below)
        auto pn = [&] () {
        std::string n;
        if (kPadSeqLenQ) n += "s";
        if (kPadSeqLenK) n += "sk";
        if (kPadHeadDimQ) n += "d";
        if (kPadHeadDimV) n += "dv";
        return n.empty() ? n : std::string("p") + n; }();
        return
        _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
        "_" + (kIsGroupMode ? "group" : "batch") + "_"
        "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
        _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
        "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
        "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
        "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
        "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
        (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
        "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
        (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
        (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kSkipMinSeqlenQ ? "_skip" : "_nskip" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" ) + (kUseTrLoad ? "_trload" : "_ntrload");
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
118 
119  template <ck_tile::index_t I> // to avoid duplicated base class prblem, introduce an template
120  // arg
122  {
123  };
124 
125  // kargs use aggregate initializer, so no constructor will provided
126  // use inheritance to minimize karg size
127  // user need to use MakeKargs() function to create kargs.
129  {
130  const void* q_ptr;
131  const void* k_ptr;
132  const void* v_ptr;
133  void* o_ptr;
134 
139 
141  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
142  // if this param is larger than 1, indicate MQA/GQA case
144  float scale_s;
145 
150 
155  };
156 
158  {
160 
161  void init_logits_soft_cap(float logits_soft_cap_)
162  {
163  if(0 < logits_soft_cap_)
164  {
165  logits_soft_cap = logits_soft_cap_;
167  }
168  else
169  {
170  logits_soft_cap = 0.f;
171  logits_soft_cap_rcp = 0.f;
172  }
173  }
174 
177  };
178 
180  {
181  const void* bias_ptr = nullptr;
184  };
185 
187  {
189  };
190 
192  {
193  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
194  const void* alibi_slope_ptr;
195  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
196  };
197 
199  {
200  // ck_tile::index_t window_size_left, window_size_right;
203  };
204 
206  {
207  float scale_p;
208  float scale_o;
209  };
210 
212  {
213  void* lse_ptr = nullptr;
216  };
217 
219  {
220  template <typename T>
222  {
223  T val;
224  const T* ptr;
225  };
226 
230  };
231 
233  {
        // Initialize dropout with a host-known Philox seed/offset pair.
        // rp_undrop caches 1/(1-p_drop) so kept elements are rescaled with a
        // multiply on the device. Assumes p_drop < 1 — TODO confirm callers
        // never pass p_drop == 1 (division by zero here otherwise).
        // NOTE(review): one or two lines appear elided in this extract between
        // computing p_undrop and rp_undrop; verify against the upstream header.
        void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
        {
            float p_undrop = 1.0 - p_drop;
            rp_undrop = 1.0 / p_undrop;

            this->drop_seed.val   = seed;
            this->drop_offset.val = offset;
            // seed/offset are immediate values, not device pointers
            this->is_drop_seed_offset_from_host = true;
        }
245 
        // Initialize dropout with the Philox seed/offset residing in device
        // memory (read by the kernel at launch time). See the host overload
        // above for the rp_undrop rationale.
        // NOTE(review): as in the host overload, line(s) appear elided in this
        // extract before the rp_undrop computation; verify upstream.
        void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
        {
            float p_undrop = 1.0 - p_drop;
            rp_undrop = 1.0 / p_undrop;

            this->drop_seed.ptr   = seed_ptr;
            this->drop_offset.ptr = offset_ptr;
            // seed/offset must be dereferenced on the device
            this->is_drop_seed_offset_from_host = false;
        }
257 
258  float rp_undrop = 1;
260  bool is_store_randval = false;
261  void* rand_val_ptr = nullptr;
262 
265  };
266 
268  {
270  };
271 
273  {
275  };
276 
279  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
280  FmhaFwdBatchModeBiasKargs,
281  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
282  FmhaFwdAlibiKargs,
283  FmhaFwdEmptyKargs<0>>>,
284  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
285  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
286  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
287  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
288  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
289  {
294  };
295 
298  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
299  FmhaFwdCommonBiasKargs,
300  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
301  FmhaFwdAlibiKargs,
302  FmhaFwdEmptyKargs<0>>>,
303  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
304  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
305  std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
306  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
307  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
308  std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
309  {
313  };
314 
315  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
316 
318  {
322  };
323 
324  template <bool Cond = !kIsGroupMode>
325  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
326  MakeKargsImpl(const void* q_ptr,
327  const void* k_ptr,
328  const void* v_ptr,
329  const void* bias_ptr,
330  void* rand_val_ptr,
331  void* lse_ptr,
332  void* o_ptr,
333  ck_tile::index_t seqlen_q,
334  ck_tile::index_t seqlen_k,
335  ck_tile::index_t hdim_q,
336  ck_tile::index_t hdim_v,
337  ck_tile::index_t num_head_q,
338  ck_tile::index_t nhead_ratio_qk,
339  float scale_s,
340  float scale_p,
341  float scale_o,
342  float logits_soft_cap,
343  ck_tile::index_t stride_q,
344  ck_tile::index_t stride_k,
345  ck_tile::index_t stride_v,
346  ck_tile::index_t stride_bias,
347  ck_tile::index_t stride_randval,
348  ck_tile::index_t stride_o,
349  ck_tile::index_t nhead_stride_q,
350  ck_tile::index_t nhead_stride_k,
351  ck_tile::index_t nhead_stride_v,
352  ck_tile::index_t nhead_stride_bias,
353  ck_tile::index_t nhead_stride_randval,
354  ck_tile::index_t nhead_stride_lse,
355  ck_tile::index_t nhead_stride_o,
356  ck_tile::index_t batch_stride_q,
357  ck_tile::index_t batch_stride_k,
358  ck_tile::index_t batch_stride_v,
359  ck_tile::index_t batch_stride_bias,
360  ck_tile::index_t batch_stride_randval,
361  ck_tile::index_t batch_stride_lse,
362  ck_tile::index_t batch_stride_o,
363  ck_tile::index_t window_size_left,
364  ck_tile::index_t window_size_right,
365  ck_tile::index_t mask_type,
366  float p_drop,
367  bool s_randval,
368  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
369  drop_seed_offset)
370  {
371  Kargs kargs{{q_ptr,
372  k_ptr,
373  v_ptr,
374  o_ptr,
375  seqlen_q,
376  seqlen_k,
377  hdim_q,
378  hdim_v,
379  num_head_q,
380  nhead_ratio_qk,
381 #if CK_TILE_FMHA_FWD_FAST_EXP2
382  static_cast<float>(scale_s * ck_tile::log2e_v<>),
383 #else
384  scale_s,
385 #endif
386  stride_q,
387  stride_k,
388  stride_v,
389  stride_o,
390  nhead_stride_q,
391  nhead_stride_k,
392  nhead_stride_v,
393  nhead_stride_o}, // args for common karg
394  {}, // placeholder for bias
395  {}, // placeholder for mask
396  {}, // placeholder for lse
397  {}, // placeholder for fp8_static_quant args
398  {}, // placeholder for dropout
399  {}, // placeholder for logits_soft_cap
400  batch_stride_q,
401  batch_stride_k,
402  batch_stride_v,
403  batch_stride_o};
404 
406  {
407  kargs.bias_ptr = bias_ptr;
408  kargs.stride_bias = stride_bias;
409  kargs.nhead_stride_bias = nhead_stride_bias;
410  kargs.batch_stride_bias = batch_stride_bias;
411  }
412  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
413  {
414  kargs.alibi_slope_ptr = bias_ptr;
415  kargs.alibi_slope_stride = stride_bias;
416  }
417  if constexpr(kHasMask)
418  {
419  kargs.window_size_left = window_size_left;
420  kargs.window_size_right = window_size_right;
421  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
422  }
423  if constexpr(kStoreLSE)
424  {
425  kargs.lse_ptr = lse_ptr;
426  kargs.nhead_stride_lse = nhead_stride_lse;
427  kargs.batch_stride_lse = batch_stride_lse;
428  }
429  if constexpr(kDoFp8StaticQuant)
430  {
431  kargs.scale_p = scale_p;
432  kargs.scale_o = scale_o;
433  }
434  if constexpr(kHasDropout)
435  {
436  if(drop_seed_offset.index() == 0) // seed & offset come from host
437  {
438  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
439  kargs.init_dropout(p_drop, seed, offset);
440  }
441  else // seed & offset come from device
442  {
443  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
444  kargs.init_dropout(p_drop,
445  reinterpret_cast<const uint64_t*>(seed_ptr),
446  reinterpret_cast<const uint64_t*>(offset_ptr));
447  }
448 
449  kargs.rand_val_ptr = rand_val_ptr;
450  kargs.stride_randval = stride_randval;
451  kargs.nhead_stride_randval = nhead_stride_randval;
452  kargs.batch_stride_randval = batch_stride_randval;
453  kargs.is_store_randval = s_randval;
454  }
455  if constexpr(kHasLogitsSoftCap)
456  {
457  kargs.init_logits_soft_cap(logits_soft_cap);
458  }
459 
460  return kargs;
461  }
462 
463  // std::variant<> can't take in a list initializer, overload for backward compatibility
    // Backward-compatible batch-mode overload: takes the dropout seed/offset
    // as a host-side (uint64_t, uint64_t) tuple and forwards everything to
    // MakeKargsImpl, wrapping the tuple into the std::variant alternative 0
    // (std::variant<> can't take a list initializer directly).
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* rand_val_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale_s,
              float scale_p,
              float scale_o,
              float logits_soft_cap,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        return MakeKargsImpl(
            q_ptr,
            k_ptr,
            v_ptr,
            bias_ptr,
            rand_val_ptr,
            lse_ptr,
            o_ptr,
            seqlen_q,
            seqlen_k,
            hdim_q,
            hdim_v,
            num_head_q,
            nhead_ratio_qk,
            scale_s,
            scale_p,
            scale_o,
            logits_soft_cap,
            stride_q,
            stride_k,
            stride_v,
            stride_bias,
            stride_randval,
            stride_o,
            nhead_stride_q,
            nhead_stride_k,
            nhead_stride_v,
            nhead_stride_bias,
            nhead_stride_randval,
            nhead_stride_lse,
            nhead_stride_o,
            batch_stride_q,
            batch_stride_k,
            batch_stride_v,
            batch_stride_bias,
            batch_stride_randval,
            batch_stride_lse,
            batch_stride_o,
            window_size_left,
            window_size_right,
            mask_type,
            p_drop,
            s_randval,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }
555 
556  // std::variant<> can't take in a list initializer, overload for backward compatibility
    // Backward-compatible batch-mode overload: takes the dropout seed/offset
    // as a tuple of device pointers and forwards to MakeKargsImpl, wrapping
    // the tuple into the std::variant alternative 1
    // (std::variant<> can't take a list initializer directly).
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* rand_val_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              float scale_s,
              float scale_p,
              float scale_o,
              float logits_soft_cap,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              bool s_randval,
              const std::tuple<const void*, const void*>& drop_seed_offset)
    {
        return MakeKargsImpl(
            q_ptr,
            k_ptr,
            v_ptr,
            bias_ptr,
            rand_val_ptr,
            lse_ptr,
            o_ptr,
            seqlen_q,
            seqlen_k,
            hdim_q,
            hdim_v,
            num_head_q,
            nhead_ratio_qk,
            scale_s,
            scale_p,
            scale_o,
            logits_soft_cap,
            stride_q,
            stride_k,
            stride_v,
            stride_bias,
            stride_randval,
            stride_o,
            nhead_stride_q,
            nhead_stride_k,
            nhead_stride_v,
            nhead_stride_bias,
            nhead_stride_randval,
            nhead_stride_lse,
            nhead_stride_o,
            batch_stride_q,
            batch_stride_k,
            batch_stride_v,
            batch_stride_bias,
            batch_stride_randval,
            batch_stride_lse,
            batch_stride_o,
            window_size_left,
            window_size_right,
            mask_type,
            p_drop,
            s_randval,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }
648 
649  template <bool Cond = kIsGroupMode>
650  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
651  MakeKargsImpl(const void* q_ptr,
652  const void* k_ptr,
653  const void* v_ptr,
654  const void* bias_ptr,
655  void* rand_val_ptr,
656  void* lse_ptr,
657  void* o_ptr,
658  const void* seqstart_q_ptr,
659  const void* seqstart_k_ptr,
660  const void* seqlen_k_ptr,
661  ck_tile::index_t hdim_q,
662  ck_tile::index_t hdim_v,
663  ck_tile::index_t num_head_q,
664  ck_tile::index_t nhead_ratio_qk,
665  float scale_s,
666  float scale_p,
667  float scale_o,
668  float logits_soft_cap,
669  ck_tile::index_t stride_q,
670  ck_tile::index_t stride_k,
671  ck_tile::index_t stride_v,
672  ck_tile::index_t stride_bias,
673  ck_tile::index_t stride_randval,
674  ck_tile::index_t stride_o,
675  ck_tile::index_t nhead_stride_q,
676  ck_tile::index_t nhead_stride_k,
677  ck_tile::index_t nhead_stride_v,
678  ck_tile::index_t nhead_stride_bias,
679  ck_tile::index_t nhead_stride_randval,
680  ck_tile::index_t nhead_stride_lse,
681  ck_tile::index_t nhead_stride_o,
682  ck_tile::index_t window_size_left,
683  ck_tile::index_t window_size_right,
684  ck_tile::index_t mask_type,
685  ck_tile::index_t min_seqlen_q,
686  float p_drop,
687  bool s_randval,
688  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
689  drop_seed_offset)
690  {
691  Kargs kargs{{q_ptr,
692  k_ptr,
693  v_ptr,
694  o_ptr,
695  -1, // seqlen will be updated by another pointer
696  -1, //
697  hdim_q,
698  hdim_v,
699  num_head_q,
700  nhead_ratio_qk,
701 #if CK_TILE_FMHA_FWD_FAST_EXP2
702  static_cast<float>(scale_s * ck_tile::log2e_v<>),
703 #else
704  scale_s,
705 #endif
706  stride_q,
707  stride_k,
708  stride_v,
709  stride_o,
710  nhead_stride_q,
711  nhead_stride_k,
712  nhead_stride_v,
713  nhead_stride_o}, // args for common karg
714  {}, // placeholder for bias
715  {}, // placeholder for mask
716  {}, // placeholder for lse
717  {}, // placeholder for fp8_static_quant args
718  {}, // placeholder for dropout
719  {}, // placeholder for logits_soft_cap
720  {}, // placeholder for min_seqlen_q
721  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
722  reinterpret_cast<const int32_t*>(seqstart_k_ptr),
723  reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
724 
726  {
727  kargs.bias_ptr = bias_ptr;
728  kargs.stride_bias = stride_bias;
729  kargs.nhead_stride_bias = nhead_stride_bias;
730  }
731  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
732  {
733  kargs.alibi_slope_ptr = bias_ptr;
734  kargs.alibi_slope_stride = stride_bias;
735  }
736  if constexpr(kHasMask)
737  {
738  kargs.window_size_left = window_size_left;
739  kargs.window_size_right = window_size_right;
740  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
741  }
742  if constexpr(kStoreLSE)
743  {
744  kargs.lse_ptr = lse_ptr;
745  kargs.nhead_stride_lse = nhead_stride_lse;
746  }
747  if constexpr(kDoFp8StaticQuant)
748  {
749  kargs.scale_p = scale_p;
750  kargs.scale_o = scale_o;
751  }
752  if constexpr(kHasDropout)
753  {
754  if(drop_seed_offset.index() == 0) // seed & offset come from host
755  {
756  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
757  kargs.init_dropout(p_drop, seed, offset);
758  }
759  else // seed & offset come from device
760  {
761  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
762  kargs.init_dropout(p_drop,
763  reinterpret_cast<const uint64_t*>(seed_ptr),
764  reinterpret_cast<const uint64_t*>(offset_ptr));
765  }
766 
767  kargs.rand_val_ptr = rand_val_ptr;
768  kargs.stride_randval = stride_randval;
769  kargs.nhead_stride_randval = nhead_stride_randval;
770  kargs.is_store_randval = s_randval;
771  }
772  if constexpr(kHasLogitsSoftCap)
773  {
774  kargs.init_logits_soft_cap(logits_soft_cap);
775  }
776  if constexpr(kSkipMinSeqlenQ)
777  {
778  kargs.min_seqlen_q = min_seqlen_q;
779  }
780 
781  return kargs;
782  }
783 
784  // std::variant<> can't take in a list initializer, overload for backward compatibility
785  template <bool Cond = kIsGroupMode>
786  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
787  MakeKargs(const void* q_ptr,
788  const void* k_ptr,
789  const void* v_ptr,
790  const void* bias_ptr,
791  void* rand_val_ptr,
792  void* lse_ptr,
793  void* o_ptr,
794  const void* seqstart_q_ptr,
795  const void* seqstart_k_ptr,
796  const void* seqlen_k_ptr,
797  ck_tile::index_t hdim_q,
798  ck_tile::index_t hdim_v,
799  ck_tile::index_t num_head_q,
800  ck_tile::index_t nhead_ratio_qk,
801  float scale_s,
802  float scale_p,
803  float scale_o,
804  float logits_soft_cap,
805  ck_tile::index_t stride_q,
806  ck_tile::index_t stride_k,
807  ck_tile::index_t stride_v,
808  ck_tile::index_t stride_bias,
809  ck_tile::index_t stride_randval,
810  ck_tile::index_t stride_o,
811  ck_tile::index_t nhead_stride_q,
812  ck_tile::index_t nhead_stride_k,
813  ck_tile::index_t nhead_stride_v,
814  ck_tile::index_t nhead_stride_bias,
815  ck_tile::index_t nhead_stride_randval,
816  ck_tile::index_t nhead_stride_lse,
817  ck_tile::index_t nhead_stride_o,
818  ck_tile::index_t window_size_left,
819  ck_tile::index_t window_size_right,
820  ck_tile::index_t mask_type,
821  ck_tile::index_t min_seqlen_q,
822  float p_drop,
823  bool s_randval,
824  const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
825  {
826  return MakeKargsImpl(
827  q_ptr,
828  k_ptr,
829  v_ptr,
830  bias_ptr,
831  rand_val_ptr,
832  lse_ptr,
833  o_ptr,
834  seqstart_q_ptr,
835  seqstart_k_ptr,
836  seqlen_k_ptr,
837  hdim_q,
838  hdim_v,
839  num_head_q,
840  nhead_ratio_qk,
841  scale_s,
842  scale_p,
843  scale_o,
844  logits_soft_cap,
845  stride_q,
846  stride_k,
847  stride_v,
848  stride_bias,
849  stride_randval,
850  stride_o,
851  nhead_stride_q,
852  nhead_stride_k,
853  nhead_stride_v,
854  nhead_stride_bias,
855  nhead_stride_randval,
856  nhead_stride_lse,
857  nhead_stride_o,
858  window_size_left,
859  window_size_right,
860  mask_type,
861  min_seqlen_q,
862  p_drop,
863  s_randval,
864  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
865  }
866 
867  // std::variant<> can't take in a list initializer, overload for backward compatibility
868  template <bool Cond = kIsGroupMode>
869  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
870  MakeKargs(const void* q_ptr,
871  const void* k_ptr,
872  const void* v_ptr,
873  const void* bias_ptr,
874  void* rand_val_ptr,
875  void* lse_ptr,
876  void* o_ptr,
877  const void* seqstart_q_ptr,
878  const void* seqstart_k_ptr,
879  const void* seqlen_k_ptr,
880  ck_tile::index_t hdim_q,
881  ck_tile::index_t hdim_v,
882  ck_tile::index_t num_head_q,
883  ck_tile::index_t nhead_ratio_qk,
884  float scale_s,
885  float scale_p,
886  float scale_o,
887  float logits_soft_cap,
888  ck_tile::index_t stride_q,
889  ck_tile::index_t stride_k,
890  ck_tile::index_t stride_v,
891  ck_tile::index_t stride_bias,
892  ck_tile::index_t stride_randval,
893  ck_tile::index_t stride_o,
894  ck_tile::index_t nhead_stride_q,
895  ck_tile::index_t nhead_stride_k,
896  ck_tile::index_t nhead_stride_v,
897  ck_tile::index_t nhead_stride_bias,
898  ck_tile::index_t nhead_stride_randval,
899  ck_tile::index_t nhead_stride_lse,
900  ck_tile::index_t nhead_stride_o,
901  ck_tile::index_t window_size_left,
902  ck_tile::index_t window_size_right,
903  ck_tile::index_t mask_type,
904  ck_tile::index_t min_seqlen_q,
905  float p_drop,
906  bool s_randval,
907  const std::tuple<const void*, const void*>& drop_seed_offset)
908  {
909  return MakeKargsImpl(
910  q_ptr,
911  k_ptr,
912  v_ptr,
913  bias_ptr,
914  rand_val_ptr,
915  lse_ptr,
916  o_ptr,
917  seqstart_q_ptr,
918  seqstart_k_ptr,
919  seqlen_k_ptr,
920  hdim_q,
921  hdim_v,
922  num_head_q,
923  nhead_ratio_qk,
924  scale_s,
925  scale_p,
926  scale_o,
927  logits_soft_cap,
928  stride_q,
929  stride_k,
930  stride_v,
931  stride_bias,
932  stride_randval,
933  stride_o,
934  nhead_stride_q,
935  nhead_stride_k,
936  nhead_stride_v,
937  nhead_stride_bias,
938  nhead_stride_randval,
939  nhead_stride_lse,
940  nhead_stride_o,
941  window_size_left,
942  window_size_right,
943  mask_type,
944  min_seqlen_q,
945  p_drop,
946  s_randval,
947  std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
948  }
949 
950  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
951  ck_tile::index_t nhead_,
952  ck_tile::index_t seqlen_q_,
953  ck_tile::index_t hdim_v_,
954  bool has_padded_seqlen_k = false)
955  {
956  // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
957  if(has_padded_seqlen_k)
958  {
959  // TODO: this may need tuning
960  return dim3(nhead_,
961  batch_size_,
962  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
963  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
964  }
965  else
966  {
967  // TODO: this may need tuning
968  return dim3(nhead_,
969  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
970  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
971  batch_size_);
972  }
973  }
974 
975  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
976  {
977  bool has_padded_seqlen_k = false;
978 
979  if constexpr(kIsGroupMode)
980  has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);
981 
982  if(has_padded_seqlen_k)
983  {
984  // const index_t num_tile_m0 = seqlen_q / kM0;
985  const index_t num_tile_n1 =
986  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
987 
988  const index_t i_block = blockIdx.z;
989  const index_t i_nhead = blockIdx.x;
990  const index_t i_batch = blockIdx.y;
991 
992  const auto f = [](index_t dividend, index_t divisor) {
993  index_t quotient = dividend / divisor;
994  index_t modulus = dividend - quotient * divisor;
995  return ck_tile::make_tuple(quotient, modulus);
996  };
997 
998  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
999 
1000  if constexpr(kHasMask)
1001  {
1002  // assume that num_tile_n1 is always 1
1003  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1004  }
1005  else
1006  {
1007  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1008  }
1009  }
1010  else
1011  {
1012  // const index_t num_tile_m0 = seqlen_q / kM0;
1013  const index_t num_tile_n1 =
1014  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
1015 
1016  const index_t i_block = blockIdx.y; // blockIdx.x
1017  const index_t i_nhead = blockIdx.x; // blockIdx.y
1018  const index_t i_batch = blockIdx.z;
1019 
1020  const auto f = [](index_t dividend, index_t divisor) {
1021  index_t quotient = dividend / divisor;
1022  index_t modulus = dividend - quotient * divisor;
1023  return ck_tile::make_tuple(quotient, modulus);
1024  };
1025 
1026  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
1027 
1028  if constexpr(kHasMask)
1029  {
1030  // assume that num_tile_n1 is always 1
1031  return ck_tile::make_tuple(gridDim.y - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
1032  }
1033  else
1034  {
1035  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
1036  }
1037  }
1038  }
1039 
    // 1-D thread block; kBlockSize comes from the pipeline configuration.
    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
1041 
1043  {
1044  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
1045  }
1046 
    // Kernel entry point: dispatches to run_() only when this configuration
    // is available on the compile target (kIsAvailable is false off-gfx950
    // when kUseTrLoad is set), so unavailable variants compile to a no-op.
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        if constexpr(kIsAvailable)
            run_(std::move(kargs));
    }
1052 
1053  CK_TILE_DEVICE void run_(Kargs kargs) const
1054  {
1055  if constexpr(kPipelineName != "qr_async_trload")
1056  {
1057  // allocate LDS
1058  __shared__ char smem_ptr[GetSmemSize()];
1059 
1060  // divide problem
1061  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1062 
1063  const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
1064  const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
1065 
1066  long_index_t batch_offset_q = 0;
1067  long_index_t batch_offset_k = 0;
1068  long_index_t batch_offset_v = 0;
1069  long_index_t batch_offset_bias = 0;
1070  long_index_t batch_offset_randval = 0;
1071  long_index_t batch_offset_lse = 0;
1072  long_index_t batch_offset_o = 0;
1073 
1074  if constexpr(kIsGroupMode)
1075  {
1076  // get starting offset for each batch
1077  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1078  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1079 
1080  batch_offset_q = query_start * kargs.stride_q;
1081  batch_offset_k = key_start * kargs.stride_k;
1082  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1083  {
1084  batch_offset_v = key_start * kargs.stride_v;
1085  }
1086  else
1087  {
1088  batch_offset_v = key_start;
1089  }
1091  {
1092  batch_offset_bias = query_start * kargs.stride_bias;
1093  }
1094  if constexpr(kStoreLSE)
1095  {
1096  batch_offset_lse = query_start;
1097  }
1098  if constexpr(kHasDropout)
1099  {
1100  batch_offset_randval = query_start * kargs.stride_randval;
1101  }
1102  batch_offset_o = query_start * kargs.stride_o;
1103 
1104  // get real # queries & # keys under group mode
1105  const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
1106  kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
1107 
1108  if constexpr(kSkipMinSeqlenQ)
1109  {
1110  if(kargs.seqlen_q <= kargs.min_seqlen_q)
1111  {
1112  return;
1113  }
1114  }
1115 
1116  // # of required blocks is different in each groups, terminate unnecessary blocks
1117  // earlier
1118  if(kargs.seqlen_q <= i_m0)
1119  {
1120  return;
1121  }
1122 
1123  if(kargs.seqlen_k_ptr != nullptr)
1124  {
1125  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1126  }
1127  else
1128  {
1129  const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
1130  kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
1131  }
1132  }
1133  else
1134  {
1135  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1136  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1137  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1139  {
1140  batch_offset_bias =
1141  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1142  }
1143  if constexpr(kStoreLSE)
1144  {
1145  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1146  }
1147  if constexpr(kHasDropout)
1148  {
1149  batch_offset_randval =
1150  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
1151  }
1152  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1153  }
1154 
1155  // for simplicity, batch stride we just modify the pointer
1156  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1157  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1158  batch_offset_q;
1159  const KDataType* k_ptr =
1160  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1161  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
1162  batch_offset_k;
1163  const VDataType* v_ptr =
1164  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1165  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
1166  batch_offset_v;
1167  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1168  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1169  batch_offset_o;
1170 
1171  // Q/K/V DRAM and DRAM window
1172  const auto q_dram = [&]() {
1173  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1174  q_ptr,
1175  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1176  make_tuple(kargs.stride_q, 1),
1178  number<1>{});
1179  if constexpr(FmhaPipeline::kQLoadOnce)
1180  {
1181  return pad_tensor_view(q_dram_naive,
1185  }
1186  else
1187  {
1188  return pad_tensor_view(
1189  q_dram_naive,
1192  }
1193  }();
1194  const auto k_dram = [&]() {
1195  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1196  k_ptr,
1197  make_tuple(kargs.seqlen_k, kargs.hdim_q),
1198  make_tuple(kargs.stride_k, 1),
1200  number<1>{});
1201 
1202  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1203  return pad_tensor_view(
1204  k_dram_naive,
1207  }();
1208  const auto v_dram = [&]() {
1209  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1210  {
1211  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1212  v_ptr,
1213  make_tuple(kargs.seqlen_k, kargs.hdim_v),
1214  make_tuple(kargs.stride_v, 1),
1216  number<1>{});
1217 
1218  const auto v_dram_transposed = transform_tensor_view(
1219  v_dram_naive,
1221  make_pass_through_transform(kargs.seqlen_k)),
1224 
1225  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
1226  return pad_tensor_view(
1227  v_dram_transposed,
1230  }
1231  else
1232  {
1233  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1234  v_ptr,
1235  make_tuple(kargs.hdim_v, kargs.seqlen_k),
1236  make_tuple(kargs.stride_v, 1),
1238  number<1>{});
1239 
1240  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
1241  return pad_tensor_view(
1242  v_dram_naive,
1245  }
1246  }();
1247 
1248  auto q_dram_window = make_tile_window(
1249  q_dram,
1250  [&]() {
1251  if constexpr(FmhaPipeline::kQLoadOnce)
1254  else
1256  }(),
1257  {i_m0, 0});
1258 
1259  auto k_dram_window = make_tile_window(
1260  k_dram,
1262  {0, 0});
1263 
1264  auto v_dram_window = make_tile_window(
1265  v_dram,
1267  {i_n1, 0});
1270  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
1271  constexpr auto bias_dram_window_lengths =
1274  {
1275  const BiasDataType* bias_ptr =
1276  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
1277  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
1278  batch_offset_bias;
1279 
1280  const auto bias_dram = [&]() {
1281  const auto bias_dram_naive =
1282  make_naive_tensor_view<address_space_enum::global>(
1283  bias_ptr,
1284  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1285  make_tuple(kargs.stride_bias, 1),
1287  number<1>{});
1288 
1289  return pad_tensor_view(bias_dram_naive,
1290  bias_dram_window_lengths,
1292  }();
1293 
1294  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
1295  }
1296  else
1297  {
1298  return make_null_tile_window(bias_dram_window_lengths);
1299  }
1300  }();
1301 
1302  // lse
1303  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1304  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1305  if constexpr(kStoreLSE)
1306  {
1307  LSEDataType* lse_ptr =
1308  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1309  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
1310  batch_offset_lse;
1311 
1312  const auto lse_dram = [&]() {
1313  const auto lse_dram_naive =
1314  make_naive_tensor_view<address_space_enum::global>(
1315  lse_ptr,
1316  make_tuple(kargs.seqlen_q),
1317  make_tuple(1),
1318  number<1>{},
1319  number<1>{});
1320 
1321  return pad_tensor_view(
1322  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1323  }();
1324 
1325  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1326  }
1327  else
1328  {
1329  return make_null_tile_window(lse_dram_window_lengths);
1330  }
1331  }();
1332 
1333  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
1334  if constexpr(kHasDropout)
1335  {
1336  return BlockDropout{i_batch_,
1337  i_nhead_,
1338  kargs.num_head_q,
1339  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
1340  : *kargs.drop_seed.ptr,
1341  kargs.is_drop_seed_offset_from_host
1342  ? kargs.drop_offset.val
1343  : *kargs.drop_offset.ptr,
1344  kargs.rp_undrop,
1345  kargs.p_undrop_in_uint8_t,
1346  kargs.is_store_randval};
1347  }
1348  else
1349  {
1350  return NullBlockDropout{};
1351  };
1352  }();
1353 
1354  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
1355  constexpr auto randval_dram_window_lengths =
1357  if constexpr(kHasDropout)
1358  {
1359  RandValOutputDataType* rand_val_ptr =
1360  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
1361  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
1362  batch_offset_randval;
1363 
1364  const auto randval_dram = [&]() {
1365  const auto randval_dram_naive =
1366  make_naive_tensor_view<address_space_enum::global>(
1367  rand_val_ptr,
1368  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
1369  make_tuple(kargs.stride_randval, 1),
1370  number<1>{},
1371  number<1>{});
1372 
1373  return pad_tensor_view(randval_dram_naive,
1374  randval_dram_window_lengths,
1376  }();
1377 
1378  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
1379  }
1380  else
1381  {
1382  return make_null_tile_window(randval_dram_window_lengths);
1383  }
1384  }();
1385 
1386  FmhaMask mask = [&]() {
1387  if constexpr(kHasMask)
1388  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1389  kargs.window_size_left,
1390  kargs.window_size_right,
1391  kargs.seqlen_q,
1392  kargs.seqlen_k,
1394  else
1395  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1396  }();
1397 
1398  // WA i_batch capture structure binding before c++20
1399  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1400  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1401  {
1402  // data loading, shared by entire wg
1403  // TODO: how to use s_read?
1404  SaccDataType slope =
1405  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1406  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1407 #if CK_TILE_FMHA_FWD_FAST_EXP2
1408  slope *= ck_tile::log2e_v<>;
1409 #endif
1410  if constexpr(kHasMask)
1411  {
1412  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1413  kargs.window_size_left,
1414  kargs.window_size_right,
1415  kargs.seqlen_q,
1416  kargs.seqlen_k,
1417  kargs.mask_type);
1418  }
1419  else
1420  {
1422  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1423  }
1424  }
1425  else
1426  {
1428  }
1429  }();
1430 
1431  AttentionVariant variant;
1432  const auto variant_params = [&] {
1433  if constexpr(kHasLogitsSoftCap)
1434  {
1436  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1437  }
1438  else
1439  {
1440  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1441  }
1442  }();
1443 
1444  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1445 
1446  auto o_acc_tile = [&]() {
1447  if constexpr(kDoFp8StaticQuant)
1448  {
1449  return FmhaPipeline{}(
1450  q_dram_window,
1451  identity{}, // q_element_func
1452  k_dram_window,
1453  identity{}, // k_element_func
1454  v_dram_window,
1455  identity{}, // v_element_func
1456  bias_dram_window,
1457  identity{}, // bias_element_func
1458  randval_dram_window,
1459  lse_dram_window,
1460  identity{}, // lse_element_func
1461  identity{}, // s_acc_element_func
1462  scales{kargs.scale_p}, // p_compute_element_func
1463  composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
1464  mask,
1465  position_encoding,
1466  kargs.scale_s,
1467  variant,
1468  variant_params,
1469  block_indices,
1470  smem_ptr,
1471  dropout);
1472  }
1473  else
1474  {
1475  return FmhaPipeline{}(q_dram_window,
1476  k_dram_window,
1477  v_dram_window,
1478  bias_dram_window,
1479  randval_dram_window,
1480  lse_dram_window,
1481  mask,
1482  position_encoding,
1483  kargs.scale_s,
1484  variant,
1485  variant_params,
1486  block_indices,
1487  smem_ptr,
1488  dropout);
1489  }
1490  }();
1491 
1492  // O DRAM and O DRAM window
1493  auto o_dram = [&]() {
1494  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1495  o_ptr,
1496  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1497  make_tuple(kargs.stride_o, 1),
1499  number<1>{});
1500 
1501  return pad_tensor_view(
1502  o_dram_naive,
1505  }();
1506 
1507  auto o_dram_window = make_tile_window(
1508  o_dram,
1510  {i_m0, i_n1});
1511 
1512  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1513  }
1514  else
1515  {
1516  // TODO: Refine the logical here.
1517  // In Decode case
1518  // 1. we don't expect KV data reused by different ThreadGroups, bypass the cache
1519  // 2. limit the LDS usage, as we want higher occupancy
1520  // In Prefill case
1521  // 1. we expect KV data reused by different ThreadGroups, use cache
1522  // 2. use more LDS, as we want better memory latency hiding
1523  // If SplitKV off, we don't expect Q data reused by different ThreadGroups, bypass the
1524  // cache
1525  constexpr bool PrefillCase = FmhaPipeline::kM0 >= 128;
1526  // divide problem
1527  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
1528 
1529  const index_t i_m0 = i_tile_m * FmhaPipeline::kM0;
1530  const index_t i_n1 = i_tile_n * FmhaPipeline::kN1;
1531 
1532  long_index_t batch_offset_q = 0;
1533  long_index_t batch_offset_k = 0; // unused for paged-kvcache
1534  long_index_t batch_offset_v = 0; // unused for paged-kvcache
1535  long_index_t batch_offset_bias = 0;
1536  long_index_t batch_offset_lse = 0;
1537  long_index_t batch_offset_o = 0;
1538  // index_t kv_l2p_offset =
1539  // 0; // logical-to-physical offset of seqlen_k coordinate. only used for
1540  // paged-kvcache
1541 
1542  if constexpr(kIsGroupMode)
1543  {
1544  // get starting offset for each batch
1545  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
1546  const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
1547 
1548  batch_offset_q = query_start * kargs.stride_q;
1549  batch_offset_k = key_start * kargs.stride_k;
1550  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
1551  {
1552  batch_offset_v = key_start * kargs.stride_v;
1553  }
1554  else
1555  {
1556  batch_offset_v = key_start;
1557  }
1559  {
1560  batch_offset_bias = query_start * kargs.stride_bias;
1561  }
1562 
1563  batch_offset_lse = query_start;
1564  batch_offset_o = query_start * kargs.stride_o;
1565 
1566  // get real # queries & # keys under group mode
1567  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
1568 
1569  // # of required blocks is different in each groups, terminate unnecessary blocks
1570  // earlier
1571  if(kargs.seqlen_q <= i_m0)
1572  {
1573  return;
1574  }
1575 
1576  if(kargs.seqlen_k_ptr != nullptr)
1577  {
1578  kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
1579  }
1580  else
1581  {
1582  kargs.seqlen_k =
1583  kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
1584  }
1585  }
1586  else
1587  {
1588  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
1589  batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
1590  batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
1591  if constexpr(kStoreLSE)
1592  {
1593  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
1594  }
1595  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
1596 
1598  {
1599  batch_offset_bias =
1600  static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
1601  }
1602  }
1603 
1604  // for simplicity, batch stride we just modify the pointer
1605  const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
1606 
1607  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
1608  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
1609  batch_offset_q;
1610  const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
1611  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
1612  batch_offset_k;
1613  const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
1614  static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
1615  batch_offset_v;
1616 
1617  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
1618  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
1619  batch_offset_o;
1620 
1621  // Q/K/V DRAM and DRAM window
1622  const auto q_dram = [&] {
1623  const auto q_dram_naive = [&] {
1624  {
1625  return make_naive_tensor_view<address_space_enum::global,
1626  memory_operation_enum::set,
1628  q_ptr,
1629  make_tuple(kargs.seqlen_q, kargs.hdim_q),
1630  make_tuple(kargs.stride_q, 1),
1632  number<1>{});
1633  }
1634  }();
1635 
1636  if constexpr(FmhaPipeline::kQLoadOnce)
1637  {
1638  const auto seqlen_q = kargs.seqlen_q;
1639  const auto q_dram_pad = pad_tensor_view(
1640  q_dram_naive,
1643 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1644  constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
1645  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1646 
1647  if constexpr(XorLengthFold > 1)
1648  {
1649  const auto q_dram_unmerged = transform_tensor_view(
1650  q_dram_pad,
1651  make_tuple(
1653  make_tuple(seqlen_q / XorLengthFold, XorLengthFold)),
1657 
1658  const auto q_dram_merged = transform_tensor_view(
1659  q_dram_unmerged,
1660  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1662  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1665 
1666  const auto q_dram_unmerged_xor = transform_tensor_view(
1667  q_dram_merged,
1668  make_tuple(make_pass_through_transform(seqlen_q / XorLengthFold),
1674 
1675  const auto q_dram_permuted = transform_tensor_view(
1676  q_dram_unmerged_xor,
1677  make_tuple(
1679  make_tuple(seqlen_q / XorLengthFold,
1684 
1685  const auto q_dram_tmp = transform_tensor_view(
1686  q_dram_permuted,
1687  make_tuple(
1688  make_pass_through_transform(seqlen_q / XorLengthFold),
1691  number<FmhaPipeline::kQKHeaddim /
1692  FmhaPipeline::kAlignmentQ>{})),
1696 
1697  return transform_tensor_view(
1698  q_dram_tmp,
1699  make_tuple(
1701  make_tuple(seqlen_q / XorLengthFold, number<XorLengthFold>{})),
1707  }
1708  else
1709 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1710  {
1711  const auto q_dram_unmerged = transform_tensor_view(
1712  q_dram_pad,
1713  make_tuple(
1714  make_pass_through_transform(seqlen_q),
1720 
1721  const auto q_dram_permuted = transform_tensor_view(
1722  q_dram_unmerged,
1723  make_tuple(
1724  make_xor_transform(make_tuple(seqlen_q,
1725  number<FmhaPipeline::kQKHeaddim /
1726  FmhaPipeline::kAlignmentQ>{})),
1730 
1731  return transform_tensor_view(
1732  q_dram_permuted,
1733  make_tuple(
1734  make_pass_through_transform(seqlen_q),
1740  }
1741  }
1742  else
1743  {
1744  return pad_tensor_view(
1745  q_dram_naive,
1748  }
1749  }();
1750 
1751  const auto make_k_dram = [&](const KDataType* data, index_t height) {
1752  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1753  data, // will update this pointer if using paged-kvcache
1754  make_tuple(height, kargs.hdim_q),
1755  make_tuple(kargs.stride_k, 1),
1757  number<1>{});
1758 
1759  const auto k_dram_pad = pad_tensor_view(
1760  k_dram_naive,
1763 
1764 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1765  constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
1766  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1767 
1768  if constexpr(XorLengthFold > 1)
1769  {
1770  const auto k_dram_unmerged = transform_tensor_view(
1771  k_dram_pad,
1773  make_tuple(height / XorLengthFold, XorLengthFold)),
1777 
1778  const auto k_dram_merged = transform_tensor_view(
1779  k_dram_unmerged,
1780  make_tuple(make_pass_through_transform(height / XorLengthFold),
1782  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1785 
1786  const auto k_dram_unmerged_xor = transform_tensor_view(
1787  k_dram_merged,
1788  make_tuple(make_pass_through_transform(height / XorLengthFold),
1794 
1795  const auto k_dram_permuted = transform_tensor_view(
1796  k_dram_unmerged_xor,
1797  make_tuple(
1799  make_tuple(height / XorLengthFold,
1804 
1805  const auto k_dram_tmp = transform_tensor_view(
1806  k_dram_permuted,
1807  make_tuple(
1808  make_pass_through_transform(height / XorLengthFold),
1811  number<FmhaPipeline::kQKHeaddim / FmhaPipeline::kAlignmentK>{})),
1815 
1816  return transform_tensor_view(
1817  k_dram_tmp,
1818  make_tuple(
1820  make_tuple(height / XorLengthFold, number<XorLengthFold>{})),
1826  }
1827  else
1828 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1829  {
1830  const auto k_dram_unmerged = transform_tensor_view(
1831  k_dram_pad,
1832  make_tuple(
1839 
1840  const auto k_dram_permuted = transform_tensor_view(
1841  k_dram_unmerged,
1842  make_tuple(
1844  height,
1849 
1850  return transform_tensor_view(
1851  k_dram_permuted,
1852  make_tuple(
1859  }
1860  };
1861  const auto k_dram = [&]() {
1862  {
1863  return make_k_dram(k_ptr, kargs.seqlen_k);
1864  }
1865  }();
1866 
1867  const auto make_v_dram = [&](const VDataType* data, index_t length) {
1868  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1869  data, // will update this pointer if using paged-kvcache
1870  make_tuple(length, kargs.hdim_v),
1871  make_tuple(kargs.hdim_v, 1),
1873  number<1>{});
1874 
1875  // TODO: Add kVHeadDim
1876  constexpr index_t XorGroupSize =
1877  FmhaPipeline::Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
1878 
1879  const auto v_dram_pad = pad_tensor_view(
1880  v_dram_naive,
1883 
1884 #if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1885  constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
1886  constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
1887 
1888  if constexpr(XorLengthFold > 1)
1889  {
1890  const auto v_dram_unmerged = transform_tensor_view(
1891  v_dram_pad,
1893  make_tuple(length / XorLengthFold, XorLengthFold)),
1897 
1898  const auto v_dram_merged = transform_tensor_view(
1899  v_dram_unmerged,
1900  make_tuple(make_pass_through_transform(length / XorLengthFold),
1902  XorLengthFold, number<FmhaPipeline::kQKHeaddim>{}))),
1905 
1906  const auto v_dram_unmerged_xor = transform_tensor_view(
1907  v_dram_merged,
1908  make_tuple(
1909  make_pass_through_transform(length / XorLengthFold),
1911  number<XorGroupSize>{}))),
1914 
1915  const auto v_dram_permuted = transform_tensor_view(
1916  v_dram_unmerged_xor,
1917  make_tuple(
1918  make_xor_transform(make_tuple(length / XorLengthFold,
1923 
1924  const auto v_dram_tmp = transform_tensor_view(
1925  v_dram_permuted,
1926  make_tuple(make_pass_through_transform(length / XorLengthFold),
1929  number<FmhaPipeline::kQKHeaddim / XorGroupSize>{})),
1933 
1934  return transform_tensor_view(
1935  v_dram_tmp,
1937  make_tuple(length / XorLengthFold, number<XorLengthFold>{})),
1940  number<XorGroupSize>{}))),
1943  }
1944  else
1945 #endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
1946  {
1947  const auto v_dram_unmerged = transform_tensor_view(
1948  v_dram_pad,
1952  number<XorGroupSize>{}))),
1955 
1956  const auto v_dram_permuted = transform_tensor_view(
1957  v_dram_unmerged,
1963 
1964  return transform_tensor_view(
1965  v_dram_permuted,
1969  number<XorGroupSize>{}))),
1972  }
1973  };
1974 
1975  const auto v_dram = [&]() {
1976  {
1977  return make_v_dram(v_ptr, kargs.seqlen_k);
1978  }
1979  }();
1980 
1981  auto q_dram_window = make_tile_window(
1982  q_dram,
1983  [&]() {
1984  if constexpr(FmhaPipeline::kQLoadOnce)
1987  else
1989  }(),
1990  {i_m0, 0});
1991 
1992  auto k_dram_window = make_tile_window(
1993  k_dram,
1995  {0, 0});
1996 
1997  auto v_dram_window = make_tile_window(
1998  v_dram,
2000  {0, 0});
2001 
2004  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
2005  constexpr auto bias_dram_window_lengths =
2008  {
2009  const BiasDataType* bias_ptr =
2010  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
2011  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
2012  batch_offset_bias;
2013 
2014  const auto bias_dram = [&]() {
2015  const auto bias_dram_naive =
2016  make_naive_tensor_view<address_space_enum::global>(
2017  bias_ptr,
2018  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
2019  make_tuple(kargs.stride_bias, 1),
2021  number<1>{});
2022 
2023  return pad_tensor_view(bias_dram_naive,
2024  bias_dram_window_lengths,
2026  }();
2027 
2028  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
2029  }
2030  else
2031  {
2032  return make_null_tile_window(bias_dram_window_lengths);
2033  }
2034  }();
2035 
2036  // lse acc
2037  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
2038  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
2039  if constexpr(kStoreLSE)
2040  {
2041  LSEDataType* lse_ptr =
2042  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
2043  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse +
2044  batch_offset_lse;
2045 
2046  const auto lse_dram = [&] {
2047  const auto lse_dram_naive = [&] {
2048  {
2049  return make_naive_tensor_view<address_space_enum::global>(
2050  lse_ptr,
2051  make_tuple(kargs.seqlen_q),
2052  make_tuple(1),
2053  number<1>{},
2054  number<1>{});
2055  }
2056  }();
2057  return pad_tensor_view(
2058  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
2059  }();
2060 
2061  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
2062  }
2063  else
2064  {
2065  return make_null_tile_window(lse_dram_window_lengths);
2066  }
2067  }();
2068 
2069  FmhaMask mask = [&]() {
2070  if constexpr(kHasMask)
2071  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
2072  kargs.window_size_left,
2073  kargs.window_size_right,
2074  kargs.seqlen_q,
2075  kargs.seqlen_k,
2077  else
2078  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
2079  }();
2080 
2081  // WA i_batch capture structure binding before c++20
2082  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
2083  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
2084  {
2085  // data loading, shared by entire wg
2086  // TODO: how to use s_read?
2087  SaccDataType slope =
2088  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
2089  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
2090 #if CK_TILE_FMHA_FWD_FAST_EXP2
2091  slope *= ck_tile::log2e_v<>;
2092 #endif
2093  if constexpr(kHasMask)
2094  {
2095  return make_alibi_from_lr_mask<SaccDataType, true, 32>(
2096  slope,
2097  kargs.window_size_left,
2098  kargs.window_size_right,
2099  kargs.seqlen_q,
2100  kargs.seqlen_k,
2101  kargs.mask_type);
2102  }
2103  else
2104  {
2106  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
2107  }
2108  }
2109  else
2110  {
2112  }
2113  }();
2114 
2115  auto o_acc_tile = [&]() {
2116  if constexpr(PrefillCase)
2117  {
2118  // allocate double lds
2119  // add __restrict__ here to avoid aliasing
2120  __shared__ char smem_ptrk0
2121  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2122  true>()];
2123  __shared__ char smem_ptrk1
2124  [FmhaPipeline::Policy::template GetSmemSizeK<typename FmhaPipeline::Problem,
2125  true>()];
2126  __shared__ char smem_ptrv0[FmhaPipeline::Policy::template GetSmemSizeV<
2127  typename FmhaPipeline::Problem>()];
2128  __shared__ char smem_ptrv1[FmhaPipeline::Policy::template GetSmemSizeV<
2129  typename FmhaPipeline::Problem>()];
2130 
2131  return FmhaPipeline{}(q_dram_window,
2132  k_dram_window,
2133  v_dram_window,
2134  bias_dram_window,
2135  lse_dram_window,
2136  mask,
2137  position_encoding,
2138  kargs.scale_s,
2139  smem_ptrk0,
2140  smem_ptrk1,
2141  smem_ptrv0,
2142  smem_ptrv1);
2143  }
2144  else
2145  {
2146  __shared__ char smem_ptr[GetSmemSize()];
2147  return FmhaPipeline{}(q_dram_window,
2148  k_dram_window,
2149  v_dram_window,
2150  bias_dram_window,
2151  lse_dram_window,
2152  mask,
2153  position_encoding,
2154  kargs.scale_s,
2155  smem_ptr);
2156  }
2157  }();
2158 
2159  // Oacc DRAM and Oacc DRAM window
2160  auto o_dram = [&] {
2161  const auto o_dram_naive = [&] {
2162  {
2163  return make_naive_tensor_view<address_space_enum::global>(
2164  o_ptr,
2165  make_tuple(kargs.seqlen_q, kargs.hdim_v),
2166  make_tuple(kargs.stride_o, 1),
2168  number<1>{});
2169  }
2170  }();
2171 
2172  return pad_tensor_view(
2173  o_dram_naive,
2176  }();
2177 
2178  auto o_dram_window = make_tile_window(
2179  o_dram,
2181  {i_m0, i_n1});
2182 
2183  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
2184  }
2185  }
2186 };
2187 
2188 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:511
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1622
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1609
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto make_xor_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1662
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_view(DataType *__restrict__ p, const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_view.hpp:471
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_dropout.hpp:26
const float rp_undrop
Definition: block_dropout.hpp:290
Definition: block_position_encoding.hpp:137
Definition: fmha_fwd_kernel.hpp:318
ck_tile::index_t kv_head_idx
Definition: fmha_fwd_kernel.hpp:321
ck_tile::index_t batch_idx
Definition: fmha_fwd_kernel.hpp:319
ck_tile::index_t qo_head_idx
Definition: fmha_fwd_kernel.hpp:320
Definition: fmha_fwd_kernel.hpp:192
ck_tile::index_t alibi_slope_stride
Definition: fmha_fwd_kernel.hpp:195
const void * alibi_slope_ptr
Definition: fmha_fwd_kernel.hpp:194
Definition: fmha_fwd_kernel.hpp:187
ck_tile::index_t batch_stride_bias
Definition: fmha_fwd_kernel.hpp:188
Definition: fmha_fwd_kernel.hpp:268
ck_tile::index_t batch_stride_randval
Definition: fmha_fwd_kernel.hpp:269
Definition: fmha_fwd_kernel.hpp:289
ck_tile::index_t batch_stride_o
Definition: fmha_fwd_kernel.hpp:293
ck_tile::index_t batch_stride_q
Definition: fmha_fwd_kernel.hpp:290
ck_tile::index_t batch_stride_k
Definition: fmha_fwd_kernel.hpp:291
ck_tile::index_t batch_stride_v
Definition: fmha_fwd_kernel.hpp:292
Definition: fmha_fwd_kernel.hpp:180
const void * bias_ptr
Definition: fmha_fwd_kernel.hpp:181
ck_tile::index_t stride_bias
Definition: fmha_fwd_kernel.hpp:182
ck_tile::index_t nhead_stride_bias
Definition: fmha_fwd_kernel.hpp:183
Definition: fmha_fwd_kernel.hpp:233
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_fwd_kernel.hpp:246
float rp_undrop
Definition: fmha_fwd_kernel.hpp:258
ck_tile::index_t stride_randval
Definition: fmha_fwd_kernel.hpp:263
ck_tile::index_t nhead_stride_randval
Definition: fmha_fwd_kernel.hpp:264
void * rand_val_ptr
Definition: fmha_fwd_kernel.hpp:261
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_fwd_kernel.hpp:234
bool is_store_randval
Definition: fmha_fwd_kernel.hpp:260
uint8_t p_undrop_in_uint8_t
Definition: fmha_fwd_kernel.hpp:259
Definition: fmha_fwd_kernel.hpp:129
ck_tile::index_t nhead_stride_k
Definition: fmha_fwd_kernel.hpp:152
float scale_s
Definition: fmha_fwd_kernel.hpp:144
ck_tile::index_t seqlen_k
Definition: fmha_fwd_kernel.hpp:136
ck_tile::index_t nhead_stride_o
Definition: fmha_fwd_kernel.hpp:154
ck_tile::index_t nhead_ratio_qk
Definition: fmha_fwd_kernel.hpp:143
ck_tile::index_t num_head_q
Definition: fmha_fwd_kernel.hpp:140
ck_tile::index_t hdim_q
Definition: fmha_fwd_kernel.hpp:137
const void * v_ptr
Definition: fmha_fwd_kernel.hpp:132
void * o_ptr
Definition: fmha_fwd_kernel.hpp:133
const void * k_ptr
Definition: fmha_fwd_kernel.hpp:131
ck_tile::index_t nhead_stride_q
Definition: fmha_fwd_kernel.hpp:151
ck_tile::index_t stride_k
Definition: fmha_fwd_kernel.hpp:147
ck_tile::index_t stride_o
Definition: fmha_fwd_kernel.hpp:149
ck_tile::index_t stride_v
Definition: fmha_fwd_kernel.hpp:148
ck_tile::index_t hdim_v
Definition: fmha_fwd_kernel.hpp:138
ck_tile::index_t nhead_stride_v
Definition: fmha_fwd_kernel.hpp:153
const void * q_ptr
Definition: fmha_fwd_kernel.hpp:130
ck_tile::index_t seqlen_q
Definition: fmha_fwd_kernel.hpp:135
ck_tile::index_t stride_q
Definition: fmha_fwd_kernel.hpp:146
Definition: fmha_fwd_kernel.hpp:212
ck_tile::index_t batch_stride_lse
Definition: fmha_fwd_kernel.hpp:215
void * lse_ptr
Definition: fmha_fwd_kernel.hpp:213
ck_tile::index_t nhead_stride_lse
Definition: fmha_fwd_kernel.hpp:214
Definition: fmha_fwd_kernel.hpp:219
bool is_drop_seed_offset_from_host
Definition: fmha_fwd_kernel.hpp:229
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_fwd_kernel.hpp:227
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_fwd_kernel.hpp:228
Definition: fmha_fwd_kernel.hpp:122
Definition: fmha_fwd_kernel.hpp:206
float scale_o
Definition: fmha_fwd_kernel.hpp:208
float scale_p
Definition: fmha_fwd_kernel.hpp:207
Definition: fmha_fwd_kernel.hpp:309
const int32_t * seqstart_q_ptr
Definition: fmha_fwd_kernel.hpp:310
const int32_t * seqlen_k_ptr
Definition: fmha_fwd_kernel.hpp:312
const int32_t * seqstart_k_ptr
Definition: fmha_fwd_kernel.hpp:311
Definition: fmha_fwd_kernel.hpp:158
float logits_soft_cap
Definition: fmha_fwd_kernel.hpp:175
float logits_soft_cap_rcp
Definition: fmha_fwd_kernel.hpp:176
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_fwd_kernel.hpp:161
Definition: fmha_fwd_kernel.hpp:199
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_fwd_kernel.hpp:202
ck_tile::index_t window_size_right
Definition: fmha_fwd_kernel.hpp:201
ck_tile::index_t window_size_left
Definition: fmha_fwd_kernel.hpp:201
Definition: fmha_fwd_kernel.hpp:273
ck_tile::index_t min_seqlen_q
Definition: fmha_fwd_kernel.hpp:274
Definition: fmha_fwd_kernel.hpp:75
Definition: fmha_fwd_kernel.hpp:27
static constexpr bool kHasDropout
Definition: fmha_fwd_kernel.hpp:56
static CK_TILE_HOST std::string GetName()
Definition: fmha_fwd_kernel.hpp:83
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:466
static constexpr bool kIsAvailable
Definition: fmha_fwd_kernel.hpp:70
static constexpr bool kDoFp8StaticQuant
Definition: fmha_fwd_kernel.hpp:57
static constexpr bool kStoreLSE
Definition: fmha_fwd_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_fwd_kernel.hpp:37
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_fwd_kernel.hpp:315
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_fwd_kernel.hpp:32
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_fwd_kernel.hpp:43
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_fwd_kernel.hpp:46
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_fwd_kernel.hpp:30
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_fwd_kernel.hpp:39
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_, bool has_padded_seqlen_k=false)
Definition: fmha_fwd_kernel.hpp:950
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_fwd_kernel.hpp:1040
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_fwd_kernel.hpp:38
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:559
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_fwd_kernel.hpp:34
static constexpr auto BiasEnum
Definition: fmha_fwd_kernel.hpp:54
static constexpr bool kPadHeadDimV
Definition: fmha_fwd_kernel.hpp:52
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_fwd_kernel.hpp:975
static constexpr bool kSkipMinSeqlenQ
Definition: fmha_fwd_kernel.hpp:58
static constexpr std::string_view kPipelineName
Definition: fmha_fwd_kernel.hpp:72
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_fwd_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_fwd_kernel.hpp:36
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< uint64_t, uint64_t > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:787
CK_TILE_DEVICE void run_(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:1053
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_fwd_kernel.hpp:60
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_fwd_kernel.hpp:1042
static constexpr bool kUseTrLoad
Definition: fmha_fwd_kernel.hpp:66
static constexpr bool kHasMask
Definition: fmha_fwd_kernel.hpp:62
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:651
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargsImpl(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:326
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, const void *seqstart_k_ptr, const void *seqlen_k_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q, float p_drop, bool s_randval, const std::tuple< const void *, const void * > &drop_seed_offset)
Definition: fmha_fwd_kernel.hpp:870
static constexpr bool kUseAsyncCopy
Definition: fmha_fwd_kernel.hpp:64
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_fwd_kernel.hpp:28
static constexpr bool kPadHeadDimQ
Definition: fmha_fwd_kernel.hpp:51
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_fwd_kernel.hpp:44
static constexpr bool kPadSeqLenQ
Definition: fmha_fwd_kernel.hpp:49
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_fwd_kernel.hpp:61
static constexpr bool kHasLogitsSoftCap
Definition: fmha_fwd_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_fwd_kernel.hpp:41
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_fwd_kernel.hpp:29
static constexpr bool kPadSeqLenK
Definition: fmha_fwd_kernel.hpp:50
static constexpr bool kIsGroupMode
Definition: fmha_fwd_kernel.hpp:48
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_fwd_kernel.hpp:1047
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:12
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: coordinate_transform.hpp:1392
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:49
const T * ptr
Definition: fmha_fwd_kernel.hpp:224