13 #include <type_traits>
17 #if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
18 #define CK_TILE_FP8_CVT_DEVICE 1
20 #define CK_TILE_FP8_CVT_DEVICE 0
70 #if CK_TILE_USE_CUSTOM_DATA_TYPE
71 struct alignas(1) float8_e4m3_t
73 static constexpr
int exponent = 4;
74 static constexpr
int mantissa = 3;
75 #if CK_TILE_USE_OCP_FP8
76 static constexpr
int bias = 7;
78 static constexpr
int bias = 8;
80 using raw_type = uint8_t;
84 static constexpr float8_e4m3_t
bit_cast(raw_type x)
92 constexpr float8_e4m3_t() : data() {}
96 explicit constexpr float8_e4m3_t(
const float& x) : data(
float_to_fp8_raw(x)) {}
100 explicit constexpr float8_e4m3_t(
const int& x) : data(
float_to_fp8_raw(static_cast<float>(x)))
106 explicit constexpr float8_e4m3_t(
const unsigned int& x)
113 explicit constexpr
operator float()
const {
return fp8_to_float_raw(data); }
117 explicit constexpr
operator int()
const {
return static_cast<int>(
fp8_to_float_raw(data)); }
121 constexpr raw_type& get() {
return data; }
124 constexpr raw_type get()
const {
return data; }
126 using fp8_t = float8_e4m3_t;
127 using fp8_raw_t =
typename fp8_t::raw_type;
129 struct alignas(1) float8_e5m2_t
131 static constexpr
int exponent = 5;
132 static constexpr
int mantissa = 2;
133 #if CK_TILE_USE_OCP_FP8
134 static constexpr
int bias = 15;
136 static constexpr
int bias = 16;
138 using raw_type = uint8_t;
142 static constexpr float8_e5m2_t
bit_cast(raw_type x)
150 constexpr float8_e5m2_t() : data() {}
154 explicit constexpr float8_e5m2_t(
const float& x) : data(
float_to_bf8_raw(x)) {}
158 explicit constexpr float8_e5m2_t(
const int& x) : data(
float_to_bf8_raw(static_cast<float>(x)))
164 explicit constexpr float8_e5m2_t(
const unsigned int& x)
171 explicit constexpr
operator float()
const {
return bf8_to_float_raw(data); }
175 explicit constexpr
operator int()
const {
return static_cast<int>(
bf8_to_float_raw(data)); }
179 constexpr raw_type& get() {
return data; }
182 constexpr raw_type get()
const {
return data; }
184 using bf8_t = float8_e5m2_t;
185 using bf8_raw_t =
typename bf8_t::raw_type;
191 struct native_t<
fp8_t>
193 using type = _BitInt(8);
197 struct native_t<
bf8_t>
199 using type =
unsigned _BitInt(8);
215 static constexpr
int exp = 4;
216 static constexpr
int mant = 3;
217 #if CK_TILE_USE_OCP_FP8
218 static constexpr
int bias = 7;
221 static constexpr
int bias = 8;
224 static constexpr uint8_t abs_mask = 0x7F;
233 static constexpr
int exp = 5;
234 static constexpr
int mant = 2;
235 #if CK_TILE_USE_OCP_FP8
236 static constexpr
int bias = 15;
239 static constexpr
int bias = 16;
242 static constexpr uint8_t abs_mask = 0x7F;
249 template <
typename SrcT,
typename DstT,
bool clip = true,
bool stoch = false>
252 static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
253 "DstT type must be fp8 or bf8.");
255 constexpr
bool is_half = std::is_same<SrcT, half_t>::value;
256 constexpr
bool is_float = std::is_same<SrcT, float>::value;
257 static_assert(is_half || is_float,
"Only half and float can be cast to f8");
262 constexpr
bool is_fnuz =
270 SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
272 unsigned long long head, mantissa;
275 unsigned long long fInf, abs_mask;
280 sign = head >> (SrcT_exp + SrcT_mant);
285 unsigned int signed_inf = 0;
286 unsigned int nan = 0;
287 if constexpr(is_fnuz)
289 signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
294 if constexpr(DstT_exp == 4)
296 signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
300 signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
302 nan = (sign << 7) + 0x7f;
305 unsigned long long ifmax = 0;
306 if constexpr(is_float)
308 if constexpr(DstT_exp == 5)
314 if constexpr(is_fnuz)
324 else if constexpr(is_half)
326 if constexpr(DstT_exp == 5)
332 if constexpr(is_fnuz)
344 if((src_bitwise & fInf) == fInf)
346 if constexpr(is_fnuz)
349 return mantissa != 0 ? nan : signed_inf;
352 if((src_bitwise & abs_mask) > ifmax)
370 const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
371 const int f8_denormal_act_exponent = 1 - f8_bias;
376 int act_exponent, f8_exponent, exponent_diff;
387 act_exponent = exponent - bias + 1;
388 exponent_diff = f8_denormal_act_exponent -
393 act_exponent = exponent - bias;
394 if(act_exponent <= f8_denormal_act_exponent)
401 exponent_diff = f8_denormal_act_exponent - act_exponent;
409 mantissa += (1ull << SrcT_mant);
412 bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
413 (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
421 if(exponent_diff > 0)
422 mantissa >>= exponent_diff;
423 else if(exponent_diff == -1)
424 mantissa <<= -exponent_diff;
425 bool implicit_one = mantissa & (1ull << SrcT_mant);
429 (act_exponent + exponent_diff) + f8_bias - (implicit_one ? 0 : 1);
432 unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
434 mantissa & (1ull << (SrcT_mant -
437 (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
442 if((1ull << SrcT_mant) & mantissa)
449 if((1ull << (SrcT_mant + 1)) & mantissa)
456 mantissa >>= (SrcT_mant - DstT_mant);
459 const int max_exp = (1 << DstT_exp) - 1;
460 if(f8_exponent > max_exp)
464 mantissa = (1 << DstT_mant) - 1;
465 f8_exponent = max_exp;
473 if(f8_exponent == 0 && mantissa == 0)
474 return is_fnuz ? 0 : (sign << 7);
475 mantissa &= (1 << DstT_mant) - 1;
476 return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
479 template <
typename SrcT,
typename DstT,
bool clip = true>
482 static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
483 "SrcT type must be fp8 or bf8.");
486 constexpr
bool is_fnuz =
490 constexpr
bool is_half = std::is_same<DstT, half_t>::value;
491 constexpr
bool is_float = std::is_same<DstT, float>::value;
492 static_assert(is_half || is_float,
"DstT type must be half_t or float.");
503 DstT fmax{0}, fmin{0};
505 if constexpr(is_half)
510 else if constexpr(is_float)
521 unsigned long long sign = x >> 7;
522 unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
523 int exponent = (x & 0x7F) >> SrcT_mant;
524 if constexpr(is_fnuz)
526 if((x & 0xff) == 0x80)
537 if constexpr(SrcT_exp == 4)
539 if((x & 0x7F) == 0x7F)
544 else if((x & 0x7C) == 0x7C)
550 return sign ? fmin : fmax;
552 return sign ? fNegInf : fInf;
560 if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
563 return bit_cast<DstT>(retval);
566 const int exp_low_cutoff =
567 (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
572 int sh = 1 +
clz(mantissa) - (32 - SrcT_mant);
575 mantissa &= ((1ull << SrcT_mant) - 1);
577 exponent += exp_low_cutoff - 1;
578 mantissa <<= DstT_mant - SrcT_mant;
583 mantissa |= 1 << DstT_mant;
584 mantissa >>= 1 - exponent;
588 retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
590 return bit_cast<DstT>(retval);
593 template <
typename X,
typename Y,
bool clip,
bool stoch>
596 return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
599 #if CK_TILE_FP8_CVT_DEVICE
603 template <fp8_
interpretation
interpret,
bool saturate,
bool stochastic_rounding = false>
604 CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(
float v,
unsigned int rng = 0)
611 unsigned char i8val[4];
614 unsigned int ival = 0;
617 if constexpr(saturate)
621 if((val.i32val & 0x7F800000) != 0x7F800000)
623 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
628 if((val.i32val & 0x7F800000) != 0x7F800000)
630 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
635 if((val.i32val & 0x7F800000) != 0x7F800000)
637 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
642 if constexpr(stochastic_rounding)
646 ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
647 : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0);
649 i8data = val.i8val[0];
655 ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival,
false)
656 : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
661 i8data = val.i8val[0];
682 template <
typename SrcT,
typename DstT>
685 constexpr
bool clip =
true;
686 constexpr
int seed = 42;
688 #if CK_TILE_FP8_CVT_DEVICE
689 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip,
true>(x, rng);
691 return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
692 impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
708 template <
typename SrcT,
typename DstT>
711 constexpr
bool clip =
true;
712 #if CK_TILE_FP8_CVT_DEVICE
713 return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip,
false>(x, 0);
715 return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
716 impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
720 template <fp8_rounding_mode rounding>
725 return float_to_fp8_rtn_raw<float, fp8_t>(x);
729 return float_to_fp8_sr_raw<float, fp8_t>(x);
737 template <fp8_rounding_mode rounding>
742 return float_to_fp8_rtn_raw<float, bf8_t>(x);
746 return float_to_fp8_sr_raw<float, bf8_t>(x);
756 #if CK_TILE_FP8_CVT_DEVICE
758 uint32_t i32val =
static_cast<uint32_t
>(x);
759 fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
763 return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
769 #if CK_TILE_FP8_CVT_DEVICE
771 uint32_t i32val =
static_cast<uint32_t
>(x);
772 fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
776 return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
799 #if CK_TILE_USE_OCP_FP8
801 struct numeric<
fp8_t>
806 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x08));
812 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xfe));
818 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7e));
825 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x20));
832 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x18));
838 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7F));
844 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xFF));
850 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x01));
855 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0));
860 struct numeric<
bf8_t>
865 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x04));
871 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xfb));
877 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7b));
883 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x34));
890 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x30));
896 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7c));
902 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7F));
908 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xFF));
914 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x01));
919 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0));
929 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x08));
935 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0xff));
941 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x7f));
947 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x20));
957 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x30));
963 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
969 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
975 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x80));
981 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0x01));
986 return bit_cast<fp8_t>(
static_cast<fp8_raw_t>(0));
996 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x04));
1002 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0xff));
1008 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x7f));
1014 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x34));
1024 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x38));
1030 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1036 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1042 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x80));
1048 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0x01));
1053 return bit_cast<bf8_t>(
static_cast<bf8_raw_t>(0));
1058 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1064 template <
typename T>
1067 static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
1068 "Only fp8_t and bf8_t are supported");
1075 uint8_t xx = bit_cast<fp8_raw_t>(x);
1077 #if CK_TILE_USE_OCP_FP8
1078 return (xx & 0x7f) == 0x7f;
1083 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1085 fp8_t sqrt(
fp8_t x) {
return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(
static_cast<float>(x))); };
1088 fp8_t exp(
fp8_t x) {
return static_cast<fp8_t>(__ocml_exp_f32(
static_cast<float>(x))); };
1100 uint8_t xx = bit_cast<bf8_raw_t>(x);
1102 #if CK_TILE_USE_OCP_FP8
1103 return (xx & 0x7f) > 0x7c;
1109 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1111 bf8_t sqrt(
bf8_t x) {
return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(
static_cast<float>(x))); };
1114 bf8_t exp(
bf8_t x) {
return static_cast<bf8_t>(__ocml_exp_f32(
static_cast<float>(x))); };
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_FLOAT_TO_FP8_DEFAULT
Definition: config.hpp:78
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition: float8.hpp:250
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition: float8.hpp:480
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition: float8.hpp:594
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:421
fp8_interpretation
FP8 interpretation used in conversion algorithms.
Definition: float8.hpp:38
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant< rounding >={})
Definition: float8.hpp:781
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition: float8.hpp:754
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition: float8.hpp:767
fp8_rounding_mode
Definition: float8.hpp:29
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:406
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant< rounding >={})
Definition: float8.hpp:721
uint8_t fp8_raw_t
Definition: float8.hpp:205
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
Definition: float8.hpp:794
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_sr_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with stochastic rounding.
Definition: float8.hpp:683
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:412
CK_TILE_HOST int clz(uint32_t x)
Definition: math.hpp:264
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:393
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
uint8_t bf8_raw_t
Definition: float8.hpp:207
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant< rounding >={})
Definition: float8.hpp:787
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:399
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_rtn_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with rounding to nearest ev...
Definition: float8.hpp:709
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
Definition: float8.hpp:792
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant< rounding >={})
Definition: float8.hpp:738
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:418
Definition: integral_constant.hpp:13
remove_cvref_t< T > type
Definition: vector_type.hpp:26
static constexpr CK_TILE_HOST_DEVICE bf8_t min()
Definition: float8.hpp:994
static constexpr CK_TILE_HOST_DEVICE bf8_t quiet_NaN()
Definition: float8.hpp:1034
static constexpr CK_TILE_HOST_DEVICE bf8_t lowest()
Definition: float8.hpp:1000
static constexpr CK_TILE_HOST_DEVICE bf8_t round_error()
Definition: float8.hpp:1022
static constexpr CK_TILE_HOST_DEVICE bf8_t signaling_NaN()
Definition: float8.hpp:1040
static constexpr CK_TILE_HOST_DEVICE bf8_t denorm_min()
Definition: float8.hpp:1046
static constexpr CK_TILE_HOST_DEVICE bf8_t epsilon()
Definition: float8.hpp:1012
static constexpr CK_TILE_HOST_DEVICE bf8_t infinity()
Definition: float8.hpp:1028
static constexpr CK_TILE_HOST_DEVICE bf8_t max()
Definition: float8.hpp:1006
static constexpr CK_TILE_HOST_DEVICE bf8_t zero()
Definition: float8.hpp:1051
static constexpr CK_TILE_HOST_DEVICE fp8_t signaling_NaN()
Definition: float8.hpp:973
static constexpr CK_TILE_HOST_DEVICE fp8_t zero()
Definition: float8.hpp:984
static constexpr CK_TILE_HOST_DEVICE fp8_t min()
Definition: float8.hpp:927
static constexpr CK_TILE_HOST_DEVICE fp8_t lowest()
Definition: float8.hpp:933
static constexpr CK_TILE_HOST_DEVICE fp8_t epsilon()
Definition: float8.hpp:945
static constexpr CK_TILE_HOST_DEVICE fp8_t quiet_NaN()
Definition: float8.hpp:967
static constexpr CK_TILE_HOST_DEVICE fp8_t max()
Definition: float8.hpp:939
static constexpr CK_TILE_HOST_DEVICE fp8_t denorm_min()
Definition: float8.hpp:979
static constexpr CK_TILE_HOST_DEVICE fp8_t round_error()
Definition: float8.hpp:955
static constexpr CK_TILE_HOST_DEVICE fp8_t infinity()
Definition: float8.hpp:961
bf8_raw_t bitwise_type
Definition: float8.hpp:231
fp8_raw_t bitwise_type
Definition: float8.hpp:213
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T lowest()
Definition: numeric.hpp:23
static constexpr CK_TILE_HOST_DEVICE T min()
Definition: numeric.hpp:20
static constexpr CK_TILE_HOST_DEVICE T quiet_NaN()
Definition: numeric.hpp:41
static constexpr CK_TILE_HOST_DEVICE T signaling_NaN()
Definition: numeric.hpp:47
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
static constexpr CK_TILE_HOST_DEVICE T round_error()
Definition: numeric.hpp:32
static constexpr CK_TILE_HOST_DEVICE T zero()
Definition: numeric.hpp:58
static constexpr CK_TILE_HOST_DEVICE T denorm_min()
Definition: numeric.hpp:53
static constexpr CK_TILE_HOST_DEVICE T epsilon()
Definition: numeric.hpp:29
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38
Definition: random.hpp:17
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106