13 #ifndef CK_USE_FNUZ_FP8 
   14 #define CK_USE_FNUZ_FP8 0 
   17 #ifndef CK_USE_OCP_FP8 
   18 #define CK_USE_OCP_FP8 0 
   21 #if(defined(__gfx942__) || defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && \ 
   22     __HIP_DEVICE_COMPILE__ 
   23 #define CK_FP8_CVT_FAST_PATH 1 
   25 #define CK_FP8_CVT_FAST_PATH 0 
   28 #if(defined(__gfx1200__) || defined(__gfx1201__) || defined(__gfx950__)) && __HIP_DEVICE_COMPILE__ 
   29 #define CK_OCP_FP8_CVT_FAST_PATH 1 
   31 #define CK_OCP_FP8_CVT_FAST_PATH 0 
   64 typedef _Float16 
half2_t __attribute__((ext_vector_type(2)));
 
   65 typedef ushort 
ushortx2_t __attribute__((ext_vector_type(2)));
 
   66 typedef short shortx2_t __attribute__((ext_vector_type(2)));
 
   67 typedef float float2_t __attribute__((ext_vector_type(2)));
 
   69 __host__ __device__ 
static inline constexpr 
bool fnuz_f8_is_nan(
f8_fnuz_t a)
 
   71     return static_cast<unsigned char>(a) == 0x80;
 
   73 __host__ __device__ 
static inline constexpr 
bool fnuz_bf8_is_nan(
bf8_fnuz_t a)
 
   75     return static_cast<unsigned char>(a) == 0x80;
 
   78 __host__ __device__ 
static inline constexpr 
bool ocp_f8_is_nan(
fp8_storage_t a)
 
   80     return (a & 0x7f) == 0x7f;
 
   82 __host__ __device__ 
static inline constexpr 
bool ocp_bf8_is_nan(
fp8_storage_t a)
 
   84     return (a & 0x7f) > 0x7c;
 
   90 template <
typename T, 
int wm, 
int we, 
bool is_fnuz, 
bool clip = false>
 
   91 __host__ __device__ 
static inline T cast_from_f8(
fp8_storage_t x)
 
   93     constexpr 
bool is_half   = __hip_internal::is_same<T, _Float16>::value;
 
   94     constexpr 
bool is_float  = __hip_internal::is_same<T, float>::value;
 
   95     constexpr 
bool is_double = __hip_internal::is_same<T, double>::value;
 
   96     static_assert(is_half || is_float || is_double, 
"only half, float and double are supported");
 
   98     constexpr 
int weo = is_half ? 5 : (is_float ? 8 : 11);
 
   99     constexpr 
