Composable Kernel: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
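// Naming: "qr" = Q kept in registers, "ks"/"vs" = K and V staged through LDS. "batch
// prefill" here refers to paged KV access: K/V rows are gathered through a page table
// (page_idx), so consecutive tile repeats may come from different physical pages.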
template <typename Problem_,
          typename Policy_ = BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaBatchPrefillPipelineQRKSVSAsync
{
    using Problem               = remove_cvref_t<Problem_>;
    using Policy                = remove_cvref_t<Policy_>;
    using QDataType             = remove_cvref_t<typename Problem::QDataType>;
    using KDataType             = remove_cvref_t<typename Problem::KDataType>;
    using VDataType             = remove_cvref_t<typename Problem::VDataType>;
    using SaccDataType          = remove_cvref_t<typename Problem::SaccDataType>;
    using SMPLComputeDataType   = remove_cvref_t<typename Problem::SMPLComputeDataType>;
    using BiasDataType          = remove_cvref_t<typename Problem::BiasDataType>;
    using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
    using LSEDataType           = remove_cvref_t<typename Problem::LSEDataType>;
    using PDataType             = remove_cvref_t<typename Problem::PDataType>;
    using OaccDataType          = remove_cvref_t<typename Problem::OaccDataType>;
    using ODataType             = remove_cvref_t<typename Problem::ODataType>;
    using AttentionVariant      = remove_cvref_t<typename Problem::AttentionVariant>;
    using FmhaMask              = remove_cvref_t<typename Problem::FmhaMask>;

    using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
    using VLayout        = remove_cvref_t<typename BlockFmhaShape::VLayout>;

    static constexpr bool kQLoadOnce = true; // whether q_tile loads the whole row (hdim) at once
    static_assert(kQLoadOnce == Policy::QLoadOnce);

    static constexpr index_t kBlockSize = Problem::kBlockSize;

    static constexpr index_t kM0           = BlockFmhaShape::kM0;
    static constexpr index_t kN0           = BlockFmhaShape::kN0;
    static constexpr index_t kK0           = BlockFmhaShape::kK0;
    static constexpr index_t kN1           = BlockFmhaShape::kN1;
    static constexpr index_t kK1           = BlockFmhaShape::kK1;
    static constexpr index_t kQKHeaddim    = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
    static constexpr auto I0               = number<0>{};
    static constexpr auto I1               = number<1>{};
    static constexpr auto I2               = number<2>{};
    static constexpr auto I3               = number<3>{};
    static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");

    static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
    // TODO: seq_q always supports padding; hdim_q/v support multiples of the vector size (e.g. 8x)
    // only seq_k padding needs special care (oob elements of p must be set to -INF instead of zero)
    static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
                  Problem::kPadHeadDimV == true);
    static constexpr bool kPadSeqLenQ       = true;
    static constexpr bool kPadSeqLenK       = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ      = true; // supports multiples of the vector size (e.g. 8x)
    static constexpr bool kPadHeadDimV      = true; // supports multiples of the vector size (e.g. 8x)
    static constexpr bool kHasLogitsSoftCap = Problem::kHasLogitsSoftCap;
    static constexpr auto BiasEnum          = Problem::BiasEnum;
    static constexpr bool kStoreLSE         = Problem::kStoreLSE;
    static constexpr bool kHasDropout       = Problem::kHasDropout;
    static_assert((CK_TILE_FMHA_FWD_FAST_EXP2 &&
                   (kHasLogitsSoftCap && Problem::BiasEnum == BlockAttentionBiasEnum::NO_BIAS ||
                    !kHasLogitsSoftCap)) ||
                  !CK_TILE_FMHA_FWD_FAST_EXP2);

    // last-dimension vector length, used to create the tensor view (and to decide the
    // buffer_load vector length) together with the tensor distribution; the tensor
    // distribution should be able to override this
    static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
    static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV = []() {
        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            return Policy::template GetAlignmentV<Problem>();
        else
            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
    }();
    static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();

#if CK_TILE_FMHA_FWD_FAST_EXP2
    static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif

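    // Occupancy hint: larger head dims (and larger K/V LDS buffers) consume more LDS and
    // registers per workgroup, so fewer blocks per CU are requested below.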
    static constexpr index_t kBlockPerCu = []() {
        if constexpr(Problem::kBlockPerCu != -1)
            return Problem::kBlockPerCu;
        else
        {
            // minimize occupancy
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI && kStoreLSE)
            {
                return 1;
            }

            if constexpr(kQKHeaddim <= 32)
            {
                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
                             FmhaMask::IsMasking)
                    return 1;
                else
                    return 2;
            }
            else if constexpr(kQKHeaddim <= 64)
            {
                if constexpr(kPadSeqLenK && kHasDropout)
                    return 2;
                else
                    return 3;
            }
            else if constexpr(kQKHeaddim <= 128)
            {
                if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                    return 1;
                // using a larger K/V LDS buffer would lower the occupancy
                else if constexpr(64 <= kK0 || 64 <= kK1)
                    return 1;
                else
                    return 2;
            }
            else if constexpr(kQKHeaddim <= 192)
            {
                if constexpr(kPadSeqLenK && kHasDropout)
                    return 1;
                else
                    return 2;
            }
            else if constexpr(kQKHeaddim <= 256)
            {
                return 1;
            }
            else
            {
                return 1;
            };
        }
    }();

    static constexpr const char* name = "qr_async";

    using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;

    static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename QElementFunction,
              typename KElementFunction,
              typename VElementFunction,
              typename BiasElementFunction,
              typename LSEElementFunction,
              typename SAccElementFunction,
              typename PComputeElementFunction,
              typename OAccElementFunction,
              typename PositionEncoding,
              typename AttentionVariantParams,
              typename BlockIndices>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const QElementFunction& q_element_func,
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const KElementFunction& /*k_element_func*/,
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               const VElementFunction& v_element_func,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               const BiasElementFunction& bias_element_func,
               RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
               LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
               const LSEElementFunction& lse_element_func,
               const SAccElementFunction& s_acc_element_func,
               const PComputeElementFunction& p_compute_element_func,
               const OAccElementFunction& o_acc_element_func,
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               const AttentionVariant& variant,
               const AttentionVariantParams& variant_params,
               const BlockIndices& block_indices,
               void* smem_ptr,
               const index_t* page_idx,
               const index_t stride_k,
               const index_t stride_v,
               DropoutType& dropout) const
    {
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");

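        // LdsSeq assigns each K prefetch stage (and the V stages after them) to one of the
        // rotating LDS buffers, so async loads can overlap the GEMMs consuming earlier stages.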
        constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();

        // K tile in LDS
        auto k_lds_ptr   = reinterpret_cast<KDataType*>(smem_ptr);
        auto k_lds_store = generate_tuple(
            [&](auto i_buf) {
                return make_tile_window(
                    make_tensor_view<address_space_enum::lds>(
                        k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
                    Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
                    {0, 0, 0});
            },
            number<Policy::NumPrefetchK>{});

        auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
            k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());

        auto k_lds_load =
            make_tile_window(k_lds_Load_view,
                             Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
                             {0, 0});

        // V tile in LDS
        auto v_lds = make_tensor_view<address_space_enum::lds>(
            reinterpret_cast<VDataType*>(smem_ptr),
            Policy::template MakeVLdsBlockDescriptor<Problem>());
        auto v_lds_window = make_tile_window(
            v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});

        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();

        auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                              q_dram_block_window_tmp.get_window_lengths(),
                                              q_dram_block_window_tmp.get_window_origin(),
                                              Policy::template MakeQRegTileDistribution<Problem>());
        q_dram_window.init_raw();

        // TODO: we use async copy for K, which is inline asm; a side effect is that we
        // have to use inline asm for Q as well
        auto q = decltype(load_tile(q_dram_window)){};
        // TODO: starting from rocm-6.2 the compiler has problems if q is cleared manually;
        // however, q is cleared in the constructor of the static distributed tensor anyway
        // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
        load_tile_raw(q, q_dram_window);
        __builtin_amdgcn_sched_barrier(0);

        using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
        auto s_acc              = SaccBlockTileType{};

        // reduction function for softmax
        const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
        const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };

        // infer Sacc, S, P, M, L, Oacc type
        using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));

        using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
            SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));

        using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());

        // init Oacc, M, L
        auto o_acc = OaccBlockTileType{};
        auto m     = MLBlockTileType{};
        auto l     = MLBlockTileType{};

        clear_tile(o_acc);
        set_tile(m, -numeric<SMPLComputeDataType>::infinity());
        clear_tile(l);

        __builtin_amdgcn_sched_barrier(0);
        const auto q_origin = q_dram_window.get_window_origin();
        const auto [seqlen_k_start, seqlen_k_end] =
            mask.GetTileRangeAlongX(q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});

        const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);

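        // num_total_loop = ceil((seqlen_k_end - seqlen_k_start) / kN0): the number of
        // kN0-wide K/V blocks this workgroup walks along the (possibly masked) seq_k range.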
        // check for early exit if there is no work to do
        if constexpr(FmhaMask::IsMasking || kPadSeqLenK)
        {
            if(num_total_loop <= 0)
            {
                if constexpr(kStoreLSE)
                {
                    auto lse =
                        make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

                    set_tile(lse, -numeric<SMPLComputeDataType>::infinity());

                    store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
                }
                buffer_load_fence(0); // rocm-6.1: if the whole tile is masked out, we need to
                                      // fence(0), otherwise we get compute errors (maybe a
                                      // compiler bug?)

                // Note: o_acc is all cleared at this point, so just return it
                return o_acc;
            }
            __builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
        }

        auto k_dram_block_window =
            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                             k_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_k_start, 0});

        auto k_dist               = Policy::template MakeKDramTileDistribution<Problem>();
        auto k_coord              = k_dist.calculate_index();
        using KDstrEncode         = typename decltype(k_dist)::DstrEncode;
        constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0];
        statically_indexed_array<index_t, NRepeat> k_offsets;
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
        });
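        // Paged KV: each row-repeat of the K tile may live in a different page, so per-repeat
        // element offsets are looked up through page_idx before building the scatter-gather
        // window (stride_k converting a page index into an element offset).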
        auto k_dram_window = make_tile_scatter_gather(k_dram_block_window.get_bottom_tensor_view(),
                                                      k_dram_block_window.get_window_lengths(),
                                                      k_dram_block_window.get_window_origin(),
                                                      k_dist,
                                                      k_offsets); // K DRAM tile window for load
        k_dram_window.init_raw();
        constexpr auto k_oob_ck = bool_constant<true>{};
        constexpr auto k_pre_np = [&]() {
            if constexpr(kPadSeqLenK &&
                         (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                          (BiasEnum != BlockAttentionBiasEnum::NO_BIAS && kHasDropout)))
                return bool_constant<true>{};
            else
                return bool_constant<false>{};
        }();
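        // k_oob_ck always range-checks the async K loads; k_pre_np additionally selects the
        // pre-nop variant of the load for the padded-seq_k cases guarded above.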

        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_window =
            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                             bias_dram_block_window_tmp.get_window_lengths(),
                             {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
                             Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

        auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
            randval_dram_block_window_tmp, seqlen_k_start);

        auto v_dist                 = Policy::template MakeVDramTileDistribution<Problem>();
        auto v_coord                = v_dist.calculate_index();
        const auto VPageIndexDim    = I1;
        using VDstrEncode           = typename decltype(v_dist)::DstrEncode;
        constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3];
        statically_indexed_array<index_t, V_KRepeat> v_offsets;
        (void)stride_k;
        static_for<0, V_KRepeat, 1>{}([&](auto k0) {
            v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
        });

        auto v_dram_window =
            make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(),
                                     v_dram_block_window_tmp.get_window_lengths(),
                                     {0, seqlen_k_start}, // TODO: hdim split?
                                     v_dist,
                                     v_offsets,
                                     VPageIndexDim);

        // prefetch K tile
        async_load_tile_raw(
            k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
        move_tile_window(k_dram_window, {0, kK0});
        __builtin_amdgcn_sched_barrier(0);

        buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
        (void)q_element_func; // ??? with rocm-6.x, using the q element func here causes
                              // scratch on hdim=64/32
        // auto q_tile = q; // tile_elementwise_in(q_element_func, q);

        index_t i_total_loops      = 0;
        constexpr index_t k0_loops = kQKHeaddim / kK0;
        constexpr index_t k1_loops = kN0 / kK1;

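        // k0_loops: number of kK0-wide slices covering the QK head dim in gemm_0;
        // k1_loops: number of kK1-deep slices covering one kN0 block in gemm_1.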
        static_assert(1 <= k0_loops);
        static_assert(1 <= k1_loops);
        // main loop
        do
        {
            // STAGE 1, QK gemm
            clear_tile(s_acc); // initialize C
            if constexpr(k0_loops > 1)
            {
                static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
                    async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
                                        k_dram_window,
                                        number<-1>{},
                                        k_oob_ck,
                                        k_pre_np);
                    if constexpr(i_k0 < k0_loops - 1)
                        move_tile_window(k_dram_window, {0, kK0});

                    async_load_fence(k_dram_window.get_num_of_access());
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                    gemm_0(s_acc,
                           get_slice_tile(
                               q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
                           get_slice_tile(k_lds_load,
                                          sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
                                          sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
                });
            }

            // TODO: this is to fix a bug when the loop count is smaller than 2:
            // otherwise the following fence/barrier would be scheduled inside the 1st loop
            if constexpr(k0_loops <= 2)
                __builtin_amdgcn_sched_barrier(0);

            async_load_fence();
            __builtin_amdgcn_s_barrier();

            const auto bias_tile = load_tile(bias_dram_window); // load bias tile
            auto v_buf           = load_tile(v_dram_window, number<-1>{}, bool_constant<false>{});
            static_for<0, V_KRepeat, 1>{}([&](auto k0) {
                v_offsets[k0] = page_idx[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
            });
            v_dram_window.update_page_idx(v_offsets);

            __builtin_amdgcn_sched_barrier(0);
            { // tail
                gemm_0(
                    s_acc,
                    get_slice_tile(
                        q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
                    get_slice_tile(k_lds_load,
                                   sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
                                   sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
            }
            __builtin_amdgcn_sched_barrier(1);

            // STAGE 2, scale_s, add bias, mask, softmax
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                        x += type_convert<SaccDataType>(bias_element_func(y));
#else
                        x += log2e_v<SaccDataType> *
                             type_convert<SaccDataType>(bias_element_func(y));
#endif
                    },
                    s_acc,
                    bias_tile);
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                const auto k_origin    = k_dram_block_window.get_window_origin();
                constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
                s_acc                  = tile_elementwise_in(s_acc_element_func, s_acc);
                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            s_acc.get_tile_distribution(), make_tuple(idx0, idx1));

                        const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);

                        s_acc(i_j_idx) *= scale_s;
                        position_encoding.update(s_acc(i_j_idx), row, col);
                    });
                });
            }
            else
            {
                s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
                if constexpr(kHasLogitsSoftCap)
                {
                    auto apply_logits_transform =
                        [&variant, &variant_params, &block_indices](auto& x) {
                            x = variant.LogitsTransform(variant_params,
                                                        variant.QueryTransform(variant_params, x),
                                                        block_indices.batch_idx,
                                                        block_indices.qo_head_idx,
                                                        block_indices.kv_head_idx);
                        };
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                    for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
                    {
                        apply_logits_transform(s_acc.thread_buf_[i]);
                    }
#else
                    for(index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
                    {
#if(defined(__gfx90a__) || defined(__gfx94__)) && \
    (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
     CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
                        // Avoid a data hazard when v_mfma is followed by inline-asm consumer
                        // instructions; in this case the compiler won't add s_nop for us
                        if(i == s_acc.thread_buf_.size() / 2)
                        {
                            __builtin_amdgcn_sched_barrier(0);
                        }
#endif
                        apply_logits_transform(s_acc.thread_buf_[i]);
                    }
#endif
                }
                else
                {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
                    tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
                }
            }
            move_tile_window(bias_dram_window, {0, kN0});
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                const auto k_origin      = k_dram_block_window.get_window_origin();
                bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
                                                           k_origin.at(number<0>{}),
                                                           number<kM0>{},
                                                           number<kN0>{});

                if(need_perpixel_check)
                {
                    set_tile_if(
                        s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
                            const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
                            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                            return !variant.LogitsMask(variant_params,
                                                       block_indices.batch_idx,
                                                       row,
                                                       col,
                                                       block_indices.qo_head_idx,
                                                       block_indices.kv_head_idx);
                        });
                }
            }

            const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
            auto m_local = block_tile_reduce<SMPLComputeDataType>(
                s,
                sequence<1>{},
                f_max,
                -numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
            block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});

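            // Online-softmax bookkeeping (FA-2 style): m is the running row max; the running
            // sum l and the partial output o_acc are rescaled by exp(m_old - m_new) before
            // adding block j's contribution, so no second pass over seq_k is needed.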
            const auto m_old = m; // m{j-1}
            tile_elementwise_inout(
                [](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}

            auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
                s.get_tile_distribution()); // Pcompute{j}

            __builtin_amdgcn_sched_barrier(0x7F);
            // store & prefetch next v, after the max reduction
            if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
            {
                auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                    Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                shuffle_tile(v_shuffle_tmp, v_buf);

                auto v_lds_window_tmp =
                    get_slice_tile(v_lds_window,
                                   sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
                                   sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});

                store_tile(
                    v_lds_window_tmp,
                    tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
            }
            else
            {
                auto v_lds_window_tmp =
                    get_slice_tile(v_lds_window,
                                   sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
                                   sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
                store_tile(v_lds_window_tmp,
                           tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
            }

            if constexpr(k1_loops > 1)
            {
                move_tile_window(
                    v_dram_window,
                    {0, kK1}); // will have scratch if moved right after load_tile(v_dram)...
                v_buf = load_tile(
                    v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
                static_for<0, V_KRepeat, 1>{}([&](auto k0) {
                    v_offsets[k0] =
                        page_idx[kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
                });
                v_dram_window.update_page_idx(v_offsets);
            }
            __builtin_amdgcn_sched_barrier(0);

            static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
                /// NOTICE: bias might be a materialized mask containing -inf values, which
                /// needs consideration here; alibi does not have this problem
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
                {
                    return raw_m == -numeric<SMPLComputeDataType>::infinity()
                               ? type_convert<SMPLComputeDataType>(0.f)
                               : raw_m;
                }
                else
                {
                    return raw_m;
                }
            };

            constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                    }
                    else
                    {
                        if constexpr(kHasLogitsSoftCap)
                        {
                            p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
                        }
                        else
                        {
                            p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
                        }
                    }
#else
                    p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
                });
            });

            auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
                p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})

            block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
            // l{j}, Oacc{j}
            constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
            sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                const auto tmp = [&]() {
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                    }
                    else
                    {
                        if constexpr(kHasLogitsSoftCap)
                        {
                            return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
                        }
                        else
                        {
                            auto row_max = scale_s * get_validated_m(m[i_idx]);
                            return exp2(scale_s * m_old[i_idx] - row_max);
                        }
                    }
                }();
#else
                const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
                l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
                sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    // FIXME: this uses a different equation from the FA v2 paper,
                    // but produces the correct result.
                    // Is the equation wrong?
                    o_acc(i_j_idx) *= tmp;
                });
            });
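            // Rescaling o_acc by tmp here and dividing by l once at the very end is
            // algebraically equivalent to the FA v2 paper's per-step l{j-1}/l{j} update,
            // which is what the FIXME above is about.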

            if constexpr(kHasDropout)
            {
                auto randval_ptr =
                    reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
                dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
                    randval_ptr,
                    seqlen_k_start + i_total_loops * kN0,
                    p_compute,
                    randval_dram_window);
            }

            const auto p = [&]() {
#if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
                // For fp32 to fp16,
                // impl::cast_tile_pk_fp16_fp32 would cause a precision issue,
                // since it uses __builtin_amdgcn_cvt_pkrtz, which rounds toward zero.
                return cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
#else
                if constexpr(std::is_same_v<PDataType, fp16_t>)
                    return impl::cast_tile_pk_fp16_fp32<PDataType>(
                        tile_elementwise_in(p_compute_element_func, p_compute));
                else
                    return cast_tile<PDataType>(
                        tile_elementwise_in(p_compute_element_func, p_compute));
#endif
            }();

            // STAGE 3, KV gemm
            if constexpr(k1_loops > 1)
            {
                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
                    if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
                    {
                        v_buf = load_tile(
                            v_dram_window, number<-1>{}, bool_constant<false>{}); // load next v_buf
                        static_for<0, V_KRepeat, 1>{}([&](auto k0) {
                            v_offsets[k0] = page_idx[kK1 * 2 + i_k1.value * kK1 +
                                                     v_coord[VPageIndexDim] + k0.value] *
                                            stride_v;
                        });
                        v_dram_window.update_page_idx(v_offsets);
                    }
                    block_sync_lds();
                    gemm_1(o_acc,
                           get_slice_tile(
                               p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
                           get_slice_tile(
                               v_lds_window,
                               sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
                               sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));

                    if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
                    {
                        auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
                            Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
                        shuffle_tile(v_shuffle_tmp, v_buf);
                        auto v_lds_window_tmp = get_slice_tile(
                            v_lds_window,
                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
                        store_tile(v_lds_window_tmp,
                                   tile_elementwise_in(v_element_func,
                                                       v_shuffle_tmp)); // store the prefetch
                    }
                    else
                    {
                        auto v_lds_window_tmp = get_slice_tile(
                            v_lds_window,
                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
                            sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
                        store_tile(v_lds_window_tmp,
                                   tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
                    }
                    if constexpr(i_k1 < k1_loops - 1)
                        move_tile_window(v_dram_window, {0, kK1});
                });
            }
            i_total_loops++;
            if(i_total_loops < num_total_loop)
            {
                page_idx += kN0;
                // move K tile windows
                move_tile_window(k_dram_block_window, {kN0, 0});
                k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    k_offsets[n0] = page_idx[k_coord[0] + kN0 / NRepeat * n0.value] * stride_k;
                });
                k_dram_window.update_page_idx(k_offsets);
                if constexpr(k1_loops >= 2 &&
                             LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
                    __builtin_amdgcn_s_barrier();
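                // prefetch the first kK0 slice of the next K block (page offsets were
                // refreshed above) while this block's P*V GEMMs are still in flight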
                async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})),
                                    k_dram_window,
                                    number<-1>{},
                                    k_oob_ck,
                                    k_pre_np);
                move_tile_window(k_dram_window, {0, kK0});
            }
            // tail
            {
                block_sync_lds();
                gemm_1(
                    o_acc,
                    get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
                    get_slice_tile(
                        v_lds_window,
                        sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
                        sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
            }
        } while(i_total_loops < num_total_loop);

        // store lse
        if constexpr(kStoreLSE)
        {
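            // LSE = m + log(l) = log(sum(exp(s - m))) + m, i.e. the row-wise log-sum-exp over
            // seq_k; with FAST_EXP2 the base-2 running max is converted back via R_LOG2E.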
            auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());

            constexpr auto lse_spans = decltype(lse)::get_distributed_spans();
            sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             BiasEnum == BlockAttentionBiasEnum::ALIBI)
                {
                    lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
                }
                else
                {
                    if constexpr(kHasLogitsSoftCap)
                    {
                        lse(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
                    }
                    else
                    {
                        lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
                    }
                }
#else
                lse(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
            });

            store_tile(lse_dram_window_tmp, tile_elementwise_in(lse_element_func, lse));
        }

        // finally, O
        constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();

        sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
            constexpr auto i_idx = make_tuple(idx0);
            const auto tmp       = [&]() {
                if constexpr(FmhaMask::IsMasking)
                {
                    return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                }
                else
                    return 1 / l[i_idx];
            }();
            sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx0, idx1);
                o_acc(i_j_idx) *= tmp;
            });
        });

        o_acc = tile_elementwise_in(o_acc_element_func, o_acc);

        return o_acc;
    }

    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename PositionEncoding,
              typename AttentionVariantParams,
              typename BlockIndices>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
               LSEDramBlockWindowTmp& lse_dram_block_window_tmp, // M0*1 tile
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
               const AttentionVariant& variant,
               const AttentionVariantParams& variant_params,
               const BlockIndices& block_indices,
               void* smem_ptr,
               const index_t* page_idx,
               const index_t stride_k,
               const index_t stride_v,
               DropoutType& dropout) const
    {
        return operator()(q_dram_block_window_tmp,
                          identity{},
                          k_dram_block_window_tmp,
                          identity{},
                          v_dram_block_window_tmp,
                          identity{},
                          bias_dram_block_window_tmp,
                          identity{},
                          randval_dram_block_window_tmp,
                          lse_dram_block_window_tmp,
                          identity{},
                          identity{},
                          identity{},
                          identity{},
                          mask,
                          position_encoding,
                          scale_s,
                          variant,
                          variant_params,
                          block_indices,
                          smem_ptr,
                          page_idx,
                          stride_k,
                          stride_v,
                          dropout);
    }
};
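
// Usage sketch (illustrative only, not part of this header): a forward FMHA kernel is
// expected to instantiate the pipeline with its Problem description and invoke it once
// per output tile, along these lines (MyFmhaProblem and the *_win window names are
// placeholders):
//
//   using Pipeline = BlockFmhaBatchPrefillPipelineQRKSVSAsync<MyFmhaProblem>;
//   auto o_acc = Pipeline{}(q_win, k_win, v_win, bias_win, randval_win, lse_win,
//                           mask, position_encoding, scale_s, variant, variant_params,
//                           block_indices, smem_ptr, page_idx, stride_k, stride_v,
//                           dropout);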

} // namespace ck_tile