/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File
fmha_batch_prefill_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
10 
11 #include <string>
12 #include <type_traits>
13 #include <utility>
14 #include <variant>
15 
16 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
17 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
18 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
19 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
20 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
21 
22 namespace ck_tile {
23 
// Batch-prefill FMHA forward kernel: composes an FMHA pipeline with an
// epilogue pipeline and handles paged-KV argument setup.
// NOTE(review): the kernel struct's name line was lost in extraction
// (Doxygen scrape) — confirm against the original header.
template <typename FmhaPipeline_, typename EpiloguePipeline_>
{
    // ---- compile-time configuration forwarded from FmhaPipeline ----

    static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
    static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
    static_assert(kBlockPerCu > 0);
    // raw occupancy value from the problem descriptor; -1 means "auto" and is
    // checked in GetName() to decide whether to embed an "o<N>" tag
    static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;

    // NOTE(review): several alias lines (QDataType/KDataType/..., VLayout,
    // FmhaMask, etc.) were lost in extraction here — they are referenced
    // below in GetName() and operator().

    static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
    // padding flags: whether each problem dimension may be partially filled
    // and therefore needs bounds handling in the tensor views
    static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
    static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
    static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
    static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
    static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
    static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
    static constexpr bool kHasMask = FmhaMask::IsMasking;

    static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
61 
    // clang-format off
    // t2s: map an element type to the short tag embedded in GetName()
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on
70 
    // Build the canonical kernel-name string (tile shape, warp layout, data
    // type, feature flags). The format must stay in sync with generate.py.
    CK_TILE_HOST static std::string GetName()
    {
        // sync with generate.py
        // clang-format off
        using bfs = typename FmhaPipeline::BlockFmhaShape;
        using g0br = typename bfs::Gemm0BlockWarps;
        using g1br = typename bfs::Gemm1BlockWarps;
        using g0wt = typename bfs::Gemm0WarpTile;
        using g1wt = typename bfs::Gemm1WarpTile;
        #define _SS_ std::string
        #define _TS_ std::to_string
        // pn: short tag listing which dims are padded, e.g. "pssk" / "pdv"
        auto pn = [&] () {
            std::string n;
            if (kPadSeqLenQ) n += "s";
            if (kPadSeqLenK) n += "sk";
            if (kPadHeadDimQ) n += "d";
            if (kPadHeadDimV) n += "dv";
            return n.empty() ? n : std::string("p") + n; }();
        return
            _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
            "_" + (kIsGroupMode ? "group" : "batch") + "_"
            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
            _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
            "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
            "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
            "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
            "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
            (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
            (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) + (kDoFp8StaticQuant ? "_squant" : "_nsquant" );
        #undef _SS_
        #undef _TS_
        // clang-format on
    }
106 
    // Empty base used when an optional feature is disabled; the template arg I
    // keeps the repeated empty bases distinct so they can all be inherited.
    template <ck_tile::index_t I> // to avoid the duplicated-base-class problem, introduce a
                                  // template arg
    // [struct declaration line lost in extraction: presumably `struct FmhaFwdEmptyKargs`,
    //  the name used in the conditional_t chains below]
    {
    };

    // kargs use aggregate initialization, so no constructor is provided;
    // inheritance is used to minimize karg size.
    // Users must call MakeKargs() to create kargs.
    // [struct declaration line lost in extraction: presumably the common kargs base]
    {
        const void* q_ptr;
        const void* k_ptr;
        const void* v_ptr;
        void* o_ptr;

        // [member declarations lost in extraction: presumably seqlen_q, seqlen_k,
        //  hdim_q, hdim_v — see the aggregate initializers in MakeKargs()]

        // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k;
        // if this param is larger than 1, it indicates the MQA/GQA case
        // [member declarations lost: presumably num_head_q, nhead_ratio_qk,
        //  num_total_pages, kv_indptr, kv_page_indices — see MakeKargs()]

#if 0 // we assume page_block_size=1 for now
        const int32_t* kv_last_page_lens;
#else
        static constexpr ck_tile::index_t page_block_size = 1;
#endif

        float scale_s;

        // [member declarations lost: presumably stride_q/k/v/o and
        //  nhead_stride_q/k/v/o — see MakeKargs()]
    };

    // [struct declaration lost: presumably `struct FmhaFwdLogitsSoftCapKargs`
    //  (named in the conditional_t chains below)]
    {
        // [member declarations lost: logits_soft_cap / logits_soft_cap_rcp are
        //  assigned below and read in operator()]

        // Store the soft-cap value and (presumably) its reciprocal; a
        // non-positive input disables soft-capping.
        void init_logits_soft_cap(float logits_soft_cap_)
        {
            if(0 < logits_soft_cap_)
            {
                logits_soft_cap = logits_soft_cap_;
                // [line lost in extraction: presumably `logits_soft_cap_rcp = 1.f / logits_soft_cap_;` — TODO confirm]
            }
            else
            {
                logits_soft_cap = 0.f;
                logits_soft_cap_rcp = 0.f;
            }
        }
    };

    // [struct declaration lost: presumably `struct FmhaFwdCommonBiasKargs`]
    {
        const void* bias_ptr = nullptr;
        // [members lost: presumably stride_bias / nhead_stride_bias — see MakeKargs()]
    };

    // [struct declaration lost: presumably `struct FmhaFwdBatchModeBiasKargs`,
    //  extending the common bias kargs]
    {
        // [member lost: presumably batch_stride_bias — see MakeKargs()]
    };

    // [struct declaration lost: presumably `struct FmhaFwdAlibiKargs`]
    {
        // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
        const void* alibi_slope_ptr;
        ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
    };

    // [struct declaration lost: presumably `struct FmhaFwdMaskKargs`]
    {
        // ck_tile::index_t window_size_left, window_size_right;
        // [members lost: presumably window_size_left / window_size_right /
        //  mask_type — assigned in MakeKargs() when kHasMask]
    };

    // [struct declaration lost: presumably `struct FmhaFwdFp8StaticQuantKargs`]
    {
        float scale_p;
        float scale_o;
    };

    // [struct declaration lost: presumably `struct FmhaFwdCommonLSEKargs`]
    {
        void* lse_ptr = nullptr;
        // [members lost: presumably nhead_stride_lse / batch_stride_lse — see MakeKargs()]
    };

    // [struct declaration lost in extraction; holds the dropout seed/offset
    //  value-or-pointer storage used below]
    {
        template <typename T>
        // [declaration line lost: presumably a union over val/ptr]
        {
            T val;
            const T* ptr;
        };

        // [members lost: presumably drop_seed / drop_offset — read in operator()]
    };

    // [struct declaration lost: presumably `struct FmhaFwdCommonDropoutKargs`]
    {
        // Initialize dropout from host-side seed/offset values.
        void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
        {
            float p_undrop = 1.0 - p_drop;
            // [line lost in extraction: presumably `p_undrop_in_uint8_t =` — TODO confirm]
            uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
            rp_undrop = 1.0 / p_undrop;

            this->drop_seed.val = seed;
            this->drop_offset.val = offset;
            this->is_drop_seed_offset_from_host = true;
        }

        // Initialize dropout from device pointers to seed/offset.
        void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
        {
            float p_undrop = 1.0 - p_drop;
            // [line lost in extraction: presumably `p_undrop_in_uint8_t =` — TODO confirm]
            uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
            rp_undrop = 1.0 / p_undrop;

            this->drop_seed.ptr = seed_ptr;
            this->drop_offset.ptr = offset_ptr;
            this->is_drop_seed_offset_from_host = false;
        }

        float rp_undrop = 1;
        // [member lost: presumably p_undrop_in_uint8_t — read in operator()]
        bool is_store_randval = false;
        void* rand_val_ptr = nullptr;

        // [members lost: presumably drop_seed / drop_offset /
        //  is_drop_seed_offset_from_host / stride_randval / nhead_stride_randval]
    };

    // [struct declaration lost: presumably `struct FmhaFwdBatchModeDropoutKargs`]
    {
        // [member lost: presumably batch_stride_randval — see MakeKargs()]
    };

    // Batch-mode kargs: common kargs plus feature kargs selected at compile time.
    // [lines lost: presumably `struct FmhaFwdBatchModeKargs : <common kargs>,`]
    std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                       FmhaFwdBatchModeBiasKargs,
                       std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                          FmhaFwdAlibiKargs,
                                          FmhaFwdEmptyKargs<0>>>,
        std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
        std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
        std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
        std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
        std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
    {
        // [members lost: presumably batch_stride_q/k/v/o — see MakeKargs()]
    };

    // Group-mode kargs: same scheme, but with common bias/dropout kargs and
    // per-group sequence start offsets.
    // [lines lost: presumably `struct FmhaFwdGroupModeKargs : <common kargs>,`]
    std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
                       FmhaFwdCommonBiasKargs,
                       std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
                                          FmhaFwdAlibiKargs,
                                          FmhaFwdEmptyKargs<0>>>,
        std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
        std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
        std::conditional_t<kDoFp8StaticQuant, FmhaFwdFp8StaticQuantKargs, FmhaFwdEmptyKargs<3>>,
        std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
        std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
    {
        // [members lost: presumably seqstart_q_ptr, batch_stride_k,
        //  batch_stride_v — see the group-mode MakeKargs()]
    };

    // Select the karg aggregate matching the kernel's mode.
    using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;

    // [struct lost in extraction: both its declaration and members were
    //  dropped; cannot be identified from this listing — confirm against the
    //  original header]
    {
    };
316 
    // Build batch-mode kernel arguments on the host (enabled only when
    // !kIsGroupMode). Feature-specific fields are filled only when the
    // corresponding compile-time flag is set.
    template <bool Cond = !kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* rand_val_ptr,
              void* lse_ptr,
              void* o_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              int32_t num_total_pages,
              const void* kv_indptr,
              const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
              const void* kv_last_page_lens,
              ck_tile::index_t page_block_size,
#endif
              float scale_s,
              float scale_p,
              float scale_o,
              float logits_soft_cap,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t batch_stride_bias,
              ck_tile::index_t batch_stride_randval,
              ck_tile::index_t batch_stride_lse,
              ck_tile::index_t batch_stride_o,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              bool s_randval,
              std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
                  drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     o_ptr,
                     seqlen_q,
                     -1, // presumably seqlen_k; recomputed per batch in operator() from kv_indptr
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     num_total_pages,
                     reinterpret_cast<const int32_t*>(kv_indptr),
                     reinterpret_cast<const int32_t*>(kv_page_indices),
#if 0 // we assume page_block_size=1 for now
                     reinterpret_cast<const int32_t*>(kv_last_page_lens),
                     page_block_size,
#endif
                     // [line lost in extraction: presumably `#if CK_TILE_FMHA_FWD_FAST_EXP2`,
                     //  pairing with the `#else`/`#endif` below — pre-fold log2e into scale_s]
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
                     scale_s,
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_o}, // args for common karg
                    {}, // placeholder for bias
                    {}, // placeholder for mask
                    {}, // placeholder for lse
                    {}, // placeholder for fp8_static_quant args
                    {}, // placeholder for dropout
                    {}, // placeholder for logits_soft_cap
                    batch_stride_q,
                    batch_stride_k,
                    batch_stride_v,
                    batch_stride_o};

        // [line lost in extraction: presumably
        //  `if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)`,
        //  matching the `else if constexpr` below]
        {
            kargs.bias_ptr = bias_ptr;
            kargs.stride_bias = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
            kargs.batch_stride_bias = batch_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            // for ALIBI the "bias" pointer/stride carry the slope table
            kargs.alibi_slope_ptr = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
            kargs.batch_stride_lse = batch_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_p = scale_p;
            kargs.scale_o = scale_o;
        }
        if constexpr(kHasDropout)
        {
            if(drop_seed_offset.index() == 0) // seed & offset come from host
            {
                const auto& [seed, offset] = std::get<0>(drop_seed_offset);
                kargs.init_dropout(p_drop, seed, offset);
            }
            else // seed & offset come from device
            {
                const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
                kargs.init_dropout(p_drop,
                                   reinterpret_cast<const uint64_t*>(seed_ptr),
                                   reinterpret_cast<const uint64_t*>(offset_ptr));
            }

            kargs.rand_val_ptr = rand_val_ptr;
            kargs.stride_randval = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
            kargs.batch_stride_randval = batch_stride_randval;
            kargs.is_store_randval = s_randval;
        }
        if constexpr(kHasLogitsSoftCap)
        {
            kargs.init_logits_soft_cap(logits_soft_cap);
        }

        return kargs;
    }