int wmo = is_half ? 10 : (is_float ? 23 : 52);
 
  101     T fInf, fNegInf, fNaN, fNeg0, fmax, fmin;
 
  102     if constexpr(is_half)
 
  104         const unsigned short int ihInf    = 0x7C00;
 
  105         const unsigned short int ihNegInf = 0xFC00;
 
  106         const unsigned short int ihNaN    = 0x7C01;
 
  107         const unsigned short int ihNeg0   = 0x8000;
 
  109         const unsigned short int ifmax = 0x7B00;
 
  110         const unsigned short int ifmin = 0xFB00;
 
  112         fInf    = bit_cast<_Float16>(ihInf);
 
  113         fNegInf = bit_cast<_Float16>(ihNegInf);
 
  114         fNaN    = bit_cast<_Float16>(ihNaN);
 
  115         fNeg0   = bit_cast<_Float16>(ihNeg0);
 
  116         fmax    = bit_cast<_Float16>(ifmax);
 
  117         fmin    = bit_cast<_Float16>(ifmin);
 
  119     else if constexpr(is_float)
 
  121         const unsigned int ifInf    = 0x7F800000;
 
  122         const unsigned int ifNegInf = 0xFF800000;
 
  123         const unsigned int ifNaN    = 0x7F800001;
 
  124         const unsigned int ifNeg0   = 0x80000000;
 
  126         const unsigned int ifmax = 0x47600000;
 
  127         const unsigned int ifmin = 0xC7600000;
 
  129         fInf    = bit_cast<float>(ifInf);
 
  130         fNegInf = bit_cast<float>(ifNegInf);
 
  131         fNaN    = bit_cast<float>(ifNaN);
 
  132         fNeg0   = bit_cast<float>(ifNeg0);
 
  133         fmax    = bit_cast<float>(ifmax);
 
  134         fmin    = bit_cast<float>(ifmin);
 
  136     else if constexpr(is_double)
 
  138         const unsigned long long ifInf    = 0x7FF0000000000000ull;
 
  139         const unsigned long long ifNegInf = 0xFFF0000000000000ull;
 
  140         const unsigned long long ifNaN    = 0x7FF0000000000001ull;
 
  141         const unsigned long long ifNeg0   = 0x8000000000000000ull;
 
  143         const unsigned long long ifmax = 0x40EC000000000000ull;
 
  144         const unsigned long long ifmin = 0xC0EC000000000000ull;
 
  146         fInf    = bit_cast<double>(ifInf);
 
  147         fNegInf = bit_cast<double>(ifNegInf);
 
  148         fNaN    = bit_cast<double>(ifNaN);
 
  149         fNeg0   = bit_cast<double>(ifNeg0);
 
  150         fmax    = bit_cast<double>(ifmax);
 
  151         fmin    = bit_cast<double>(ifmin);
 
  159     unsigned long long sign     = x >> 7;
 
  160     unsigned long long mantissa = x & ((1 << wm) - 1);
 
  161     int exponent                = (x & 0x7F) >> wm;
 
  162     if constexpr(is_fnuz)
 
  175         if constexpr(we == 4)
 
  177             if((x & 0x7F) == 0x7F)
 
  182         else if((x & 0x7C) == 0x7C)
 
  188                     return sign ? fmin : fmax;
 
  190                 return sign ? fNegInf : fInf;
 
  202     if constexpr(we == 5 && is_half && !is_fnuz)
 
  205         return bit_cast<T>(retval);
 
  208     const int exp_low_cutoff = (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (is_fnuz ? 1 : 0);
 
  213 #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ 
  215         int sh = 1 + __clz(mantissa) - (32 - wm);
 
  217         int sh = 1 + __builtin_clz(mantissa) - (32 - wm);
 
  221         mantissa &= ((1ull << wm) - 1);
 
  223     exponent += exp_low_cutoff - 1;
 
  224     mantissa <<= wmo - wm;
 
  229         mantissa |= 1 << wmo;
 
  230         mantissa >>= 1 - exponent;
 
  234     if constexpr(
sizeof(T) == 2)
 
  235         retval = (sign << 15) | (exponent << 10) | mantissa;
 
  236     else if constexpr(sizeof(T) == 4)
 
  237         retval = (sign << 31) | (exponent << 23) | mantissa;
 
  239         retval = (sign << 63) | (static_cast<
unsigned long long>(exponent) << 52) | mantissa;
 
  244 #if CK_FP8_CVT_FAST_PATH 
  245 template <ck_fp8_
interpretation_t 
interpret>
 
  246 static __host__ __device__ 
float cast_to_f32_from_f8(
fp8_storage_t v)
 
  251         unsigned char i8val[4];
 
  259                   "Only FNUZ and OCP interpretations are supported");
 
  264         return __builtin_amdgcn_cvt_f32_fp8(val.i32val, 0);
 
  268         return __builtin_amdgcn_cvt_f32_bf8(val.i32val, 0);
 
  272 template <ck_fp8_
interpretation_t 
interpret>
 
  275     const auto i16val = bit_cast<uint16_t>(v);
 
  281                   "Only FNUZ and OCP interpretations are supported");
 
  286         return __builtin_amdgcn_cvt_pk_f32_fp8(i16val, 
false);
 
  290         return __builtin_amdgcn_cvt_pk_f32_bf8(i16val, 
false);
 
  306     static constexpr 
unsigned int we = 4; 
 
  307     static constexpr 
unsigned int wm = 3; 
 
  311         return (data == other.
data) && (fp8_impl::ocp_f8_is_nan(data) == 
false); 
 
  315     __host__ __device__ 
explicit operator float() const
 
  317     __host__ 
