Composable Kernel: include/ck_tile/ref/naive_attention.hpp Source File

naive_attention.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 #include <thread>
10 #include <string>
11 
12 namespace ck_tile {
13 
14 enum class naive_attention_layout_enum
15 {
16  DEFAULT, // maybe this tensor is not used, set some irrelevant value
17  BSHD, // [batch, seqlen, nhead, hdim]
18  BHSD, // [batch, nhead, seqlen, hdim]
19  BS3HD, // [batch, seqlen, 3, nhead, hdim], used when qkv are packed
20  PHSD, // [pages, nhead, page_size, hdim]
21  // PHSDX, // [pages, nhead, page_size/x, hdim, x], where <# used pages>*page_size = seqlen
22  PHDSX, // [pages, nhead, hdim/x, page_size, x], where <# used pages>*page_size = seqlen
23  PHDS, // [pages, nhead, hdim, page_size], where <# used pages>*page_size = seqlen
24 
25  // scale layout used for dynamic dequant
26  SCALE_HS, // [nhead, tokens] or [nhead, tokens-per-group], used for KVCache quant
27  SCALE_SH, // [tokens, nhead]
28 };
29 
30 // will be used to specialize the kernel variation
31 enum class naive_attention_variation_enum
32 {
33  FLASH_BATCHED = 0, // standard flash attention, or xformer/sdpa, used for training
34  FLASH_GROUPED = 1, // grouped (varlen) flash attention
35  DECODE_PAGED, // decode attn, where kv tokens come from another buffer called kvcache
36 };
37 
38 enum class naive_attention_quant_algo
39 {
40  NO = 0,
41  KV_8BIT_PERHEAD = 1,
42  // FP8/INT8 quant for KVCache, per-token quant
43  // [num_tokens, nhead, hdim] -> [nhead, num_tokens]
44  KV_8BIT_PERTOKEN = 2,
45 };
46 
47 // TODO: for simplicity, this will be used as host/device arg
48 struct naive_attention_fwd_args
49 {
50  void* q_ptr;
51  void* k_ptr;
52  void* v_ptr;
53  void* o_ptr;
54  void* context_len_ptr; // [batch] used when seqlen_kv comes from a pointer (each element
55  // is a per-batch count, not a cumulative sum)
56  void* page_table_ptr; // [batch, max_pages_per_seq] kv tokens are spread across blocks (paged attn)
57  void* kscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
58  void* vscale_ptr; // [nhead, max_kv_tokens] used for kvcache dequant
59  float scale_s;
60  int hdim;
61  int hdim_v; // could be cross-attn, where V and Q/K hdim are different
62  int batch_q;
63  int batch_kv;
64  int batch_ratio_kv; // batch_q / batch_kv
65  int seqlen_q; // in decode case, this should be 1
66  int seqlen_kv; // if context_len_ptr is not nullptr, ignore this field
67  int nhead_q;
68  int nhead_kv;
69  int nhead_ratio_kv; // nhead_q / nhead_kv
70  int page_size; // if paged, the number of kv tokens held in each block
71  int max_pages_per_seq; // row stride of page_table_ptr
72  int max_kv_tokens; // used as stride to access the kv scale pointers
73 };
74 
75 // this is the trait for the host API
76 struct naive_attention_fwd_traits
77 {
78  std::string q_type;
79  std::string k_type;
80  std::string v_type;
81  std::string o_type;
82  std::string q_layout;
83  std::string k_layout;
84  std::string v_layout;
85  std::string o_layout;
86  int variation; // sync with naive_attention_variation_enum
87  int quant_algo; // sync with naive_attention_quant_algo
88 };
89 
90 // this is the trait for the kernel template
91 template <naive_attention_variation_enum variation_, naive_attention_quant_algo quant_algo_>
92 struct naive_attention_fwd_kernel_traits
93 {
94  static constexpr naive_attention_variation_enum variation = variation_;
95  static constexpr naive_attention_quant_algo quant_algo = quant_algo_;
96 };
97 
98 // for simplicity, please do not use const-reference type for the template type
99 template <typename QType,
100  typename KType,
101  typename VType,
102  typename OType,
103  typename AccType,
104  typename KVScaleType,
105  naive_attention_layout_enum QLayout,
106  naive_attention_layout_enum KLayout,
107  naive_attention_layout_enum VLayout,
108  naive_attention_layout_enum OLayout,
109  naive_attention_layout_enum KScaleLayout,
110  naive_attention_layout_enum VScaleLayout,
111  typename Traits>
112 struct naive_attention_fwd_kernel
113 {
114  static constexpr bool is_kvcache_i8 =
115  std::is_same_v<KType, int8_t> && std::is_same_v<VType, int8_t>;
116  static constexpr bool is_kvcache_fp8 =
117  std::is_same_v<KType, fp8_t> && std::is_same_v<VType, fp8_t>;
118 
119  static constexpr int v_per_token_quant_group_size = 64;
120 
121  // TODO: hardcode
122  using SoftmaxType = float; // always using float to do softmax compute
123  using QuantComputeType = float; // used for quant/dequant scale compute
124  using QCompute = KType; // src A of gemm1, same type as K
125  using PType = VType; // src A of gemm2, same type as V
126  using OAccType = float; // always float, in case int8 FA
127 
128  using p_vec_type = ext_vector_t<PType, 16 / sizeof(PType)>;
129  static constexpr int p_vec_elem = 16 / sizeof(PType); // number of PType elements per 16-byte vector
130 
131  // clang-format off
132  template <typename T_> struct scale_max { static constexpr float value = 1; /* dummy code */ };
133  template <> struct scale_max<int8_t> { static constexpr float value = 127.0; };
134  template <> struct scale_max<fp8_t> { static constexpr float value = 240.0; };
135  // clang-format on
136 
137  __host__ __device__ naive_attention_fwd_kernel() {}
138 
139  template <typename T, naive_attention_layout_enum Layout>
140  struct addresser
141  {
142  int b, s, h, d; // batch, seqlen, nhead, hdim
143  T* base_ptr;
144  __device__ addresser(int b_, int s_, int h_, int d_, void* base_ptr_)
145  : b(b_), s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(base_ptr_))
146  {
147  }
148 
149  // TODO: all batch/nhead offsets are accumulated into the base pointer
150  __device__ T* get_base(int i_b, int i_h)
151  {
152  if constexpr(Layout == naive_attention_layout_enum::BSHD)
153  return base_ptr + i_b * s * h * d + i_h * d;
154  else if constexpr(Layout == naive_attention_layout_enum::BHSD)
155  return base_ptr + i_b * s * h * d + i_h * s * d;
156  }
157 
158  __device__ int get_offset(int i_s, int i_d)
159  {
160  if constexpr(Layout == naive_attention_layout_enum::BSHD)
161  return i_s * h * d + i_d;
162  else if constexpr(Layout == naive_attention_layout_enum::BHSD)
163  return i_s * d + i_d;
164  }
165 
166  // the APIs below directly use the pointer stored in this struct
167  __device__ void init(int i_b, int i_h) { base_ptr = get_base(i_b, i_h); }
168  __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
169  __device__ void store(T value, int i_s, int i_d) { base_ptr[get_offset(i_s, i_d)] = value; }
170  };
171 
172  template <typename T, naive_attention_layout_enum Layout>
173  struct page_addresser
174  {
175  int s, h, d; // page_size, nhead, hdim
176  static constexpr int x = 16 / sizeof(T); // pack 4 dwords (16 bytes)
177  T* base_ptr;
178  int* page_table_ptr; // TODO: the page table is always int
179  int i_h; // store current head
180 
181  __device__ page_addresser(int s_, int h_, int d_, void* base_ptr_, void* pptr_)
182  : s(s_),
183  h(h_),
184  d(d_),
185  base_ptr(reinterpret_cast<T*>(base_ptr_)),
186  page_table_ptr(reinterpret_cast<int*>(pptr_))
187  {
188  }
189 
190  __device__ int64_t get_phy_page_idx(int i_s)
191  {
192  // computing the page index dynamically is simple but slow
193  int page_idx = i_s / s;
194  int phy = page_table_ptr[page_idx];
195  return static_cast<int64_t>(phy);
196  }
197 
198  __device__ int get_phy_page_offset(int i_s)
199  {
200  // computing the page offset dynamically is simple but slow
201  return i_s % s;
202  }
203 
204  __device__ int64_t get_offset(int i_s, int i_d)
205  {
206  int page_offset = get_phy_page_offset(i_s);
207  int64_t page_idx = get_phy_page_idx(i_s);
208  int64_t base_ = page_idx * h * s * d;
209  if constexpr(Layout == naive_attention_layout_enum::PHSD)
210  return static_cast<int64_t>(i_h * s * d + page_offset * d + i_d) + base_;
211  else if constexpr(Layout == naive_attention_layout_enum::PHDSX)
212  {
213  int d_r = i_d / x;
214  int d_x = i_d % x;
215  return static_cast<int64_t>(i_h * d * s + d_r * s * x + page_offset * x + d_x) +
216  base_;
217  }
218  else if constexpr(Layout == naive_attention_layout_enum::PHDS)
219  {
220  return static_cast<int64_t>(i_h * d * s + i_d * s + page_offset) + base_;
221  }
222  }
223 
224  // the APIs below directly use the pointer stored in this struct
225  __device__ void init(int /*i_b*/, int i_h_) { i_h = i_h_; }
226  __device__ T load(int i_s, int i_d) { return base_ptr[get_offset(i_s, i_d)]; }
227  __device__ void store(T /*value*/, int /*i_s*/, int /*i_d*/) {}
228  };
229 
230  template <typename T, naive_attention_layout_enum Layout>
231  struct kvscale_addresser
232  {
233  int s, h, d; // seqlen (tokens), nhead, hdim
234  T* base_ptr;
235  __device__ kvscale_addresser(int s_, int h_, int d_, void* p_)
236  : s(s_), h(h_), d(d_), base_ptr(reinterpret_cast<T*>(p_))
237  {
238  }
239  __device__ int get_offset(int i_s, int i_h, int i_d)
240  {
241  if constexpr(Layout == naive_attention_layout_enum::SCALE_HS)
242  {
243  // [nhead, tokens]
244  (void)i_d;
245  return i_h * s + i_s;
246  }
247  else if constexpr(Layout == naive_attention_layout_enum::DEFAULT)
248  {
249  return 0;
250  }
251  // [h, 2, d]
252  // return i_h * 2 * d + i_kv * d + i_d;
253  }
254  __device__ T load(int i_s, int i_h, int i_d) { return base_ptr[get_offset(i_s, i_h, i_d)]; }
255  };
256 
257  __device__ __host__ static constexpr int get_block_size() { return 256; }
258 
259  // for simplicity, 1 WG always computes 1 token along q and all tokens along kv
260  // compute all hdim from q, compute WG_SIZE hdim from v
261  // 1) in prefill case, seqlen_q >= 1, seqlen_kv >= 1, batch_q=batch_kv
262  // 2) in decode case, seqlen_q = 1, batch_q is input num-tokens, batch_kv is 1
263  // 3) in paged-attn case, we still use 1 WG compute all the seqlen-kv for simplicity
264  // TODO: could support split-kv to validate intermediate logsum
265  __host__ static dim3 get_grid_size(naive_attention_fwd_args args)
266  {
267  constexpr int wg_size = get_block_size();
268  auto g =
269  dim3((args.hdim_v + wg_size - 1) / wg_size, args.seqlen_q, args.batch_q * args.nhead_q);
270  return g;
271  }
272 
273  // reduce a single 32-bit value within a wave
274  template <typename T, typename F>
275  __device__ constexpr T wave_reduce(T local, F reduce_f)
276  {
277  // constexpr int wave_size = 64;
278  constexpr int reduce_stage = 6; // 1<<6=64
279  T v_local = local;
280 #pragma unroll
281  for(int i_stage = 0; i_stage < reduce_stage; i_stage++)
282  {
283  int src_lane = __lane_id() ^ (1 << i_stage);
284  int32_t v_remote_tmp =
285  __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));
286  T v_remote = bit_cast<T>(v_remote_tmp);
287  v_local = reduce_f(v_local, v_remote);
288  }
289  return v_local;
290  }
291 
292  // Note: this function must be called after wave_reduce
293  // Note: avoid calling this under divergent if...else... branches (it uses __syncthreads)
294  template <typename T, typename F>
295  __device__ constexpr T cross_wave_reduce(T local, F reduce_f, T* smem)
296  {
297  constexpr int waves = 4;
298  constexpr int wave_size = 64;
299  int lane_id = threadIdx.x % wave_size;
300 
301  __syncthreads();
302  smem[threadIdx.x] = local;
303  __syncthreads();
304 
305  // the data within a single wave is the same,
306  // but for simplicity we still read the value from each lane.
307  T v_local = smem[lane_id];
308 #pragma unroll
309  for(int i_stage = 1; i_stage < waves; i_stage++)
310  {
311  T v_remote = smem[i_stage * wave_size + lane_id];
312  v_local = reduce_f(v_local, v_remote);
313  }
314  return v_local;
315  }
316 
317  // kernel entry point
318  __device__ void operator()(naive_attention_fwd_args args)
319  {
320  constexpr int wg_size = get_block_size();
321  __shared__ char smem[wg_size * 4 * sizeof(float)]; // should be enough
322  char* smem_quant_q = smem + wg_size * 2 * sizeof(float); // second half, should be enough
323  int i_dv = blockIdx.x * wg_size + threadIdx.x; // index of hdim_v
324  int i_sq = blockIdx.y; // index of seqlen_q
325  int i_batch = blockIdx.z; // index of batch_q * nhead_q
326  int i_bq = i_batch / args.nhead_q; // index of batch_q
327  int i_hq = i_batch % args.nhead_q; // index of nhead_q
328 
329  int i_bk = i_bq / args.batch_ratio_kv;
330  int i_hk = i_hq / args.nhead_ratio_kv;
331 
332  void* page_table_ptr = [&]() {
333  if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
334  {
335  return reinterpret_cast<int*>(args.page_table_ptr) + i_bq * args.max_pages_per_seq;
336  }
337  else
338  {
339  return nullptr;
340  }
341  }();
342 
343  auto q_addr = [&]() {
344  if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
345  {
346  return addresser<QType, QLayout>{
347  args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
348  }
349  else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
350  {
351  return addresser<QType, QLayout>{
352  args.batch_q, args.seqlen_q, args.nhead_q, args.hdim, args.q_ptr};
353  }
354  }();
355  auto k_addr = [&]() {
356  if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
357  {
358  return addresser<KType, KLayout>{
359  args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim, args.k_ptr};
360  }
361  else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
362  {
363  return page_addresser<KType, KLayout>{
364  args.page_size, args.nhead_kv, args.hdim, args.k_ptr, page_table_ptr};
365  }
366  }();
367  auto v_addr = [&]() {
368  if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
369  {
370  return addresser<VType, VLayout>{
371  args.batch_kv, args.seqlen_kv, args.nhead_kv, args.hdim_v, args.v_ptr};
372  }
373  else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
374  {
375  return page_addresser<VType, VLayout>{
376  args.page_size, args.nhead_kv, args.hdim_v, args.v_ptr, page_table_ptr};
377  }
378  }();
379  auto o_addr = [&]() {
380  if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
381  {
382  return addresser<OType, OLayout>{
383  args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
384  }
385  else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
386  {
387  return addresser<OType, OLayout>{
388  args.batch_q, args.seqlen_q, args.nhead_q, args.hdim_v, args.o_ptr};
389  }
390  }();
391 
392  q_addr.init(i_bq, i_hq);
393  k_addr.init(i_bk, i_hk);
394  v_addr.init(i_bk, i_hk);
395  o_addr.init(i_bq, i_hq);
396 
397  auto f_max = [](auto x_, auto y_) { return max(x_, y_); };
398  auto f_sum = [](auto x_, auto y_) { return x_ + y_; };
399  auto f_absmax_f32 = [](float v_0_, float v_1_) {
400  // float rtn;
401  // asm volatile("v_max_f32 %0, abs(%1), abs(%2)" : "=v"(rtn) : "v"(v_0_), "v"(v_1_));
402  // return rtn;
403  return max(abs(v_0_), abs(v_1_));
404  };
405 
406  int seqlen_kv = [&]() {
407  if constexpr(Traits::variation == naive_attention_variation_enum::FLASH_BATCHED)
408  {
409  return args.seqlen_kv;
410  }
411  else if constexpr(Traits::variation == naive_attention_variation_enum::DECODE_PAGED)
412  {
413  return reinterpret_cast<int*>(args.context_len_ptr)[i_bq];
414  }
415  }();
416 
417  SoftmaxType row_max = -numeric<SoftmaxType>::infinity();
418  SoftmaxType l{0};
419  // AccType o_acc = {0};
420  OAccType o_acc = {0};
421 
422  int sk_loops = (seqlen_kv + wg_size - 1) / wg_size;
423  QuantComputeType q_dequant_scale = .0f;
424  kvscale_addresser<KVScaleType, KScaleLayout> kscale_addr{
425  args.max_kv_tokens, args.nhead_kv, args.hdim, args.kscale_ptr};
426  kvscale_addresser<KVScaleType, VScaleLayout> vscale_addr{
427  args.max_kv_tokens, args.nhead_kv, args.hdim_v, args.vscale_ptr};
428 
429  if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
430  {
431  // AccType is i32 now, seqlen_q = 1, hdim up to 256
432  AccType q = 0;
433  AccType k_s = 0;
434  if(static_cast<int>(threadIdx.x) < args.hdim)
435  {
436  q = type_convert<AccType>(q_addr.load(0, threadIdx.x));
437  k_s = type_convert<AccType>(kscale_addr.load(i_hk, threadIdx.x, 0));
438  }
439  // 1) we apply the k scale to q
440  AccType q_forwarded = q * k_s;
441 
442  // 2) apply smooth-quant
443  // find absmax
444  AccType qf_max = wave_reduce(q_forwarded, f_absmax_f32);
445  qf_max = cross_wave_reduce(qf_max, f_absmax_f32, reinterpret_cast<AccType*>(smem));
446 
447  // per-token scale
448  q_dequant_scale = type_convert<QuantComputeType>(qf_max) / scale_max<QCompute>::value;
449 
450  // divide by scale
451  q = q / q_dequant_scale;
452 
453  // fp32->i8
454  QCompute quantized_q = static_cast<QCompute>(q);
455  __syncthreads();
456  reinterpret_cast<QCompute*>(smem)[threadIdx.x] = quantized_q;
457  __syncthreads();
458 
459  // after the above process we have two things:
460  // 1) int8 q data stored in smem(no need to reload)
461  // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
462  }
463  else if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERTOKEN)
464  {
465  if(std::is_same_v<QType, fp16_t> || std::is_same_v<QType, bf16_t>)
466  {
467  // dynamically quantize q here
468  float q = 0;
469  if(static_cast<int>(threadIdx.x) < args.hdim)
470  {
471  q = type_convert<float>(q_addr.load(i_sq, threadIdx.x));
472  }
473 
474  // apply smooth-quant
475  // find absmax
476  float q_max = wave_reduce(q, f_absmax_f32);
477  q_max = cross_wave_reduce(q_max, f_absmax_f32, reinterpret_cast<float*>(smem));
478 
479  // per-token scale
480  q_dequant_scale =
481  type_convert<QuantComputeType>(q_max) / scale_max<QCompute>::value;
482 
483  // divide by scale
484  q = q / q_dequant_scale;
485 
486  QCompute quantized_q = type_convert<QCompute>(q);
487  __syncthreads();
488  reinterpret_cast<QCompute*>(smem_quant_q)[threadIdx.x] = quantized_q;
489  __syncthreads();
490 
491  // after the above process we have two things:
492  // 1) fp8 q data stored in smem(no need to reload from global)
493  // 2) per-token scale q_dequant_scale, to be mul after 1st gemm
494  }
495  }
496 
497  for(int i_loop1 = 0; i_loop1 < sk_loops; i_loop1++)
498  {
499  int i_sk = i_loop1 * wg_size + threadIdx.x;
500  // gemm-1
501  SoftmaxType s_softmax = -numeric<SoftmaxType>::infinity();
502  if(i_sk < seqlen_kv)
503  {
504  AccType s_acc{0}; // clear for every loop
505  for(auto i_dq = 0; i_dq < args.hdim; i_dq++)
506  {
507  auto q = [&]() {
508  if constexpr(Traits::quant_algo ==
509  naive_attention_quant_algo::KV_8BIT_PERHEAD ||
510  Traits::quant_algo ==
511  naive_attention_quant_algo::KV_8BIT_PERTOKEN)
512  {
513  return reinterpret_cast<QCompute*>(smem_quant_q)[i_dq];
514  }
515  else
516  return q_addr.load(i_sq, i_dq); // q will have duplicate load
517  }();
518  auto k = [&]() { return k_addr.load(i_sk, i_dq); }();
519 
520  s_acc += type_convert<AccType>(q) * type_convert<AccType>(k);
521  }
522  // scale
523  s_softmax = type_convert<SoftmaxType>(s_acc);
524  s_softmax *=
525  type_convert<SoftmaxType>(args.scale_s * ck_tile::log2e_v<SoftmaxType>);
526  if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
527  {
528  s_softmax *= q_dequant_scale; // post scale the per-token factor
529  }
530  else if constexpr(Traits::quant_algo ==
531  naive_attention_quant_algo::KV_8BIT_PERTOKEN)
532  {
533  SoftmaxType k_per_token_scale =
534  type_convert<SoftmaxType>(kscale_addr.load(i_sk, i_hk, 0));
535  s_softmax *= q_dequant_scale;
536  s_softmax *= k_per_token_scale;
537  }
538  }
539 
540  // s->p
541  QuantComputeType p_dequant_scale = 1.;
542  {
543  // softmax, find max
544  SoftmaxType old_max = row_max;
545  SoftmaxType cur_max = wave_reduce(s_softmax, f_max);
546 
547  cur_max = cross_wave_reduce(cur_max, f_max, reinterpret_cast<SoftmaxType*>(smem));
548  row_max = max(old_max, cur_max); // update row_max
549  // softmax, exp(i_elem - max)
550  SoftmaxType p_compute = __builtin_amdgcn_exp2f(s_softmax - row_max);
551 
552  // compute exp_sum
553  SoftmaxType row_sum = wave_reduce(p_compute, f_sum);
554  row_sum = cross_wave_reduce(row_sum, f_sum, reinterpret_cast<SoftmaxType*>(smem));
555 
556  // update l, pre-scale o_acc
557  SoftmaxType tmp = __builtin_amdgcn_exp2f(old_max - row_max);
558  l = tmp * l + row_sum;
559  o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
560 
561  // store p_compute into smem so that every thread can read the same p_compute for the
562  // 2nd gemm
563  if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
564  {
565  QuantComputeType v_s = 0;
566  if(static_cast<int>(threadIdx.x) < args.hdim_v)
567  {
568  v_s =
569  type_convert<QuantComputeType>(vscale_addr.load(i_hk, threadIdx.x, 1));
570  }
571 
572  // 1) we apply the v scale to p
573  QuantComputeType p_forwarded = p_compute * v_s;
574 
575  // 2) apply smooth-quant
576  // find absmax
577  QuantComputeType pf_max = wave_reduce(p_forwarded, f_absmax_f32);
578  pf_max = cross_wave_reduce(
579  pf_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
580 
581  // per-token scale
582  p_dequant_scale = pf_max / scale_max<PType>::value; // 127.0;
583 
584  // divide by scale
585  p_compute = p_compute / p_dequant_scale;
586 
587  // fp32->i8
588  PType quantized_p = static_cast<PType>(p_compute);
589  __syncthreads();
590  reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
591  __syncthreads();
592  // after the above process we have two things:
593  // 1) int8 p data stored in smem(no need to reload)
594  // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
595  }
596  else if constexpr(Traits::quant_algo ==
597  naive_attention_quant_algo::KV_8BIT_PERTOKEN)
598  {
599  // apply the v scale to p_compute up front; this is compute-friendly
600  auto v_scale = type_convert<QuantComputeType>(vscale_addr.load(i_sk, i_hk, 0));
601  p_compute *= v_scale;
602  // smooth-quant
603  // find absmax
604  QuantComputeType p_max = wave_reduce(p_compute, f_absmax_f32);
605  p_max = cross_wave_reduce(
606  p_max, f_absmax_f32, reinterpret_cast<QuantComputeType*>(smem));
607 
608  // per-token scale
609  p_dequant_scale = p_max / scale_max<PType>::value; // 240.0;
610 
611  // divide by scale
612  p_compute = p_compute / p_dequant_scale;
613 
614  // fp32->i8
615  PType quantized_p = type_convert<PType>(p_compute);
616  __syncthreads();
617  reinterpret_cast<PType*>(smem)[threadIdx.x] = quantized_p;
618  __syncthreads();
619  // after the above process we have two things:
620  // 1) fp8_t p data stored in smem(no need to reload)
621  // 2) per-token scale p_dequant_scale, to be mul after 2nd gemm
622  }
623  else
624  {
625  __syncthreads();
626  reinterpret_cast<PType*>(smem)[threadIdx.x] = type_convert<PType>(p_compute);
627  __syncthreads();
628  }
629  }
630 
631  // gemm-2, simple loop over vector by vector
632  constexpr int gemm_2_loop = wg_size / p_vec_elem;
633  {
634  AccType o_acc_local = {0};
635  int sk_start = i_loop1 * wg_size; // we start from the first seqlen_kv element
636  for(int i_loop2 = 0; i_loop2 < gemm_2_loop; i_loop2++)
637  {
638  p_vec_type p_vec = reinterpret_cast<p_vec_type*>(smem)[i_loop2];
639 #pragma unroll
640  for(int i_j = 0; i_j < p_vec_elem; i_j++)
641  {
642  int sv_offset = i_loop2 * p_vec_elem + i_j;
643  int i_sv = sk_start + sv_offset;
644 
645  VType v = 0;
646  if(i_dv < args.hdim_v && i_sv < seqlen_kv)
647  {
648  v = v_addr.load(i_sv, i_dv);
649  }
650 
651  AccType v_compute = [&]() { return type_convert<AccType>(v); }();
652 
653  o_acc_local += type_convert<AccType>(p_vec[i_j]) * v_compute;
654  }
655  }
656 
657  OAccType post_scale_o_acc_local = [&]() {
658  if constexpr(Traits::quant_algo == naive_attention_quant_algo::KV_8BIT_PERHEAD)
659  {
660  // apply the p dequant scale to the local acc
661  return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
662  p_dequant_scale);
663  }
664  else if constexpr(Traits::quant_algo ==
665  naive_attention_quant_algo::KV_8BIT_PERTOKEN)
666  {
667  // apply the p dequant scale to the local acc
668  return type_convert<OAccType>(type_convert<QuantComputeType>(o_acc_local) *
669  p_dequant_scale);
670  }
671  else
672  {
673  return type_convert<OAccType>(o_acc_local);
674  }
675  }();
676  o_acc += post_scale_o_acc_local;
677  }
678  }
679 
680  // post scale o_acc
681  {
682  SoftmaxType tmp = l == 0.f ? 0.f : 1.f / l; // in case masking
683  o_acc = type_convert<OAccType>(type_convert<SoftmaxType>(o_acc) * tmp);
684  }
685 
686  // store O
687  if(i_dv < args.hdim_v)
688  o_addr.store(type_convert<OType>(o_acc), i_sq, i_dv);
689  }
690 };
691 
692 #define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_() \
693  { \
694  using ktraits_ = naive_attention_fwd_kernel_traits< \
695  static_cast<naive_attention_variation_enum>(variation_), \
696  static_cast<naive_attention_quant_algo>(quant_algo_)>; \
697  using k_ = naive_attention_fwd_kernel<q_type_, \
698  k_type_, \
699  v_type_, \
700  o_type_, \
701  acc_type_, \
702  kvscale_type_, \
703  q_layout_, \
704  k_layout_, \
705  v_layout_, \
706  o_layout_, \
707  k_scale_layout_, \
708  v_scale_layout_, \
709  ktraits_>; \
710  dim3 grids = k_::get_grid_size(a); \
711  r = ck_tile::launch_kernel(s, \
712  ck_tile::make_kernel(k_{}, grids, k_::get_block_size(), 0, a)); \
713  }
714 
715 #define CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_() \
716  if(t.variation == 0 && t.q_layout == "bshd" && t.k_layout == "bshd" && t.v_layout == "bshd" && \
717  t.o_layout == "bshd") \
718  { \
719  constexpr auto q_layout_ = naive_attention_layout_enum::BSHD; \
720  constexpr auto k_layout_ = naive_attention_layout_enum::BSHD; \
721  constexpr auto v_layout_ = naive_attention_layout_enum::BSHD; \
722  constexpr auto o_layout_ = naive_attention_layout_enum::BSHD; \
723  constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
724  constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
725  constexpr int variation_ = 0; \
726  CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
727  } \
728  else if(t.variation == 0 && t.q_layout == "bhsd" && t.k_layout == "bhsd" && \
729  t.v_layout == "bhsd" && t.o_layout == "bhsd") \
730  { \
731  constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
732  constexpr auto k_layout_ = naive_attention_layout_enum::BHSD; \
733  constexpr auto v_layout_ = naive_attention_layout_enum::BHSD; \
734  constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
735  constexpr auto k_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
736  constexpr auto v_scale_layout_ = naive_attention_layout_enum::DEFAULT; \
737  constexpr int variation_ = 0; \
738  CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
739  } \
740  else if(t.variation == 2 && t.q_layout == "bhsd" && t.k_layout == "phdsx" && \
741  t.v_layout == "phds" && t.o_layout == "bhsd") \
742  { \
743  constexpr auto q_layout_ = naive_attention_layout_enum::BHSD; \
744  constexpr auto k_layout_ = naive_attention_layout_enum::PHDSX; \
745  constexpr auto v_layout_ = naive_attention_layout_enum::PHDS; \
746  constexpr auto o_layout_ = naive_attention_layout_enum::BHSD; \
747  constexpr auto k_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
748  constexpr auto v_scale_layout_ = naive_attention_layout_enum::SCALE_HS; \
749  constexpr int variation_ = 2; \
750  CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_(); \
751  }
752 
753 //
754 CK_TILE_HOST float naive_attention_fwd(naive_attention_fwd_traits t,
755  naive_attention_fwd_args a,
756  ck_tile::stream_config s)
757 {
758  float r = -1;
759  // TODO: do not explicitly create too many instances!
760  if(t.q_type == "fp16" && t.k_type == "fp16" && t.v_type == "fp16" && t.o_type == "fp16" &&
761  t.quant_algo == 0)
762  {
763  using q_type_ = fp16_t;
764  using k_type_ = fp16_t;
765  using v_type_ = fp16_t;
766  using o_type_ = fp16_t;
767  using acc_type_ = float;
768  using kvscale_type_ = float;
769  constexpr int quant_algo_ = 0;
770  CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
771  }
772  else if(t.q_type == "bf16" && t.k_type == "bf16" && t.v_type == "bf16" && t.o_type == "bf16" &&
773  t.quant_algo == 0)
774  {
775  using q_type_ = bf16_t;
776  using k_type_ = bf16_t;
777  using v_type_ = bf16_t;
778  using o_type_ = bf16_t;
779  using acc_type_ = float;
780  using kvscale_type_ = float;
781  constexpr int quant_algo_ = 0;
782  CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
783  }
784  else if(t.q_type == "bf16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "bf16" &&
785  t.quant_algo == 2)
786  {
787  using q_type_ = bf16_t;
788  using k_type_ = fp8_t;
789  using v_type_ = fp8_t;
790  using o_type_ = bf16_t;
791  using acc_type_ = float; // NOTE!
792  using kvscale_type_ = float;
793  constexpr int quant_algo_ = 2;
794  CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
795  }
796  else if(t.q_type == "fp16" && t.k_type == "fp8" && t.v_type == "fp8" && t.o_type == "fp16" &&
797  t.quant_algo == 2)
798  {
799  using q_type_ = fp16_t;
800  using k_type_ = fp8_t;
801  using v_type_ = fp8_t;
802  using o_type_ = fp16_t;
803  using acc_type_ = float; // NOTE!
804  using kvscale_type_ = float;
805  constexpr int quant_algo_ = 2;
806  CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
807  }
808  else if(t.q_type == "bf16" && t.k_type == "int8" && t.v_type == "int8" && t.o_type == "bf16" &&
809  t.quant_algo == 2)
810  {
811  using q_type_ = bf16_t;
812  using k_type_ = int8_t;
813  using v_type_ = int8_t;
814  using o_type_ = bf16_t;
815  using acc_type_ = int32_t; // NOTE!
816  using kvscale_type_ = float;
817  constexpr int quant_algo_ = 2;
818  CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_();
819  }
820  return r;
821 }
822 
823 #undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_LAOYUT_
824 #undef CK_TILE_DISPATCH_NAIVE_ATTEN_FWD_INTERNAL_
825 
826 } // namespace ck_tile
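
Below is a minimal host-side usage sketch for the naive_attention_fwd entry point above, covering the plain fp16, BSHD, batched case (variation 0, quant_algo 0). It is illustrative only: the wrapper name run_naive_attention_fp16, the device buffers d_q/d_k/d_v/d_o, and the 1/sqrt(hdim) softmax scale are assumptions made for the example; only the traits/args fields defined in this header are used.

#include <cmath>
#include "ck_tile/ref/naive_attention.hpp"

// Hypothetical wrapper: the caller owns and fills the device buffers d_q/d_k/d_v/d_o,
// each laid out as BSHD = [batch, seqlen, nhead, hdim] in fp16.
float run_naive_attention_fp16(void* d_q, void* d_k, void* d_v, void* d_o,
                               int batch, int nhead, int seqlen, int hdim)
{
    ck_tile::naive_attention_fwd_traits t{};
    t.q_type = "fp16"; t.k_type = "fp16"; t.v_type = "fp16"; t.o_type = "fp16";
    t.q_layout = "bshd"; t.k_layout = "bshd"; t.v_layout = "bshd"; t.o_layout = "bshd";
    t.variation  = 0; // FLASH_BATCHED
    t.quant_algo = 0; // no KV-cache quantization

    ck_tile::naive_attention_fwd_args a{};
    a.q_ptr = d_q; a.k_ptr = d_k; a.v_ptr = d_v; a.o_ptr = d_o;
    a.context_len_ptr = nullptr; // seqlen_kv is taken from the field below
    a.page_table_ptr  = nullptr; // not paged
    a.kscale_ptr = nullptr; a.vscale_ptr = nullptr; // no dequant scales
    a.scale_s = 1.0f / std::sqrt(static_cast<float>(hdim)); // usual softmax scaling
    a.hdim = hdim; a.hdim_v = hdim;
    a.batch_q = batch; a.batch_kv = batch; a.batch_ratio_kv = 1;
    a.seqlen_q = seqlen; a.seqlen_kv = seqlen;
    a.nhead_q = nhead; a.nhead_kv = nhead; a.nhead_ratio_kv = 1;
    a.page_size = 0; a.max_pages_per_seq = 0; a.max_kv_tokens = 0;

    ck_tile::stream_config s{}; // default stream
    // returns the value of ck_tile::launch_kernel, or -1 if this type/layout
    // combination is not instantiated by the dispatch macros above
    return ck_tile::naive_attention_fwd(t, a, s);
}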