468 
    // Build group-mode kernel arguments on the host (enabled only when
    // kIsGroupMode). Sequence lengths are resolved per group at kernel run
    // time from seqstart_q_ptr / kv_indptr, so no batch strides for Q/O.
    template <bool Cond = kIsGroupMode>
    CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
    MakeKargs(const void* q_ptr,
              const void* k_ptr,
              const void* v_ptr,
              const void* bias_ptr,
              void* rand_val_ptr,
              void* lse_ptr,
              void* o_ptr,
              const void* seqstart_q_ptr,
              ck_tile::index_t hdim_q,
              ck_tile::index_t hdim_v,
              ck_tile::index_t num_head_q,
              ck_tile::index_t nhead_ratio_qk,
              int32_t num_total_pages,
              const void* kv_indptr,
              const void* kv_page_indices,
#if 0 // we assume page_block_size=1 for now
              const void* kv_last_page_lens,
              ck_tile::index_t page_block_size,
#endif
              float scale_s,
              float scale_p,
              float scale_o,
              float logits_soft_cap,
              ck_tile::index_t stride_q,
              ck_tile::index_t stride_k,
              ck_tile::index_t stride_v,
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_o,
              ck_tile::index_t nhead_stride_q,
              ck_tile::index_t nhead_stride_k,
              ck_tile::index_t nhead_stride_v,
              ck_tile::index_t nhead_stride_bias,
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_lse,
              ck_tile::index_t nhead_stride_o,
              ck_tile::index_t batch_stride_k,
              ck_tile::index_t batch_stride_v,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
              bool s_randval,
              std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
                  drop_seed_offset)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
                     v_ptr,
                     o_ptr,
                     -1, // seqlen will be updated by another pointer
                     -1, //
                     hdim_q,
                     hdim_v,
                     num_head_q,
                     nhead_ratio_qk,
                     num_total_pages,
                     reinterpret_cast<const int32_t*>(kv_indptr),
                     reinterpret_cast<const int32_t*>(kv_page_indices),