explicit operator float() const
 
  320 #if CK_OCP_FP8_CVT_FAST_PATH 
  321         return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
 
  323         return fp8_impl::cast_from_f8<float, wm, we, false>(
 
  329     __host__ __device__ 
explicit operator _Float16() const
 
  331     __host__ 
explicit operator _Float16() const
 
  334 #if CK_OCP_FP8_CVT_FAST_PATH 
  335         return static_cast<_Float16
>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
 
  337         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
 
  352     static constexpr 
unsigned int we = 5; 
 
  353     static constexpr 
unsigned int wm = 2; 
 
  357         return (data == other.
data) && (fp8_impl::ocp_bf8_is_nan(data) == 
false); 
 
  361     __host__ __device__ 
explicit operator float() const
 
  364     __host__ 
explicit operator float() const
 
  367 #if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) 
  368         return fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data);
 
  370         return fp8_impl::cast_from_f8<float, wm, we, false>(
 
  376     __host__ __device__ 
explicit operator _Float16() const
 
  378     __host__ 
explicit operator _Float16() const
 
  381 #if defined(__gfx950__) || defined(__gfx1200__) || defined(__gfx1201__) 
  382         return static_cast<_Float16
>(fp8_impl::cast_to_f32_from_f8<default_interpret>(this->data));
 
  384         return fp8_impl::cast_from_f8<_Float16, wm, we, false>(
 
  390 template <
typename T>
 
  391 __host__ __device__ 
static inline constexpr 
bool fp8_is_nan(T);
 
  394 __host__ __device__ 
inline constexpr 
bool fp8_is_nan(
f8_ocp_t a)
 
  396     return fp8_impl::ocp_f8_is_nan(a.
data);
 
  399 __host__ __device__ 
inline constexpr 
bool fp8_is_nan(
bf8_ocp_t a)
 
  401     return fp8_impl::ocp_bf8_is_nan(a.
data);
 
  404 __host__ __device__ 
inline constexpr 
bool fp8_is_nan(
f8_fnuz_t a)
 
  406     return fp8_impl::fnuz_f8_is_nan(a);
 
  409 __host__ __device__ 
inline constexpr 
bool fp8_is_nan(
bf8_fnuz_t a)
 
  411     return fp8_impl::fnuz_bf8_is_nan(a);
 
  414 template <
typename T,
 
  416                               is_same_v<T, bf8_fnuz_t> || is_same_v<T, f8_fnuz_t>,
 
  418 __host__ __device__ 
static inline constexpr 
bool fp8_is_inf(T)
 
  423 __host__ __device__ 
inline constexpr 
bool fp8_is_inf(
bf8_ocp_t a)
 
  425     return (a.
data & 0x7f) == 0x7c;
 
  431 #define __fp8_impl_assert_ocp_support(interp)                                      \ 
  433         if(interp != ck_fp8_interpretation_t::CK_E4M3_OCP &&                       \ 
  434            interp != ck_fp8_interpretation_t::CK_E5M2_OCP)                         \ 
  436             __hip_assert(false && "type is unsupported by current target device"); \
 
  439 #define __fp8_impl_assert_fnuz_support(interp)                                     \ 
  441         if(interp != ck_fp8_interpretation_t::CK_E4M3_FNUZ &&                      \ 
  442            interp != ck_fp8_interpretation_t::CK_E5M2_FNUZ)                        \ 
  444             __hip_assert(false && "type is unsupported by current target device"); \
 
  448 __host__ __device__ 
static inline void 
  451 #if defined(__HIP_DEVICE_COMPILE__) && __HIP_DEVICE_COMPILE__ 
  461 #if defined(__gfx950__) 
  464           bool stochastic_rounding                                                 = 
false,
 
  467 static __device__ 
fp8_storage_t cast_to_f8_from_f16(_Float16 v, 
unsigned int rng = 0)
 
  476     constexpr 
unsigned int i32val = 0;
 
  479     if constexpr(saturate)
 
  481         if((val.i32val & 0x7FFF) != 0x7FFF)
 
  483             val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
 
  488         __builtin_amdgcn_cvt_scalef32_sr_fp8_f16(i32val, val.half_vec[0], rng,  1.f, 0);
 
  495           bool stochastic_rounding                                                 = 
false,
 
  502         cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
 
  503         cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
 
  508           bool stochastic_rounding                                                 = 
false,
 
  511 static __device__ 
fp8_storage_t cast_to_f8_from_f16(_Float16 v, 
unsigned int rng = 0)
 
  520     constexpr 
unsigned int i32val = 0;
 
  523     if constexpr(saturate)
 
  525         if((val.i32val & 0x7FFF) != 0x7FFF)
 
  527             val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
 
  532         __builtin_amdgcn_cvt_scalef32_sr_bf8_f16(i32val, val.half_vec[0], rng,  1.f, 0);
 
  539           bool stochastic_rounding                                                 = 
false,
 
  546         cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
 
  547         cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
 
  552           bool stochastic_rounding                                                 = 
false,
 
  555 static __device__ 
fp8_storage_t cast_to_f8_from_f16(_Float16 v, 
unsigned int rng = 0)
 
  570     if constexpr(saturate)
 
  572         if((val.i32val & 0x7FFF) != 0x7FFF)
 
  574             val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
 
  579         __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec,  1.f, 0);
 
  586           bool stochastic_rounding                                                 = 
false,
 
  591 #if CK_WORKAROUND_FP16_TO_FP8_CONVERSION 
  593         cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
 
  594         cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
 
  608     if constexpr(saturate)
 
  610         if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
 
  612             val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 448.0, -448.0);
 
  614         if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
 
  616             val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 448.0, -448.0);
 
  621         __builtin_amdgcn_cvt_scalef32_pk_fp8_f16(i16x2val, val.half_vec,  1.f, 0);
 
  629           bool stochastic_rounding                                                 = 
