/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/fmha/kernel/fmha_batch_prefill_kernel.hpp Source File
fmha_batch_prefill_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
11 
12 #include <string>
13 #include <type_traits>
14 #include <utility>
15 #include <variant>
16 
17 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
18 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
19 // S''[seqlen_q, seqlen_k] = S'[seqlen_q, seqlen_k] + Bias[seqlen_q, seqlen_k]
20 // P[seqlen_q, seqlen_k] = Softmax(S''[seqlen_q, seqlen_k])
21 // O[seqlen_q, hdim_v] = P[seqlen_q, seqlen_k] @ V^T[hdim_v, seqlen_k]
22 
23 namespace ck_tile {
24 
// NOTE(review): this listing is a Doxygen text extraction; the class-name line
// (original line 26) and the type-alias lines (original lines ~28-59, e.g. the
// QDataType/KDataType/VDataType/ODataType/VLayout/FmhaMask aliases used below)
// were dropped, so the declarations below appear without their headers.
// Confirm every reconstructed name against the original header in the repo.
25 template <typename FmhaPipeline_, typename EpiloguePipeline_>
27 {
// workgroup size and occupancy hints forwarded from the pipeline configuration
30  static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
31 
32  static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
33  static_assert(kBlockPerCu > 0);
// raw (user-requested) value; -1 appears to mean "unspecified" (see GetName())
34  static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
35 
45 
47 
// compile-time feature flags forwarded from the pipeline; they select which
// conditional karg mix-ins and code paths are compiled in below
48  static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
49  static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
50  static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
51  static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
52  static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
53  static constexpr bool kHasLogitsSoftCap = FmhaPipeline::kHasLogitsSoftCap;
54  static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
55  static constexpr bool kStoreLSE = FmhaPipeline::kStoreLSE;
56  static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
57  static constexpr auto QScaleEnum = FmhaPipeline::Problem::QScaleEnum;
60  static constexpr bool kHasMask = FmhaMask::IsMasking;
61 
62  static constexpr bool kUseAsyncCopy = FmhaPipeline::Policy::AsyncCopy;
63 
// t2s ("type to string"): maps an element type to the short name embedded in
// the kernel-name string built by GetName() (e.g. t2s<QDataType>::name).
64  // clang-format off
65  template <typename T> struct t2s;
66  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
67  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
68  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
69  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
70  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
71  // clang-format on
72 
// Builds the kernel's canonical name string (tile shape, warp layout, data
// type, padding/bias/mask/lse/dropout variants). The encoding must stay in
// sync with the codegen script (generate.py) that produces instance names.
73  CK_TILE_HOST static std::string GetName()
74  {
75  // sync with generate.py
76  // clang-format off
77  using bfs = typename FmhaPipeline::BlockFmhaShape;
78  using g0br = typename bfs::Gemm0BlockWarps;
79  using g1br = typename bfs::Gemm1BlockWarps;
80  using g0wt = typename bfs::Gemm0WarpTile;
81  using g1wt = typename bfs::Gemm1WarpTile;
82  #define _SS_ std::string
83  #define _TS_ std::to_string
// "pn" = padding-name suffix: which dimensions this instance pads
84  auto pn = [&] () {
85  std::string n;
86  if (kPadSeqLenQ) n += "s";
87  if (kPadSeqLenK) n += "sk";
88  if (kPadHeadDimQ) n += "d";
89  if (kPadHeadDimV) n += "dv";
90  return n.empty() ? n : std::string("p") + n; }();
91  return
92  _SS_("fmha_batch_prefill_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
93  "_" + (kIsGroupMode ? "group" : "batch") + "_"
94  "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
95  _TS_(bfs::kN1) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kQKHeaddim) + "_" +
96  "r" + _TS_(g0br::at(ck_tile::number<0>{})) + "x" + _TS_(g0br::at(ck_tile::number<1>{})) + "x" + _TS_(g0br::at(ck_tile::number<2>{})) + "_" +
97  "r" + _TS_(g1br::at(ck_tile::number<0>{})) + "x" + _TS_(g1br::at(ck_tile::number<1>{})) + "x" + _TS_(g1br::at(ck_tile::number<2>{})) + "_" +
98  "w" + _TS_(g0wt::at(ck_tile::number<0>{})) + "x" + _TS_(g0wt::at(ck_tile::number<1>{})) + "x" + _TS_(g0wt::at(ck_tile::number<2>{})) + "_" +
99  "w" + _TS_(g1wt::at(ck_tile::number<0>{})) + "x" + _TS_(g1wt::at(ck_tile::number<1>{})) + "x" + _TS_(g1wt::at(ck_tile::number<2>{})) + "_" +
100  (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
101  "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "_npad" : "_" + pn) +
102  (kHasLogitsSoftCap ? "_logits" : "_nlogits" ) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("_nbias") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
103  (kHasMask ? "_" + _SS_(FmhaMask::name) : "_nmask") + (kStoreLSE ? "_lse" : "_nlse" ) + (kHasDropout ? "_dropout" : "_ndropout" ) +
// NOTE(review): the extraction dropped original line 104 here (the final term
// of the returned string and its terminating ';'); confirm against the repo.
105  #undef _SS_
106  #undef _TS_
107  // clang-format on
108  }
109 
// ---------------------------------------------------------------------------
// Kernel-argument ("karg") structs. The batch-/group-mode Kargs aggregates are
// assembled by multiple inheritance from a common block plus conditional
// mix-ins, so unused features add no bytes to the karg struct.
//
// NOTE(review): the Doxygen extraction dropped every `struct <Name>` header
// line and several member-declaration lines in this region. The names used in
// the inline notes below are inferred from the std::conditional_t lists at the
// bottom of this region — confirm each against the original header.
// ---------------------------------------------------------------------------
110  template <ck_tile::index_t I> // to avoid duplicated base class problem, introduce a template
111  // arg
// NOTE(review): header dropped — presumably `struct FmhaFwdEmptyKargs` (the
// empty placeholder mix-in referenced as FmhaFwdEmptyKargs<0..4> below).
113  {
114  };
115 
116  // kargs use aggregate initializer, so no constructor will be provided
117  // use inheritance to minimize karg size
118  // user need to use MakeKargs() function to create kargs.
// NOTE(review): header dropped — presumably `struct FmhaFwdCommonKargs`.
// Members such as seqlen_q/seqlen_k/hdim_q/hdim_v/num_head_q/nhead_ratio_qk/
// num_total_pages/kv_indptr/kv_page_indices and the stride fields (read via
// kargs.* in MakeKargs and operator()) were also dropped from this listing.
120  {
121  const void* q_ptr;
122  const void* k_ptr;
123  const void* v_ptr;
124  void* o_ptr;
125 
130 
132  // for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
133  // if this param is larger than 1, indicate MQA/GQA case
135 
139 #if 0 // we assume page_block_size=1 for now
140  const int32_t* kv_last_page_lens;
142 #else
143  static constexpr ck_tile::index_t page_block_size = 1;
144 #endif
145 
146  float scale_s;
147 
152 
157  };
158 
// NOTE(review): presumably `struct FmhaFwdLogitsSoftCapKargs` (see conditional
// list below); holds logits_soft_cap and logits_soft_cap_rcp.
160  {
162 
163  void init_logits_soft_cap(float logits_soft_cap_)
164  {
165  if(0 < logits_soft_cap_)
166  {
167  logits_soft_cap = logits_soft_cap_;
// NOTE(review): original line 168 (presumably the logits_soft_cap_rcp
// computation for the enabled case) was dropped by the extraction.
169  }
170  else
171  {
// non-positive input disables soft-capping: both fields zeroed
172  logits_soft_cap = 0.f;
173  logits_soft_cap_rcp = 0.f;
174  }
175  }
176 
179  };
180 
// NOTE(review): presumably `struct FmhaFwdCommonBiasKargs` (elementwise bias).
182  {
183  const void* bias_ptr = nullptr;
186  };
187 
// NOTE(review): presumably `struct FmhaFwdBatchModeBiasKargs`; its member
// (batch_stride_bias, assigned in the batch-mode MakeKargs) was dropped.
189  {
191  };
192 
// NOTE(review): presumably `struct FmhaFwdAlibiKargs`.
194  {
195  // alibi is batch*nhead*1, no matter in batch/group mode, they are the same
196  const void* alibi_slope_ptr;
197  ck_tile::index_t alibi_slope_stride; // stride in batch, or 0 for all batch share same slope
198  };
199 
// NOTE(review): presumably `struct FmhaFwdMaskKargs`; window_size_left/right
// and mask_type members (assigned in MakeKargs) were dropped from the listing.
201  {
202  // ck_tile::index_t window_size_left, window_size_right;
205  };
206 
// NOTE(review): presumably `struct FmhaFwdCommonLSEKargs`.
208  {
209  void* lse_ptr = nullptr;
212  };
213 
// NOTE(review): header dropped — a mix-in providing the drop_seed/drop_offset
// members used below; the nested template appears to be a val/ptr union so a
// seed/offset can come either by value (host) or via device pointer.
215  {
216  template <typename T>
218  {
219  T val;
220  const T* ptr;
221  };
222 
226  };
227 
// NOTE(review): presumably `struct FmhaFwdCommonDropoutKargs`.
229  {
// Initialize dropout from host-side seed/offset values.
230  void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
231  {
232  float p_undrop = 1.0 - p_drop;
// NOTE(review): original lines 233-234 (presumably the p_undrop_in_uint8_t
// computation read by operator()) were dropped by the extraction.
235  rp_undrop = 1.0 / p_undrop;
236 
237  this->drop_seed.val = seed;
238  this->drop_offset.val = offset;
239  this->is_drop_seed_offset_from_host = true;
240  }
241 
// Initialize dropout from device-resident seed/offset pointers.
242  void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
243  {
244  float p_undrop = 1.0 - p_drop;
247  rp_undrop = 1.0 / p_undrop;
248 
249  this->drop_seed.ptr = seed_ptr;
250  this->drop_offset.ptr = offset_ptr;
251  this->is_drop_seed_offset_from_host = false;
252  }
253 
254  float rp_undrop = 1;
256  bool is_store_randval = false;
257  void* rand_val_ptr = nullptr;
258 
261  };
262 
// NOTE(review): presumably `struct FmhaFwdBatchModeDropoutKargs`; its member
// (batch_stride_randval, assigned in batch-mode MakeKargs) was dropped.
264  {
266  };
267 
// Batch-mode karg aggregate: common block + conditionally-compiled mix-ins.
// NOTE(review): the `struct FmhaFwdBatchModeKargs : FmhaFwdCommonKargs,`
// header line was dropped by the extraction.
270  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
271  FmhaFwdBatchModeBiasKargs,
272  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
273  FmhaFwdAlibiKargs,
274  FmhaFwdEmptyKargs<0>>>,
275  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
276  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
277  std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<3>>,
278  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
279  {
// NOTE(review): batch_stride_q/k/v/o members (original lines 280-283,
// initialized at the tail of batch-mode MakeKargs) were dropped.
284  };
285 
// Group-mode karg aggregate; differs from batch mode in bias/dropout mix-ins.
// NOTE(review): the `struct FmhaFwdGroupModeKargs : FmhaFwdCommonKargs,`
// header line was dropped by the extraction.
288  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS,
289  FmhaFwdCommonBiasKargs,
290  std::conditional_t<BiasEnum == BlockAttentionBiasEnum::ALIBI,
291  FmhaFwdAlibiKargs,
292  FmhaFwdEmptyKargs<0>>>,
293  std::conditional_t<kHasMask, FmhaFwdMaskKargs, FmhaFwdEmptyKargs<1>>,
294  std::conditional_t<kStoreLSE, FmhaFwdCommonLSEKargs, FmhaFwdEmptyKargs<2>>,
295  std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<3>>,
296  std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<4>>
297  {
// NOTE(review): seqstart_q_ptr / batch_stride_k / batch_stride_v members
// (original lines 298-300, initialized in group-mode MakeKargs) were dropped.
301  };
302 
303  using Kargs = std::conditional_t<kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs>;
304 
// NOTE(review): struct with dropped header and members (original lines
// 305-310); contents not visible in this listing.
306  {
310  };
311 
// Host-side construction of BATCH-mode kernel arguments (enabled only when
// !kIsGroupMode via SFINAE). Aggregate-initializes the common karg block,
// then fills in the conditionally-compiled mix-ins (bias / mask / lse /
// dropout / logits soft cap) that this instance actually has.
312  template <bool Cond = !kIsGroupMode>
313  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
314  MakeKargs(const void* q_ptr,
315  const void* k_ptr,
316  const void* v_ptr,
317  const void* bias_ptr,
318  void* rand_val_ptr,
319  void* lse_ptr,
320  void* o_ptr,
321  ck_tile::index_t seqlen_q,
322  ck_tile::index_t hdim_q,
323  ck_tile::index_t hdim_v,
324  ck_tile::index_t num_head_q,
325  ck_tile::index_t nhead_ratio_qk,
326  int32_t num_total_pages,
327  const void* kv_indptr,
328  const void* kv_page_indices,
329 #if 0 // we assume page_block_size=1 for now
330  const void* kv_last_page_lens,
331  ck_tile::index_t page_block_size,
332 #endif
333  float scale_s,
334  [[maybe_unused]] float scale_p,
335  [[maybe_unused]] float scale_o,
336  float logits_soft_cap,
337  ck_tile::index_t stride_q,
338  ck_tile::index_t stride_k,
339  ck_tile::index_t stride_v,
340  ck_tile::index_t stride_bias,
341  ck_tile::index_t stride_randval,
342  ck_tile::index_t stride_o,
343  ck_tile::index_t nhead_stride_q,
344  ck_tile::index_t nhead_stride_k,
345  ck_tile::index_t nhead_stride_v,
346  ck_tile::index_t nhead_stride_bias,
347  ck_tile::index_t nhead_stride_randval,
348  ck_tile::index_t nhead_stride_lse,
349  ck_tile::index_t nhead_stride_o,
350  ck_tile::index_t batch_stride_q,
351  ck_tile::index_t batch_stride_k,
352  ck_tile::index_t batch_stride_v,
353  ck_tile::index_t batch_stride_bias,
354  ck_tile::index_t batch_stride_randval,
355  ck_tile::index_t batch_stride_lse,
356  ck_tile::index_t batch_stride_o,
357  ck_tile::index_t window_size_left,
358  ck_tile::index_t window_size_right,
359  ck_tile::index_t mask_type,
360  float p_drop,
361  bool s_randval,
362  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
363  drop_seed_offset)
364  {
365  Kargs kargs{{q_ptr,
366  k_ptr,
367  v_ptr,
368  o_ptr,
369  seqlen_q,
370  -1,
371  hdim_q,
372  hdim_v,
373  num_head_q,
374  nhead_ratio_qk,
375  num_total_pages,
376  reinterpret_cast<const int32_t*>(kv_indptr),
377  reinterpret_cast<const int32_t*>(kv_page_indices),
378 #if 0 // we assume page_block_size=1 for now
379  reinterpret_cast<const int32_t*>(kv_last_page_lens),
380  page_block_size,
381 #endif
// NOTE(review): the extraction dropped original line 382 — presumably
// "#if CK_TILE_FMHA_FWD_FAST_EXP2" pairing with the #else/#endif below
// (the same macro is used visibly later in this file).
383  static_cast<float>(scale_s * ck_tile::log2e_v<>),
384 #else
385  scale_s,
386 #endif
387  stride_q,
388  stride_k,
389  stride_v,
390  stride_o,
391  nhead_stride_q,
392  nhead_stride_k,
393  nhead_stride_v,
394  nhead_stride_o}, // args for common karg
395  {}, // placeholder for bias
396  {}, // placeholder for mask
397  {}, // placeholder for lse
398  {}, // placeholder for dropout
399  {}, // placeholder for logits_soft_cap
400  batch_stride_q,
401  batch_stride_k,
402  batch_stride_v,
403  batch_stride_o};
404 
// NOTE(review): the extraction dropped original line 405 — presumably
// "if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)"
// to match the "else if" on ALIBI below.
406  {
407  kargs.bias_ptr = bias_ptr;
408  kargs.stride_bias = stride_bias;
409  kargs.nhead_stride_bias = nhead_stride_bias;
410  kargs.batch_stride_bias = batch_stride_bias;
411  }
412  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
413  {
// alibi reuses the bias pointer/stride arguments for the slope table
414  kargs.alibi_slope_ptr = bias_ptr;
415  kargs.alibi_slope_stride = stride_bias;
416  }
417  if constexpr(kHasMask)
418  {
419  kargs.window_size_left = window_size_left;
420  kargs.window_size_right = window_size_right;
421  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
422  }
423  if constexpr(kStoreLSE)
424  {
425  kargs.lse_ptr = lse_ptr;
426  kargs.nhead_stride_lse = nhead_stride_lse;
427  kargs.batch_stride_lse = batch_stride_lse;
428  }
429  if constexpr(kHasDropout)
430  {
431  if(drop_seed_offset.index() == 0) // seed & offset come from host
432  {
433  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
434  kargs.init_dropout(p_drop, seed, offset);
435  }
436  else // seed & offset come from device
437  {
438  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
439  kargs.init_dropout(p_drop,
440  reinterpret_cast<const uint64_t*>(seed_ptr),
441  reinterpret_cast<const uint64_t*>(offset_ptr));
442  }
443 
444  kargs.rand_val_ptr = rand_val_ptr;
445  kargs.stride_randval = stride_randval;
446  kargs.nhead_stride_randval = nhead_stride_randval;
447  kargs.batch_stride_randval = batch_stride_randval;
448  kargs.is_store_randval = s_randval;
449  }
450  if constexpr(kHasLogitsSoftCap)
451  {
452  kargs.init_logits_soft_cap(logits_soft_cap);
453  }
454 
455  return kargs;
456  }
457 
// Host-side construction of GROUP-mode kernel arguments (enabled only when
// kIsGroupMode via SFINAE). Differs from the batch-mode overload in taking
// seqstart_q_ptr (per-group query offsets) instead of a fixed seqlen_q and
// most batch strides; seqlen_q is left as -1 and resolved per-batch on device.
458  template <bool Cond = kIsGroupMode>
459  CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
460  MakeKargs(const void* q_ptr,
461  const void* k_ptr,
462  const void* v_ptr,
463  const void* bias_ptr,
464  void* rand_val_ptr,
465  void* lse_ptr,
466  void* o_ptr,
467  const void* seqstart_q_ptr,
468  ck_tile::index_t hdim_q,
469  ck_tile::index_t hdim_v,
470  ck_tile::index_t num_head_q,
471  ck_tile::index_t nhead_ratio_qk,
472  int32_t num_total_pages,
473  const void* kv_indptr,
474  const void* kv_page_indices,
475 #if 0 // we assume page_block_size=1 for now
476  const void* kv_last_page_lens,
477  ck_tile::index_t page_block_size,
478 #endif
479  float scale_s,
480  [[maybe_unused]] float scale_p,
481  [[maybe_unused]] float scale_o,
482  float logits_soft_cap,
483  ck_tile::index_t stride_q,
484  ck_tile::index_t stride_k,
485  ck_tile::index_t stride_v,
486  ck_tile::index_t stride_bias,
487  ck_tile::index_t stride_randval,
488  ck_tile::index_t stride_o,
489  ck_tile::index_t nhead_stride_q,
490  ck_tile::index_t nhead_stride_k,
491  ck_tile::index_t nhead_stride_v,
492  ck_tile::index_t nhead_stride_bias,
493  ck_tile::index_t nhead_stride_randval,
494  ck_tile::index_t nhead_stride_lse,
495  ck_tile::index_t nhead_stride_o,
496  ck_tile::index_t batch_stride_k,
497  ck_tile::index_t batch_stride_v,
498  ck_tile::index_t window_size_left,
499  ck_tile::index_t window_size_right,
500  ck_tile::index_t mask_type,
501  float p_drop,
502  bool s_randval,
503  std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
504  drop_seed_offset)
505  {
506  Kargs kargs{{q_ptr,
507  k_ptr,
508  v_ptr,
509  o_ptr,
510  -1, // seqlen will be updated by another pointer
511  -1, //
512  hdim_q,
513  hdim_v,
514  num_head_q,
515  nhead_ratio_qk,
516  num_total_pages,
517  reinterpret_cast<const int32_t*>(kv_indptr),
518  reinterpret_cast<const int32_t*>(kv_page_indices),
519 #if 0 // we assume page_block_size=1 for now
520  reinterpret_cast<const int32_t*>(kv_last_page_lens),
521  page_block_size,
522 #endif
// NOTE(review): the extraction dropped original line 523 — presumably
// "#if CK_TILE_FMHA_FWD_FAST_EXP2" pairing with the #else/#endif below.
524  static_cast<float>(scale_s * ck_tile::log2e_v<>),
525 #else
526  scale_s,
527 #endif
528  stride_q,
529  stride_k,
530  stride_v,
531  stride_o,
532  nhead_stride_q,
533  nhead_stride_k,
534  nhead_stride_v,
535  nhead_stride_o}, // args for common karg
536  {}, // placeholder for bias
537  {}, // placeholder for mask
538  {}, // placeholder for lse
539  {}, // placeholder for dropout
540  {}, // placeholder for logits_soft_cap
541  reinterpret_cast<const int32_t*>(seqstart_q_ptr),
542  batch_stride_k,
543  batch_stride_v};
544 
// NOTE(review): the extraction dropped original line 545 — presumably
// "if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)"
// to match the "else if" on ALIBI below.
546  {
547  kargs.bias_ptr = bias_ptr;
548  kargs.stride_bias = stride_bias;
549  kargs.nhead_stride_bias = nhead_stride_bias;
550  }
551  else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
552  {
// alibi reuses the bias pointer/stride arguments for the slope table
553  kargs.alibi_slope_ptr = bias_ptr;
554  kargs.alibi_slope_stride = stride_bias;
555  }
556  if constexpr(kHasMask)
557  {
558  kargs.window_size_left = window_size_left;
559  kargs.window_size_right = window_size_right;
560  kargs.mask_type = static_cast<ck_tile::GenericAttentionMaskEnum>(mask_type);
561  }
562  if constexpr(kStoreLSE)
563  {
564  kargs.lse_ptr = lse_ptr;
565  kargs.nhead_stride_lse = nhead_stride_lse;
566  }
567  if constexpr(kHasDropout)
568  {
569  if(drop_seed_offset.index() == 0) // seed & offset come from host
570  {
571  const auto& [seed, offset] = std::get<0>(drop_seed_offset);
572  kargs.init_dropout(p_drop, seed, offset);
573  }
574  else // seed & offset come from device
575  {
576  const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
577  kargs.init_dropout(p_drop,
578  reinterpret_cast<const uint64_t*>(seed_ptr),
579  reinterpret_cast<const uint64_t*>(offset_ptr));
580  }
581 
582  kargs.rand_val_ptr = rand_val_ptr;
583  kargs.stride_randval = stride_randval;
584  kargs.nhead_stride_randval = nhead_stride_randval;
585  kargs.is_store_randval = s_randval;
586  }
587  if constexpr(kHasLogitsSoftCap)
588  {
589  kargs.init_logits_soft_cap(logits_soft_cap);
590  }
591 
592  return kargs;
593  }
594 
595  CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
596  ck_tile::index_t nhead_,
597  ck_tile::index_t seqlen_q_,
598  ck_tile::index_t hdim_v_)
599  {
600  if constexpr(kIsGroupMode)
601  {
602  // TODO: this may need tuning
603  return dim3(nhead_,
604  batch_size_,
605  ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
606  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1));
607  }
608  else
609  {
610  // TODO: this may need tuning
611  return dim3(ck_tile::integer_divide_ceil(seqlen_q_, FmhaPipeline::kM0) *
612  ck_tile::integer_divide_ceil(hdim_v_, FmhaPipeline::kN1),
613  nhead_,
614  batch_size_);
615  }
616  }
617 
618  CK_TILE_DEVICE static constexpr auto GetTileIndex(const Kargs& kargs)
619  {
620  if constexpr(kIsGroupMode)
621  {
622  // const index_t num_tile_m0 = seqlen_q / kM0;
623  const index_t num_tile_n1 =
624  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
625 
626  const index_t i_block = blockIdx.z;
627  const index_t i_nhead = blockIdx.x;
628  const index_t i_batch = blockIdx.y;
629 
630  const auto f = [](index_t dividend, index_t divisor) {
631  index_t quotient = dividend / divisor;
632  index_t modulus = dividend - quotient * divisor;
633  return ck_tile::make_tuple(quotient, modulus);
634  };
635 
636  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
637  if constexpr(kHasMask)
638  {
639  // assume that num_tile_n1 is always 1
640  return ck_tile::make_tuple(gridDim.z - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
641  }
642  else
643  {
644  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
645  }
646  }
647  else
648  {
649  // const index_t num_tile_m0 = seqlen_q / kM0;
650  const index_t num_tile_n1 =
651  ck_tile::integer_divide_ceil(kargs.hdim_v, FmhaPipeline::kN1);
652 
653  const index_t i_block = blockIdx.x;
654  const index_t i_nhead = blockIdx.y;
655  const index_t i_batch = blockIdx.z;
656 
657  const auto f = [](index_t dividend, index_t divisor) {
658  index_t quotient = dividend / divisor;
659  index_t modulus = dividend - quotient * divisor;
660  return ck_tile::make_tuple(quotient, modulus);
661  };
662 
663  const auto [i_tile_m, i_tile_n] = f(i_block, num_tile_n1);
664 
665  if constexpr(kHasMask)
666  {
667  // assume that num_tile_n1 is always 1
668  return ck_tile::make_tuple(gridDim.x - 1 - i_tile_m, i_tile_n, i_nhead, i_batch);
669  }
670  else
671  {
672  return ck_tile::make_tuple(i_tile_m, i_tile_n, i_nhead, i_batch);
673  }
674  }
675  }
676 
677  CK_TILE_HOST static dim3 BlockSize()
678  {
679  if(is_wave32())
680  {
681  return dim3(kBlockSize / 2);
682  }
683  else
684  {
685  return dim3(kBlockSize);
686  }
687  }
688 
// Returns the LDS byte count the launch must reserve: the larger of the main
// pipeline's and the epilogue's requirements (the two phases reuse the same
// buffer — see smem_ptr in operator()).
// NOTE(review): the signature line (original line 689, presumably
// "CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()")
// was dropped by the extraction — confirm against the original header.
690  {
691  return ck_tile::max(FmhaPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
692  }
693 
// Kernel entry point: each workgroup computes one (kM0 x kN1) tile of O for
// one (batch, head) pair. Flow: decode tile coordinates -> resolve per-batch
// offsets (and, in group mode, the real seqlen_q/seqlen_k) -> build padded
// DRAM tile windows for Q/K/V (+ optional bias/LSE/randval) -> run the FMHA
// pipeline -> store the accumulated tile through the epilogue.
//
// NOTE(review): this Doxygen text dump dropped many original lines in this
// function (template-argument/sequence lines of pad_tensor_view /
// make_tile_window calls and some "if constexpr" guard lines). The code is
// transcribed as-is and is NOT compilable in this form — confirm every
// flagged spot against the original header in the repo.
694  CK_TILE_DEVICE void operator()(Kargs kargs) const
695  {
696  // allocate LDS
697  __shared__ char smem_ptr[GetSmemSize()];
698 
699  // divide problem
700  const auto [i_tile_m, i_tile_n, i_nhead, i_batch] = GetTileIndex(kargs);
701 
// starting row/column of this workgroup's output tile
702  const index_t i_m0 = amd_wave_read_first_lane(i_tile_m * FmhaPipeline::kM0);
703  const index_t i_n1 = amd_wave_read_first_lane(i_tile_n * FmhaPipeline::kN1);
704 
705  long_index_t batch_offset_q = 0;
706  long_index_t batch_offset_bias = 0;
707  long_index_t batch_offset_randval = 0;
708  long_index_t batch_offset_lse = 0;
709  long_index_t batch_offset_o = 0;
710 
// number of KV pages belonging to this batch (CSR-style kv_indptr)
711  const int32_t num_page_blocks = kargs.kv_indptr[i_batch + 1] - kargs.kv_indptr[i_batch];
712 #if 0 // we assume page_block_size=1 for now
713  const int32_t last_page_len = kargs.kv_last_page_lens[i_batch];
714 #endif
715  if constexpr(kIsGroupMode)
716  {
717  // get starting offset for each batch
718  const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
719 
720  batch_offset_q = query_start * kargs.stride_q;
721 
// advance the page-index table to this batch's first page
722  kargs.kv_page_indices += kargs.kv_indptr[i_batch];
723 
// NOTE(review): the extraction dropped original line 724 — presumably the
// bias-related "if constexpr" guard matching the body below.
725  {
726  batch_offset_bias = query_start * kargs.stride_bias;
727  }
728  if constexpr(kStoreLSE)
729  {
730  batch_offset_lse = query_start;
731  }
732  if constexpr(kHasDropout)
733  {
734  batch_offset_randval = query_start * kargs.stride_randval;
735  }
736  batch_offset_o = query_start * kargs.stride_o;
737 
738  // get real # queries & # keys under group mode
739  kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - query_start;
740 
741  // # of required blocks is different in each groups, terminate unnecessary blocks
742  // earlier
743  if(kargs.seqlen_q <= i_m0)
744  {
745  return;
746  }
747 
748 #if 0 // we assume page_block_size=1 for now
749  kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
750 #else
751  kargs.seqlen_k = num_page_blocks;
752 #endif
753  }
754  else
755  {
756  batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
757 
758  kargs.kv_page_indices += kargs.kv_indptr[i_batch];
759 
// NOTE(review): the extraction dropped original line 760 — presumably the
// bias-related "if constexpr" guard matching the body below.
761  {
762  batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
763  }
764  if constexpr(kStoreLSE)
765  {
766  batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
767  }
768  if constexpr(kHasDropout)
769  {
770  batch_offset_randval =
771  static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
772  }
773  batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
774 
775 #if 0 // we assume page_block_size=1 for now
776  kargs.seqlen_k = (num_page_blocks - 1) * kargs.page_block_size + last_page_len;
777 #else
778  kargs.seqlen_k = num_page_blocks;
779 #endif
780  }
781 
782  // for simplicity, batch stride we just modify the pointer
783  const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
784  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
785  batch_offset_q;
// K/V heads are shared across query-head groups (MQA/GQA): divide by ratio
786  const KDataType* k_ptr =
787  reinterpret_cast<const KDataType*>(kargs.k_ptr) +
788  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k;
789  const VDataType* v_ptr =
790  reinterpret_cast<const VDataType*>(kargs.v_ptr) +
791  static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v;
792  ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
793  static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
794  batch_offset_o;
795 
796  // Q/K/V DRAM and DRAM window
797  const auto q_dram = [&]() {
798  const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
799  q_ptr,
800  make_tuple(kargs.seqlen_q, kargs.hdim_q),
801  make_tuple(kargs.stride_q, 1),
// NOTE(review): original line 802 (a template/alignment argument) dropped.
803  number<1>{});
804  if constexpr(FmhaPipeline::kQLoadOnce)
805  {
806  return pad_tensor_view(
807  q_dram_naive,
// NOTE(review): original lines 808-809 (tile lengths / pad sequence) dropped.
810  }
811  else
812  {
813  return pad_tensor_view(
814  q_dram_naive,
// NOTE(review): original lines 815-816 (tile lengths / pad sequence) dropped.
817  }
818  }();
819  const auto k_dram = [&]() {
820  const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
821  k_ptr,
822  make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_q),
823  make_tuple(kargs.stride_k, 1),
825  number<1>{});
826 
827  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
828  return pad_tensor_view(
829  k_dram_naive,
// NOTE(review): original lines 830-831 (tile lengths / pad sequence) dropped.
832  }();
833  const auto v_dram = [&]() {
834  if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
835  {
836  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
837  v_ptr,
838  make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v),
839  make_tuple(kargs.stride_v, 1),
841  number<1>{});
842 
// swap the two visible dimensions so V presents as (hdim_v, seqlen-like)
843  const auto v_dram_transposed = transform_tensor_view(
844  v_dram_naive,
845  make_tuple(
846  make_pass_through_transform(kargs.hdim_v),
847  make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)),
// NOTE(review): original lines 848-849 (dimension-mapping sequences) dropped.
850 
851  constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true;
852  return pad_tensor_view(
853  v_dram_transposed,
// NOTE(review): original lines 854-855 (tile lengths / pad sequence) dropped.
856  }
857  else
858  {
859  const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
860  v_ptr,
861  make_tuple(kargs.hdim_v, kargs.num_total_pages * kargs.page_block_size),
862  make_tuple(kargs.stride_v, 1),
864  number<1>{});
865 
866  constexpr bool kPadHeadDimV_ = kUseAsyncCopy ? kPadHeadDimV : false;
867  return pad_tensor_view(
868  v_dram_naive,
// NOTE(review): original lines 869-870 (tile lengths / pad sequence) dropped.
871  }
872  }();
873 
874  auto q_dram_window = make_tile_window(
875  q_dram,
876  [&]() {
877  if constexpr(FmhaPipeline::kQLoadOnce)
// NOTE(review): original lines 878-881 (window lengths per branch) dropped.
880  else
882  }(),
883  {i_m0, 0});
884 
885  auto k_dram_window = make_tile_window(
// NOTE(review): original line 886 (k_dram window arguments) dropped.
887 
888  auto v_dram_window =
889  make_tile_window(v_dram,
// NOTE(review): original line 890 (window lengths) dropped.
891  {i_n1, 0});
894  const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
895  constexpr auto bias_dram_window_lengths =
// NOTE(review): original lines 896-897 (lengths expr + "if constexpr" guard,
// presumably on ELEMENTWISE_BIAS) dropped.
898  {
899  const BiasDataType* bias_ptr =
900  reinterpret_cast<const BiasDataType*>(kargs.bias_ptr) +
901  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_bias +
902  batch_offset_bias;
903 
904  const auto bias_dram = [&]() {
905  const auto bias_dram_naive = make_naive_tensor_view<address_space_enum::global>(
906  bias_ptr,
907  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
908  make_tuple(kargs.stride_bias, 1),
910  number<1>{});
911 
912  return pad_tensor_view(bias_dram_naive,
913  bias_dram_window_lengths,
// NOTE(review): original line 914 (pad sequence) dropped.
915  }();
916 
917  return make_tile_window(bias_dram, bias_dram_window_lengths, {i_m0, 0});
918  }
919  else
920  {
921  return make_null_tile_window(bias_dram_window_lengths);
922  }
923  }();
924 
925  // lse
926  auto lse_dram_window = [&, i_nhead_ = i_nhead]() {
927  constexpr auto lse_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
928  if constexpr(kStoreLSE)
929  {
930  LSEDataType* lse_ptr =
931  reinterpret_cast<LSEDataType*>(kargs.lse_ptr) +
932  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse + batch_offset_lse;
933 
934  const auto lse_dram = [&]() {
935  const auto lse_dram_naive = make_naive_tensor_view<address_space_enum::global>(
936  lse_ptr,
937  make_tuple(kargs.seqlen_q),
938  make_tuple(1),
939  number<1>{},
940  number<1>{});
941 
942  return pad_tensor_view(
943  lse_dram_naive, lse_dram_window_lengths, sequence<kPadSeqLenQ>{});
944  }();
945 
946  return make_tile_window(lse_dram, lse_dram_window_lengths, {i_m0});
947  }
948  else
949  {
950  return make_null_tile_window(lse_dram_window_lengths);
951  }
952  }();
953 
// dropout state: seed/offset come either from host-side values or from
// device pointers, as recorded by init_dropout()
954  auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
955  if constexpr(kHasDropout)
956  {
957  return BlockDropout{i_batch_,
958  i_nhead_,
959  kargs.num_head_q,
960  kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
961  : *kargs.drop_seed.ptr,
962  kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
963  : *kargs.drop_offset.ptr,
964  kargs.rp_undrop,
965  kargs.p_undrop_in_uint8_t,
966  kargs.is_store_randval};
967  }
968  else
969  {
970  return NullBlockDropout{};
971  };
972  }();
973 
974  auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
975  constexpr auto randval_dram_window_lengths =
// NOTE(review): original line 976 (window lengths expr) dropped.
977  if constexpr(kHasDropout)
978  {
979  RandValOutputDataType* rand_val_ptr =
980  reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
981  static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
982  batch_offset_randval;
983 
984  const auto randval_dram = [&]() {
985  const auto randval_dram_naive =
986  make_naive_tensor_view<address_space_enum::global>(
987  rand_val_ptr,
988  make_tuple(kargs.seqlen_q, kargs.seqlen_k),
989  make_tuple(kargs.stride_randval, 1),
991  number<1>{});
992 
993  return pad_tensor_view(randval_dram_naive,
994  randval_dram_window_lengths,
// NOTE(review): original line 995 (pad sequence) dropped.
996  }();
997 
998  return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
999  }
1000  else
1001  {
1002  return make_null_tile_window(randval_dram_window_lengths);
1003  }
1004  }();
1005 
1006  FmhaMask mask = [&]() {
1007  if constexpr(kHasMask)
1008  return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
1009  kargs.window_size_left,
1010  kargs.window_size_right,
1011  kargs.seqlen_q,
1012  kargs.seqlen_k,
// NOTE(review): original line 1013 (final argument, presumably the mask-type
// condition) dropped.
1014  else
1015  return FmhaMask{kargs.seqlen_q, kargs.seqlen_k};
1016  }();
1017 
1018  // WA i_batch capture structure binding before c++20
1019  auto position_encoding = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
1020  if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
1021  {
1022  // data loading, shared by entire wg
1023  // TODO: how to use s_read?
1024  SaccDataType slope =
1025  *(reinterpret_cast<const SaccDataType*>(kargs.alibi_slope_ptr) +
1026  i_batch_ * kargs.alibi_slope_stride + i_nhead_);
1027 #if CK_TILE_FMHA_FWD_FAST_EXP2
1028  slope *= ck_tile::log2e_v<>;
1029 #endif
1030  if constexpr(kHasMask)
1031  {
1032  return make_alibi_from_lr_mask<SaccDataType, true>(slope,
1033  kargs.window_size_left,
1034  kargs.window_size_right,
1035  kargs.seqlen_q,
1036  kargs.seqlen_k,
1037  kargs.mask_type);
1038  }
1039  else
1040  {
// NOTE(review): original line 1041 (the constructed Alibi type name and
// opening brace) dropped.
1042  slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
1043  }
1044  }
1045  else
1046  {
// NOTE(review): original line 1047 (the empty-position-encoding return)
// dropped.
1048  }
1049  }();
1050 
1051  AttentionVariant variant;
1052  const auto variant_params = [&] {
1053  if constexpr(kHasLogitsSoftCap)
1054  {
// NOTE(review): original line 1055 (the LogitsSoftCap params type name and
// opening brace) dropped.
1056  mask, kargs.scale_s, kargs.logits_soft_cap, kargs.logits_soft_cap_rcp};
1057  }
1058  else
1059  {
1060  return ck_tile::StandardAttentionParams<FmhaMask>{mask, kargs.scale_s};
1061  }
1062  }();
1063 
1064  BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
1065 
// run the attention pipeline for this tile; result is the accumulated O tile
1066  auto o_acc_tile = [&]() {
1067  return FmhaPipeline{}(q_dram_window,
1068  k_dram_window,
1069  v_dram_window,
1070  bias_dram_window,
1071  randval_dram_window,
1072  lse_dram_window,
1073  mask,
1074  position_encoding,
1075  kargs.scale_s,
1076  variant,
1077  variant_params,
1078  block_indices,
1079  smem_ptr,
1080  kargs.kv_page_indices,
1081  kargs.stride_k,
1082  kargs.stride_v,
1083  dropout);
1084  }();
1085 
1086  // O DRAM and O DRAM window
1087  auto o_dram = [&]() {
1088  const auto o_dram_naive = make_naive_tensor_view<address_space_enum::global>(
1089  o_ptr,
1090  make_tuple(kargs.seqlen_q, kargs.hdim_v),
1091  make_tuple(kargs.stride_o, 1),
1093  number<1>{});
1094 
1095  return pad_tensor_view(
1096  o_dram_naive,
// NOTE(review): original lines 1097-1098 (tile lengths / pad sequence)
// dropped.
1099  }();
1100 
1101  auto o_dram_window =
1102  make_tile_window(o_dram,
// NOTE(review): original line 1103 (window lengths) dropped.
1104  {i_m0, i_n1});
1105 
1106  EpiloguePipeline{}(o_dram_window, o_acc_tile, nullptr);
1107  }
1108 };
1109 
1110 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition: config.hpp:238
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
#define _TS_
#define _SS_
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_view(const OldTensorView &old_tensor_view, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_view.hpp:526
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1558
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
int32_t int32_t
Definition: integer.hpp:10
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
GenericAttentionMaskEnum
Definition: block_masking.hpp:11
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
unsigned char uint8_t
Definition: stdint.h:124
unsigned __int64 uint64_t
Definition: stdint.h:136
Definition: block_position_encoding.hpp:48
Definition: block_attention_bias_enum.hpp:19
Definition: block_attention_quant_scale_enum.hpp:18
Definition: block_dropout.hpp:53
const float rp_undrop
Definition: block_dropout.hpp:371
Definition: block_position_encoding.hpp:137
Definition: fmha_batch_prefill_kernel.hpp:306
ck_tile::index_t kv_head_idx
Definition: fmha_batch_prefill_kernel.hpp:309
ck_tile::index_t qo_head_idx
Definition: fmha_batch_prefill_kernel.hpp:308
ck_tile::index_t batch_idx
Definition: fmha_batch_prefill_kernel.hpp:307
Definition: fmha_batch_prefill_kernel.hpp:194
ck_tile::index_t alibi_slope_stride
Definition: fmha_batch_prefill_kernel.hpp:197
const void * alibi_slope_ptr
Definition: fmha_batch_prefill_kernel.hpp:196
ck_tile::index_t batch_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:190
ck_tile::index_t batch_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:265
ck_tile::index_t batch_stride_o
Definition: fmha_batch_prefill_kernel.hpp:283
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:282
ck_tile::index_t batch_stride_q
Definition: fmha_batch_prefill_kernel.hpp:280
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:281
ck_tile::index_t nhead_stride_bias
Definition: fmha_batch_prefill_kernel.hpp:185
ck_tile::index_t stride_bias
Definition: fmha_batch_prefill_kernel.hpp:184
const void * bias_ptr
Definition: fmha_batch_prefill_kernel.hpp:183
ck_tile::index_t stride_randval
Definition: fmha_batch_prefill_kernel.hpp:259
void init_dropout(float p_drop, const uint64_t *seed_ptr, const uint64_t *offset_ptr)
Definition: fmha_batch_prefill_kernel.hpp:242
ck_tile::index_t nhead_stride_randval
Definition: fmha_batch_prefill_kernel.hpp:260
void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
Definition: fmha_batch_prefill_kernel.hpp:230
void * rand_val_ptr
Definition: fmha_batch_prefill_kernel.hpp:257
float rp_undrop
Definition: fmha_batch_prefill_kernel.hpp:254
bool is_store_randval
Definition: fmha_batch_prefill_kernel.hpp:256
uint8_t p_undrop_in_uint8_t
Definition: fmha_batch_prefill_kernel.hpp:255
Definition: fmha_batch_prefill_kernel.hpp:120
ck_tile::index_t stride_q
Definition: fmha_batch_prefill_kernel.hpp:148
ck_tile::index_t stride_v
Definition: fmha_batch_prefill_kernel.hpp:150
int32_t num_total_pages
Definition: fmha_batch_prefill_kernel.hpp:136
float scale_s
Definition: fmha_batch_prefill_kernel.hpp:146
ck_tile::index_t seqlen_q
Definition: fmha_batch_prefill_kernel.hpp:126
ck_tile::index_t stride_k
Definition: fmha_batch_prefill_kernel.hpp:149
ck_tile::index_t nhead_stride_o
Definition: fmha_batch_prefill_kernel.hpp:156
ck_tile::index_t nhead_stride_k
Definition: fmha_batch_prefill_kernel.hpp:154
ck_tile::index_t nhead_ratio_qk
Definition: fmha_batch_prefill_kernel.hpp:134
ck_tile::index_t nhead_stride_v
Definition: fmha_batch_prefill_kernel.hpp:155
ck_tile::index_t nhead_stride_q
Definition: fmha_batch_prefill_kernel.hpp:153
const int32_t * kv_page_indices
Definition: fmha_batch_prefill_kernel.hpp:138
const void * v_ptr
Definition: fmha_batch_prefill_kernel.hpp:123
const int32_t * kv_indptr
Definition: fmha_batch_prefill_kernel.hpp:137
void * o_ptr
Definition: fmha_batch_prefill_kernel.hpp:124
ck_tile::index_t seqlen_k
Definition: fmha_batch_prefill_kernel.hpp:127
ck_tile::index_t stride_o
Definition: fmha_batch_prefill_kernel.hpp:151
ck_tile::index_t hdim_v
Definition: fmha_batch_prefill_kernel.hpp:129
ck_tile::index_t num_head_q
Definition: fmha_batch_prefill_kernel.hpp:131
static constexpr ck_tile::index_t page_block_size
Definition: fmha_batch_prefill_kernel.hpp:143
const void * k_ptr
Definition: fmha_batch_prefill_kernel.hpp:122
ck_tile::index_t hdim_q
Definition: fmha_batch_prefill_kernel.hpp:128
const void * q_ptr
Definition: fmha_batch_prefill_kernel.hpp:121
ck_tile::index_t batch_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:211
ck_tile::index_t nhead_stride_lse
Definition: fmha_batch_prefill_kernel.hpp:210
void * lse_ptr
Definition: fmha_batch_prefill_kernel.hpp:209
bool is_drop_seed_offset_from_host
Definition: fmha_batch_prefill_kernel.hpp:225
ValueOrPointer< uint64_t > drop_seed
Definition: fmha_batch_prefill_kernel.hpp:223
ValueOrPointer< uint64_t > drop_offset
Definition: fmha_batch_prefill_kernel.hpp:224
Definition: fmha_batch_prefill_kernel.hpp:113
ck_tile::index_t batch_stride_v
Definition: fmha_batch_prefill_kernel.hpp:300
ck_tile::index_t batch_stride_k
Definition: fmha_batch_prefill_kernel.hpp:299
const int32_t * seqstart_q_ptr
Definition: fmha_batch_prefill_kernel.hpp:298
float logits_soft_cap_rcp
Definition: fmha_batch_prefill_kernel.hpp:178
void init_logits_soft_cap(float logits_soft_cap_)
Definition: fmha_batch_prefill_kernel.hpp:163
float logits_soft_cap
Definition: fmha_batch_prefill_kernel.hpp:177
Definition: fmha_batch_prefill_kernel.hpp:201
ck_tile::index_t window_size_right
Definition: fmha_batch_prefill_kernel.hpp:203
ck_tile::index_t window_size_left
Definition: fmha_batch_prefill_kernel.hpp:203
ck_tile::GenericAttentionMaskEnum mask_type
Definition: fmha_batch_prefill_kernel.hpp:204
Definition: fmha_batch_prefill_kernel.hpp:65
Definition: fmha_batch_prefill_kernel.hpp:27
static constexpr CK_TILE_DEVICE auto GetTileIndex(const Kargs &kargs)
Definition: fmha_batch_prefill_kernel.hpp:618
static constexpr bool kIsGroupMode
Definition: fmha_batch_prefill_kernel.hpp:48
static constexpr ck_tile::index_t kBlockPerCu
Definition: fmha_batch_prefill_kernel.hpp:32
ck_tile::remove_cvref_t< typename FmhaPipeline::VDataType > VDataType
Definition: fmha_batch_prefill_kernel.hpp:38
ck_tile::remove_cvref_t< FmhaPipeline_ > FmhaPipeline
Definition: fmha_batch_prefill_kernel.hpp:28
ck_tile::remove_cvref_t< typename FmhaPipeline::KDataType > KDataType
Definition: fmha_batch_prefill_kernel.hpp:37
static constexpr bool kPadSeqLenQ
Definition: fmha_batch_prefill_kernel.hpp:49
ck_tile::remove_cvref_t< typename FmhaPipeline::RandValOutputDataType > RandValOutputDataType
Definition: fmha_batch_prefill_kernel.hpp:41
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, ck_tile::index_t seqlen_q, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_q, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t batch_stride_bias, ck_tile::index_t batch_stride_randval, ck_tile::index_t batch_stride_lse, ck_tile::index_t batch_stride_o, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:314
static constexpr bool kPadHeadDimV
Definition: fmha_batch_prefill_kernel.hpp:52
ck_tile::remove_cvref_t< typename FmhaPipeline::LSEDataType > LSEDataType
Definition: fmha_batch_prefill_kernel.hpp:42
ck_tile::remove_cvref_t< typename FmhaPipeline::QDataType > QDataType
Definition: fmha_batch_prefill_kernel.hpp:36
static constexpr bool kHasMask
Definition: fmha_batch_prefill_kernel.hpp:60
static CK_TILE_HOST std::string GetName()
Definition: fmha_batch_prefill_kernel.hpp:73
ck_tile::remove_cvref_t< typename FmhaPipeline::BiasDataType > BiasDataType
Definition: fmha_batch_prefill_kernel.hpp:39
ck_tile::remove_cvref_t< typename FmhaPipeline::FmhaMask > FmhaMask
Definition: fmha_batch_prefill_kernel.hpp:59
static CK_TILE_HOST dim3 BlockSize()
Definition: fmha_batch_prefill_kernel.hpp:677
static constexpr CK_TILE_HOST std::enable_if_t< Cond, Kargs > MakeKargs(const void *q_ptr, const void *k_ptr, const void *v_ptr, const void *bias_ptr, void *rand_val_ptr, void *lse_ptr, void *o_ptr, const void *seqstart_q_ptr, ck_tile::index_t hdim_q, ck_tile::index_t hdim_v, ck_tile::index_t num_head_q, ck_tile::index_t nhead_ratio_qk, int32_t num_total_pages, const void *kv_indptr, const void *kv_page_indices, float scale_s, [[maybe_unused]] float scale_p, [[maybe_unused]] float scale_o, float logits_soft_cap, ck_tile::index_t stride_q, ck_tile::index_t stride_k, ck_tile::index_t stride_v, ck_tile::index_t stride_bias, ck_tile::index_t stride_randval, ck_tile::index_t stride_o, ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_o, ck_tile::index_t batch_stride_k, ck_tile::index_t batch_stride_v, ck_tile::index_t window_size_left, ck_tile::index_t window_size_right, ck_tile::index_t mask_type, float p_drop, bool s_randval, std::variant< std::pair< uint64_t, uint64_t >, std::pair< const void *, const void * >> drop_seed_offset)
Definition: fmha_batch_prefill_kernel.hpp:460
static constexpr bool kPadSeqLenK
Definition: fmha_batch_prefill_kernel.hpp:50
static constexpr auto QScaleEnum
Definition: fmha_batch_prefill_kernel.hpp:57
static constexpr bool kHasLogitsSoftCap
Definition: fmha_batch_prefill_kernel.hpp:53
static constexpr bool kHasDropout
Definition: fmha_batch_prefill_kernel.hpp:56
static constexpr bool kStoreLSE
Definition: fmha_batch_prefill_kernel.hpp:55
ck_tile::remove_cvref_t< typename FmhaPipeline::SaccDataType > SaccDataType
Definition: fmha_batch_prefill_kernel.hpp:44
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fmha_batch_prefill_kernel.hpp:689
static constexpr ck_tile::index_t kBlockSize
Definition: fmha_batch_prefill_kernel.hpp:30
static constexpr auto BiasEnum
Definition: fmha_batch_prefill_kernel.hpp:54
ck_tile::remove_cvref_t< typename FmhaPipeline::VLayout > VLayout
Definition: fmha_batch_prefill_kernel.hpp:46
ck_tile::remove_cvref_t< typename FmhaPipeline::ODataType > ODataType
Definition: fmha_batch_prefill_kernel.hpp:43
ck_tile::remove_cvref_t< typename FmhaPipeline::AttentionVariant > AttentionVariant
Definition: fmha_batch_prefill_kernel.hpp:58
static constexpr bool kUseAsyncCopy
Definition: fmha_batch_prefill_kernel.hpp:62
static constexpr ck_tile::index_t kBlockPerCuInput
Definition: fmha_batch_prefill_kernel.hpp:34
static constexpr CK_TILE_HOST auto GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_, ck_tile::index_t hdim_v_)
Definition: fmha_batch_prefill_kernel.hpp:595
static constexpr bool kPadHeadDimQ
Definition: fmha_batch_prefill_kernel.hpp:51
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: fmha_batch_prefill_kernel.hpp:694
std::conditional_t< kIsGroupMode, FmhaFwdGroupModeKargs, FmhaFwdBatchModeKargs > Kargs
Definition: fmha_batch_prefill_kernel.hpp:303
ck_tile::remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: fmha_batch_prefill_kernel.hpp:29
Definition: variants.hpp:63
float logits_soft_cap
Definition: variants.hpp:128
Definition: block_dropout.hpp:39
Definition: variants.hpp:51
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:1392
Definition: sequence.hpp:49