19 __host__ 
inline int clz(uint32_t x) { 
return __builtin_clz(x); }
 
   20 __device__ 
inline int clz(uint32_t x) { 
return __clz(x); }
 
   28 template <
typename X, 
typename Y, 
bool negative_zero_nan, 
bool clip, 
bool stoch>
 
   33     constexpr 
int out_mant = NumericUtils<Y>::mant;
 
   37     constexpr 
int in_mant = NumericUtils<X>::mant;
 
   40     uint32_t head, mantissa, sign;
 
   42     constexpr Y nan_code        = 0x80;
 
   43     constexpr uint32_t nan_mask = NumericUtils<X>::nan_mask;
 
   46     using T_bitwise     = 
typename NumericUtils<X>::bitwise_type;
 
   47     T_bitwise x_bitwise = bit_cast<T_bitwise>(x);
 
   50     head     = x_bitwise & NumericUtils<X>::head_mask;
 
   51     mantissa = x_bitwise & NumericUtils<X>::mant_mask;
 
   52     exponent = (head >> in_mant) & NumericUtils<X>::exp_mask;
 
   53     sign     = head >> (in_exp + in_mant);
 
   54     bias     = NumericUtils<X>::bias;
 
   56     uint32_t signed_inf   = (sign << (in_exp + in_mant)) + (((1 << in_exp) - 1) << in_mant);
 
   57     uint32_t drop_mask    = (1 << (in_mant - out_mant)) - 1;
 
   58     constexpr 