false,
 
  632 static __device__ 
fp8_storage_t cast_to_f8_from_f16(_Float16 v, 
unsigned int rng = 0)
 
  647     if constexpr(saturate)
 
  649         if((val.i32val & 0x7FFF) != 0x7FFF)
 
  651             val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
 
  656         __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec,  1.f, 0);
 
  663           bool stochastic_rounding                                                 = 
false,
 
  668 #if CK_WORKAROUND_FP16_TO_FP8_CONVERSION 
  670         cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[0], rng),
 
  671         cast_to_f8_from_f16<interpret, saturate, stochastic_rounding>(v[1], rng)};
 
  685     if constexpr(saturate)
 
  687         if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
 
  689             val.half_vec[0] = __builtin_amdgcn_fmed3h(val.half_vec[0], 57344.0, -57344.0);
 
  691         if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
 
  693             val.half_vec[1] = __builtin_amdgcn_fmed3h(val.half_vec[1], 57344.0, -57344.0);
 
  698         __builtin_amdgcn_cvt_scalef32_pk_bf8_f16(i16x2val, val.half_vec,  1.f, 0);
 
  706           bool stochastic_rounding                                                 = 
false,
 
  709 static __device__ 
fp8_storage_t cast_to_f8_from_bf16(ushort v, 
unsigned int rng = 0)
 
  718     constexpr 
unsigned int i32val = 0;
 
  719     val.bhalf_vec[0]              = v;
 
  721     if constexpr(saturate)
 
  723         if((val.i32val & 0x7FFF) != 0x7FFF)
 
  726                 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
 
  727                             bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
 
  732     val.i32val = __builtin_amdgcn_cvt_scalef32_sr_fp8_bf16(
 
  733         i32val, val.bhalf_vec[0], rng,  1.f, 0);
 
  740           bool stochastic_rounding                                                 = 
false,
 
  747         cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
 
  748         cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
 
  753           bool stochastic_rounding                                                 = 
false,
 
  756 static __device__ 
fp8_storage_t cast_to_f8_from_bf16(ushort v, 
unsigned int rng = 0)
 
  765     constexpr 
unsigned int i32val = 0;
 
  766     val.bhalf_vec[0]              = v;
 
  768     if constexpr(saturate)
 
  770         if((val.i32val & 0x7FFF) != 0x7FFF)
 
  772             val.bhalf_vec[0] = ushort(
 
  773                 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
 
  774                      bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
 
  779     val.i32val = __builtin_amdgcn_cvt_scalef32_sr_bf8_bf16(
 
  780         i32val, val.bhalf_vec[0], rng,  1.f, 0);
 
  787           bool stochastic_rounding                                                 = 
false,
 
  794         cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
 
  795         cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
 
  800           bool stochastic_rounding                                                 = 
false,
 
  803 static __device__ 
fp8_storage_t cast_to_f8_from_bf16(ushort v, 
unsigned int rng = 0)
 
  816     val.bhalf_vec[0]             = v;
 
  818     if constexpr(saturate)
 
  820         if((val.i32val & 0x7FFF) != 0x7FFF)
 
  823                 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
 
  824                             bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
 
  830         __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec,  1.f, 0);
 
  837           bool stochastic_rounding                                                 = 
false,
 
  842 #if CK_WORKAROUND_BF16_TO_FP8_CONVERSION 
  844         cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[0], rng),
 
  845         cast_to_f8_from_bf16<interpret, saturate, stochastic_rounding>(v[1], rng)};
 
  859     if constexpr(saturate)
 
  861         if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
 
  864                 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
 
  865                             bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 448.0, -448.0)) >>
 
  868         if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
 
  871                 ushort((bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
 
  872                             bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 448.0, -448.0)) >>
 
  878         __builtin_amdgcn_cvt_scalef32_pk_fp8_bf16(i16x2val, val.bhalf_vec,  1.f, 0);
 
  886           bool stochastic_rounding                                                 = 
false,
 
  889 static __device__ 
