/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.4.3/include/hip/hip_bfloat16.h Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.4.3/include/hip/hip_bfloat16.h Source File#

HIP Runtime API Reference: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.4.3/include/hip/hip_bfloat16.h Source File
hip_bfloat16.h
Go to the documentation of this file.
1 
29 #ifndef _HIP_BFLOAT16_H_
30 #define _HIP_BFLOAT16_H_
31 
32 #if __cplusplus < 201103L || !defined(__HIPCC__)
33 
34 // If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
35 // include a minimal definition of hip_bfloat16
36 
37 #include <stdint.h>
/* Minimal definition of hip_bfloat16 for C compilers, pre-C++11 compilers
 * and host-only compilers: just the 16-bit storage, with no conversions or
 * arithmetic. */
typedef struct
{
    uint16_t data; /* raw 16-bit bfloat16 bit pattern */
} hip_bfloat16;
43 
44 #else // __cplusplus < 201103L || !defined(__HIPCC__)
45 
46 #include <cmath>
47 #include <cstddef>
48 #include <cstdint>
49 #include <hip/hip_runtime.h>
50 #include <ostream>
51 #include <type_traits>
52 
53 #pragma clang diagnostic push
54 #pragma clang diagnostic ignored "-Wshadow"
55 struct hip_bfloat16
56 {
57  uint16_t data;
58 
59  enum truncate_t
60  {
61  truncate
62  };
63 
64  __host__ __device__ hip_bfloat16() = default;
65 
66  // round upper 16 bits of IEEE float to convert to bfloat16
67  explicit __host__ __device__ hip_bfloat16(float f)
68  : data(float_to_bfloat16(f))
69  {
70  }
71 
72  explicit __host__ __device__ hip_bfloat16(float f, truncate_t)
73  : data(truncate_float_to_bfloat16(f))
74  {
75  }
76 
77  // zero extend lower 16 bits of bfloat16 to convert to IEEE float
78  __host__ __device__ operator float() const
79  {
80  union
81  {
82  uint32_t int32;
83  float fp32;
84  } u = {uint32_t(data) << 16};
85  return u.fp32;
86  }
87 
88  static __host__ __device__ hip_bfloat16 round_to_bfloat16(float f)
89  {
90  hip_bfloat16 output;
91  output.data = float_to_bfloat16(f);
92  return output;
93  }
94 
95  static __host__ __device__ hip_bfloat16 round_to_bfloat16(float f, truncate_t)
96  {
97  hip_bfloat16 output;
98  output.data = truncate_float_to_bfloat16(f);
99  return output;
100  }
101 
102 private:
103  static __host__ __device__ uint16_t float_to_bfloat16(float f)
104  {
105  union
106  {
107  float fp32;
108  uint32_t int32;
109  } u = {f};
110  if(~u.int32 & 0x7f800000)
111  {
112  // When the exponent bits are not all 1s, then the value is zero, normal,
113  // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
114  // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
115  // This causes the bfloat16's mantissa to be incremented by 1 if the 16
116  // least significant bits of the float mantissa are greater than 0x8000,
117  // or if they are equal to 0x8000 and the least significant bit of the
118  // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
119  // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
120  // has the value 0x7f, then incrementing it causes it to become 0x00 and
121  // the exponent is incremented by one, which is the next higher FP value
122  // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
123  // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
124  // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
125  // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
126  // incrementing it causes it to become an exponent of 0xFF and a mantissa
127  // of 0x00, which is Inf, the next higher value to the unrounded value.
128  u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
129  }
130  else if(u.int32 & 0xffff)
131  {
132  // When all of the exponent bits are 1, the value is Inf or NaN.
133  // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
134  // mantissa bit. Quiet NaN is indicated by the most significant mantissa
135  // bit being 1. Signaling NaN is indicated by the most significant
136  // mantissa bit being 0 but some other bit(s) being 1. If any of the
137  // lower 16 bits of the mantissa are 1, we set the least significant bit
138  // of the bfloat16 mantissa, in order to preserve signaling NaN in case
139  // the bloat16's mantissa bits are all 0.
140  u.int32 |= 0x10000; // Preserve signaling NaN
141  }
142  return uint16_t(u.int32 >> 16);
143  }
144 
145  // Truncate instead of rounding, preserving SNaN
146  static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f)
147  {
148  union
149  {
150  float fp32;
151  uint32_t int32;
152  } u = {f};
153  return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
154  }
155 };
156 #pragma clang diagnostic pop
157 
158 typedef struct
159 {
160  uint16_t data;
161 } hip_bfloat16_public;
162 
163 static_assert(std::is_standard_layout<hip_bfloat16>{},
164  "hip_bfloat16 is not a standard layout type, and thus is "
165  "incompatible with C.");
166 
167 static_assert(std::is_trivial<hip_bfloat16>{},
168  "hip_bfloat16 is not a trivial type, and thus is "
169  "incompatible with C.");
170 
171 static_assert(sizeof(hip_bfloat16) == sizeof(hip_bfloat16_public)
172  && offsetof(hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
173  "internal hip_bfloat16 does not match public hip_bfloat16");
174 
175 inline std::ostream& operator<<(std::ostream& os, const hip_bfloat16& bf16)
176 {
177  return os << float(bf16);
178 }
179 inline __host__ __device__ hip_bfloat16 operator+(hip_bfloat16 a)
180 {
181  return a;
182 }
183 inline __host__ __device__ hip_bfloat16 operator-(hip_bfloat16 a)
184 {
185  a.data ^= 0x8000;
186  return a;
187 }
188 inline __host__ __device__ hip_bfloat16 operator+(hip_bfloat16 a, hip_bfloat16 b)
189 {
190  return hip_bfloat16(float(a) + float(b));
191 }
192 inline __host__ __device__ hip_bfloat16 operator-(hip_bfloat16 a, hip_bfloat16 b)
193 {
194  return hip_bfloat16(float(a) - float(b));
195 }
196 inline __host__ __device__ hip_bfloat16 operator*(hip_bfloat16 a, hip_bfloat16 b)
197 {
198  return hip_bfloat16(float(a) * float(b));
199 }
200 inline __host__ __device__ hip_bfloat16 operator/(hip_bfloat16 a, hip_bfloat16 b)
201 {
202  return hip_bfloat16(float(a) / float(b));
203 }
204 inline __host__ __device__ bool operator<(hip_bfloat16 a, hip_bfloat16 b)
205 {
206  return float(a) < float(b);
207 }
208 inline __host__ __device__ bool operator==(hip_bfloat16 a, hip_bfloat16 b)
209 {
210  return float(a) == float(b);
211 }
212 inline __host__ __device__ bool operator>(hip_bfloat16 a, hip_bfloat16 b)
213 {
214  return b < a;
215 }
216 inline __host__ __device__ bool operator<=(hip_bfloat16 a, hip_bfloat16 b)
217 {
218  return !(a > b);
219 }
220 inline __host__ __device__ bool operator!=(hip_bfloat16 a, hip_bfloat16 b)
221 {
222  return !(a == b);
223 }
224 inline __host__ __device__ bool operator>=(hip_bfloat16 a, hip_bfloat16 b)
225 {
226  return !(a < b);
227 }
228 inline __host__ __device__ hip_bfloat16& operator+=(hip_bfloat16& a, hip_bfloat16 b)
229 {
230  return a = a + b;
231 }
232 inline __host__ __device__ hip_bfloat16& operator-=(hip_bfloat16& a, hip_bfloat16 b)
233 {
234  return a = a - b;
235 }
236 inline __host__ __device__ hip_bfloat16& operator*=(hip_bfloat16& a, hip_bfloat16 b)
237 {
238  return a = a * b;
239 }
240 inline __host__ __device__ hip_bfloat16& operator/=(hip_bfloat16& a, hip_bfloat16 b)
241 {
242  return a = a / b;
243 }
244 inline __host__ __device__ hip_bfloat16& operator++(hip_bfloat16& a)
245 {
246  return a += hip_bfloat16(1.0f);
247 }
248 inline __host__ __device__ hip_bfloat16& operator--(hip_bfloat16& a)
249 {
250  return a -= hip_bfloat16(1.0f);
251 }
252 inline __host__ __device__ hip_bfloat16 operator++(hip_bfloat16& a, int)
253 {
254  hip_bfloat16 orig = a;
255  ++a;
256  return orig;
257 }
258 inline __host__ __device__ hip_bfloat16 operator--(hip_bfloat16& a, int)
259 {
260  hip_bfloat16 orig = a;
261  --a;
262  return orig;
263 }
264 
265 namespace std
266 {
267  constexpr __host__ __device__ bool isinf(hip_bfloat16 a)
268  {
269  return !(~a.data & 0x7f80) && !(a.data & 0x7f);
270  }
271  constexpr __host__ __device__ bool isnan(hip_bfloat16 a)
272  {
273  return !(~a.data & 0x7f80) && +(a.data & 0x7f);
274  }
275  constexpr __host__ __device__ bool iszero(hip_bfloat16 a)
276  {
277  return !(a.data & 0x7fff);
278  }
279 }
280 
281 #endif // __cplusplus < 201103L || !defined(__HIPCC__)
282 
283 #endif // _HIP_BFLOAT16_H_
Struct to represent a 16 bit brain floating point number.
Definition: hip_bfloat16.h:40
uint16_t data
Definition: hip_bfloat16.h:41