fmha_fwd_kernel.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"

#include <string>
#include <type_traits>
#include <utility>
#include <variant>

// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
// S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
// P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
// O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
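
// A scalar reference of the above (orientation only; the kernel computes the
// same result tiled, fused, and in parallel):
//   for each query row i:
//       s[j] = scale_s * dot(Q[i, :], K[j, :]) + bias[i, j]        (for all j)
//       m = max_j s[j];  p[j] = exp(s[j] - m) / sum_j exp(s[j] - m)
//       O[i, :] = sum_j p[j] * V[j, :]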

namespace ck_tile {

template <typename FmhaPipeline_, typename EpiloguePipeline_>
struct FmhaFwdKernel
{
    using FmhaPipeline     = ck_tile::remove_cvref_t<FmhaPipeline_>;
    using EpiloguePipeline = ck_tile::remove_cvref_t<EpiloguePipeline_>;
    static constexpr ck_tile::index_t kBlockSize  = FmhaPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    using QDataType             = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
    using KDataType             = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
    using VDataType             = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
    using BiasDataType          = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
    using RandValOutputDataType =
        ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
    using LSEDataType  = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
    using ODataType    = ck_tile::remove_cvref_t<typename FmhaPipeline::ODataType>;
    using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;

    using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;

    static constexpr bool kIsGroupMode      = FmhaPipeline::kIsGroupMode;
    static constexpr bool kPadSeqLenQ       = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK       = FmhaPipeline::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ      = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV      = FmhaPipeline::kPadHeadDimV;
    static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
    static constexpr auto BiasEnum          = FmhaPipeline::BiasEnum;
    static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
    static constexpr bool kHasDropout       = FmhaPipeline::kHasDropout;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
    static constexpr bool kSkipMinSeqlenQ   = FmhaPipeline::Problem::kSkipMinSeqlenQ;

    using AttentionVariant = ck_tile::remove_cvref_t<typename FmhaPipeline::AttentionVariant>;
    using FmhaMask         = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
    static constexpr bool kHasMask = FmhaMask::IsMasking;

    static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

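    // GetName() encodes the tile shape and feature flags into a unique kernel name;
    // a hypothetical 128-hdim fp16 batch-mode config with all optional features off
    // would come out roughly as
    // "fmha_fwd_d128_fp16_batch_b128x128x32x128x32x128_r4x1x1_r4x1x1_w32x32x16_w32x32x16_qr_vr_npad_nlogits_nbias_nmask_nlse_ndropout_nskip_nsquant".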
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        using bfs = typename FmhaPipeline::BlockFmhaShape;
        using g0br = typename bfs::Gemm0BlockWarps;
        using g1br = typename bfs::Gemm1BlockWarps;
        using g0wt = typename bfs::Gemm0WarpTile;
        using g1wt = typename bfs::Gemm1WarpTile;
        #define _SS_ std::string
        #define _TS_ std::to_string
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadSeqLenK) n += "sk";
            if (kPadHeadDimQ) n += "d";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_fwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + "_"
            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
            _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
            "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
            "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
            "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
            "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
            (kHasLogitsSoftCap ? "_logits" : "_nlogits") + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
            (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse") + (kHasDropout ? "_dropout" : "_ndropout") + (kSkipMinSeqlenQ ? "_skip" : "_nskip") + (kDoFp8StaticQuant ? "_squant" : "_nsquant");
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    // to avoid the duplicated-base-class problem, introduce a dummy template arg
    template <ck_tile::index_t I>
    struct FmhaFwdEmptyKargs
    {
    };

    // Kargs are aggregate-initialized, so no constructor is provided.
    // Inheritance is used to minimize the karg size;
    // use the MakeKargs() functions to create them.
    struct FmhaFwdCommonKargs
    {
        const void* q_ptr;
        const void* k_ptr;
        const void* v_ptr;
        void* o_ptr;

        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
        ck_tile::index_t hdim_q;
        ck_tile::index_t hdim_v;

        ck_tile::index_t num_head_q;
        // for MQA/GQA, nhead_q and nhead_k could differ. This parameter is
        // nhead_q / nhead_k; a value larger than 1 indicates an MQA/GQA case.
        ck_tile::index_t nhead_ratio_qk;
        float scale_s;

        ck_tile::index_t stride_q;
        ck_tile::index_t stride_k;
        ck_tile::index_t stride_v;
        ck_tile::index_t stride_o;

        ck_tile::index_t nhead_stride_q;
        ck_tile::index_t nhead_stride_k;
        ck_tile::index_t nhead_stride_v;
        ck_tile::index_t nhead_stride_o;
    };

    struct FmhaFwdLogitsSoftCapKargs
    {
        void init_logits_soft_cap(float logits_soft_cap_)
        {
            if(0 < logits_soft_cap_)
            {
                logits_soft_cap     = logits_soft_cap_;
                logits_soft_cap_rcp = 1.f / logits_soft_cap_;
            }
            else
            {
                logits_soft_cap     = 0.f;
                logits_soft_cap_rcp = 0.f;
            }
        }

        float logits_soft_cap;
        float logits_soft_cap_rcp;
    };
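
    // e.g. logits_soft_cap = 30.f stores logits_soft_cap_rcp = 1/30; the attention
    // variant uses the pair to clamp scores smoothly (conceptually
    // s -> logits_soft_cap * tanh(s * logits_soft_cap_rcp)) without a division.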

    struct FmhaFwdCommonBiasKargs
    {
        const void* bias_ptr               = nullptr;
        ck_tile::index_t stride_bias       = 0;
        ck_tile::index_t nhead_stride_bias = 0;
    };

    struct FmhaFwdBatchModeBiasKargs : FmhaFwdCommonBiasKargs
    {
        ck_tile::index_t batch_stride_bias = 0;
    };

    struct FmhaFwdAlibiKargs
    {
        // alibi is batch*nhead*1; batch and group mode treat it identically
        const void* alibi_slope_ptr;
        ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 if all batches share one slope
    };

    struct FmhaFwdMaskKargs
    {
        ck_tile::index_t window_size_left, window_size_right;
        ck_tile::GenericAttentionMaskEnum mask_type;
    };

    struct FmhaFwdFp8StaticQuantKargs
    {
        float scale_p;
        float scale_o;
    };

    struct FmhaFwdCommonLSEKargs
    {
        void* lse_ptr                     = nullptr;
        ck_tile::index_t nhead_stride_lse = 0;
        ck_tile::index_t batch_stride_lse = 0;
    };

    // base holding the dropout seed/offset state shared by batch and group mode
    struct FmhaFwdDropoutSeedOffset
    {
        template <typename T>
        union ValueOrPointer
        {
            T val;
            const T* ptr;
        };

        ValueOrPointer<uint64_t> drop_seed;
        ValueOrPointer<uint64_t> drop_offset;
        bool is_drop_seed_offset_from_host;
    };

    struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
    {
        void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
        {
            float p_undrop = 1.0 - p_drop;
            p_undrop_in_uint8_t =
                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
            rp_undrop = 1.0 / p_undrop;

            this->drop_seed.val                 = seed;
            this->drop_offset.val               = offset;
            this->is_drop_seed_offset_from_host = true;
        }

        void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
        {
            float p_undrop = 1.0 - p_drop;
            p_undrop_in_uint8_t =
                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
            rp_undrop = 1.0 / p_undrop;

            this->drop_seed.ptr                 = seed_ptr;
            this->drop_offset.ptr               = offset_ptr;
            this->is_drop_seed_offset_from_host = false;
        }

        float rp_undrop             = 1;
        uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
        bool is_store_randval       = false;
        void* rand_val_ptr          = nullptr;

        ck_tile::index_t stride_randval       = 0;
        ck_tile::index_t nhead_stride_randval = 0;
    };
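
    // e.g. init_dropout(0.1f, seed, offset) stores p_undrop_in_uint8_t =
    // floor(0.9 * 255) = 229 and rp_undrop = 1 / 0.9 ≈ 1.111; BlockDropout compares
    // generated uint8 random values against this threshold and rescales the kept
    // probabilities by rp_undrop so their expected value is unchanged.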

    struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
    {
        ck_tile::index_t batch_stride_randval = 0;
    };

    struct FmhaFwdSkipMinSeqlenQKargs
    {
        ck_tile::index_t min_seqlen_q;
    };

    struct FmhaFwdBatchModeKargs
        : FmhaFwdCommonKargs,
          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                             FmhaFwdBatchModeBiasKargs,
                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                                FmhaFwdAlibiKargs,
                                                FmhaFwdEmptyKargs<0>>>,
          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
          std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
          std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
    {
        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
        ck_tile::index_t batch_stride_v;
        ck_tile::index_t batch_stride_o;
    };

    struct FmhaFwdGroupModeKargs
        : FmhaFwdCommonKargs,
          std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                             FmhaFwdCommonBiasKargs,
                             std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                                FmhaFwdAlibiKargs,
                                                FmhaFwdEmptyKargs<0>>>,
          std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
          std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
          std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
          std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
          std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
          std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
        const int32_t* seqlen_k_ptr;
    };

    using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
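
    // With every optional feature disabled, Kargs collapses to FmhaFwdCommonKargs
    // plus a handful of strides: each FmhaFwdEmptyKargs<I> base is an empty class
    // (made unique by I), so it contributes no storage to the argument block that
    // is passed to the kernel.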

    struct BlockIndices
    {
        ck_tile::index_t batch_idx;
        ck_tile::index_t qo_head_idx;
        ck_tile::index_t kv_head_idx;
    };

    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargsImpl(const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr,
                  void* rand_val_ptr, void* lse_ptr, void* o_ptr,
                  ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
                  ck_tile::index_t hdim_q, ck_tile::index_t hdim_v,
                  ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk,
                  float scale_s, float scale_p, float scale_o, float logits_soft_cap,
                  ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v,
                  ck_tile::index_t stride_bias, ck_tile::index_t stride_randval,
                  ck_tile::index_t stride_o,
                  ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k,
                  ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias,
                  ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse,
                  ck_tile::index_t nhead_stride_o,
                  ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k,
                  ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias,
                  ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse,
                  ck_tile::index_t batch_stride_o,
                  ck_tile::index_t window_size_left, ck_tile::index_t window_size_right,
                  ck_tile::index_t mask_type,
                  float p_drop, bool s_randval,
                  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
                      drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     o_ptr,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
                     scale_s,
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_o}, // args for the common kargs
                    {}, // placeholder for bias
                    {}, // placeholder for mask
                    {}, // placeholder for lse
                    {}, // placeholder for fp8_static_quant args
                    {}, // placeholder for dropout
                    {}, // placeholder for logits_soft_cap
                    batch_stride_q,
                    batch_stride_k,
                    batch_stride_v,
                    batch_stride_o};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
            kargs.batch_stride_bias = batch_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            kargs.alibi_slope_ptr    = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_p = scale_p;
            kargs.scale_o = scale_o;
        }
        if constexpr(kHasDropout)
        {
            if(drop_seed_offset.index() == 0) // seed & offset come from host
            {
                const auto& [seed, offset] = std::get<0>(drop_seed_offset);
                kargs.init_dropout(p_drop, seed, offset);
            }
            else // seed & offset come from device
            {
                const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
                kargs.init_dropout(p_drop,
                                   reinterpret_cast<const uint64_t*>(seed_ptr),
                                   reinterpret_cast<const uint64_t*>(offset_ptr));
            }

            kargs.rand_val_ptr         = rand_val_ptr;
            kargs.stride_randval       = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
            kargs.batch_stride_randval = batch_stride_randval;
            kargs.is_store_randval     = s_randval;
        }
        if constexpr(kHasLogitsSoftCap)
        {
            kargs.init_logits_soft_cap(logits_soft_cap);
        }

        return kargs;
    }

    // std::variant<> can't take in a list initializer, overload for backward compatibility
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr,
              void* rand_val_ptr, void* lse_ptr, void* o_ptr,
              ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q, ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk,
              float scale_s, float scale_p, float scale_o, float logits_soft_cap,
              ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias, ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t window_size_left, ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop, bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        return MakeKargsImpl(
            q_ptr, k_ptr, v_ptr, bias_ptr, rand_val_ptr, lse_ptr, o_ptr,
            seqlen_q, seqlen_k, hdim_q, hdim_v, num_head_q, nhead_ratio_qk,
            scale_s, scale_p, scale_o, logits_soft_cap,
            stride_q, stride_k, stride_v, stride_bias, stride_randval, stride_o,
            nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias,
            nhead_stride_randval, nhead_stride_lse, nhead_stride_o,
            batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_bias,
            batch_stride_randval, batch_stride_lse, batch_stride_o,
            window_size_left, window_size_right, mask_type, p_drop, s_randval,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }
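
    // Typical host-side flow (an illustrative sketch; the launch wrapper and the
    // argument values are placeholders, not part of this header):
    //   using Kernel = FmhaFwdKernel<SomePipeline, SomeEpilogue>;
    //   auto kargs = Kernel::MakeKargs(q, k, v, /*bias*/ nullptr, /*randval*/ nullptr,
    //                                  /*lse*/ nullptr, o, seqlen_q, seqlen_k, ...);
    //   dim3 grids = Kernel::GridSize(batch, nhead, seqlen_q, hdim_v);
    //   // launch with Kernel::BlockSize() threads per block; the device side
    //   // simply invokes Kernel{}(kargs).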

    // std::variant<> can't take in a list initializer, overload for backward compatibility
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr,
              void* rand_val_ptr, void* lse_ptr, void* o_ptr,
              ck_tile::index_t seqlen_q, ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q, ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk,
              float scale_s, float scale_p, float scale_o, float logits_soft_cap,
              ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias, ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t window_size_left, ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop, bool s_randval,
              const std::tuple<const void*, const void*>& drop_seed_offset)
    {
        return MakeKargsImpl(
            q_ptr, k_ptr, v_ptr, bias_ptr, rand_val_ptr, lse_ptr, o_ptr,
            seqlen_q, seqlen_k, hdim_q, hdim_v, num_head_q, nhead_ratio_qk,
            scale_s, scale_p, scale_o, logits_soft_cap,
            stride_q, stride_k, stride_v, stride_bias, stride_randval, stride_o,
            nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias,
            nhead_stride_randval, nhead_stride_lse, nhead_stride_o,
            batch_stride_q, batch_stride_k, batch_stride_v, batch_stride_bias,
            batch_stride_randval, batch_stride_lse, batch_stride_o,
            window_size_left, window_size_right, mask_type, p_drop, s_randval,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }

    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargsImpl(const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr,
                  void* rand_val_ptr, void* lse_ptr, void* o_ptr,
                  const void* seqstart_q_ptr, const void* seqstart_k_ptr,
                  const void* seqlen_k_ptr,
                  ck_tile::index_t hdim_q, ck_tile::index_t hdim_v,
                  ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk,
                  float scale_s, float scale_p, float scale_o, float logits_soft_cap,
                  ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v,
                  ck_tile::index_t stride_bias, ck_tile::index_t stride_randval,
                  ck_tile::index_t stride_o,
                  ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k,
                  ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias,
                  ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse,
                  ck_tile::index_t nhead_stride_o,
                  ck_tile::index_t window_size_left, ck_tile::index_t window_size_right,
                  ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q,
                  float p_drop, bool s_randval,
                  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
                      drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     o_ptr,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
#if CK_TILE_FMHA_FWD_FAST_EXP2
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
                     scale_s,
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_o}, // args for the common kargs
                    {}, // placeholder for bias
                    {}, // placeholder for mask
                    {}, // placeholder for lse
                    {}, // placeholder for fp8_static_quant args
                    {}, // placeholder for dropout
                    {}, // placeholder for logits_soft_cap
                    {}, // placeholder for min_seqlen_q
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                    reinterpret_cast<const int32_t*>(seqlen_k_ptr)};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
            kargs.bias_ptr          = bias_ptr;
            kargs.stride_bias       = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            kargs.alibi_slope_ptr    = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left  = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type         = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr          = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_p = scale_p;
            kargs.scale_o = scale_o;
        }
        if constexpr(kHasDropout)
        {
            if(drop_seed_offset.index() == 0) // seed & offset come from host
            {
                const auto& [seed, offset] = std::get<0>(drop_seed_offset);
                kargs.init_dropout(p_drop, seed, offset);
            }
            else // seed & offset come from device
            {
                const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
                kargs.init_dropout(p_drop,
                                   reinterpret_cast<const uint64_t*>(seed_ptr),
                                   reinterpret_cast<const uint64_t*>(offset_ptr));
            }

            kargs.rand_val_ptr         = rand_val_ptr;
            kargs.stride_randval       = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
            kargs.is_store_randval     = s_randval;
        }
        if constexpr(kHasLogitsSoftCap)
        {
            kargs.init_logits_soft_cap(logits_soft_cap);
        }
        if constexpr(kSkipMinSeqlenQ)
        {
            kargs.min_seqlen_q = min_seqlen_q;
        }

        return kargs;
    }
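
    // Group (varlen) mode packs all sequences back to back: e.g. two sequences of
    // lengths 128 and 256 give seqstart_q_ptr = {0, 128, 384}. The kernel recovers
    // each seqlen_q as seqstart_q_ptr[b + 1] - seqstart_q_ptr[b], and a non-null
    // seqlen_k_ptr overrides the K length per batch (e.g. for padded KV caches).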

    // std::variant<> can't take in a list initializer, overload for backward compatibility
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr,
              void* rand_val_ptr, void* lse_ptr, void* o_ptr,
              const void* seqstart_q_ptr, const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q, ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk,
              float scale_s, float scale_p, float scale_o, float logits_soft_cap,
              ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias, ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t window_size_left, ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q,
              float p_drop, bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
    {
        return MakeKargsImpl(
            q_ptr, k_ptr, v_ptr, bias_ptr, rand_val_ptr, lse_ptr, o_ptr,
            seqstart_q_ptr, seqstart_k_ptr, seqlen_k_ptr,
            hdim_q, hdim_v, num_head_q, nhead_ratio_qk,
            scale_s, scale_p, scale_o, logits_soft_cap,
            stride_q, stride_k, stride_v, stride_bias, stride_randval, stride_o,
            nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias,
            nhead_stride_randval, nhead_stride_lse, nhead_stride_o,
            window_size_left, window_size_right, mask_type, min_seqlen_q,
            p_drop, s_randval,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }

    // std::variant<> can't take in a list initializer, overload for backward compatibility
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr, const void* k_ptr, const void* v_ptr, const void* bias_ptr,
              void* rand_val_ptr, void* lse_ptr, void* o_ptr,
              const void* seqstart_q_ptr, const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
              ck_tile::index_t hdim_q, ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk,
              float scale_s, float scale_p, float scale_o, float logits_soft_cap,
              ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias, ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t window_size_left, ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type, ck_tile::index_t min_seqlen_q,
              float p_drop, bool s_randval,
              const std::tuple<const void*, const void*>& drop_seed_offset)
    {
        return MakeKargsImpl(
            q_ptr, k_ptr, v_ptr, bias_ptr, rand_val_ptr, lse_ptr, o_ptr,
            seqstart_q_ptr, seqstart_k_ptr, seqlen_k_ptr,
            hdim_q, hdim_v, num_head_q, nhead_ratio_qk,
            scale_s, scale_p, scale_o, logits_soft_cap,
            stride_q, stride_k, stride_v, stride_bias, stride_randval, stride_o,
            nhead_stride_q, nhead_stride_k, nhead_stride_v, nhead_stride_bias,
            nhead_stride_randval, nhead_stride_lse, nhead_stride_o,
            window_size_left, window_size_right, mask_type, min_seqlen_q,
            p_drop, s_randval,
            std::make_pair(std::get<0>(drop_seed_offset), std::get<1>(drop_seed_offset)));
    }

    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
                                                ck_tile::index_t nhead_,
                                                ck_tile::index_t seqlen_q_,
                                                ck_tile::index_t hdim_v_,
                                                bool has_padded_seqlen_k = false)
    {
        // has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
        if(has_padded_seqlen_k)
        {
            // TODO: this may need tuning
            return dim3(nhead_,
                        batch_size_,
                        ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
                            ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
        }
        else
        {
            // TODO: this may need tuning
            return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
                            ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
                        nhead_,
                        batch_size_);
        }
    }
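
    // e.g. batch 4, nhead 32, seqlen_q 4096, hdim_v 128 with kM0 = 128, kN1 = 128
    // gives ceil(4096/128) * ceil(128/128) = 32 tiles: dim3(32, 32, 4) in the
    // default layout (tiles on x) or dim3(32, 4, 32) when seqlen_k is padded.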

    CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
    {
        bool has_padded_seqlen_k = false;

        if constexpr(kIsGroupMode)
            has_padded_seqlen_k = (kargs.seqlen_k_ptr != nullptr);

        if(has_padded_seqlen_k)
        {
            // const index_t num_tile_m0 = seqlen_q / kM0;
            const index_t num_tile_n1 =
                ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);

            const index_t i_block = blockIdx.z;
            const index_t i_nhead = blockIdx.x;
            const index_t i_batch = blockIdx.y;

            const auto f = [](index_t dividend, index_t divisor) {
                index_t quotient = dividend / divisor;
                index_t modulus  = dividend - quotient * divisor;
                return ck_tile::make_tuple(quotient, modulus);
            };

            const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

            if constexpr(kHasMask)
            {
                // assume that num_tile_n1 is always 1
                return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
            }
            else
            {
                return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
            }
        }
        else
        {
            // const index_t num_tile_m0 = seqlen_q / kM0;
            const index_t num_tile_n1 =
                ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);

            const index_t i_block = blockIdx.x;
            const index_t i_nhead = blockIdx.y;
            const index_t i_batch = blockIdx.z;

            const auto f = [](index_t dividend, index_t divisor) {
                index_t quotient = dividend / divisor;
                index_t modulus  = dividend - quotient * divisor;
                return ck_tile::make_tuple(quotient, modulus);
            };

            const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);

            if constexpr(kHasMask)
            {
                // assume that num_tile_n1 is always 1
                return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
            }
            else
            {
                return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
            }
        }
    }
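
    // Note: with masking enabled, GetTileIndex issues the m-tiles in reverse
    // (gridDim - 1 - i_tile_m). Under a causal mask later query tiles cover more
    // keys, so starting the heavier tiles first tends to even out the wave tail.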

    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

        long_index_t batch_offset_q       = 0;
        long_index_t batch_offset_k       = 0;
        long_index_t batch_offset_v       = 0;
        long_index_t batch_offset_bias    = 0;
        long_index_t batch_offset_randval = 0;
        long_index_t batch_offset_lse     = 0;
        long_index_t batch_offset_o       = 0;

        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
            const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

            batch_offset_q = query_start * kargs.stride_q;
            batch_offset_k = key_start * kargs.stride_k;
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                batch_offset_v = key_start * kargs.stride_v;
            }
            else
            {
                batch_offset_v = key_start;
            }
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = query_start * kargs.stride_bias;
            }
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = query_start;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval = query_start * kargs.stride_randval;
            }
            batch_offset_o = query_start * kargs.stride_o;

            // get the real # of queries & keys under group mode
            const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
            kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];

            if constexpr(kSkipMinSeqlenQ)
            {
                if(kargs.seqlen_q <= kargs.min_seqlen_q)
                {
                    return;
                }
            }

            // the # of required blocks differs per group; terminate unnecessary
            // blocks early
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }

            if(kargs.seqlen_k_ptr != nullptr)
            {
                kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
            }
            else
            {
                const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
                kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
            }
        }
        else
        {
            batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
            batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
            batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
            }
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
            }
            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
        }

        // for simplicity, the batch stride is applied by just offsetting the pointers
        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                                 batch_offset_q;
        const KDataType* k_ptr =
            reinterpret_cast<const KDataType*>(kargs.k_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
            batch_offset_k;
        const VDataType* v_ptr =
            reinterpret_cast<const VDataType*>(kargs.v_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
            batch_offset_v;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                           batch_offset_o;

        // Q/K/V DRAM and DRAM windows
        const auto q_dram = [&]() {
            const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                q_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_q, 1),
                number<FmhaPipeline::kAlignmentQ>{},
                number<1>{});
            if constexpr(FmhaPipeline::kQLoadOnce)
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kSubQKHeaddim>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
            else
            {
                return pad_tensor_view(
                    q_dram_naive,
                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
            }
        }();
        const auto k_dram = [&]() {
            const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                k_ptr,
                make_tuple(kargs.seqlen_k, kargs.hdim_q),
                make_tuple(kargs.stride_k, 1),
                number<FmhaPipeline::kAlignmentK>{},
                number<1>{});

            constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
            return pad_tensor_view(
                k_dram_naive,
                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
                sequence<kPadSeqLenK_, kPadHeadDimQ>{});
        }();
        const auto v_dram = [&]() {
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    v_ptr,
                    make_tuple(kargs.seqlen_k, kargs.hdim_v),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                const auto v_dram_transposed =
                    transform_tensor_view(v_dram_naive,
                                          make_tuple(make_pass_through_transform(kargs.hdim_v),
                                                     make_pass_through_transform(kargs.seqlen_k)),
                                          make_tuple(sequence<1>{}, sequence<0>{}),
                                          make_tuple(sequence<0>{}, sequence<1>{}));

                constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : false;
                return pad_tensor_view(
                    v_dram_transposed,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
                    sequence<kPadHeadDimV, kPadSeqLenK_>{});
            }
            else
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    v_ptr,
                    make_tuple(kargs.hdim_v, kargs.seqlen_k),
                    make_tuple(kargs.stride_v, 1),
                    number<FmhaPipeline::kAlignmentV>{},
                    number<1>{});

                constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
                return pad_tensor_view(
                    v_dram_naive,
                    make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
                    sequence<kPadHeadDimV_, kPadSeqLenK>{});
            }
        }();
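
        // Note: in the RowMajor-V case above, the transposed view re-indexes
        // V[seqlen_k, hdim_v] as [hdim_v, seqlen_k], so both layouts feed the
        // P @ V^T GEMM through an identically shaped [kN1, kK1] window.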

        auto q_dram_window = make_tile_window(
            q_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                    return make_tuple(number<FmhaPipeline::kM0>{},
                                      number<FmhaPipeline::kSubQKHeaddim>{});
                else
                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
            }(),
            {i_m0, 0});

        auto k_dram_window = make_tile_window(
            k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0});

        auto v_dram_window =
            make_tile_window(v_dram,
                             make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
                             {i_n1, 0});

        // WA: copy-capture i_nhead; structured bindings cannot be captured before C++20
        const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto bias_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                const BiasDataType* bias_ptr =
                    reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
                    batch_offset_bias;

                const auto bias_dram = [&]() {
                    const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                        bias_ptr,
                        make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                        make_tuple(kargs.stride_bias, 1),
                        number<FmhaPipeline::kAlignmentBias>{},
                        number<1>{});

                    return pad_tensor_view(bias_dram_naive,
                                           bias_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();

                return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(bias_dram_window_lengths);
            }
        }();
1279 
1280  // lse
1281  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
1282  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
1283  if constexpr(kStoreLSE)
1284  {
1285  LSEDataType* lse_ptr =
1286  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
1287  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
1288 
1289  const auto lse_dram = [&]() {
1290  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1291  lse_ptr,
1292  make_tuple(kargs.seqlen_q),
1293  make_tuple(1),
1294  number<1>{},
1295  number<1>{});
1296 
1297  return pad_tensor_view(
1298  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
1299  }();
1300 
1301  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
1302  }
1303  else
1304  {
1305  return make_null_tile_window(lse_dram_window_lengths);
1306  }
1307  }();
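
        // The stored LSE (per-row log-sum-exp of the scaled scores) is what lets a
        // backward pass or a split-KV combine step reconstruct the softmax
        // normalization later.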

        auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
            if constexpr(kHasDropout)
            {
                return BlockDropout{i_batch_,
                                    i_nhead_,
                                    kargs.num_head_q,
                                    kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
                                                                        : *kargs.drop_seed.ptr,
                                    kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
                                                                        : *kargs.drop_offset.ptr,
                                    kargs.rp_undrop,
                                    kargs.p_undrop_in_uint8_t,
                                    kargs.is_store_randval};
            }
            else
            {
                return NullBlockDropout{};
            }
        }();

        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto randval_dram_window_lengths =
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
            if constexpr(kHasDropout)
            {
                RandValOutputDataType* rand_val_ptr =
                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
                    batch_offset_randval;

                const auto randval_dram = [&]() {
                    const auto randval_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            rand_val_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_randval, 1),
                            number<1>{},
                            number<1>{});

                    return pad_tensor_view(randval_dram_naive,
                                           randval_dram_window_lengths,
                                           sequence<kPadSeqLenQ, kPadSeqLenK>{});
                }();

                return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(randval_dram_window_lengths);
            }
        }();

        FmhaMask mask = [&]() {
            if constexpr(kHasMask)
                return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    kargs.mask_type == GenericAttentionMaskEnum::MASK_FROM_TOP_LEFT);
            else
                return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();

        // WA: copy-capture i_batch; structured bindings cannot be captured before C++20
        auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                // data loading, shared by the entire workgroup
                // TODO: how to use s_read?
                SaccDataType slope =
                    *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                slope *= ck_tile::log2e_v<>;
#endif
                if constexpr(kHasMask)
                {
                    return make_alibi_from_lr_mask<SaccDataType, true>(slope,
                                                                       kargs.window_size_left,
                                                                       kargs.window_size_right,
                                                                       kargs.seqlen_q,
                                                                       kargs.seqlen_k,
                                                                       kargs.mask_type);
                }
                else
                {
                    return Alibi<SaccDataType, true>{
                        slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
                }
            }
            else
            {
                return EmptyPositionEncoding<SaccDataType>{};
            }
        }();

        AttentionVariant variant;
        const auto variant_params = [&] {
            if constexpr(kHasLogitsSoftCap)
            {
                return ck_tile::LogitsSoftCapParams<FmhaMask>{
                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
            }
            else
            {
                return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
            }
        }();
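
        // variant_params bundles what the attention variant consumes per tile:
        // StandardAttentionParams carries {mask, scale_s}; LogitsSoftCapParams
        // additionally carries the cap and its precomputed reciprocal so the soft
        // clamp needs no division on the device side.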

        BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};

        auto o_acc_tile = [&]() {
            if constexpr(kDoFp8StaticQuant)
            {
                return FmhaPipeline{}(
                    q_dram_window,
                    identity{}, // q_element_func
                    k_dram_window,
                    identity{}, // k_element_func
                    v_dram_window,
                    identity{}, // v_element_func
                    bias_dram_window,
                    identity{}, // bias_element_func
                    randval_dram_window,
                    lse_dram_window,
                    identity{}, // lse_element_func
                    identity{}, // s_acc_element_func
                    scales{kargs.scale_p}, // p_compute_element_func
                    composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                    mask,
                    position_encoding,
                    kargs.scale_s,
                    variant,
                    variant_params,
                    block_indices,
                    smem_ptr,
                    dropout);
            }
            else
            {
                return FmhaPipeline{}(q_dram_window,
                                      k_dram_window,
                                      v_dram_window,
                                      bias_dram_window,
                                      randval_dram_window,
                                      lse_dram_window,
                                      mask,
                                      position_encoding,
                                      kargs.scale_s,
                                      variant,
                                      variant_params,
                                      block_indices,
                                      smem_ptr,
                                      dropout);
            }
        }();
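
        // In the fp8 static-quant path the extra element functors rescale the
        // softmax probabilities by scale_p and squeeze the O accumulator into fp8
        // range via composes(saturates<fp8_t>{}, scales{kargs.scale_o}); all other
        // element functions stay identity.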

        // O DRAM and O DRAM window
        auto o_dram = [&]() {
            const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.stride_o, 1),
                number<FmhaPipeline::kAlignmentO>{},
                number<1>{});

            return pad_tensor_view(
                o_dram_naive,
                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                sequence<kPadSeqLenQ, kPadHeadDimV>{});
        }();

        auto o_dram_window =
            make_tile_window(o_dram,
                             make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                             {i_m0, i_n1});

        EpiloguePipeline{}(o_dram_window, o_acc_tile);
    }
};

} // namespace ck_tile