22         return (x >> traits::mant) & 
exp_mask;
 
   26         return (x >> (
traits::exp + traits::mant)) == _numeric::binary_zero;
 
   36         for(uint32_t i = 0; i < traits::mant; ++i)
 
   38             mantissa += std::ldexp(
static_cast<float>(x & 0b1), -(traits::mant - i));
 
   49     static constexpr 
int e8m0_bias = 127; 
 
   50     float sign                     = utils::is_positive(data) ? 1.0 : -1.0;
 
   52     float mant = utils::get_mantissa(data);
 
   54     return std::ldexp(sign * mant, 
exp + scale_exp - e8m0_bias);
 
   67         uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
 
   75         bitwise_type mantissa =
 
   87         float prev_val = bit_cast<float>(prev_bit);
 
   88         float diff     = max_value - prev_val;
 
   90         float actual_max = max_value + (diff / 2);
 
   92         if(std::abs(value) < actual_max)
 
   95                    (exp << numeric_traits<T>::mant) | mantissa;
 
  108                        (exp << numeric_traits<T>::mant);
 
  114     x = bit_cast<uint32_t>(value);
 
  116     uint32_t head, mantissa;
 
  132     const int mini_denormal_act_exponent = 1 - mini_bias;
 
  134     int act_exponent, out_exponent, exponent_diff;
 
  136     bool is_subnorm = 
false;
 
  140         act_exponent  = exponent - bias + 1;
 
  141         exponent_diff = mini_denormal_act_exponent - act_exponent;
 
  146         act_exponent = exponent - bias;
 
  147         if(act_exponent <= mini_denormal_act_exponent)
 
  149             exponent_diff = mini_denormal_act_exponent - act_exponent;
 
  156         mantissa += (1UL << mfmt);
 
  160     shift_amount      = (shift_amount >= 64) ? 63 : shift_amount;
 
  161     bool midpoint     = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
 
  165     if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
 
  168         if(std::abs(value) <= std::abs(min_subnorm - value))
 
  174     if(exponent_diff > 0)
 
  175         mantissa >>= exponent_diff;
 
  176     else if(exponent_diff == -1)
 
  177         mantissa <<= -exponent_diff;
 
  178     bool implicit_one = mantissa & (1 << mfmt);
 
  179     out_exponent      = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
 
  183     mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
 
  185     if(out_exponent == 0)
 
  187         if((1UL << mfmt) & mantissa)
 
  194         if((1UL << (mfmt + 1)) & mantissa)
 
  203     if(out_exponent == 0 && mantissa == 0)
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
 
__host__ T exp(T x)
Definition: math_v2.hpp:391
 
__host__ __device__ bool is_subnormal(T x)
Definition: mxfp_utils.hpp:45
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST_DEVICE T::raw_type convert_to_type(float value)
Definition: mxfp_convert.hpp:58
 
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp=127)
Definition: mxfp_convert.hpp:46
 
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:412
 
int32_t int32_t
Definition: integer.hpp:10
 
Definition: numeric.hpp:81
 
Definition: mxfp_convert.hpp:11
 
typename T::raw_type raw_type
Definition: mxfp_convert.hpp:15
 
static constexpr bool is_positive(raw_type x)
Definition: mxfp_convert.hpp:24
 
static constexpr double get_mantissa(raw_type x)
Definition: mxfp_convert.hpp:33
 
static constexpr int exp_mask
Definition: mxfp_convert.hpp:17
 
static constexpr int get_exponent(raw_type x)
Definition: mxfp_convert.hpp:19
 
static constexpr bool is_subnormal(raw_type x)
Definition: mxfp_convert.hpp:28
 
Definition: numeric.hpp:18
 
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26