29#ifndef _HIP_BFLOAT16_H_
30#define _HIP_BFLOAT16_H_
32#if __cplusplus < 201103L || !defined(__HIPCC__)
53#pragma clang diagnostic push
54#pragma clang diagnostic ignored "-Wshadow"
68 :
data(float_to_bfloat16(f))
72 explicit __host__ __device__
hip_bfloat16(
float f, truncate_t)
73 :
data(truncate_float_to_bfloat16(f))
78 __host__ __device__
operator float()
const
84 } u = {uint32_t(
data) << 16};
88 static __host__ __device__
hip_bfloat16 round_to_bfloat16(
float f)
91 output.
data = float_to_bfloat16(f);
95 static __host__ __device__
hip_bfloat16 round_to_bfloat16(
float f, truncate_t)
98 output.
data = truncate_float_to_bfloat16(f);
103 static __host__ __device__ uint16_t float_to_bfloat16(
float f)
110 if(~u.int32 & 0x7f800000)
128 u.int32 += 0x7fff + ((u.int32 >> 16) & 1);
130 else if(u.int32 & 0xffff)
142 return uint16_t(u.int32 >> 16);
146 static __host__ __device__ uint16_t truncate_float_to_bfloat16(
float f)
153 return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
156#pragma clang diagnostic pop
161} hip_bfloat16_public;
163static_assert(std::is_standard_layout<hip_bfloat16>{},
164 "hip_bfloat16 is not a standard layout type, and thus is "
165 "incompatible with C.");
167static_assert(std::is_trivial<hip_bfloat16>{},
168 "hip_bfloat16 is not a trivial type, and thus is "
169 "incompatible with C.");
171static_assert(
sizeof(
hip_bfloat16) ==
sizeof(hip_bfloat16_public)
172 && offsetof(
hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
173 "internal hip_bfloat16 does not match public hip_bfloat16");
175inline std::ostream& operator<<(std::ostream& os,
const hip_bfloat16& bf16)
177 return os << float(bf16);
206 return float(a) < float(b);
210 return float(a) == float(b);
267 constexpr __host__ __device__
bool isinf(
hip_bfloat16 a)
269 return !(~a.data & 0x7f80) && !(a.
data & 0x7f);
271 constexpr __host__ __device__
bool isnan(
hip_bfloat16 a)
273 return !(~a.data & 0x7f80) && +(a.
data & 0x7f);
275 constexpr __host__ __device__
bool iszero(
hip_bfloat16 a)
277 return !(a.
data & 0x7fff);
Struct to represent a 16-bit brain floating-point (bfloat16) number.
Definition hip_bfloat16.h:40
uint16_t data
Definition hip_bfloat16.h:41