// compile-time traits: attention variation + quantization algorithm
template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
struct naive_attention_fwd_kernel_traits { /* variation, quant_algo */ };

// the kernel itself, templated over the element types, KVScaleType, the layouts and the traits
template <typename QType, /* ... */ typename KVScaleType, /* ... */>
struct naive_attention_fwd_kernel
{
    static constexpr bool is_kvcache_i8 =
        std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
    static constexpr bool is_kvcache_fp8 =
        std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
// plain (non-paged) tensor addresser; T is the element type, Layout the memory layout
template <typename T, naive_attention_layout_enum Layout>
struct addresser
{
    int b, s, h, d;
    T* base_ptr;

    __device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
        : b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
    {
    }

    // ... (get_base / init / load / store elided in this listing)

    // element offset of (i_s, i_d) within one (batch, head) slice
    __device__ int get_offset(int i_s, int i_d)
    {
        if constexpr(Layout == naive_attention_layout_enum::BSHD)
            return i_s * h * d + i_d; // BSHD: the per-token stride is h * d
        else // BHSD: each head owns a contiguous [s, d] block, stride d
            return i_s * d + i_d;
    }
};
// paged-KV addresser: a logical token index is translated through the page table
// into a physical page before the in-page element offset is computed
template <typename T, naive_attention_layout_enum Layout>
struct page_addresser
{
    // number of T elements packed along the innermost dimension (16 bytes worth)
    static constexpr int x = 16 / sizeof(T);

    int s, h, d; // page size, number of heads, head dim
    T* base_ptr;
    int* page_table_ptr;
    int i_h;

    __device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
        : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_)),
          page_table_ptr(reinterpret_cast<int*>(pptr_))
    {
    }

    // logical token index -> physical page index, via the page table
    __device__ int64_t get_phy_page_idx(int i_s)
    {
        int page_idx = i_s / s;
        // ... (phy looked up through page_table_ptr; elided)
        return static_cast<int64_t>(phy);
    }

    // ... (get_phy_page_offset(i_s): position of the token inside its page)

    // int64 element offset of (i_s, i_d); the layout branches are elided here.
    // base_ is the element offset of the physical page (from get_phy_page_idx),
    // page_offset the position inside that page, and d_r / d_x split i_d by x.
    __device__ int64_t get_offset(int i_s, int i_d)
    {
        // [head][token-in-page][dim]:
        return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
        // PHDSX ([head][dim / x][token-in-page][x]):
        return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) + base_;
        // PHDS ([head][dim][token-in-page]):
        return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
    }

    __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
    __device__ void store(T, int, int) {} // the paged KV cache is read-only here
};
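// ---------------------------------------------------------------------------
// Standalone sketch (not part of this header): the logical-to-physical token
// translation a paged KV cache needs, written out on the host. A logical token
// index i_s is split into a logical page (i_s / page_size) and an in-page
// offset (i_s % page_size); the per-sequence page table then maps the logical
// page to a physical page. Only page_size, max_pages_per_seq and the int page
// table correspond to fields of naive_attention_fwd_args; the function name,
// the [batch][max_pages_per_seq] table layout and everything else here are
// illustrative assumptions.
#include <cstdint>

inline int64_t logical_to_physical_token(const int* page_table, // assumed [batch][max_pages_per_seq]
                                         int batch_idx,
                                         int max_pages_per_seq,
                                         int page_size,
                                         int i_s)
{
    int logical_page  = i_s / page_size;  // which page of this sequence
    int in_page       = i_s % page_size;  // token position inside the page
    int physical_page = page_table[batch_idx * max_pages_per_seq + logical_page];
    // token index as if the physical pages were laid out back to back
    return static_cast<int64_t>(physical_page) * page_size + in_page;
}
// ---------------------------------------------------------------------------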
// addresser for the per-head / per-token K and V dequant scale tensors
template <typename T, naive_attention_layout_enum Layout>
struct kvscale_addresser
{
    int s, h, d;
    T* base_ptr;
    __device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
        : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
    {
    }
    __device__ int get_offset(int i_s, int i_h, int i_d)
    {
        // SCALE_HS: one scale per (head, token)
        return i_h * s + i_s;
        // ... (other scale layouts elided in this listing)
    }
// reduce a per-lane value across the 64 lanes of one wave with an XOR butterfly
template <typename T, typename F>
constexpr __device__ T wave_reduce(T local, F reduce_f)
{
    T v_local = local;
    constexpr int reduce_stage = 6; // log2(64) exchange stages
    for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
    {
        int src_lane         = __lane_id() ^ (1 << i_stage); // butterfly partner
        int32_t v_remote_tmp =
            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
        T v_remote = bit_cast<T>(v_remote_tmp);
        v_local    = reduce_f(v_local, v_remote);
    }
    return v_local;
}
// combine per-wave partial results across the four waves of one workgroup via LDS
template <typename T, typename F>
constexpr __device__ T cross_wave_reduce(T local, F reduce_f, T* smem)
{
    constexpr int waves     = 4;
    constexpr int wave_size = 64;
    int lane_id             = threadIdx.x % wave_size;
    // ...
    smem[threadIdx.x] = local;
    // ...
    T v_local = smem[lane_id];
    for(int i_stage = 1; i_stage < waves; i_stage++)
    {
        T v_remote = smem[i_stage * wave_size + lane_id];
        v_local    = reduce_f(v_local, v_remote);
    }
    return v_local;
}
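// ---------------------------------------------------------------------------
// Standalone host-side sketch (not part of this header): models the XOR
// butterfly of wave_reduce with plain arrays instead of ds_bpermute, to show
// why every lane of a 64-wide wave holds the full reduction after
// log2(64) = 6 stages. All names here are illustrative.
#include <array>

template <typename F>
std::array<float, 64> simulate_wave_reduce(std::array<float, 64> lanes, F reduce_f)
{
    for(int stage = 0; stage < 6; stage++)
    {
        std::array<float, 64> next = lanes; // all lanes exchange "simultaneously"
        for(int lane = 0; lane < 64; lane++)
        {
            int src_lane = lane ^ (1 << stage); // butterfly partner, as in wave_reduce
            next[lane]   = reduce_f(lanes[lane], lanes[src_lane]);
        }
        lanes = next;
    }
    return lanes; // every entry now equals reduce_f folded over all 64 inputs
}

// e.g. summing lane ids 0..63 leaves 2016 in every lane:
//   std::array<float, 64> in{};
//   for(int i = 0; i < 64; i++) in[i] = static_cast<float>(i);
//   auto out = simulate_wave_reduce(in, [](float a, float b) { return a + b; });
//   // out[0] == out[63] == 2016.0f
// ---------------------------------------------------------------------------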
// body of __device__ void operator()(naive_attention_fwd_args args):
// grid: blockIdx.x tiles the V / output head dim, blockIdx.y is the query token,
// blockIdx.z flattens (batch, query head)
__shared__ char smem[wg_size * 4 * sizeof(float)];
char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // staging area for quantized Q
int i_dv    = blockIdx.x * wg_size + threadIdx.x; // output / V head-dim index
int i_sq    = blockIdx.y;                         // query token index
int i_batch = blockIdx.z;                         // flattened (batch, q-head) index
int i_bq    = i_batch / args.nhead_q;
int i_hq    = i_batch % args.nhead_q;
// ... (i_bk, i_hk: KV-side batch / head indices, derivation elided)
void* page_table_ptr = [&]() { /* ... */ }(); // page table used by the paged-KV variation
// plain or paged addresser is selected per tensor layout
auto q_addr = [&]() { /* ... */ }();
auto k_addr = [&]() { /* ... */ }();
auto v_addr = [&]() { /* ... */ }();
auto o_addr = [&]() { /* ... */ }();
// ...
q_addr.init(i_bq, i_hq);
k_addr.init(i_bk, i_hk);
v_addr.init(i_bk, i_hk);
o_addr.init(i_bq, i_hq);

auto f_max        = [](auto x_, auto y_) { return max(x_, y_); };
auto f_sum        = [](auto x_, auto y_) { return x_ + y_; };
auto f_absmax_f32 = [](float v_0_, float v_1_) { /* ... abs-max of the two */ };
// ...
int seqlen_kv = [&]() { /* ... args.seqlen_kv or the per-batch context length */ }();
// ...
int sk_loops = (seqlen_kv + wg_size - 1) / wg_size; // ceil-divide KV tokens over the workgroup
// load this row of Q and, for the quantized paths, pre-quantize it into LDS
// ...
if(static_cast<int>(threadIdx.x) < args.hdim)
{
    q   = type_convert<AccType>(q_addr.load(0, threadIdx.x));
    k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
}
AccType q_forwarded = q * k_s; // K dequant scale folded in for the abs-max estimate
// ...
// dynamic per-row scale: abs-max of q_forwarded over the whole workgroup
AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
qf_max         = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
// ... (q_dequant_scale derived from qf_max; elided)
q = q / q_dequant_scale;
// ... (conversion to QCompute elided)
reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
// ...
// second path: fp16 / bf16 Q quantized on the fly into QCompute
if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
{
    // ...
    if(static_cast<int>(threadIdx.x) < args.hdim)
    {
        q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
    }
    // ... (q_dequant_scale computed in elided lines)
    q = q / q_dequant_scale;
    // ...
    QCompute quantized_q = type_convert<QCompute>(q);
    // ...
    reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
}
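// ---------------------------------------------------------------------------
// Standalone sketch (not part of this header): the dynamic per-row quantization
// pattern used for Q above, in scalar host code. The abs-max of the row picks a
// dequant scale so the scaled values fill the representable range of the compute
// type; the values are divided by that scale before conversion, and the Q*K dot
// products are multiplied by it again afterwards (the "s_softmax *= q_dequant_scale"
// seen below). The function name, the zero guard and the example range limits
// (127 for int8, ~448 for fp8 e4m3) are illustrative assumptions.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

inline std::vector<float> quantize_row_dynamic(const std::vector<float>& row,
                                               float type_max, // e.g. 127.f (int8) or 448.f (fp8 e4m3)
                                               float& dequant_scale)
{
    float absmax = 0.0f;
    for(float v : row)
        absmax = std::max(absmax, std::fabs(v));
    dequant_scale = (absmax > 0.0f ? absmax : 1.0f) / type_max; // multiplied back after the dot product
    std::vector<float> scaled(row.size());
    for(std::size_t i = 0; i < row.size(); ++i)
        scaled[i] = row[i] / dequant_scale; // would be converted to int8 / fp8 here
    return scaled;
}
// ---------------------------------------------------------------------------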
// loop over the KV sequence, wg_size tokens per iteration (one per thread)
for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
{
    int i_sk = i_loop1 * wg_size + threadIdx.x; // this thread's KV token
    // ...
    // gemm-1: dot product of the Q row with this thread's K row
    for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
    {
        auto q = [&]() {
            if constexpr(Traits::quant_algo == /* ... */ ||
                         Traits::quant_algo == /* ... */)
                return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq]; // pre-quantized Q from LDS
            else
                return q_addr.load(i_sq, i_dq);
        }();
        auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
        s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
    }
    // scale by scale_s (folded with log2(e) so exp2 can be used below), then undo
    // the quantization scales
    s_softmax = type_convert<SoftmaxType>(s_acc);
    s_softmax *= type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
    if constexpr(Traits::quant_algo == /* ... */)
    {
        s_softmax *= q_dequant_scale;
    }
    else if constexpr(Traits::quant_algo == /* ... per-token KV quantization ... */)
    {
        auto k_per_token_scale =
            type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
        s_softmax *= q_dequant_scale;
        s_softmax *= k_per_token_scale;
    }
    // ... (elided: workgroup reductions producing cur_max and row_sum)
    // online softmax: when the running max moves, rescale l and o_acc
    row_max = max(old_max, cur_max);
    SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);
    // ...
    SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
    l     = tmp * l + row_sum;
    o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
    // stage the probabilities P into LDS for gemm-2, quantizing to PType when the
    // V cache is quantized
    if(static_cast<int>(threadIdx.x) < args.hdim_v)
    {
        // per-head V dequant scale for this output channel (the variable it is
        // assigned to is elided in this listing)
        /* ... = */ type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
    }
    // ... (per-tensor path, if constexpr header elided)
    p_compute = p_compute / p_dequant_scale;
    // ...
    PType quantized_p = static_cast<PType>(p_compute);
    // ...
    reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
    // ...
    else if constexpr(Traits::quant_algo == /* ... per-token V quantization ... */)
    {
        // fold the per-token V scale into P before quantizing
        auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
        p_compute *= v_scale;
        // ...
        p_compute = p_compute / p_dequant_scale;
        // ...
        PType quantized_p = type_convert<PType>(p_compute);
        // ...
        reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
    }
    // ... (no KV quantization: P goes to LDS directly)
    reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
    // gemm-2: multiply the staged P by V; each thread accumulates one output
    // channel (i_dv) over the wg_size KV tokens of this iteration
    constexpr int gemm_2_loop = wg_size / p_vec_elem;
    AccType o_acc_local       = {0};
    int sk_start              = i_loop1 * wg_size;
    for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
    {
        // ... (p_vec read from LDS; the inner loop over its p_vec_elem entries,
        //      i_j, and sv_offset are elided)
        int i_sv = sk_start + sv_offset;
        // ...
        if(i_dv < args.hdim_v && i_sv < seqlen_kv)
        {
            v = v_addr.load(i_sv, i_dv);
        }
        // ...
        AccType v_compute = [&]() { return type_convert<AccType>(v); }();
        // ...
        o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
    }
    // undo the P (and, per quant algo, the V) scales before accumulating
    OAccType post_scale_o_acc_local = [&]() {
        if constexpr(Traits::quant_algo == /* ... */)
            return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
                                          /* ... dequant scale(s) ... */);
        else if constexpr(Traits::quant_algo == /* ... */)
            return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
                                          /* ... */);
        else
            return type_convert<OAccType>(o_acc_local);
    }();
    o_acc += post_scale_o_acc_local;
}
// after the KV loop: final normalization by the softmax denominator (the factor
// tmp is computed in elided lines), then store one output element
o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
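// ---------------------------------------------------------------------------
// Standalone sketch (not part of this header): the online-softmax bookkeeping
// used in the loop above, in scalar host code. Whenever a new score raises the
// running maximum, the accumulated sum l and the output accumulator are rescaled
// by exp(old_max - new_max), so after the last token o_acc / l equals the
// one-shot softmax-weighted sum. The kernel uses exp2 with log2(e) folded into
// scale_s; plain exp is used here for clarity. Names are illustrative.
#include <algorithm>
#include <cmath>

struct online_softmax_state
{
    float row_max = -INFINITY; // running maximum of the scores seen so far
    float l       = 0.0f;      // running sum of exp(score - row_max)
    float o_acc   = 0.0f;      // running sum of exp(score - row_max) * value (1-d value here)
};

inline void online_softmax_step(online_softmax_state& st, float score, float value)
{
    float new_max = std::max(st.row_max, score);
    float rescale = std::exp(st.row_max - new_max); // 0 on the first step (row_max = -inf)
    float p       = std::exp(score - new_max);
    st.l          = rescale * st.l + p;
    st.o_acc      = rescale * st.o_acc + p * value;
    st.row_max    = new_max;
}
// after all tokens: the attention output for this channel is st.o_acc / st.l
// ---------------------------------------------------------------------------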
// instantiate the kernel for the compile-time choices gathered so far and launch it
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_()                                    \
    {                                                                                   \
        using ktraits_ = naive_attention_fwd_kernel_traits<                             \
            static_cast<naive_attention_variation_enum>(variation_),                    \
            static_cast<naive_attention_quant_algo>(quant_algo_)>;                      \
        using k_ = naive_attention_fwd_kernel<q_type_, /* ... */>;                      \
        /* ... */                                                                       \
        dim3 grids = k_::get_grid_size(a);                                              \
        r          = ck_tile::launch_kernel(s,                                          \
            ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a));             \
    }
// dispatch on the runtime layout strings: each branch fixes the layout enums and
// the variation at compile time, then instantiates and launches via the macro above
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_()                                                 \
    if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
       t.o_layout == "bshd")                                                                       \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto v_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr int variation_       = 0;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }                                                                                              \
    else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" &&                    \
            t.v_layout == "bhsd" && t.o_layout == "bhsd")                                          \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto v_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr int variation_       = 0;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }                                                                                              \
    else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" &&                   \
            t.v_layout == "phds" && t.o_layout == "bhsd")                                          \
    {                                                                                              \
        constexpr auto q_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::PHDSX;                       \
        constexpr auto v_layout_       = naive_attention_layout_enum::PHDS;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
        constexpr int variation_       = 2;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
    }
// inside naive_attention_fwd(): per-branch element-type selection feeding the
// dispatch macros above (the surrounding checks on t.q_type / t.k_type / ... and
// the remaining type aliases are elided in this listing)

// non-quantized paths (quant_algo 0):
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 0;
// ...
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 0;
// ...
// fp8 KV-cache paths (quant_algo 2):
        using k_type_             = fp8_t;
        using v_type_             = fp8_t;
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
// ...
        using k_type_             = fp8_t;
        using v_type_             = fp8_t;
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
// ...
// int8 KV-cache path (32-bit integer accumulation, quant_algo 2):
        using acc_type_           = int32_t;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
// ...

#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_
// ---------------------------------------------------------------------------
// Symbols referenced by the listing above, as recorded in this page's
// cross-reference index (declarations only; definitions live in the named headers).
//
// enums (naive_attention.hpp):
//   naive_attention_layout_enum, naive_attention_variation_enum,
//   naive_attention_quant_algo
//
// element / utility types (ck_tile):
//   fp8_t  = _BitInt(8)   [float8.hpp]
//   fp16_t = _Float16     [half.hpp]
//   bf16_t = bfloat16_t   [bfloat16.hpp]
//   int8_t                [int8.hpp]
//   int64_t = long        [data_type.hpp]
//   ext_vector_t<T, N>    [vector_type.hpp]
//
// helpers:
//   CK_TILE_HOST                                              [config.hpp]
//   constexpr CK_TILE_HOST_DEVICE T max(T x)                  [math.hpp]
//   CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t& x)   [bfloat16.hpp]
//   static constexpr CK_TILE_HOST_DEVICE T numeric<T>::infinity()  [numeric.hpp]
//   stream_config                                             [stream_config.hpp]
//
// host-side entry point (naive_attention.hpp):
//   CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
//                                          naive_attention_fwd_args a,
//                                          ck_tile::stream_config s);
//
// struct naive_attention_fwd_traits
//   std::string q_type, k_type, v_type, o_type;
//   std::string q_layout, k_layout, v_layout, o_layout;
//   int variation;
//   int quant_algo;
//
// struct naive_attention_fwd_args
//   void* q_ptr;  void* k_ptr;  void* v_ptr;  void* o_ptr;
//   void* context_len_ptr;  void* page_table_ptr;
//   void* kscale_ptr;  void* vscale_ptr;
//   float scale_s;
//   int hdim;  int hdim_v;
//   int batch_q;  int batch_kv;  int batch_ratio_kv;
//   int seqlen_q;  int seqlen_kv;
//   int nhead_q;  int nhead_kv;  int nhead_ratio_kv;
//   int page_size;  int max_pages_per_seq;  int max_kv_tokens;
//
// struct naive_attention_fwd_kernel_traits
//   static constexpr naive_attention_variation_enum variation;
//   static constexpr naive_attention_quant_algo quant_algo;
//
// struct naive_attention_fwd_kernel (selected members)
//   using SoftmaxType      = float;
//   using QuantComputeType = float;
//   using QCompute         = KType;
//   using PType            = VType;
//   using OAccType         = float;
//   using p_vec_type       = ext_vector_t<PType, 16 / sizeof(PType)>;
//   static constexpr int  p_vec_elem;
//   static constexpr int  v_per_token_quant_group_size;
//   static constexpr bool is_kvcache_i8, is_kvcache_fp8;
//   __device__ static constexpr __host__ int get_block_size();
//   static __host__ dim3 get_grid_size(naive_attention_fwd_args args);
//   __device__ void operator()(naive_attention_fwd_args args);
// ---------------------------------------------------------------------------
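// ---------------------------------------------------------------------------
// Standalone usage sketch (not part of this header): the rough shape of a
// host-side call into naive_attention_fwd for the plain fp16 "bshd" path
// dispatched above. Field names follow naive_attention_fwd_traits /
// naive_attention_fwd_args as summarized in the index; the "fp16" type strings,
// the ck_tile namespace qualification, the placeholder device pointers and the
// helper name are assumptions for illustration, not verified against the
// type-dispatch code elided from this listing.
#include <cmath>

inline float run_naive_attention_fwd_sketch(void* q_dev, void* k_dev, void* v_dev, void* o_dev,
                                            int batch, int nhead, int seqlen_q, int seqlen_kv, int hdim)
{
    ck_tile::naive_attention_fwd_traits t{};
    t.q_type = "fp16"; t.k_type = "fp16"; t.v_type = "fp16"; t.o_type = "fp16";
    t.q_layout = "bshd"; t.k_layout = "bshd"; t.v_layout = "bshd"; t.o_layout = "bshd";
    t.variation  = 0; // plain batched attention, no paged KV cache
    t.quant_algo = 0; // no KV-cache quantization

    ck_tile::naive_attention_fwd_args a{};
    a.q_ptr = q_dev; a.k_ptr = k_dev; a.v_ptr = v_dev; a.o_ptr = o_dev;
    a.scale_s  = 1.0f / std::sqrt(static_cast<float>(hdim));
    a.hdim     = hdim;     a.hdim_v    = hdim;
    a.seqlen_q = seqlen_q; a.seqlen_kv = seqlen_kv;
    a.batch_q  = batch;    a.batch_kv  = batch;  a.batch_ratio_kv = 1;
    a.nhead_q  = nhead;    a.nhead_kv  = nhead;  a.nhead_ratio_kv = 1;

    // forwards naive_attention_fwd's float return value
    return ck_tile::naive_attention_fwd(t, a, ck_tile::stream_config{});
}
// ---------------------------------------------------------------------------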