#if 0 // we assume page_block_size=1 for now
                     reinterpret_cast<const int32_t*>(kv_last_page_lens),
                     page_block_size,
#endif
                     // [line lost in extraction: presumably `#if CK_TILE_FMHA_FWD_FAST_EXP2`,
                     //  pairing with the `#else`/`#endif` below — pre-fold log2e into scale_s]
                     static_cast<float>(scale_s * ck_tile::log2e_v<>),
#else
                     scale_s,
#endif
                     stride_q,
                     stride_k,
                     stride_v,
                     stride_o,
                     nhead_stride_q,
                     nhead_stride_k,
                     nhead_stride_v,
                     nhead_stride_o}, // args for common karg
                    {}, // placeholder for bias
                    {}, // placeholder for mask
                    {}, // placeholder for lse
                    {}, // placeholder for fp8_static_quant args
                    {}, // placeholder for dropout
                    {}, // placeholder for logits_soft_cap
                    reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                    batch_stride_k,
                    batch_stride_v};

        // [line lost in extraction: presumably
        //  `if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)`,
        //  matching the `else if constexpr` below]
        {
            kargs.bias_ptr = bias_ptr;
            kargs.stride_bias = stride_bias;
            kargs.nhead_stride_bias = nhead_stride_bias;
        }
        else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            // for ALIBI the "bias" pointer/stride carry the slope table
            kargs.alibi_slope_ptr = bias_ptr;
            kargs.alibi_slope_stride = stride_bias;
        }
        if constexpr(kHasMask)
        {
            kargs.window_size_left = window_size_left;
            kargs.window_size_right = window_size_right;
            kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
        }
        if constexpr(kStoreLSE)
        {
            kargs.lse_ptr = lse_ptr;
            kargs.nhead_stride_lse = nhead_stride_lse;
        }
        if constexpr(kDoFp8StaticQuant)
        {
            kargs.scale_p = scale_p;
            kargs.scale_o = scale_o;
        }
        if constexpr(kHasDropout)
        {
            if(drop_seed_offset.index() == 0) // seed & offset come from host
            {
                const auto& [seed, offset] = std::get<0>(drop_seed_offset);
                kargs.init_dropout(p_drop, seed, offset);
            }
            else // seed & offset come from device
            {
                const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
                kargs.init_dropout(p_drop,
                                   reinterpret_cast<const uint64_t*>(seed_ptr),
                                   reinterpret_cast<const uint64_t*>(offset_ptr));
            }

            kargs.rand_val_ptr = rand_val_ptr;
            kargs.stride_randval = stride_randval;
            kargs.nhead_stride_randval = nhead_stride_randval;
            kargs.is_store_randval = s_randval;
        }
        if constexpr(kHasLogitsSoftCap)
        {
            kargs.init_logits_soft_cap(logits_soft_cap);
        }

        return kargs;
    }
