template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
 
template <typename QType,
 
          typename KVScaleType,
 
        std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
 
        std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
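// Hedged aside (illustrative, not from this header): the is_kvcache_i8 /
// is_kvcache_fp8 flags above follow the usual std::is_same_v pattern for
// steering quantized code paths at compile time. The int8 case is shown;
// the fp8 flag has the same shape with the library's fp8_t.
#include <cstdint>
#include <type_traits>

template <typename KType, typename VType>
struct kvcache_flags_sketch
{
    static constexpr bool is_i8 =
        std::is_same_v<KType, std::int8_t> && std::is_same_v<VType, std::int8_t>;
};

static_assert(kvcache_flags_sketch<std::int8_t, std::int8_t>::is_i8, "i8 cache");
static_assert(!kvcache_flags_sketch<float, std::int8_t>::is_i8, "mixed types");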
 
    template <typename T, naive_attention_layout_enum Layout>
 
        __device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
            : b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
 
                return i_s * h * d + i_d;
 
                return i_s * d + i_d;
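// Hedged aside: with base_ptr already positioned at a given (batch, head),
// the two returns above are plain row-major strides, BSHD striding the token
// index by h*d and BHSD striding it by d. A tiny host check with made-up
// sizes:
#include <cstdio>

int main()
{
    int h = 2, d = 8, i_s = 3, i_d = 5;      // illustrative sizes
    long bshd_off = (long)i_s * h * d + i_d; // s-stride = h * d
    long bhsd_off = (long)i_s * d + i_d;     // s-stride = d
    std::printf("bshd=%ld bhsd=%ld\n", bshd_off, bhsd_off);
}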
 
    template <typename T, naive_attention_layout_enum Layout>
 
        static constexpr int x = 16 / sizeof(T);
 
              base_ptr(reinterpret_cast<T*>(base_ptr_)),
 
            int page_idx = i_s / s;
 
            return static_cast<int64_t>(phy);
 
                return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
 
                return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) + base_;
 
                return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
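// Hedged model of the paged-KV addressing above: split the logical token
// index into (page, in-page offset) using the page size, map the logical
// page to a physical one through the page table, then apply a PHDS-style
// in-page layout. All names and strides here are illustrative assumptions.
#include <cstdint>
#include <cstdio>
#include <vector>

std::int64_t paged_offset_sketch(int i_s, int page_size, int h, int d,
                                 int i_h, int i_d,
                                 const std::vector<int>& page_table)
{
    int page_idx     = i_s / page_size;              // logical page
    int page_offset  = i_s % page_size;              // token within the page
    std::int64_t phy = page_table[page_idx];         // physical page id
    std::int64_t base_ = phy * (std::int64_t)h * d * page_size; // page stride
    return base_ + (std::int64_t)i_h * d * page_size // head-major in page
                 + (std::int64_t)i_d * page_size + page_offset;
}

int main()
{
    std::vector<int> page_table = {7, 2, 5};         // logical -> physical
    std::printf("%lld\n",
                (long long)paged_offset_sketch(17, 16, 2, 8, 1, 3, page_table));
}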
 
        __device__ void init(int, int i_h_) { i_h = i_h_; }
 
        __device__ void store(T, int, int) {}
 
    template <typename T, naive_attention_layout_enum Layout>
 
            : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
 
                return i_h * s + i_s;
 
    template <typename T, typename F>
 
        constexpr int reduce_stage = 6;
 
        for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
 
            int src_lane = __lane_id() ^ (1 << i_stage);
 
                __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
 
            T v_remote = bit_cast<T>(v_remote_tmp);
 
            v_local    = reduce_f(v_local, v_remote);
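// Hedged aside: host-side model of the XOR-butterfly wave reduction above.
// Six stages cover a 64-lane wave (2^6 == 64); at each stage a lane pairs
// with lane ^ (1 << stage), so after the loop every lane holds the full
// reduction. ds_bpermute is modeled by indexing the partner's array slot.
#include <algorithm>
#include <cassert>

int main()
{
    constexpr int wave_size = 64, reduce_stage = 6;
    float v[wave_size];
    for(int i = 0; i < wave_size; i++) v[i] = float(i);

    for(int stage = 0; stage < reduce_stage; stage++)
    {
        float nxt[wave_size];
        for(int lane = 0; lane < wave_size; lane++)
        {
            int src_lane = lane ^ (1 << stage);            // butterfly partner
            nxt[lane]    = std::max(v[lane], v[src_lane]); // f_max reducer
        }
        std::copy(nxt, nxt + wave_size, v);
    }
    for(int lane = 0; lane < wave_size; lane++)
        assert(v[lane] == 63.0f); // every lane sees the global max
}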
 
    template <typename T, typename F>
 
        constexpr int waves     = 4;
        constexpr int wave_size = 64;
 
        int lane_id             = threadIdx.x % wave_size;
 
        smem[threadIdx.x] = local;
 
        T v_local = smem[lane_id];
 
        for(int i_stage = 1; i_stage < waves; i_stage++)
 
            T v_remote = smem[i_stage * wave_size + lane_id];
 
            v_local    = reduce_f(v_local, v_remote);
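// Hedged aside: host-side model of the cross-wave step above. Four waves of
// 64 lanes park their already wave-reduced values in shared memory; each
// lane then folds in the same lane slot from the other three waves.
#include <cassert>

int main()
{
    constexpr int waves = 4, wave_size = 64;
    float smem[waves * wave_size];
    for(int w = 0; w < waves; w++)            // pretend wave w reduced to w+1
        for(int l = 0; l < wave_size; l++)
            smem[w * wave_size + l] = float(w + 1);

    for(int lane_id = 0; lane_id < wave_size; lane_id++)
    {
        float v_local = smem[lane_id];        // wave 0's copy
        for(int i_stage = 1; i_stage < waves; i_stage++)
            v_local += smem[i_stage * wave_size + lane_id]; // f_sum reducer
        assert(v_local == 1.f + 2.f + 3.f + 4.f);
    }
}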
 
        __shared__ char smem[wg_size * 4 * sizeof(float)];
 
        char* smem_quant_q = smem + wg_size * 2 * sizeof(float);
 
        int i_dv           = blockIdx.x * wg_size + threadIdx.x;
 
        int i_sq           = blockIdx.y;
 
        int i_batch        = blockIdx.z;
 
        int i_bq           = i_batch / args.nhead_q;
        int i_hq           = i_batch % args.nhead_q;
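// Hedged reading of the index math above (the grid shape itself is an
// assumption here, not the confirmed get_grid_size): blockIdx.x tiles hdim_v
// in wg_size chunks, blockIdx.y walks query tokens, and blockIdx.z fuses
// (batch, q-head), recovered with a divide/modulo by nhead_q.
#include <cstdio>

int main()
{
    int wg_size = 256, hdim_v = 512, seqlen_q = 2, batch_q = 2, nhead_q = 8;
    int gx = (hdim_v + wg_size - 1) / wg_size; // blocks over hdim_v
    int gy = seqlen_q;                         // one block row per query
    int gz = batch_q * nhead_q;                // fused batch * head
    std::printf("grid = (%d, %d, %d)\n", gx, gy, gz);
    for(int z = 0; z < gz; z++)
        std::printf("z=%d -> batch %d, head %d\n", z, z / nhead_q, z % nhead_q);
}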
 
        void* page_table_ptr = [&]() {
 
        auto q_addr = [&]() {
 
        auto k_addr = [&]() {
 
        auto v_addr = [&]() {
 
        auto o_addr = [&]() {
 
        q_addr.init(i_bq, i_hq);
        k_addr.init(i_bk, i_hk);
        v_addr.init(i_bk, i_hk);
        o_addr.init(i_bq, i_hq);
 
        auto f_max        = [](auto x_, auto y_) { return max(x_, y_); };
        auto f_sum        = [](auto x_, auto y_) { return x_ + y_; };
 
        auto f_absmax_f32 = [](float v_0_, float v_1_) {
 
        int seqlen_kv = [&]() {
 
        int sk_loops                     = (seqlen_kv + wg_size - 1) / wg_size;
 
            if(static_cast<int>(threadIdx.x) < args.hdim)
 
                q   = type_convert<AccType>(q_addr.load(0, threadIdx.x));
                k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
 
            AccType q_forwarded = q * k_s;
 
            AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
            qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
 
            q = q / q_dequant_scale;
 
            reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
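// Hedged aside: the shape of the dynamic Q quantization above, on the host.
// Take the absmax of the scale-forwarded query vector (the wave/cross-wave
// reduces stand in for the loop here), derive a dequant scale against the
// target type's range (127 for int8 is an assumption), then divide and
// round-convert.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

int main()
{
    std::vector<float> q = {0.5f, -3.2f, 1.7f, -0.1f};
    float qf_max = 0.f;
    for(float x : q) qf_max = std::max(qf_max, std::fabs(x)); // absmax reduce
    float q_dequant_scale = qf_max / 127.f;                   // int8 range
    std::vector<std::int8_t> quantized_q(q.size());
    for(std::size_t i = 0; i < q.size(); i++)
        quantized_q[i] =
            (std::int8_t)std::lround(q[i] / q_dequant_scale); // quantize
}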
 
            if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
 
                if(static_cast<int>(threadIdx.x) < args.hdim)
 
                    q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
 
                q = q / q_dequant_scale;
 
                QCompute quantized_q = type_convert<QCompute>(q);
 
                reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
 
        for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
 
            int i_sk = i_loop1 * wg_size + threadIdx.x;
 
                for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
 
                        if constexpr(Traits::quant_algo ==
 
                                     Traits::quant_algo ==
 
                            return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
 
                            return q_addr.load(i_sq, i_dq);
 
                    auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
 
                    s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
 
                s_softmax = type_convert<SoftmaxType>(s_acc);
 
                    type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
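// Why scale_s gets multiplied by log2(e) above: exp(x) == exp2(x * log2(e)),
// so folding log2e into the softmax scale lets the kernel use the hardware
// base-2 exponential (__builtin_amdgcn_exp2f) instead of exp. A quick check:
#include <cassert>
#include <cmath>

int main()
{
    float scale_s  = 0.125f, s = 3.7f;
    float via_exp  = std::exp(s * scale_s);
    float via_exp2 = std::exp2(s * scale_s * 1.44269504f); // log2(e)
    assert(std::fabs(via_exp - via_exp2) < 1e-5f);
}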
 
                    s_softmax *= q_dequant_scale;
 
                else if constexpr(Traits::quant_algo ==
 
                        type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
 
                    s_softmax *= q_dequant_scale;
                    s_softmax *= k_per_token_scale;
 
                row_max = max(old_max, cur_max);
 
                SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);
 
                SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
                l               = tmp * l + row_sum;
                o_acc           = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
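// Hedged aside: the three lines above are the online-softmax (flash-style)
// update. Keep a running max m and running sum l; when a new block of logits
// arrives, rescale the old sum and old output by exp2(m_old - m_new). A host
// model in base-2 to match the exp2-based kernel:
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

int main()
{
    std::vector<float> logits = {1.f, 4.f, 2.f, 3.f}; // already scaled by log2e
    std::vector<float> v      = {1.f, 2.f, 3.f, 4.f}; // toy V column
    float m = -1e30f, l = 0.f, o_acc = 0.f;
    for(std::size_t i = 0; i < logits.size(); i++)
    {
        float row_max = std::max(m, logits[i]);
        float tmp     = std::exp2(m - row_max);       // rescales old state
        float p       = std::exp2(logits[i] - row_max);
        l     = tmp * l + p;
        o_acc = o_acc * tmp + p * v[i];
        m     = row_max;
    }
    float o = o_acc / l; // final normalization, done once at the end
    (void)o;
}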
 
                    if(static_cast<int>(threadIdx.x) < args.hdim_v)
 
                            type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
 
                    p_compute = p_compute / p_dequant_scale;
 
                    PType quantized_p = static_cast<PType>(p_compute);
 
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
 
                else if constexpr(Traits::quant_algo ==
 
                    auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
 
                    p_compute *= v_scale;
 
                    p_compute = p_compute / p_dequant_scale;
 
                    PType quantized_p = type_convert<PType>(p_compute);
 
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
 
                    reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
 
            constexpr int gemm_2_loop = wg_size / p_vec_elem;
 
                AccType o_acc_local = {0};
 
                int sk_start = i_loop1 * wg_size;
 
                for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
 
                        int i_sv      = sk_start + sv_offset;
 
                        if(i_dv < args.hdim_v && i_sv < seqlen_kv)
 
                            v = v_addr.load(i_sv, i_dv);
 
                        AccType v_compute = [&]() { return type_convert<AccType>(v); }();
 
                        o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
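// Hedged model of the second GEMM above: the probabilities for this block of
// wg_size tokens sit in shared memory, and each thread owns one output
// column i_dv, accumulating p[t] * V[t][i_dv] over the block in
// p_vec_elem-wide chunks (16 bytes / sizeof(PType); 16 is assumed here for
// an 8-bit PType).
#include <vector>

int main()
{
    constexpr int wg_size = 256, p_vec_elem = 16;
    constexpr int gemm_2_loop = wg_size / p_vec_elem;
    std::vector<float> p(wg_size, 0.01f), v_col(wg_size, 2.f);
    float o_acc_local = 0.f;
    for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
        for(int i_j = 0; i_j < p_vec_elem; i_j++)
        {
            int i_sv = i_loop2 * p_vec_elem + i_j; // token within the block
            o_acc_local += p[i_sv] * v_col[i_sv];  // p_vec[i_j] * V load
        }
    (void)o_acc_local;
}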
 
                OAccType post_scale_o_acc_local = [&]() {
 
                        return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
 
                    else if constexpr(Traits::quant_algo ==
 
                        return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
 
                        return type_convert<OAccType>(o_acc_local);
 
                o_acc += post_scale_o_acc_local;
 
            o_acc           = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
 
            o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
 
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_()                                                        \
        using ktraits_ = naive_attention_fwd_kernel_traits<                                                 \
            static_cast<naive_attention_variation_enum>(variation_),                                        \
            static_cast<naive_attention_quant_algo>(quant_algo_)>;                                          \
        using k_   = naive_attention_fwd_kernel<q_type_,                                                    \
        dim3 grids = k_::get_grid_size(a);                                                                  \
        r          = ck_tile::launch_kernel(s,                                                              \
                                   ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a));          \
#define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_()                                                 \
    if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
       t.o_layout == "bshd")                                                                       \
 
        constexpr auto q_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto v_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BSHD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr int variation_       = 0;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
 
    else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" &&                    \
            t.v_layout == "bhsd" && t.o_layout == "bhsd")                                          \
 
        constexpr auto q_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto v_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT;                     \
        constexpr int variation_       = 0;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
 
    else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" &&                   \
            t.v_layout == "phds" && t.o_layout == "bhsd")                                          \
 
        constexpr auto q_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_layout_       = naive_attention_layout_enum::PHDSX;                       \
        constexpr auto v_layout_       = naive_attention_layout_enum::PHDS;                        \
        constexpr auto o_layout_       = naive_attention_layout_enum::BHSD;                        \
        constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
        constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS;                    \
        constexpr int variation_       = 2;                                                        \
        CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_();                                              \
 
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 0;
 
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 0;
 
        using k_type_             = fp8_t;
        using v_type_             = fp8_t;
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
 
        using k_type_             = fp8_t;
        using v_type_             = fp8_t;
        using acc_type_           = float;
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
 
        using kvscale_type_       = float;
        constexpr int quant_algo_ = 2;
 
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
#undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_
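// Hedged usage sketch: driving the dispatch above down its plain fp16 "bshd"
// branch. Struct and field names follow this header; the include path and
// the ck_tile namespace qualification are assumptions, and all sizes are
// illustrative.
#include <ck_tile/ops/naive_attention.hpp> // assumed include path

float run_naive_attention_sketch(void* q, void* k, void* v, void* o,
                                 ck_tile::stream_config s)
{
    ck_tile::naive_attention_fwd_traits t{};
    t.q_type = t.k_type = t.v_type = t.o_type = "fp16";
    t.q_layout = t.k_layout = t.v_layout = t.o_layout = "bshd";
    t.variation  = 0; // plain attention, first macro branch
    t.quant_algo = 0; // no quantization

    ck_tile::naive_attention_fwd_args a{};
    a.q_ptr = q; a.k_ptr = k; a.v_ptr = v; a.o_ptr = o;
    a.scale_s  = 0.125f; // 1/sqrt(hdim) for hdim = 64
    a.hdim     = 64;
    a.hdim_v   = 64;
    a.batch_q  = 1; a.batch_kv = 1; a.batch_ratio_kv = 1;
    a.seqlen_q = 128; a.seqlen_kv = 128;
    a.nhead_q  = 8; a.nhead_kv = 8; a.nhead_ratio_kv = 1;

    return ck_tile::naive_attention_fwd(t, a, s); // returns a float (timing)
}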