int max_exp = (1 << out_exp) - (negative_zero_nan ? 1 : 2);
 
   60     if constexpr(negative_zero_nan)
 
   62         if((x_bitwise & nan_mask) == nan_mask)
 
   67         if((x_bitwise & nan_mask) == nan_mask)
 
   68             return signed_inf + (mantissa != 0 ? 1 : 0);
 
   82     const int out_bias                  = (1 << (out_exp - 1)) - 1 + (negative_zero_nan ? 1 : 0);
 
   83     const int out_denormal_act_exponent = 1 - out_bias; 
 
   88     int act_exponent, out_exponent, exponent_diff;
 
   98         act_exponent  = exponent - bias + 1;
 
   99         exponent_diff = out_denormal_act_exponent -
 
  104         act_exponent = exponent - bias;
 
  105         if(act_exponent <= out_denormal_act_exponent)
 
  112             exponent_diff = out_denormal_act_exponent - act_exponent;
 
  120         mantissa += (1 << in_mant); 
 
  123     bool midpoint = (mantissa & ((1 << (in_mant - out_mant + exponent_diff)) - 1)) ==
 
  124                     (1 << (in_mant - out_mant + exponent_diff - 1));
 
  130     if(exponent_diff > 0)
 
  131         mantissa >>= exponent_diff;
 
  132     else if(exponent_diff == -1)
 
  133         mantissa <<= -exponent_diff;
 
  134     bool implicit_one = mantissa & (1 << in_mant);
 
  137         (act_exponent + exponent_diff)  + out_bias - (implicit_one ? 0 : 1);
 
  142         (1 << (in_mant - out_mant)); 
 
  143     mantissa += (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & drop_mask;
 
  146     if(out_exponent == 0)
 
  148         if((1 << in_mant) & mantissa)
 
  156         if((1 << (in_mant + 1)) & mantissa)
 
  164     mantissa >>= (in_mant - out_mant);
 
  166     if(out_exponent > max_exp)
 
  170             mantissa     = (1 << out_mant) - 1;
 
  171             out_exponent = max_exp;
 
  180     if(out_exponent == 0 && mantissa == 0)
 
  181         return negative_zero_nan ? 0 : (sign << (out_exp + out_mant));
 
  182     mantissa &= (1 << out_mant) - 1;
 
  183     return (sign << (out_exp + out_mant)) | (out_exponent << out_mant) | mantissa;
 
  186 template <
typename X, 
typename Y, 
bool negative_zero_nan>
 
  191     constexpr 
int in_mant = NumericUtils<X>::mant;
 
  195     constexpr 
int out_mant = NumericUtils<Y>::mant;
 
  198     constexpr X nan_code = 0x80;
 
  199     using T_bitwise      = 
typename NumericUtils<Y>::bitwise_type;
 
  201     constexpr T_bitwise Inf_bitwise    = NumericUtils<Y>::Inf;
 
  202     constexpr T_bitwise NegInf_bitwise = NumericUtils<Y>::NegInf;
 
  203     constexpr T_bitwise NaN_bitwise    = NumericUtils<Y>::NaN;
 
  204     constexpr T_bitwise Neg0_bitwise   = NumericUtils<Y>::Neg0;
 
  206     constexpr Y Inf    = bit_cast<Y>(Inf_bitwise);
 
  207     constexpr Y NegInf = bit_cast<Y>(NegInf_bitwise);
 
  208     constexpr Y NaN    = bit_cast<Y>(NaN_bitwise);
 
  209     constexpr Y Neg0   = bit_cast<Y>(Neg0_bitwise);
 
  213         return static_cast<Y
>(0);
 
  216     uint32_t sign     = x >> (in_exp + in_mant);
 
  217     uint32_t mantissa = x & ((1 << in_mant) - 1);
 
  218     int exponent      = (x & 0x7F) >> in_mant;
 
  220     constexpr 
int exp_low_cutoff =
 
  221         (1 << (out_exp - 1)) - (1 << (in_exp - 1)) + 1 - (negative_zero_nan ? 1 : 0);
 
  224     if constexpr(negative_zero_nan)
 
  233         if(exponent == ((1 << in_exp) - 1))
 
  234             return (mantissa == 0) ? (sign ? NegInf : Inf) : NaN;
 
  237     if constexpr((NumericUtils<Y>::mant == 10) && (NumericUtils<X>::mant == 2) &&
 
  242         return bit_cast<Y>(retval);
 
  249         int sh = 1 + 
clz(mantissa) - (32 - in_mant);
 
  252         mantissa &= ((1 << in_mant) - 1);
 
  254     exponent += exp_low_cutoff - 1;
 
  255     mantissa <<= out_mant - in_mant;
 
  260         mantissa |= 1 << out_mant;
 
  261         mantissa >>= 1 - exponent;
 
  265     retval = (sign << (out_exp + out_mant)) | (exponent << out_mant) | mantissa;
 
  266     return bit_cast<Y>(retval);
 
  271 template <
typename X, 
typename Y, 
bool negative_zero_nan, 
bool clip, 
bool stoch>
 
  275     constexpr 
bool is_half  = std::is_same<X, half_t>::value;
 
  276     constexpr 
bool is_float = std::is_same<X, float>::value;
 
  277     static_assert(is_half || is_float, 
"Only half and float can be casted.");
 
  279     return run_cast_to_f8<X, Y, negative_zero_nan, clip, stoch>(x, rng);
 
  282 template <
typename X, 
typename Y, 
bool negative_zero_nan>
 
  286     constexpr 
bool is_half  = std::is_same<Y, half_t>::value;
 
  287     constexpr 
bool is_float = std::is_same<Y, float>::value;
 
  288     static_assert(is_half || is_float, 
"only half and float are supported.");
 
  290     return run_cast_from_f8<X, Y, negative_zero_nan>(x);
 
__host__ T exp(T x)
Definition: math_v2.hpp:391
 
Definition: check_err.hpp:24
 
__host__ __device__ Y cast_from_f8(X x)
Definition: f8_utils.hpp:283
 
__host__ __device__ Y cast_to_f8(X x, uint32_t rng)
Definition: f8_utils.hpp:272
 
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition: float8.hpp:250
 
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition: float8.hpp:480
 
f8_rounding_mode
Definition: f8_utils.hpp:14
 
__host__ int clz(uint32_t x)
Definition: f8_utils.hpp:19