611 
612  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
613  ck_tile::index_t nhead_,
614  ck_tile::index_t seqlen_q_,
615  ck_tile::index_t hdim_v_)
616  {
617  if constexpr(kIsGroupMode)
618  {
619  // TODO: this may need tuning
620  return dim3(nhead_,
621  batch_size_,
622  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
623  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
624  }
625  else
626  {
627  // TODO: this may need tuning
628  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
629  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
630  nhead_,
631  batch_size_);
632  }
633  }
634 
635  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
636  {
637  if constexpr(kIsGroupMode)
638  {
639  // const index_t num_tile_m0 = seqlen_q / kM0;
640  const index_t num_tile_n1 =
641  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
642 
643  const index_t i_block = blockIdx.z;
644  const index_t i_nhead = blockIdx.x;
645  const index_t i_batch = blockIdx.y;
646 
647  const auto f = [](index_t dividend, index_t divisor) {
648  index_t quotient = dividend / divisor;
649  index_t modulus = dividend - quotient * divisor;
650  return ck_tile::make_tuple(quotient, modulus);
651  };
652 
653  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
654  if constexpr(kHasMask)
655  {
656  // assume that num_tile_n1 is always 1
657  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
658  }
659  else
660  {
661  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
662  }
663  }
664  else
665  {
666  // const index_t num_tile_m0 = seqlen_q / kM0;
667  const index_t num_tile_n1 =
668  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
669 
670  const index_t i_block = blockIdx.x;
671  const index_t i_nhead = blockIdx.y;
672  const index_t i_batch = blockIdx.z;
673 
674  const auto f = [](index_t dividend, index_t divisor) {
675  index_t quotient = dividend / divisor;
676  index_t modulus = dividend - quotient * divisor;
677  return ck_tile::make_tuple(quotient, modulus);
678  };
679 
680  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
681 
682  if constexpr(kHasMask)
683  {
684  // assume that num_tile_n1 is always 1
685  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
686  }
687  else
688  {
689  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
690  }
691  }
692  }
693 
    // One workgroup = kBlockSize threads.
    CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }

    // [function signature line lost in extraction: presumably
    //  `CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()`]
    // LDS requirement: the main and epilogue pipelines share one allocation,
    // so take the larger of the two.
    {
        return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
700 
    // Kernel body: resolve this workgroup's tile, build DRAM views/windows for
    // Q/K/V (paged KV), optional bias/LSE/randval, then run the FMHA pipeline
    // and write the output tile through the epilogue.
    // NOTE(review): many hyperlinked argument lines were lost in extraction
    // (marked below); the remaining tokens are preserved verbatim.
    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        // allocate LDS
        __shared__ char smem_ptr[GetSmemSize()];

        // divide problem
        const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);

        // element offsets of this tile; readfirstlane keeps them scalar
        const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
        const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

        long_index_t batch_offset_q = 0;
        long_index_t batch_offset_bias = 0;
        long_index_t batch_offset_randval = 0;
        long_index_t batch_offset_lse = 0;
        long_index_t batch_offset_o = 0;

        // number of KV pages owned by this batch (CSR-style indptr)
        const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
