17 template <
typename Problem_,
40 static_assert(
kQLoadOnce == Policy::QLoadOnce);
56 static_assert(
kSubQKHeaddim <= 256,
"hdim bigger than 256 is not suitable for this pipeline!");
61 static_assert(Problem::kPadSeqLenQ ==
true && Problem::kPadHeadDimQ ==
true &&
62 Problem::kPadHeadDimV ==
true);
68 static constexpr
auto BiasEnum = Problem::BiasEnum;
69 static constexpr
bool kStoreLSE = Problem::kStoreLSE;
82 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
83 return Policy::template GetAlignmentV<Problem>();
85 return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
89 kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
91 #if CK_TILE_FMHA_FWD_FAST_EXP2
92 static constexpr
auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
96 if constexpr(Problem::kBlockPerCu != -1)
97 return Problem::kBlockPerCu;
126 else if constexpr(64 <=
kK0 || 64 <=
kK1)
149 static constexpr
const char*
name =
"qr_async";
151 using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
155 return Policy::template GetSmemSize<Problem>();
158 template <
typename QDramBlockWindowTmp,
159 typename KDramBlockWindowTmp,
160 typename VDramBlockWindowTmp,
161 typename BiasDramBlockWindowTmp,
162 typename RandValDramBlockWindowTmp,
163 typename LSEDramBlockWindowTmp,
164 typename QElementFunction,
165 typename KElementFunction,
166 typename VElementFunction,
167 typename BiasElementFunction,
168 typename LSEElementFunction,
169 typename SAccElementFunction,
170 typename PComputeElementFunction,
171 typename OAccElementFunction,
172 typename PositionEncoding,
173 typename AttentionVariantParams,
174 typename BlockIndices>
176 operator()(
const QDramBlockWindowTmp& q_dram_block_window_tmp,
177 const QElementFunction& q_element_func,
178 const KDramBlockWindowTmp& k_dram_block_window_tmp,
179 const KElementFunction& ,
180 const VDramBlockWindowTmp& v_dram_block_window_tmp,
181 const VElementFunction& v_element_func,
182 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
183 const BiasElementFunction& bias_element_func,
184 RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
185 LSEDramBlockWindowTmp& lse_dram_window_tmp,
186 const LSEElementFunction& lse_element_func,
187 const SAccElementFunction& s_acc_element_func,
188 const PComputeElementFunction& p_compute_element_func,
189 const OAccElementFunction& o_acc_element_func,
191 PositionEncoding position_encoding,
194 const AttentionVariantParams& variant_params,
195 const BlockIndices& block_indices,
208 static_assert(
kM0 == QDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
209 kN0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
210 kK0 == KDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}] &&
211 kN1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
212 kK1 == VDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}] &&
213 kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[
number<0>{}] &&
214 kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[
number<1>{}],
217 constexpr
auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
220 auto k_lds_ptr =
reinterpret_cast<KDataType*
>(smem_ptr);
224 make_tensor_view<address_space_enum::lds>(
225 k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
226 Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
231 auto k_lds_Load_view = make_tensor_view<address_space_enum::lds>(
232 k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
236 Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
240 auto v_lds = make_tensor_view<address_space_enum::lds>(
242 Policy::template MakeVLdsBlockDescriptor<Problem>());
244 v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
247 constexpr
auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
248 constexpr
auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
250 auto q_dram_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
251 q_dram_block_window_tmp.get_window_lengths(),
252 q_dram_block_window_tmp.get_window_origin(),
253 Policy::template MakeQRegTileDistribution<Problem>());
254 q_dram_window.init_raw();
258 auto q = decltype(
load_tile(q_dram_window)){};
263 __builtin_amdgcn_sched_barrier(0);
265 using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
266 auto s_acc = SaccBlockTileType{};
269 const auto f_max = [](
auto e0,
auto e1) {
return max(e0, e1); };
270 const auto f_sum = [](
auto e0,
auto e1) {
return e0 + e1; };
273 using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
275 using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
278 using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
281 auto o_acc = OaccBlockTileType{};
282 auto m = MLBlockTileType{};
283 auto l = MLBlockTileType{};
289 __builtin_amdgcn_sched_barrier(0);
290 const auto q_origin = q_dram_window.get_window_origin();
291 const auto [seqlen_k_start, seqlen_k_end] =
299 if(num_total_loop <= 0)
304 make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
316 __builtin_amdgcn_sched_barrier(0);
319 auto k_dram_block_window =
321 k_dram_block_window_tmp.get_window_lengths(),
322 {seqlen_k_start, 0});
324 auto k_dist = Policy::template MakeKDramTileDistribution<Problem>();
325 auto k_coord = k_dist.calculate_index();
326 using KDstrEncode =
typename decltype(k_dist)::DstrEncode;
327 constexpr
index_t NRepeat = KDstrEncode::hs_lengthss_[
I0][
I0];
330 k_offsets[n0] = page_idx[k_coord[0] +
kN0 / NRepeat * n0.value] * stride_k;
333 k_dram_block_window.get_window_lengths(),
334 k_dram_block_window.get_window_origin(),
337 k_dram_window.init_raw();
339 constexpr
auto k_pre_np = [&]() {
348 const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
349 auto bias_dram_window =
351 bias_dram_block_window_tmp.get_window_lengths(),
352 {bias_origin.at(number<0>{}), seqlen_k_start},
353 Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());
355 auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
356 randval_dram_block_window_tmp, seqlen_k_start);
358 auto v_dist = Policy::template MakeVDramTileDistribution<Problem>();
359 auto v_coord = v_dist.calculate_index();
360 const auto VPageIndexDim =
I1;
361 using VDstrEncode =
typename decltype(v_dist)::DstrEncode;
362 constexpr
index_t V_KRepeat = VDstrEncode::hs_lengthss_[
I1][
I3];
366 v_offsets[k0] = page_idx[v_coord[VPageIndexDim] + k0.value] * stride_v;
371 v_dram_block_window_tmp.get_window_lengths(),
379 k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, number<-1>{}, k_oob_ck, k_pre_np);
381 __builtin_amdgcn_sched_barrier(0);
384 (void)q_element_func;
391 static_assert(1 <= k0_loops);
392 static_assert(1 <= k1_loops);
398 if constexpr(k0_loops > 1)
400 static_for<0, k0_loops - 1, 1>{}([&](
auto i_k0) {
406 if constexpr(i_k0 < k0_loops - 1)
410 __builtin_amdgcn_s_barrier();
411 __builtin_amdgcn_sched_barrier(0);
414 q, sequence<0, i_k0 * kK0>{}, sequence<
kM0, (i_k0 + 1) *
kK0>{}),
416 sequence<(LdsSeq.at(number<i_k0>{})) *
kN0, 0>{},
417 sequence<(LdsSeq.at(number<i_k0>{}) + 1) *
kN0,
kK0>{}));
423 if constexpr(k0_loops <= 2)
424 __builtin_amdgcn_sched_barrier(0);
427 __builtin_amdgcn_s_barrier();
429 const auto bias_tile =
load_tile(bias_dram_window);
430 auto v_buf =
load_tile(v_dram_window,
number<-1>{}, bool_constant<false>{});
431 static_for<0, V_KRepeat, 1>{}([&](
auto k0) {
432 v_offsets[k0] = page_idx[
kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v;
434 v_dram_window.update_page_idx(v_offsets);
436 __builtin_amdgcn_sched_barrier(0);
441 q, sequence<0, (k0_loops - 1) *
kK0>{}, sequence<kM0, k0_loops * kK0>{}),
443 sequence<(LdsSeq.at(number<k0_loops - 1>{})) *
kN0, 0>{},
444 sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) *
kN0,
kK0>{}));
446 __builtin_amdgcn_sched_barrier(1);
454 [&](
auto& x,
const auto& y) {
455 #if !CK_TILE_FMHA_FWD_FAST_EXP2
456 x += type_convert<SaccDataType>(bias_element_func(y));
458 x += log2e_v<SaccDataType> *
459 type_convert<SaccDataType>(bias_element_func(y));
467 const auto k_origin = k_dram_block_window.get_window_origin();
468 constexpr
auto s_spans = decltype(s_acc)::get_distributed_spans();
473 s_acc.get_tile_distribution(),
make_tuple(idx0, idx1));
475 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
476 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
477 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
479 s_acc(i_j_idx) *= scale_s;
480 position_encoding.update(s_acc(i_j_idx), row, col);
489 auto apply_logits_transform =
490 [&variant, &variant_params, &block_indices](
auto& x) {
491 x = variant.LogitsTransform(variant_params,
492 variant.QueryTransform(variant_params, x),
493 block_indices.batch_idx,
494 block_indices.qo_head_idx,
495 block_indices.kv_head_idx);
497 #if !CK_TILE_FMHA_FWD_FAST_EXP2
498 for(
index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
500 apply_logits_transform(s_acc.thread_buf_[i]);
503 for(
index_t i = 0; i < s_acc.thread_buf_.size(); ++i)
505 #if(defined(__gfx90a__) || defined(__gfx94__)) && \
506 (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \
507 CK_TILE_ATTENTION_USE_SOFTSIGN_ASM)
510 if(i == s_acc.thread_buf_.size() / 2)
512 __builtin_amdgcn_sched_barrier(0);
515 apply_logits_transform(s_acc.thread_buf_[i]);
521 #if !CK_TILE_FMHA_FWD_FAST_EXP2
529 const auto k_origin = k_dram_block_window.get_window_origin();
530 bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
531 k_origin.at(number<0>{}),
535 if(need_perpixel_check)
539 const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
540 const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
541 return !variant.LogitsMask(variant_params,
542 block_indices.batch_idx,
545 block_indices.qo_head_idx,
546 block_indices.kv_head_idx);
551 const auto s = cast_tile<SMPLComputeDataType>(s_acc);
552 auto m_local = block_tile_reduce<SMPLComputeDataType>(
559 const auto m_old = m;
561 [](
auto& e0,
auto e1,
auto e2) { e0 =
max(e1, e2); }, m, m_old, m_local);
563 auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
564 s.get_tile_distribution());
566 __builtin_amdgcn_sched_barrier(0x7F);
568 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
570 auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
571 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
574 auto v_lds_window_tmp =
576 sequence<(LdsSeq.at(number<k0_loops>{})) *
kN1, 0>{},
577 sequence<(LdsSeq.at(number<k0_loops>{}) + 1) *
kN1,
kK1>{});
585 auto v_lds_window_tmp =
587 sequence<(LdsSeq.at(number<k0_loops>{})) *
kN1, 0>{},
588 sequence<(LdsSeq.at(number<k0_loops>{}) + 1) *
kN1,
kK1>{});
593 if constexpr(k1_loops > 1)
599 v_dram_window, number<-1>{}, bool_constant<false>{});
600 static_for<0, V_KRepeat, 1>{}([&](
auto k0) {
602 page_idx[
kK1 * 2 + v_coord[VPageIndexDim] + k0.value] * stride_v;
604 v_dram_window.update_page_idx(v_offsets);
606 __builtin_amdgcn_sched_barrier(0);
615 ? type_convert<SMPLComputeDataType>(0.f)
624 constexpr
auto p_spans = decltype(p_compute)::get_distributed_spans();
627 #if CK_TILE_FMHA_FWD_FAST_EXP2
628 auto row_max = scale_s * get_validated_m(m[i_idx]);
631 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
632 #if CK_TILE_FMHA_FWD_FAST_EXP2
636 p_compute(i_j_idx) =
exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
642 p_compute(i_j_idx) =
exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
646 p_compute(i_j_idx) =
exp2(scale_s * s[i_j_idx] - row_max);
650 p_compute(i_j_idx) =
exp(s[i_j_idx] - get_validated_m(m[i_idx]));
655 auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
660 constexpr
auto o_spans = decltype(o_acc)::get_distributed_spans();
663 #if CK_TILE_FMHA_FWD_FAST_EXP2
664 const auto tmp = [&]() {
668 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
674 return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
678 auto row_max = scale_s * get_validated_m(m[i_idx]);
679 return exp2(scale_s * m_old[i_idx] - row_max);
684 const auto tmp =
exp(m_old[i_idx] - get_validated_m(m[i_idx]));
686 l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
688 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
692 o_acc(i_j_idx) *= tmp;
699 reinterpret_cast<char*
>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
700 dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
702 seqlen_k_start + i_total_loops *
kN0,
704 randval_dram_window);
707 const auto p = [&]() {
708 #if CK_TILE_FMHA_FLOAT_TO_FLOAT16_RTN
714 if constexpr(std::is_same_v<PDataType, fp16_t>)
715 return impl::cast_tile_pk_fp16_fp32<PDataType>(
718 return cast_tile<PDataType>(
724 if constexpr(k1_loops > 1)
726 static_for<0, k1_loops - 1, 1>{}([&](
auto i_k1) {
727 if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
730 v_dram_window, number<-1>{}, bool_constant<false>{});
731 static_for<0, V_KRepeat, 1>{}([&](
auto k0) {
732 v_offsets[k0] = page_idx[
kK1 * 2 + i_k1.value *
kK1 +
733 v_coord[VPageIndexDim] + k0.value] *
736 v_dram_window.update_page_idx(v_offsets);
741 p, sequence<0, i_k1 * kK1>{}, sequence<
kM0, (i_k1 + 1) *
kK1>{}),
744 sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) *
kN1, 0>{},
745 sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) *
kN1,
kK1>{}));
747 if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
749 auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
750 Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
754 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) *
kN1, 0>{},
755 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) *
kN1,
kK1>{});
764 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) *
kN1, 0>{},
765 sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) *
kN1,
kK1>{});
769 if constexpr(i_k1 < k1_loops - 1)
774 if(i_total_loops < num_total_loop)
779 k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
781 static_for<0, NRepeat, 1>{}([&](
auto n0) {
782 k_offsets[n0] = page_idx[k_coord[0] +
kN0 / NRepeat * n0.value] * stride_k;
784 k_dram_window.update_page_idx(k_offsets);
785 if constexpr(k1_loops >= 2 &&
786 LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
787 __builtin_amdgcn_s_barrier();
803 sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) *
kN1, 0>{},
804 sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) *
kN1,
kK1>{}));
806 }
while(i_total_loops < num_total_loop);
811 auto lse = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
813 constexpr
auto lse_spans = decltype(lse)::get_distributed_spans();
814 sweep_tile_span(lse_spans[number<0>{}], [&, m_ = m, l_ = l](
auto idx0) {
816 #if CK_TILE_FMHA_FWD_FAST_EXP2
820 lse(i_idx) = m_[i_idx] * R_LOG2E +
log(l_[i_idx]);
826 lse(i_idx) = m_[i_idx] * R_LOG2E +
log(l_[i_idx]);
830 lse(i_idx) = m_[i_idx] * scale_s * R_LOG2E +
log(l_[i_idx]);
834 lse(i_idx) = m_[i_idx] +
log(l_[i_idx]);
842 constexpr
auto o_spans = decltype(o_acc)::get_distributed_spans();
846 const auto tmp = [&]() {
847 if constexpr(FmhaMask::IsMasking)
849 return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
855 constexpr
auto i_j_idx =
make_tuple(idx0, idx1);
856 o_acc(i_j_idx) *= tmp;
865 template <
typename QDramBlockWindowTmp,
866 typename KDramBlockWindowTmp,
867 typename VDramBlockWindowTmp,
868 typename BiasDramBlockWindowTmp,
869 typename RandValDramBlockWindowTmp,
870 typename LSEDramBlockWindowTmp,
871 typename PositionEncoding,
872 typename AttentionVariantParams,
873 typename BlockIndices>
875 operator()(
const QDramBlockWindowTmp& q_dram_block_window_tmp,
876 const KDramBlockWindowTmp& k_dram_block_window_tmp,
877 const VDramBlockWindowTmp& v_dram_block_window_tmp,
878 const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
879 RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
880 LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
882 PositionEncoding position_encoding,
885 const AttentionVariantParams& variant_params,
886 const BlockIndices& block_indices,
893 return operator()(q_dram_block_window_tmp,
895 k_dram_block_window_tmp,
897 v_dram_block_window_tmp,
899 bias_dram_block_window_tmp,
901 randval_dram_block_window_tmp,
902 lse_dram_block_window_tmp,
#define CK_TILE_FMHA_FWD_FAST_EXP2
Definition: config.hpp:223
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:421
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
constexpr CK_TILE_DEVICE auto get_slice_tile(const tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile, sequence< SliceBegins... > slice_begins, sequence< SliceEnds... > slice_ends)
Definition: slice_tile.hpp:23
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc &in_element_func, const InTensor &... in_dstr_tensors)
Definition: tile_elementwise.hpp:40
CK_TILE_DEVICE void set_tile(DstrTensors &dstr_tensor, const T &value)
Definition: tile_elementwise.hpp:95
CK_TILE_DEVICE auto async_load_fence(index_t cnt=0)
Definition: load_tile.hpp:122
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:83
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={})
Definition: block_reduce.hpp:18
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:33
CK_TILE_DEVICE void shuffle_tile(OutTensor &out, const InTensor &in)
Definition: shuffle_tile.hpp:154
BlockFmhaPipelineQXKSVSCustomPolicy< true, true, 3, 3 > BlockFmhaBatchPrefillPipelineQRKSVSAsyncDefaultPolicy
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async_default_policy.hpp:16
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:412
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition: load_tile.hpp:58
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
CK_TILE_DEVICE void buffer_load_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:756
CK_TILE_HOST_DEVICE void set_tile_if(static_distributed_tensor< DataType, StaticTileDistribution > &out_tensor, DataType value, XIndicesPredicate predicate)
Definition: static_distributed_tensor.hpp:175
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:406
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:729
constexpr CK_TILE_HOST_DEVICE auto get_x_indices_from_distributed_indices(StaticTileDistribution tile_distribution, DistributedIndices distributed_indices)
Definition: static_distributed_tensor.hpp:159
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:110
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:177
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:418
constexpr bool is_same_v
Definition: type.hpp:283
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:20
static constexpr bool kPadSeqLenK
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:64
remove_cvref_t< typename Problem::BiasDataType > BiasDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:28
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:33
static constexpr index_t kK1
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:48
static constexpr index_t kAlignmentV
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:81
remove_cvref_t< typename Problem::KDataType > KDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:24
static constexpr index_t kN1
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:47
static constexpr auto I1
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:52
remove_cvref_t< typename Problem::BlockFmhaShape > BlockFmhaShape
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:37
remove_cvref_t< typename Problem::FmhaMask > FmhaMask
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:35
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:153
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const QElementFunction &q_element_func, const KDramBlockWindowTmp &k_dram_block_window_tmp, const KElementFunction &, const VDramBlockWindowTmp &v_dram_block_window_tmp, const VElementFunction &v_element_func, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, const BiasElementFunction &bias_element_func, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_window_tmp, const LSEElementFunction &lse_element_func, const SAccElementFunction &s_acc_element_func, const PComputeElementFunction &p_compute_element_func, const OAccElementFunction &o_acc_element_func, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, const index_t *page_idx, const index_t stride_k, const index_t stride_v, DropoutType &dropout) const
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:176
static constexpr index_t kM0
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:44
remove_cvref_t< typename Problem::AttentionVariant > AttentionVariant
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:34
static constexpr index_t kAlignmentO
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:87
static constexpr bool kHasLogitsSoftCap
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:67
static constexpr auto BiasEnum
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:68
remove_cvref_t< Policy_ > Policy
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:22
static constexpr auto I0
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:51
static constexpr index_t kSubQKHeaddim
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:50
static constexpr bool kHasDropout
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:70
static constexpr index_t kBlockSize
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:42
static constexpr bool kPadHeadDimQ
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:65
static constexpr index_t kQKHeaddim
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:49
static constexpr index_t kAlignmentK
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:80
static constexpr index_t kN0
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:45
remove_cvref_t< typename Problem::OaccDataType > OaccDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:32
static constexpr index_t kBlockPerCu
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:95
static constexpr bool kStoreLSE
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:69
static constexpr auto I3
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:54
remove_cvref_t< typename Problem::QDataType > QDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:23
static constexpr bool kPadHeadDimV
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:66
static constexpr bool kPadSeqLenQ
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:63
static constexpr index_t kAlignmentBias
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:88
remove_cvref_t< typename BlockFmhaShape::VLayout > VLayout
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:38
std::conditional_t< kHasDropout, BlockDropout, NullBlockDropout > DropoutType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:151
CK_TILE_HOST_DEVICE auto operator()(const QDramBlockWindowTmp &q_dram_block_window_tmp, const KDramBlockWindowTmp &k_dram_block_window_tmp, const VDramBlockWindowTmp &v_dram_block_window_tmp, const BiasDramBlockWindowTmp &bias_dram_block_window_tmp, RandValDramBlockWindowTmp &randval_dram_block_window_tmp, LSEDramBlockWindowTmp &lse_dram_block_window_tmp, FmhaMask mask, PositionEncoding position_encoding, float scale_s, const AttentionVariant &variant, const AttentionVariantParams &variant_params, const BlockIndices &block_indices, void *smem_ptr, const index_t *page_idx, const index_t stride_k, const index_t stride_v, DropoutType &dropout) const
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:875
remove_cvref_t< typename Problem::SaccDataType > SaccDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:26
remove_cvref_t< typename Problem::RandValOutputDataType > RandValOutputDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:29
static constexpr const char * name
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:149
static constexpr index_t kK0
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:46
remove_cvref_t< typename Problem::VDataType > VDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:25
remove_cvref_t< typename Problem::PDataType > PDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:31
remove_cvref_t< Problem_ > Problem
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:21
static constexpr index_t kAlignmentQ
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:79
static constexpr auto I2
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:53
remove_cvref_t< typename Problem::SMPLComputeDataType > SMPLComputeDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:27
static constexpr bool kQLoadOnce
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:39
static constexpr bool kIsGroupMode
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:58
remove_cvref_t< typename Problem::LSEDataType > LSEDataType
Definition: block_fmha_batch_prefill_pipeline_qr_ks_vs_async.hpp:30
Definition: integral_constant.hpp:13
Definition: functional.hpp:86
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38
Definition: sequence.hpp:52
Definition: functional.hpp:43