fp8_storage_t cast_to_f8_from_bf16(ushort v, 
unsigned int rng = 0)
 
  902     val.bhalf_vec[0]             = v;
 
  904     if constexpr(saturate)
 
  906         if((val.i32val & 0x7FFF) != 0x7FFF)
 
  908             val.bhalf_vec[0] = ushort(
 
  909                 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
 
  910                      bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
 
  916         __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec,  1.f, 0);
 
  923           bool stochastic_rounding                                                 = 
false,
 
  940     if constexpr(saturate)
 
  942         if((val.i16_vec[0] & 0x7FFF) != 0x7FFF)
 
  944             val.bhalf_vec[0] = ushort(
 
  945                 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
 
  946                      bit_cast<float>(uint32_t{val.bhalf_vec[0]} << 16), 57344.0, -57344.0)) >>
 
  949         if((val.i16_vec[1] & 0x7FFF) != 0x7FFF)
 
  951             val.bhalf_vec[1] = ushort(
 
  952                 (bit_cast<uint32_t>(__builtin_amdgcn_fmed3f(
 
  953                      bit_cast<float>(uint32_t{val.bhalf_vec[1]} << 16), 57344.0, -57344.0)) >>
 
  959         __builtin_amdgcn_cvt_scalef32_pk_bf8_bf16(i16x2val, val.bhalf_vec,  1.f, 0);
 
  965 #if CK_FP8_CVT_FAST_PATH 
  968 template <ck_fp8_
interpretation_t 
interpret, 
bool saturate, 
bool stochastic_rounding = false>
 
  969 static __device__ 
fp8_storage_t cast_to_f8_from_f32(
float v, 
unsigned int rng = 0)
 
  976         unsigned char i8val[4]; 
 
  979     unsigned int ival = 0;
 
  982     if constexpr(saturate)
 
  986             if((val.i32val & 0x7F800000) != 0x7F800000)
 
  988                 val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
 
  993             if((val.i32val & 0x7F800000) != 0x7F800000)
 
  995                 val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
 
 1000             if((val.i32val & 0x7F800000) != 0x7F800000)
 
 1002                 val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
 
 1007     if constexpr(stochastic_rounding)
 
 1011                          ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
 
 1012                          : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); 
 
 1014         i8data     = val.i8val[0]; 
 
 1020                          ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, 
false)
 
 1021                          : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
 
 1026         i8data     = val.i8val[0];
 
 1031 template <ck_fp8_
interpretation_t 
interpret, 
bool saturate, 
bool stochastic_rounding = false>
 
 1034     if constexpr(stochastic_rounding)
 
 1038             cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[0], rng),
 
 1039             cast_to_f8_from_f32<interpret, saturate, stochastic_rounding>(v[1], rng)};
 
 1046             unsigned int i32val;
 
 1047             unsigned char i8val[4];
 
 1053         unsigned int ival = 0;
 
 1055         if constexpr(saturate)
 
 1059                 if((val0.i32val & 0x7F800000) != 0x7F800000)
 
 1061                     val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 240.0, -240.0);
 
 1063                 if((val1.i32val & 0x7F800000) != 0x7F800000)
 
 1065                     val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 240.0, -240.0);
 
 1070                 if((val0.i32val & 0x7F800000) != 0x7F800000)
 
 1072                     val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 448.0, -448.0);
 
 1074                 if((val1.i32val & 0x7F800000) != 0x7F800000)
 
 1076                     val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 448.0, -448.0);
 
 1081                 if((val0.i32val & 0x7F800000) != 0x7F800000)
 
 1083                     val0.fval = __builtin_amdgcn_fmed3f(val0.fval, 57344.0, -57344.0);
 
 1085                 if((val1.i32val & 0x7F800000) != 0x7F800000)
 
 1087                     val1.fval = __builtin_amdgcn_fmed3f(val1.fval, 57344.0, -57344.0);
 
 1096             ival = __builtin_amdgcn_cvt_pk_fp8_f32(val0.fval, val1.fval, ival, 
false);
 
 1100             ival = __builtin_amdgcn_cvt_pk_bf8_f32(val0.fval, val1.fval, ival, 
false);
 
 1113 template <
typename T, 
int wm, 
int we, 
bool is_fnuz, 
bool clip = false, 
bool stoch = false>
 
 1114 __host__ __device__ 
static inline fp8_storage_t cast_to_f8(T _x, 
unsigned int rng = 0)
 
 1116     constexpr 
bool is_half   = __hip_internal::is_same<T, _Float16>::value;
 
 1117     constexpr 
bool is_float  = __hip_internal::is_same<T, float>::value;
 
 1118     constexpr 