#if 0 // we assume page_block_size=1 for now
        const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
#endif
        if constexpr(kIsGroupMode)
        {
            // get starting offset for each batch
            const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];

            batch_offset_q = query_start * kargs.stride_q;

            // advance the page table to this batch's first page
            kargs.kv_page_indices += kargs.kv_indptr[i_batch];

            // [line lost in extraction: presumably
            //  `if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)`]
            {
                batch_offset_bias = query_start * kargs.stride_bias;
            }
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = query_start;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval = query_start * kargs.stride_randval;
            }
            batch_offset_o = query_start * kargs.stride_o;

            // get real # queries & # keys under group mode
            kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start;

            // # of required blocks is different in each groups, terminate unnecessary blocks
            // earlier
            if(kargs.seqlen_q <= i_m0)
            {
                return;
            }

#if 0 // we assume page_block_size=1 for now
            kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
            kargs.seqlen_k = num_page_blocks;
#endif
        }
        else
        {
            batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;

            kargs.kv_page_indices += kargs.kv_indptr[i_batch];

            // [line lost in extraction: presumably
            //  `if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)`]
            {
                batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
            }
            if constexpr(kStoreLSE)
            {
                batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
            }
            if constexpr(kHasDropout)
            {
                batch_offset_randval =
                    static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
            }
            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;

#if 0 // we assume page_block_size=1 for now
            kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
#else
            kargs.seqlen_k = num_page_blocks;
#endif
        }

        // for simplicity, batch stride we just modify the pointer
        const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
                                 batch_offset_q;
        // K/V heads are shared across nhead_ratio_qk query heads (MQA/GQA)
        const KDataType* k_ptr =
            reinterpret_cast<const KDataType*>(kargs.k_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
        const VDataType* v_ptr =
            reinterpret_cast<const VDataType*>(kargs.v_ptr) +
            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;
        ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
                           static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
                           batch_offset_o;

        // Q/K/V DRAM and DRAM window
        const auto q_dram = [&]() {
            const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                q_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_q),
                make_tuple(kargs.stride_q, 1),
                // [argument line lost in extraction]
                number<1>{});
            if constexpr(FmhaPipeline::kQLoadOnce)
            {
                return pad_tensor_view(
                    q_dram_naive,
                    // [argument lines lost in extraction]
            }
            else
            {
                return pad_tensor_view(
                    q_dram_naive,
                    // [argument lines lost in extraction]
            }
        }();
        const auto k_dram = [&]() {
            // K spans the whole page pool; pages are selected via kv_page_indices
            const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                k_ptr,
                make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
                make_tuple(kargs.stride_k, 1),
                // [argument line lost in extraction]
                number<1>{});

            constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
            return pad_tensor_view(
                k_dram_naive,
                // [argument lines lost in extraction]
        }();
        const auto v_dram = [&]() {
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    v_ptr,
                    make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
                    make_tuple(kargs.stride_v, 1),
                    // [argument line lost in extraction]
                    number<1>{});

                const auto v_dram_transposed = transform_tensor_view(
                    v_dram_naive,
                    make_tuple(
                        make_pass_through_transform(kargs.hdim_v),
                        make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
                    // [argument lines lost in extraction]

                constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
                return pad_tensor_view(
                    v_dram_transposed,
                    // [argument lines lost in extraction]
            }
            else
            {
                const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                    v_ptr,
                    make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
                    make_tuple(kargs.stride_v, 1),
                    // [argument line lost in extraction]
                    number<1>{});

                constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
                return pad_tensor_view(
                    v_dram_naive,
                    // [argument lines lost in extraction]
            }
        }();

        auto q_dram_window = make_tile_window(
            q_dram,
            [&]() {
                if constexpr(FmhaPipeline::kQLoadOnce)
                // [window-lengths line lost in extraction]
                else
                // [window-lengths line lost in extraction]
            }(),
            {i_m0, 0});

        auto k_dram_window = make_tile_window(
            // [argument lines lost in extraction]

        auto v_dram_window =
            make_tile_window(v_dram,
                             // [window-lengths line lost in extraction]
                             {i_n1, 0});
        // [lines lost in extraction before the bias window]
        const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto bias_dram_window_lengths =
            // [lines lost in extraction: window lengths and, presumably, the
            //  `if constexpr(BiasEnum == ...ELEMENTWISE_BIAS)` guard]
            {
                const BiasDataType* bias_ptr =
                    reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
                    batch_offset_bias;

                const auto bias_dram = [&]() {
                    const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                        bias_ptr,
                        make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                        make_tuple(kargs.stride_bias, 1),
                        // [argument line lost in extraction]
                        number<1>{});

                    return pad_tensor_view(bias_dram_naive,
                                           bias_dram_window_lengths,
                                           // [argument line lost in extraction]
                }();

                return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(bias_dram_window_lengths);
            }
        }();

        // lse
        auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
            if constexpr(kStoreLSE)
            {
                LSEDataType* lse_ptr =
                    reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;

                const auto lse_dram = [&]() {
                    const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                        lse_ptr,
                        make_tuple(kargs.seqlen_q),
                        make_tuple(1),
                        number<1>{},
                        number<1>{});

                    return pad_tensor_view(
                        lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
                }();

                return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
            }
            else
            {
                return make_null_tile_window(lse_dram_window_lengths);
            }
        }();

        // dropout functor; seed/offset come either from host values or
        // device pointers, as recorded by init_dropout()
        auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
            if constexpr(kHasDropout)
            {
                return BlockDropout{i_batch_,
                                    i_nhead_,
                                    kargs.num_head_q,
                                    kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
                                                                        : *kargs.drop_seed.ptr,
                                    kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
                                                                        : *kargs.drop_offset.ptr,
                                    kargs.rp_undrop,
                                    kargs.p_undrop_in_uint8_t,
                                    kargs.is_store_randval};
            }
            else
            {
                return NullBlockDropout{};
            };
        }();

        auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
            constexpr auto randval_dram_window_lengths =
            // [window-lengths line lost in extraction]
            if constexpr(kHasDropout)
            {
                RandValOutputDataType* rand_val_ptr =
                    reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
                    batch_offset_randval;

                const auto randval_dram = [&]() {
                    const auto randval_dram_naive =
                        make_naive_tensor_view<address_space_enum::global>(
                            rand_val_ptr,
                            make_tuple(kargs.seqlen_q, kargs.seqlen_k),
                            make_tuple(kargs.stride_randval, 1),
                            number<1>{},
                            number<1>{});

                    return pad_tensor_view(randval_dram_naive,
                                           randval_dram_window_lengths,
                                           // [argument line lost in extraction]
                }();

                return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
            }
            else
            {
                return make_null_tile_window(randval_dram_window_lengths);
            }
        }();

        FmhaMask mask = [&]() {
            if constexpr(kHasMask)
                return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
                    kargs.window_size_left,
                    kargs.window_size_right,
                    kargs.seqlen_q,
                    kargs.seqlen_k,
                    // [argument line lost in extraction]
            else
                return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
        }();

        // WA i_batch capture structure binding before c++20
        auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                // data loading, shared by entire wg
                // TODO: how to use s_read?
                SaccDataType slope =
                    *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
                      i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                slope *= ck_tile::log2e_v<>;
