14 template <
typename DTYPE>
17 return DTYPE::dataInfo.hasInf;
36 return static_cast<int>(x);
42 return get_exponent_value<T>(x) == 0;
48 double mantissa = is_subnormal<T>(x) ? 0.0f : 1.0f;
50 for(uint i = 0; i < NumericUtils<T>::mant; i++)
74 if(is_subnormal<T>(data))
78 float d_mant = get_mantissa_value<T>(data);
80 float data_value = d_sign * d_exp * d_mant;
84 return data_value * scale_value;
130 float diff = max_value - prev_val;
132 float actual_max = max_value + (diff / 2);
134 if(std::abs(value) < actual_max)
137 (exp << NumericUtils<T>::mant) | mantissa;
141 if(!get_data_has_inf<T>())
150 (exp << NumericUtils<T>::mant);
156 x = bit_cast<uint32_t>(value);
158 uint32_t head, mantissa;
159 int32_t exponent, bias;
174 const int mini_denormal_act_exponent = 1 - mini_bias;
176 int act_exponent, out_exponent, exponent_diff;
178 bool is_subnorm =
false;
182 act_exponent = exponent - bias + 1;
183 exponent_diff = mini_denormal_act_exponent - act_exponent;
188 act_exponent = exponent - bias;
189 if(act_exponent <= mini_denormal_act_exponent)
191 exponent_diff = mini_denormal_act_exponent - act_exponent;
198 mantissa += (1UL << mfmt);
202 shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
203 bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
207 if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
210 if(std::abs(value) <= std::abs(min_subnorm - value))
216 if(exponent_diff > 0)
217 mantissa >>= exponent_diff;
218 else if(exponent_diff == -1)
219 mantissa <<= -exponent_diff;
220 bool implicit_one = mantissa & (1 << mfmt);
221 out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
225 mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
227 if(out_exponent == 0)
229 if((1UL << mfmt) & mantissa)
236 if((1UL << (mfmt + 1)) & mantissa)
245 if(out_exponent == 0 && mantissa == 0)
255 template <
typename T>
284 float diff = max_value - prev_val;
286 float actual_max = max_value + (diff / 2);
288 if(std::abs(value) < actual_max)
290 double d_max_value =
static_cast<double>(max_value);
291 double d_actual_max =
static_cast<double>(actual_max);
292 double d_value =
static_cast<double>(value);
293 double d_is = std::abs(d_max_value - d_actual_max);
294 double d_seed =
static_cast<double>(seed);
295 double d_prob = 1.0f - (std::abs(d_value - d_max_value) / d_is);
297 double thresh = UINT_MAX * d_prob;
299 if(!get_data_has_inf<T>() || d_seed <= thresh)
307 | (exp << NumericUtils<T>::mant);
312 if(!get_data_has_inf<T>())
318 | (exp << NumericUtils<T>::mant);
323 uint32_t f32 = bit_cast<uint32_t>(value);
333 int32_t
exp = f32_exp;
334 auto mant = f32_mant;
335 bool subnorm =
false;
368 mant += seed >> sr_shift;
376 auto biased_exp =
static_cast<uint32_t
>(
exp);
380 auto val = sign | biased_exp << NumericUtils<T>::mant | mant;
__host__ T exp(T x)
Definition: math_v2.hpp:391
__host__ T pow(T x, T gamma)
Definition: math_v2.hpp:427
Definition: check_err.hpp:24
T convert_to_type_sr(float value, uint32_t seed)
Definition: mxfp_utils.hpp:256
__host__ __device__ T sat_convert_to_type(float value)
__host__ __device__ bool is_subnormal(T x)
Definition: mxfp_utils.hpp:40
__host__ __device__ bool get_data_has_inf()
Definition: mxfp_utils.hpp:62
__host__ __device__ T sat_convert_to_type_sr(float value, uint32_t seed)
__host__ __device__ float convert_to_float(T data, int scale_exp)
Definition: mxfp_utils.hpp:68
T convert_to_type(float value)
Definition: mxfp_utils.hpp:97
__host__ __device__ int get_exponent_value(T x)
Definition: mxfp_utils.hpp:30
__host__ __device__ bool is_zero(e8m0_bexp_t const scale, T const data)
__host__ __device__ bool is_inf(e8m0_bexp_t const scale, T const data)
__host__ __device__ double get_mantissa_value(T x)
Definition: mxfp_utils.hpp:46
__host__ __device__ bool is_nan(e8m0_bexp_t const scale, T const data)
bool getDataHasInf()
Definition: mxfp_utils.hpp:15
__host__ __device__ float to_float(e8m0_bexp_t const scale, T const data)
Definition: data_type.hpp:2831
__host__ static constexpr __device__ T Max()
Definition: data_type.hpp:2833
Definition: data_type.hpp:3078
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:25
Definition: mxfp_utils.hpp:9
float value_float
Definition: mxfp_utils.hpp:10
uint32_t value_bitwise
Definition: mxfp_utils.hpp:11