bool is_double = __hip_internal::is_same<T, double>::value;
 
 1119     static_assert(is_half || is_float || is_double,
 
 1120                   "Only half, float and double can be cast to f8");
 
 1122     constexpr 
int mfmt = (
sizeof(T) == 8) ? 52 : ((
sizeof(T) == 4) ? 23 : 10);
 
 1128     T_bitwise x_bitwise = bit_cast<T_bitwise>(_x);
 
 1130     unsigned long long x{x_bitwise};
 
 1132     unsigned long long head, mantissa;
 
 1135     unsigned long long fInf, mask;
 
 1137     if constexpr(
sizeof(T) == 8)
 
 1139         head     = x & 0xFFF0000000000000ull;
 
 1140         mantissa = x & 0xFFFFFFFFFFFFFull;
 
 1141         exponent = (head >> 52) & 0x7FF;
 
 1144         fInf     = 0x7FF0000000000000ull;
 
 1145         mask     = 0x7FFFFFFFFFFFFFFFull;
 
 1147     else if constexpr(
sizeof(T) == 4)
 
 1149         head     = x & 0xFF800000;
 
 1150         mantissa = x & 0x7FFFFF;
 
 1151         exponent = (head >> 23) & 0xFF;
 
 1160         mantissa = x & 0x3FF;
 
 1161         exponent = (head >> 10) & 0x1F;
 
 1167     unsigned int signed_inf = 0;
 
 1168     unsigned int nan        = 0;
 
 1169     if constexpr(is_fnuz)
 
 1171         signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
 
 1176         if constexpr(we == 4)
 
 1178             signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
 
 1182             signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
 
 1184         nan = (sign << 7) + 0x7f;
 
 1187     unsigned long long ifmax = 0;
 
 1188     if constexpr(
sizeof(T) == 8)
 
 1190         if constexpr(we == 5)
 
 1192             ifmax = 0x40EC000000000000ull;
 
 1196             if constexpr(is_fnuz)
 
 1198                 ifmax = 0x406E000000000000ull;
 
 1202                 ifmax = 0x407C000000000000ull;
 
 1206     else if(
sizeof(T) == 4)
 
 1208         if constexpr(we == 5)
 
 1214             if constexpr(is_fnuz)
 
 1226         if constexpr(we == 5)
 
 1232             if constexpr(is_fnuz)
 
 1243     if((x & fInf) == fInf)
 
 1245         if constexpr(is_fnuz)
 
 1248         return mantissa != 0 ? nan : signed_inf;
 
 1251     if((x & mask) > ifmax)
 
 1269     const int f8_bias                  = (1 << (we - 1)) - 1 + (is_fnuz ? 1 : 0);
 
 1270     const int f8_denormal_act_exponent = 1 - f8_bias; 
 
 1275     int act_exponent, f8_exponent, exponent_diff;
 
 1286         act_exponent  = exponent - bias + 1;
 
 1287         exponent_diff = f8_denormal_act_exponent -
 
 1292         act_exponent = exponent - bias;
 
 1293         if(act_exponent <= f8_denormal_act_exponent)
 
 1300             exponent_diff = f8_denormal_act_exponent - act_exponent;
 
 1308         mantissa += (1ull << mfmt); 
 
 1311     bool midpoint = (mantissa & ((1ull << (mfmt - wm + exponent_diff)) - 1)) ==
 
 1312                     (1ull << (mfmt - wm + exponent_diff - 1));
 
 1320     if(exponent_diff > 0)
 
 1321         mantissa >>= exponent_diff;
 
 1322     else if(exponent_diff == -1)
 
 1323         mantissa <<= -exponent_diff;
 
 1324     bool implicit_one = mantissa & (1ull << mfmt);
 
 1328         (act_exponent + exponent_diff)  + f8_bias - (implicit_one ? 0 : 1);
 
 1331     unsigned long long drop_mask = (1ull << (mfmt - wm)) - 1;
 
 1333         mantissa & (1ull << (mfmt - wm)); 
 
 1335         (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
 
 1338     if(f8_exponent == 0)
 
 1340         if((1ull << mfmt) & mantissa)
 
 1347         if((1ull << (mfmt + 1)) & mantissa)
 
 1354     mantissa >>= (mfmt - wm);
 
 1357     const int max_exp = (1 << we) - 1;
 
 1358     if(f8_exponent > max_exp)
 
 1362             mantissa    = (1 << wm) - 1;
 
 1363             f8_exponent = max_exp;
 
 1371     if(f8_exponent == 0 && mantissa == 0)
 
 1372         return is_fnuz ? 0 : (sign << 7);
 
 1373     mantissa &= (1 << wm) - 1;
 
 1374     return (sign << 7) | (f8_exponent << wm) | mantissa;
 
 1388           bool stochastic_rounding = 
false>
 
 1389 #if CK_FP8_CVT_FAST_PATH 
 1390 __host__ __device__ 
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
 
 1392     __is_interpret_supported(interp);
 
 1394     if constexpr(stochastic_rounding)
 
 1396 #if defined(__gfx950__) 
 1398         rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
 1401         constexpr 
int seed = 1254739;
 
 1402 #ifndef CK_CODE_GEN_RTC 
 1403         rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f);
 
 1405         rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&f), f);
 
 1409     return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
 
 1413 __host__ __device__ 
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
 
 1416 __host__ 
