11 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 0 
   12 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 1 
   14 #ifndef CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT 
   15 #define CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 
   18 #ifndef CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 
   19 #define CK_TILE_ATTENTION_USE_SOFTSIGN_ASM 0 
   24 __device__ 
inline float 
   27 #if(defined(__gfx90a__) || defined(__gfx94__)) &&                                               \ 
   28     (CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN && \ 
   29      CK_TILE_ATTENTION_USE_SOFTSIGN_ASM) 
   31     float result, numerator, denominator;
 
   33         "v_mul_f32_e32 %[denominator], %[logits], %[logits_soft_cap_rcp]\n" 
   34         "v_add_f32_e64 %[denominator], |%[denominator]|, 1.0\n" 
   35         "v_rcp_f32_e32 %[denominator], %[denominator]\n" 
   36         "v_mul_f32_e32 %[numerator], %[softmax_scale], %[logits]\n" 
   37         "v_mul_f32_e32 %[result], %[numerator], %[denominator]" 
   38         : [numerator] 
"=&v"(numerator), [denominator] 
"=&v"(denominator), [result] 
"=v"(result)
 
   39         : [softmax_scale] 
"s"(softmax_scale),
 
   41           [logits_soft_cap_rcp] 
"v"(logits_soft_cap_rcp));
 
   44     return softmax_scale * logits * rcp<float>(1.f + 
abs(logits * logits_soft_cap_rcp));
 
   49 template <
typename ImplMask>
 
   61 template <
typename ImplMask, 
bool UseExp2 = false>
 
  101         if constexpr(UseExp2)
 
  110                                             float logits_soft_cap_,
 
  111                                             float logits_soft_cap_rcp_)
 
  119         if constexpr(UseExp2)
 
  136     template <
typename Params, 
typename T>
 
  139         return type_convert<float>(q) * params.sm_scale;
 
  144     template <
typename Params, 
typename T>
 
  147                                                  [[maybe_unused]] uint32_t batch_idx,
 
  149                                                  [[maybe_unused]] uint32_t qo_head_idx,
 
  150                                                  [[maybe_unused]] uint32_t kv_head_idx)
 const 
  155     template <
typename Params>
 
  156     __device__ __forceinline__ 
bool LogitsMask(
const Params& params,
 
  157                                                [[maybe_unused]] uint32_t batch_idx,
 
  160                                                [[maybe_unused]] uint32_t qo_head_idx,
 
  161                                                [[maybe_unused]] uint32_t kv_head_idx)
 const 
  163         return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
 
  167 template <
bool UseExp2 = false>
 
  172     template <
typename Params, 
typename T>
 
  175         if constexpr(UseExp2)
 
  181             return type_convert<float>(q) * params.sm_scale;
 
  187     template <
typename Params, 
typename T>
 
  190                                                  [[maybe_unused]] uint32_t batch_idx,
 
  192                                                  [[maybe_unused]] uint32_t qo_head_idx,
 
  193                                                  [[maybe_unused]] uint32_t kv_head_idx)
 const 
  195         if constexpr(UseExp2)
 
  197 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 
  198             return params.logits_soft_cap *
 
  200 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 
  202                 params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
 
  207 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 
  208             return params.logits_soft_cap *
 
  209                    tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
 
  210 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 
  211             return type_convert<float>(logits) *
 
  212                    rcp<float>(1.f + 
abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
 
  217     template <
typename Params>
 
  218     __device__ __forceinline__ 
bool LogitsMask(
const Params& params,
 
  219                                                [[maybe_unused]] uint32_t batch_idx,
 
  222                                                [[maybe_unused]] uint32_t qo_head_idx,
 
  223                                                [[maybe_unused]] uint32_t kv_head_idx)
 const 
  225         return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
 
  234 template <u
int32_t VARIANT_CODE, 
bool UseExp2 = false>
 
  243     template <
typename Params, 
typename T>
 
  250         return type_convert<float>(q) * params.sm_scale;
 
  255     template <
typename Params, 
typename T>
 
  258                                                  [[maybe_unused]] uint32_t batch_idx,
 
  260                                                  [[maybe_unused]] uint32_t qo_head_idx,
 
  261                                                  [[maybe_unused]] uint32_t kv_head_idx)
 const 
  265             if constexpr(UseExp2)
 
  267 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 
  268                 return params.logits_soft_cap *
 
  270 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 
  272                     params.sm_scale, type_convert<float>(logits), params.logits_soft_cap_rcp);
 
  277 #if CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_TANH 
  278                 return params.logits_soft_cap *
 
  279                        tanhf(type_convert<float>(logits) * params.logits_soft_cap_rcp);
 
  280 #elif CK_TILE_ATTENTION_LOGITS_SOFT_CAP_DEFAULT == CK_TILE_ATTENTION_LOGITS_SOFT_CAP_SOFTSIGN 
  281                 return type_convert<float>(logits) *
 
  283                                   abs(type_convert<float>(logits) * params.logits_soft_cap_rcp));
 
  290     template <