#endif
                if constexpr(kHasMask)
                {
                    return make_alibi_from_lr_mask<SaccDataType, true>(slope,
                                                                       kargs.window_size_left,
                                                                       kargs.window_size_right,
                                                                       kargs.seqlen_q,
                                                                       kargs.seqlen_k,
                                                                       kargs.mask_type);
                }
                else
                {
                    // [line lost in extraction: presumably the Alibi constructor, whose
                    //  argument list continues below]
                        slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
                }
            }
            else
            {
                // [line lost in extraction: presumably returns an empty position encoding]
            }
        }();

        AttentionVariant variant;
        const auto variant_params = [&] {
            if constexpr(kHasLogitsSoftCap)
            {
                // [line lost in extraction: presumably the logits-soft-cap params
                //  constructor, whose argument list continues below]
                    mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
            }
            else
            {
                return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
            }
        }();

        BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};

        // run the main pipeline; the fp8 static-quant path injects scale_p on
        // the probabilities and scale_o (+ fp8 saturation) on the output
        auto o_acc_tile = [&]() {
            if constexpr(kDoFp8StaticQuant)
            {
                return FmhaPipeline{}(
                    q_dram_window,
                    identity{}, // q_element_func
                    k_dram_window,
                    identity{}, // k_element_func
                    v_dram_window,
                    identity{}, // v_element_func
                    bias_dram_window,
                    identity{}, // bias_element_func
                    randval_dram_window,
                    lse_dram_window,
                    identity{}, // lse_element_func
                    identity{}, // s_acc_element_func
                    scales{kargs.scale_p}, // p_compute_element_func
                    composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                    mask,
                    position_encoding,
                    kargs.scale_s,
                    variant,
                    variant_params,
                    block_indices,
                    smem_ptr,
                    kargs.kv_page_indices,
                    kargs.stride_k,
                    kargs.stride_v,
                    dropout);
            }
            else
            {
                return FmhaPipeline{}(q_dram_window,
                                      k_dram_window,
                                      v_dram_window,
                                      bias_dram_window,
                                      randval_dram_window,
                                      lse_dram_window,
                                      mask,
                                      position_encoding,
                                      kargs.scale_s,
                                      variant,
                                      variant_params,
                                      block_indices,
                                      smem_ptr,
                                      kargs.kv_page_indices,
                                      kargs.stride_k,
                                      kargs.stride_v,
                                      dropout);
            }
        }();

        // O DRAM and O DRAM window
        auto o_dram = [&]() {
            const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                o_ptr,
                make_tuple(kargs.seqlen_q, kargs.hdim_v),
                make_tuple(kargs.stride_o, 1),
                // [argument line lost in extraction]
                number<1>{});

            return pad_tensor_view(
                o_dram_naive,
                // [argument lines lost in extraction]
        }();

        auto o_dram_window =
            make_tile_window(o_dram,
                             // [window-lengths line lost in extraction]
                             {i_m0, i_n1});

        // write the accumulated output tile to global memory
        EpiloguePipeline{}(o_dram_window, o_acc_tile);
    }
