22 return (x >> traits::mant) &
exp_mask;
26 return (x >> (
traits::exp + traits::mant)) == _numeric::binary_zero;
36 for(uint32_t i = 0; i < traits::mant; ++i)
38 mantissa += std::ldexp(
static_cast<float>(x & 0b1), -(traits::mant - i));
49 static constexpr
int e8m0_bias = 127;
50 float sign = utils::is_positive(data) ? 1.0 : -1.0;
52 float mant = utils::get_mantissa(data);
54 return std::ldexp(sign * mant,
exp + scale_exp - e8m0_bias);
67 uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
75 bitwise_type mantissa =
87 float prev_val = bit_cast<float>(prev_bit);
88 float diff = max_value - prev_val;
90 float actual_max = max_value + (diff / 2);
92 if(std::abs(value) < actual_max)
95 (exp << numeric_traits<T>::mant) | mantissa;
108 (exp << numeric_traits<T>::mant);
114 x = bit_cast<uint32_t>(value);
116 uint32_t head, mantissa;
132 const int mini_denormal_act_exponent = 1 - mini_bias;
134 int act_exponent, out_exponent, exponent_diff;
136 bool is_subnorm =
false;
140 act_exponent = exponent - bias + 1;
141 exponent_diff = mini_denormal_act_exponent - act_exponent;
146 act_exponent = exponent - bias;
147 if(act_exponent <= mini_denormal_act_exponent)
149 exponent_diff = mini_denormal_act_exponent - act_exponent;
156 mantissa += (1UL << mfmt);
160 shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
161 bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
165 if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
168 if(std::abs(value) <= std::abs(min_subnorm - value))
174 if(exponent_diff > 0)
175 mantissa >>= exponent_diff;
176 else if(exponent_diff == -1)
177 mantissa <<= -exponent_diff;
178 bool implicit_one = mantissa & (1 << mfmt);
179 out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
183 mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
185 if(out_exponent == 0)
187 if((1UL << mfmt) & mantissa)
194 if((1UL << (mfmt + 1)) & mantissa)
203 if(out_exponent == 0 && mantissa == 0)
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
__host__ T exp(T x)
Definition: math_v2.hpp:391
__host__ __device__ bool is_subnormal(T x)
Definition: mxfp_utils.hpp:45
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE T::raw_type convert_to_type(float value)
Definition: mxfp_convert.hpp:58
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp=127)
Definition: mxfp_convert.hpp:46
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:412
int32_t int32_t
Definition: integer.hpp:10
Definition: numeric.hpp:81
Definition: mxfp_convert.hpp:11
typename T::raw_type raw_type
Definition: mxfp_convert.hpp:15
static constexpr bool is_positive(raw_type x)
Definition: mxfp_convert.hpp:24
static constexpr double get_mantissa(raw_type x)
Definition: mxfp_convert.hpp:33
static constexpr int exp_mask
Definition: mxfp_convert.hpp:17
static constexpr int get_exponent(raw_type x)
Definition: mxfp_convert.hpp:19
static constexpr bool is_subnormal(raw_type x)
Definition: mxfp_convert.hpp:28
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26