typename Params>
 
  291     __device__ __forceinline__ 
bool LogitsMask(
const Params& params,
 
  292                                                [[maybe_unused]] uint32_t batch_idx,
 
  295                                                [[maybe_unused]] uint32_t qo_head_idx,
 
  296                                                [[maybe_unused]] uint32_t kv_head_idx)
 const 
  298         return !params.impl_mask.IsOutOfBound(qo_idx, kv_idx);
 
__device__ float exp2_soft_sign_impl(float softmax_scale, float logits, float logits_soft_cap_rcp)
Definition: variants.hpp:25
 
Definition: cluster_descriptor.hpp:13
 
constexpr uint32_t ALIBI
Definition: variants.hpp:232
 
CK_TILE_DEVICE float tanh_fast< float >(float x)
Definition: math.hpp:1394
 
constexpr uint32_t LOGITS_SOFT_CAP
Definition: variants.hpp:231
 
constexpr uint32_t CUSTOM_MASK
Definition: variants.hpp:229
 
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:393
 
constexpr uint32_t SLIDING_WINDOW
Definition: variants.hpp:230
 
Definition: variants.hpp:236
 
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:291
 
__device__ __host__ ComposedAttention()=default
 
static constexpr bool use_exp2
Definition: variants.hpp:237
 
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:244
 
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:256
 
static constexpr bool use_logits_soft_cap
Definition: variants.hpp:239
 
Definition: variants.hpp:169
 
__device__ __host__ LogitsSoftCap()=default
 
__device__ __forceinline__ T LogitsTransform(const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:188
 
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:218
 
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:173
 
Definition: variants.hpp:63
 
float logits_soft_cap_rcp
Definition: variants.hpp:129
 
__host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition: variants.hpp:87
 
const ImplMask & impl_mask
Definition: variants.hpp:126
 
__device__ __host__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_, float logits_soft_cap_rcp_)
Definition: variants.hpp:108
 
__device__ LogitsSoftCapParams(const ImplMask &impl_mask_, float sm_scale_, float logits_soft_cap_)
Definition: variants.hpp:65
 
float sm_scale
Definition: variants.hpp:127
 
float logits_soft_cap
Definition: variants.hpp:128
 
Definition: variants.hpp:133
 
__device__ __forceinline__ bool LogitsMask(const Params &params, [[maybe_unused]] uint32_t batch_idx, uint32_t qo_idx, uint32_t kv_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:156
 
__device__ __forceinline__ T QueryTransform(const Params &params, T q) const
Definition: variants.hpp:137
 
__device__ __host__ StandardAttention()=default
 
__device__ __forceinline__ T LogitsTransform([[maybe_unused]] const Params &params, T logits, [[maybe_unused]] uint32_t batch_idx, [[maybe_unused]] uint32_t qo_head_idx, [[maybe_unused]] uint32_t kv_head_idx) const
Definition: variants.hpp:145
 
Definition: variants.hpp:51
 
const ImplMask & impl_mask
Definition: variants.hpp:57
 
__device__ __host__ StandardAttentionParams(const ImplMask &impl_mask_, float sm_scale_)
Definition: variants.hpp:52
 
float sm_scale
Definition: variants.hpp:58