static inline fp8_storage_t cvt_float_to_fp8(
const float f)
 
 1420     if constexpr(stochastic_rounding)
 
 1422 #if defined(__gfx950__) 
 1424         rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
 1427         constexpr 
int seed = 1254739;
 
 1428 #ifndef CK_CODE_GEN_RTC 
 1429         rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f);
 
 1431         rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&f), f);
 
 1438         return cast_to_f8<float,
 
 1443                           stochastic_rounding>(f, rng);
 
 1447         return cast_to_f8<float,
 
 1452                           stochastic_rounding>(f, rng);
 
 1456         return cast_to_f8<float,
 
 1461                           stochastic_rounding>(f, rng);
 
 1465         return cast_to_f8<float,
 
 1470                           stochastic_rounding>(f, rng);
 
 1474         __hip_assert(
false && 
"FP8 type is not supported by current target device");
 
 1491           bool stochastic_rounding = 
false>
 
 1492 #if CK_FP8_CVT_FAST_PATH 
 1495     __is_interpret_supported(interp);
 
 1497     if constexpr(stochastic_rounding)
 
 1499 #if defined(__gfx950__) 
 1501         rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
 1504         constexpr 
int seed = 1254739;
 
 1505 #ifndef CK_CODE_GEN_RTC 
 1506         rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&f), f[0]);
 
 1508         rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&f), f[0]);
 
 1512     return cast_to_f8_from_f32<interp, sat == ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
 
 1522     return fp8x2_storage_t{cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[0]),
 
 1523                            cvt_float_to_fp8<interp, sat, stochastic_rounding>(f[1])};
 
 1538           bool stochastic_rounding = 
false>
 
 1539 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 
 1540 __host__ __device__ 
static inline fp8_storage_t cvt_half_t_to_fp8(
const _Float16 x)
 
 1542 __host__ 
static inline fp8_storage_t cvt_half_t_to_fp8(
const _Float16 x)
 
 1546         __is_interpret_supported(interp);
 
 1548         if constexpr(stochastic_rounding)
 
 1550 #if defined(__gfx950__) 
 1552             rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
 1555             constexpr 
int seed = 1254739;
 
 1556 #ifndef CK_CODE_GEN_RTC 
 1557             rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&x), x);
 
 1559             rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&x), x);
 
 1563 #if defined(__gfx950__) 
 1564         return cast_to_f8_from_f16<interp,
 
 1566                                    stochastic_rounding>(x, rng);
 
 1569         return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
 
 1570             static_cast<float>(x));
 
 1586           bool stochastic_rounding = 
false>
 
 1587 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 
 1594         __is_interpret_supported(interp);
 
 1596         if constexpr(stochastic_rounding)
 
 1598 #if defined(__gfx950__) 
 1600             rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
 1603             constexpr 
int seed = 1254739;
 
 1604 #ifndef CK_CODE_GEN_RTC 
 1605             rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&x), x[0]);
 
 1607             rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&x), x[0]);
 
 1611 #if defined(__gfx950__) 
 1612         return cast_to_f8_from_f16<interp,
 
 1614                                    stochastic_rounding>(x, rng);
 
 1617         return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
 
 1618             float2_t{
static_cast<float>(x[0]), 
static_cast<float>(x[1])});
 
 1634           bool stochastic_rounding = 
false>
 
 1635 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 
 1636 __host__ __device__ 
static inline fp8_storage_t cvt_bhalf_t_to_fp8(
const ushort x)
 
 1638 __host__ 
static inline fp8_storage_t cvt_bhalf_t_to_fp8(
const ushort x)
 
 1642         __is_interpret_supported(interp);
 
 1644         if constexpr(stochastic_rounding)
 
 1646 #if defined(__gfx950__) 
 1648             rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
 1651             constexpr 