1147 };
1148 
1149 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition: config.hpp:223
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:510
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:529
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
__host__ __device__ composes(Ts &&...) -> composes< remove_cvref_t< Ts >... >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_dropout.hpp:26
const float rp_undrop
Definition: block_dropout.hpp:290
Definition: block_position_encoding.hpp:137
Definition: fmha_batch_prefill_kernel.hpp:311
ck_tile::index_t kv_head_idx
Definition: fmha_batch_prefill_kernel.hpp:314
ck_tile::index_t qo_head_idx
Definition: fmha_batch_prefill_kernel.hpp:313
ck_tile::index_t batch_idx
Definition: fmha_batch_prefill_kernel.hpp:312
Definition: fmha_batch_prefill_kernel.hpp:191
ck_tile::index_t alibi_slope_stride
Definition: fmha_batch_prefill_kernel.hpp:194
const void * alibi_slope_ptr
Definition: fmha_batch_prefill_kernel.hpp:193
ck_tile::index_t batch_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:187
ck_tile::index_t batch_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:268
ck_tile::index_t batch_stride_o
Definition: fmha_batch_prefill_kernel.hpp:287
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:286
ck_tile::index_t batch_stride_q
Definition: fmha_batch_prefill_kernel.hpp:284
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:285
ck_tile::index_t nhead_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:182
ck_tile::index_t stride_bias
Definition: fmha_batch_prefill_kernel.hpp:181
const void * bias_ptr
Definition: fmha_batch_prefill_kernel.hpp:180
ck_tile::index_t stride_randval
Definition: fmha_batch_prefill_kernel.hpp:262
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_batch_prefill_kernel.hpp:245
ck_tile::index_t nhead_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:263
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_batch_prefill_kernel.hpp:233
void * rand_val_ptr
Definition: fmha_batch_prefill_kernel.hpp:260
float rp_undrop
Definition: fmha_batch_prefill_kernel.hpp:257
bool is_store_randval
Definition: fmha_batch_prefill_kernel.hpp:259
uint8_t p_undrop_in_uint8_t
Definition: fmha_batch_prefill_kernel.hpp:258
Definition: fmha_batch_prefill_kernel.hpp:117
ck_tile::index_t stride_q
Definition: fmha_batch_prefill_kernel.hpp:145
ck_tile::index_t stride_v
Definition: fmha_batch_prefill_kernel.hpp:147
int32_t num_total_pages
Definition: fmha_batch_prefill_kernel.hpp:133
float scale_s
Definition: fmha_batch_prefill_kernel.hpp:143
ck_tile::index_t seqlen_q
Definition: fmha_batch_prefill_kernel.hpp:123
ck_tile::index_t stride_k
Definition: fmha_batch_prefill_kernel.hpp:146
ck_tile::index_t nhead_stride_o
Definition: fmha_batch_prefill_kernel.hpp:153
ck_tile::index_t nhead_stride_k
Definition: fmha_batch_prefill_kernel.hpp:151
ck_tile::index_t nhead_ratio_qk
Definition: fmha_batch_prefill_kernel.hpp:131
ck_tile::index_t nhead_stride_v
Definition: fmha_batch_prefill_kernel.hpp:152
ck_tile::index_t nhead_stride_q
Definition: fmha_batch_prefill_kernel.hpp:150
const int32_t * kv_page_indices
Definition: fmha_batch_prefill_kernel.hpp:135
const void * v_ptr
Definition: fmha_batch_prefill_kernel.hpp:120
const int32_t * kv_indptr
Definition: fmha_batch_prefill_kernel.hpp:134
void * o_ptr
Definition: fmha_batch_prefill_kernel.hpp:121
ck_tile::index_t seqlen_k
Definition: fmha_batch_prefill_kernel.hpp:124
ck_tile::index_t stride_o
Definition: fmha_batch_prefill_kernel.hpp:148
ck_tile::index_t hdim_v
Definition: fmha_batch_prefill_kernel.hpp:126
ck_tile::index_t num_head_q
Definition: fmha_batch_prefill_kernel.hpp:128
static constexpr ck_tile::index_t page_block_size
Definition: fmha_batch_prefill_kernel.hpp:140
const void * k_ptr
Definition: fmha_batch_prefill_kernel.hpp:119
ck_tile::index_t hdim_q
Definition: fmha_batch_prefill_kernel.hpp:125
const void * q_ptr
Definition: fmha_batch_prefill_kernel.hpp:118
ck_tile::index_t batch_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:214
ck_tile::index_t nhead_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:213
void * lse_ptr
Definition: fmha_batch_prefill_kernel.hpp:212
bool is_drop_seed_offset_from_host
Definition: fmha_batch_prefill_kernel.hpp:228
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_batch_prefill_kernel.hpp:226
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_batch_prefill_kernel.hpp:227
Definition: fmha_batch_prefill_kernel.hpp:110
float scale_p
Definition: fmha_batch_prefill_kernel.hpp:206
float scale_o
Definition: fmha_batch_prefill_kernel.hpp:207
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:305
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:304
const int32_t * seqstart_q_ptr
Definition: fmha_batch_prefill_kernel.hpp:303
float logits_soft_cap_rcp
Definition: fmha_batch_prefill_kernel.hpp:175
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_batch_prefill_kernel.hpp:160
float logits_soft_cap
Definition: fmha_batch_prefill_kernel.hpp:174
Definition: fmha_batch_prefill_kernel.hpp:198
ck_tile::index_t window_size_right
Definition: fmha_batch_prefill_kernel.hpp:200
ck_tile::index_t window_size_left
Definition: fmha_batch_prefill_kernel.hpp:200
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_batch_prefill_kernel.hpp:201
Definition: fmha_batch_prefill_kernel.hpp:63
Definition: fmha_batch_prefill_kernel.hpp:26
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_batch_prefill_kernel.hpp:635
static constexpr bool kIsGroupMode
Definition: fmha_batch_prefill_kernel.hpp:46
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_batch_prefill_kernel.hpp:30
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_batch_prefill_kernel.hpp:36
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_batch_prefill_kernel.hpp:27
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_batch_prefill_kernel.hpp:35
static constexpr bool kPadSeqLenQ
Definition: fmha_batch_prefill_kernel.hpp:47
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_batch_prefill_kernel.hpp:39
static constexpr bool kDoFp8StaticQuant
Definition: fmha_batch_prefill_kernel.hpp:55
static constexpr bool kPadHeadDimV
Definition: fmha_batch_prefill_kernel.hpp:50
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_batch_prefill_kernel.hpp:40
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_batch_prefill_kernel.hpp:34
static constexpr bool kHasMask
Definition: fmha_batch_prefill_kernel.hpp:58
static CK_TILE_HOST std::string GetName()
Definition: fmha_batch_prefill_kernel.hpp:71
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_batch_prefill_kernel.hpp:37
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_batch_prefill_kernel.hpp:57
static constexpr bool kPadSeqLenK
Definition: fmha_batch_prefill_kernel.hpp:48
static constexpr bool kHasLogitsSoftCap
Definition: fmha_batch_prefill_kernel.hpp:51
static constexpr bool kHasDropout
Definition: fmha_batch_prefill_kernel.hpp:54
static constexpr bool kStoreLSE
Definition: fmha_batch_prefill_kernel.hpp:53
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_batch_prefill_kernel.hpp:42
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_batch_prefill_kernel.hpp:696
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_batch_prefill_kernel.hpp:29
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:471
static constexpr auto BiasEnum
Definition: fmha_batch_prefill_kernel.hpp:52
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_batch_prefill_kernel.hpp:44
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_batch_prefill_kernel.hpp:41
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_batch_prefill_kernel.hpp:56
static constexpr bool kUseAsyncCopy
Definition: fmha_batch_prefill_kernel.hpp:60
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, float scale_p, float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:319
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_batch_prefill_kernel.hpp:32
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition: fmha_batch_prefill_kernel.hpp:612
static constexpr bool kPadHeadDimQ
Definition: fmha_batch_prefill_kernel.hpp:49
static constexpr CK_TILE_HOST auto BlockSize()
Definition: fmha_batch_prefill_kernel.hpp:694
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_batch_prefill_kernel.hpp:701
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_batch_prefill_kernel.hpp:308
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_batch_prefill_kernel.hpp:28
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:12
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: coordinate_transform.hpp:1443
Definition: unary_element_function.hpp:56
Definition: math.hpp:28
Definition: sequence.hpp:52