/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.0.2/include/hip/hip_bfloat16.h Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.0.2/include/hip/hip_bfloat16.h Source File

HIP Runtime API Reference: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.0.2/include/hip/hip_bfloat16.h Source File
hip_bfloat16.h
Go to the documentation of this file.
1
29#ifndef _HIP_BFLOAT16_H_
30#define _HIP_BFLOAT16_H_
31
32#if __cplusplus < 201103L || !defined(__HIPCC__)
33
34// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
35// include a minimal definition of hip_bfloat16
36
37#include <stdint.h>
// Struct to represent a 16 bit brain floating point number.
// Minimal definition used for C compilers, C++ compilers below C++11, and
// host-only compilers: it exposes only the raw storage, no conversions.
typedef struct
{
    // Raw bfloat16 bit pattern.
    uint16_t data;
} hip_bfloat16;
43
44#else // __cplusplus < 201103L || !defined(__HIPCC__)
45
46#include <cmath>
47#include <cstddef>
48#include <cstdint>
49#include <hip/hip_runtime.h>
50#include <ostream>
51#include <type_traits>
52
53struct hip_bfloat16
54{
55 uint16_t data;
56
57 enum truncate_t
58 {
59 truncate
60 };
61
62 __host__ __device__ hip_bfloat16() = default;
63
64 // round upper 16 bits of IEEE float to convert to bfloat16
65 explicit __host__ __device__ hip_bfloat16(float f)
66 : data(float_to_bfloat16(f))
67 {
68 }
69
70 explicit __host__ __device__ hip_bfloat16(float f, truncate_t)
71 : data(truncate_float_to_bfloat16(f))
72 {
73 }
74
75 // zero extend lower 16 bits of bfloat16 to convert to IEEE float
76 __host__ __device__ operator float() const
77 {
78 union
79 {
80 uint32_t int32;
81 float fp32;
82 } u = {uint32_t(data) << 16};
83 return u.fp32;
84 }
85
86 static __host__ __device__ hip_bfloat16 round_to_bfloat16(float f)
87 {
88 hip_bfloat16 output;
89 output.data = float_to_bfloat16(f);
90 return output;
91 }
92
93 static __host__ __device__ hip_bfloat16 round_to_bfloat16(float f, truncate_t)
94 {
95 hip_bfloat16 output;
96 output.data = truncate_float_to_bfloat16(f);
97 return output;
98 }
99
100private:
101 static __host__ __device__ uint16_t float_to_bfloat16(float f)
102 {
103 union
104 {
105 float fp32;
106 uint32_t int32;
107 } u = {f};
108 if(~u.int32 & 0x7f800000)
109 {
110 // When the exponent bits are not all 1s, then the value is zero, normal,
111 // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
112 // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
113 // This causes the bfloat16's mantissa to be incremented by 1 if the 16
114 // least significant bits of the float mantissa are greater than 0x8000,
115 // or if they are equal to 0x8000 and the least significant bit of the
116 // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
117 // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
118 // has the value 0x7f, then incrementing it causes it to become 0x00 and
119 // the exponent is incremented by one, which is the next higher FP value
120 // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
121 // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
122 // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
123 // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
124 // incrementing it causes it to become an exponent of 0xFF and a mantissa
125 // of 0x00, which is Inf, the next higher value to the unrounded value.
126 u.int32 += 0x7fff + ((u.int32 >> 16) & 1); // Round to nearest, round to even
127 }
128 else if(u.int32 & 0xffff)
129 {
130 // When all of the exponent bits are 1, the value is Inf or NaN.
131 // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
132 // mantissa bit. Quiet NaN is indicated by the most significant mantissa
133 // bit being 1. Signaling NaN is indicated by the most significant
134 // mantissa bit being 0 but some other bit(s) being 1. If any of the
135 // lower 16 bits of the mantissa are 1, we set the least significant bit
136 // of the bfloat16 mantissa, in order to preserve signaling NaN in case
137 // the bloat16's mantissa bits are all 0.
138 u.int32 |= 0x10000; // Preserve signaling NaN
139 }
140 return uint16_t(u.int32 >> 16);
141 }
142
143 // Truncate instead of rounding, preserving SNaN
144 static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f)
145 {
146 union
147 {
148 float fp32;
149 uint32_t int32;
150 } u = {f};
151 return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
152 }
153};
154
// Mirror of the minimal C-visible definition (a single uint16_t member); the
// static_asserts that follow compare hip_bfloat16 against this layout.
typedef struct
{
    uint16_t data; // raw 16-bit storage
} hip_bfloat16_public;
159
160static_assert(std::is_standard_layout<hip_bfloat16>{},
161 "hip_bfloat16 is not a standard layout type, and thus is "
162 "incompatible with C.");
163
164static_assert(std::is_trivial<hip_bfloat16>{},
165 "hip_bfloat16 is not a trivial type, and thus is "
166 "incompatible with C.");
167
168static_assert(sizeof(hip_bfloat16) == sizeof(hip_bfloat16_public)
169 && offsetof(hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
170 "internal hip_bfloat16 does not match public hip_bfloat16");
171
172inline std::ostream& operator<<(std::ostream& os, const hip_bfloat16& bf16)
173{
174 return os << float(bf16);
175}
176inline __host__ __device__ hip_bfloat16 operator+(hip_bfloat16 a)
177{
178 return a;
179}
180inline __host__ __device__ hip_bfloat16 operator-(hip_bfloat16 a)
181{
182 a.data ^= 0x8000;
183 return a;
184}
185inline __host__ __device__ hip_bfloat16 operator+(hip_bfloat16 a, hip_bfloat16 b)
186{
187 return hip_bfloat16(float(a) + float(b));
188}
189inline __host__ __device__ hip_bfloat16 operator-(hip_bfloat16 a, hip_bfloat16 b)
190{
191 return hip_bfloat16(float(a) - float(b));
192}
193inline __host__ __device__ hip_bfloat16 operator*(hip_bfloat16 a, hip_bfloat16 b)
194{
195 return hip_bfloat16(float(a) * float(b));
196}
197inline __host__ __device__ hip_bfloat16 operator/(hip_bfloat16 a, hip_bfloat16 b)
198{
199 return hip_bfloat16(float(a) / float(b));
200}
201inline __host__ __device__ bool operator<(hip_bfloat16 a, hip_bfloat16 b)
202{
203 return float(a) < float(b);
204}
205inline __host__ __device__ bool operator==(hip_bfloat16 a, hip_bfloat16 b)
206{
207 return float(a) == float(b);
208}
209inline __host__ __device__ bool operator>(hip_bfloat16 a, hip_bfloat16 b)
210{
211 return b < a;
212}
213inline __host__ __device__ bool operator<=(hip_bfloat16 a, hip_bfloat16 b)
214{
215 return !(a > b);
216}
217inline __host__ __device__ bool operator!=(hip_bfloat16 a, hip_bfloat16 b)
218{
219 return !(a == b);
220}
221inline __host__ __device__ bool operator>=(hip_bfloat16 a, hip_bfloat16 b)
222{
223 return !(a < b);
224}
225inline __host__ __device__ hip_bfloat16& operator+=(hip_bfloat16& a, hip_bfloat16 b)
226{
227 return a = a + b;
228}
229inline __host__ __device__ hip_bfloat16& operator-=(hip_bfloat16& a, hip_bfloat16 b)
230{
231 return a = a - b;
232}
233inline __host__ __device__ hip_bfloat16& operator*=(hip_bfloat16& a, hip_bfloat16 b)
234{
235 return a = a * b;
236}
237inline __host__ __device__ hip_bfloat16& operator/=(hip_bfloat16& a, hip_bfloat16 b)
238{
239 return a = a / b;
240}
241inline __host__ __device__ hip_bfloat16& operator++(hip_bfloat16& a)
242{
243 return a += hip_bfloat16(1.0f);
244}
245inline __host__ __device__ hip_bfloat16& operator--(hip_bfloat16& a)
246{
247 return a -= hip_bfloat16(1.0f);
248}
249inline __host__ __device__ hip_bfloat16 operator++(hip_bfloat16& a, int)
250{
251 hip_bfloat16 orig = a;
252 ++a;
253 return orig;
254}
255inline __host__ __device__ hip_bfloat16 operator--(hip_bfloat16& a, int)
256{
257 hip_bfloat16 orig = a;
258 --a;
259 return orig;
260}
261
262namespace std
263{
264 constexpr __host__ __device__ bool isinf(hip_bfloat16 a)
265 {
266 return !(~a.data & 0x7f80) && !(a.data & 0x7f);
267 }
268 constexpr __host__ __device__ bool isnan(hip_bfloat16 a)
269 {
270 return !(~a.data & 0x7f80) && +(a.data & 0x7f);
271 }
272 constexpr __host__ __device__ bool iszero(hip_bfloat16 a)
273 {
274 return !(a.data & 0x7fff);
275 }
276}
277
278#endif // __cplusplus < 201103L || !defined(__HIPCC__)
279
280#endif // _HIP_BFLOAT16_H_
Struct to represent a 16 bit brain floating point number.
Definition hip_bfloat16.h:40
uint16_t data
Definition hip_bfloat16.h:41