int seed = 1254739;
 
 1652 #ifndef CK_CODE_GEN_RTC 
 1653             rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&x),
 
 1654                                                static_cast<float>(x));
 
 1656             rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&x), 
static_cast<float>(x));
 
 1660 #if defined(__gfx950__) 
 1661         return cast_to_f8_from_bf16<interp,
 
 1663                                     stochastic_rounding>(x, rng);
 
 1666         return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
 
 1667             bit_cast<float>(uint32_t{x} << 16)); 
 
 1683           bool stochastic_rounding = 
false>
 
 1684 #if CK_FP8_CVT_FAST_PATH || CK_USE_OCP_FP8 
 1690 #if CK_WORKAROUND_BF16_TO_FP8_CONVERSION 
 1691     return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
 
 1692         float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
 
 1693                  bit_cast<float>(uint32_t{x[1]} << 16)}); 
 
 1696         __is_interpret_supported(interp);
 
 1698         if constexpr(stochastic_rounding)
 
 1700 #if defined(__gfx950__) 
 1702             rng = __builtin_amdgcn_prng_b32(__builtin_amdgcn_s_memrealtime() *
 
 1705             constexpr 
int seed = 1254739;
 
 1706 #ifndef CK_CODE_GEN_RTC 
 1707             rng = prand_generator<float, seed>(
reinterpret_cast<uintptr_t
>(&x),
 
 1708                                                static_cast<float>(x[0]));
 
 1710             rng = prand_generator<float, seed>(
reinterpret_cast<size_t>(&x),
 
 1711                                                static_cast<float>(x[0]));
 
 1715 #if defined(__gfx950__) 
 1716         return cast_to_f8_from_bf16<interp,
 
 1718                                     stochastic_rounding>(x, rng);
 
 1721         return cvt_float_to_fp8<interp, ck_saturation_t::CK_SATFINITE, stochastic_rounding>(
 
 1722             float2_t{bit_cast<float>(uint32_t{x[0]} << 16),
 
 1723                      bit_cast<float>(uint32_t{x[1]} << 16)}); 
 
 1732 using f8_t  = f8_ocp_t;
 
 1733 using bf8_t = bf8_ocp_t;
 
 1734 #define CK_FP8_TYPE_FNUZ 0 
 1735 #define CK_FP8_TYPE_OCP 1 
 1739 #define CK_FP8_TYPE_FNUZ 1 
 1740 #define CK_FP8_TYPE_OCP 0 
#define __fp8_impl_assert_fnuz_support(interp)
Definition: amd_ck_fp8.hpp:439
 
#define __fp8_impl_assert_ocp_support(interp)
Definition: amd_ck_fp8.hpp:431
 
ushort ushortx2_t
Definition: amd_ck_fp8.hpp:65
 
short shortx2_t
Definition: amd_ck_fp8.hpp:66
 
float float2_t
Definition: amd_ck_fp8.hpp:67
 
fp8_storage_t fp8x2_storage_t
Definition: amd_ck_fp8.hpp:63
 
_Float16 half2_t
Definition: amd_ck_fp8.hpp:64
 
__host__ constexpr __device__ Y bit_cast(const X &x)
Definition: type.hpp:306
 
bf8_fnuz_t bf8_t
Definition: amd_ck_fp8.hpp:1738
 
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
 
ck_fp8_interpretation_t
Describes FP8 interpretation.
Definition: amd_ck_fp8.hpp:45
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
 
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:21
 
unsigned _BitInt(8) bf8_fnuz_t
Definition: amd_ck_fp8.hpp:37
 
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
 
_BitInt(8) f8_fnuz_t
Definition: amd_ck_fp8.hpp:36
 
ck_saturation_t
Describes saturation behavior.
Definition: amd_ck_fp8.hpp:56
 
unsigned char fp8_storage_t
Definition: amd_ck_fp8.hpp:39
 
Definition: amd_ck_fp8.hpp:344
 
__host__ constexpr __device__ bool operator==(const bf8_ocp_t &other) const
Definition: amd_ck_fp8.hpp:355
 
fp8_storage_t data_type
Definition: amd_ck_fp8.hpp:345
 
data_type data
Definition: amd_ck_fp8.hpp:346
 
Definition: amd_ck_fp8.hpp:298
 
fp8_storage_t data_type
Definition: amd_ck_fp8.hpp:299
 
data_type data
Definition: amd_ck_fp8.hpp:300
 
__host__ constexpr __device__ bool operator==(const f8_ocp_t &other) const
Definition: amd_ck_fp8.hpp:309