/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.3.3/include/hip/hip_bfloat16.h Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.3.3/include/hip/hip_bfloat16.h Source File

HIP Runtime API Reference: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hip/checkouts/docs-5.3.3/include/hip/hip_bfloat16.h Source File
hip_bfloat16.h
Go to the documentation of this file.
1
29#ifndef _HIP_BFLOAT16_H_
30#define _HIP_BFLOAT16_H_
31
32#if __cplusplus < 201103L || !defined(__HIPCC__)
33
34// If this is a C compiler, C++ compiler below C++11, or a host-only compiler, we only
35// include a minimal definition of hip_bfloat16
36
37#include <stdint.h>
/*! \brief Struct to represent a 16 bit brain floating point number.
 *
 * Minimal definition used when the header is compiled by a C compiler, a C++
 * compiler older than C++11, or a host-only compiler: only the raw 16-bit
 * pattern is exposed, with no conversion or arithmetic helpers.
 */
typedef struct
{
    uint16_t data; /* raw bfloat16 bit pattern */
} hip_bfloat16;
44#else // __cplusplus < 201103L || !defined(__HIPCC__)
45
46#include <cmath>
47#include <cstddef>
48#include <cstdint>
49#include <hip/hip_runtime.h>
50#include <ostream>
51#include <type_traits>
52
53#pragma clang diagnostic push
54#pragma clang diagnostic ignored "-Wshadow"
// bfloat16 ("brain floating point"): the upper 16 bits of an IEEE-754 binary32
// float, i.e. 1 sign bit, 8 exponent bits and 7 mantissa bits.
struct hip_bfloat16
{
    // Raw bfloat16 bit pattern.
    uint16_t data;

    // Tag type: pass hip_bfloat16::truncate to request a truncating (instead of
    // rounding) conversion from float.
    enum truncate_t
    {
        truncate
    };

    // Defaulted so the type stays trivial; default-initialization leaves
    // `data` indeterminate.
    __host__ __device__ hip_bfloat16() = default;

    // Convert from float, rounding to nearest (ties to even).
    explicit __host__ __device__ hip_bfloat16(float f)
        : data(float_to_bfloat16(f))
    {
    }

    // Convert from float by dropping the low 16 bits (signaling NaN preserved).
    explicit __host__ __device__ hip_bfloat16(float f, truncate_t)
        : data(truncate_float_to_bfloat16(f))
    {
    }

    // Widen to float: the 16 stored bits become the upper half of the binary32
    // pattern and the lower 16 bits are filled with zeros, so the conversion is
    // exact. NOTE(review): reads the inactive union member (type punning),
    // which relies on the compiler permitting it.
    __host__ __device__ operator float() const
    {
        union
        {
            uint32_t bits;
            float    value;
        } widened = {static_cast<uint32_t>(data) << 16};
        return widened.value;
    }

    // Factory equivalent of hip_bfloat16(f): round to nearest, ties to even.
    static __host__ __device__ hip_bfloat16 round_to_bfloat16(float f)
    {
        return hip_bfloat16(f);
    }

    // Factory equivalent of hip_bfloat16(f, truncate): truncating conversion.
    static __host__ __device__ hip_bfloat16 round_to_bfloat16(float f, truncate_t)
    {
        return hip_bfloat16(f, truncate);
    }

private:
    // Round a float to bfloat16 bits.
    //
    // Finite values and zero (exponent not all ones): add 0x7FFF plus the
    // current bfloat16 LSB before keeping the upper 16 bits. The result rounds
    // up when the discarded half is above 0x8000. On an exact tie (0x8000) it
    // rounds up only if the kept LSB is odd, which gives ties-to-even. A carry
    // out of the 7-bit mantissa moves into the exponent. This correctly turns
    // the largest subnormal into the smallest normal, and the largest finite
    // value into Inf.
    //
    // Inf/NaN (exponent all ones): Inf (zero mantissa) passes through
    // unchanged. For a NaN whose only set mantissa bits lie in the discarded
    // low half, bit 16 is forced on. Otherwise plain truncation would leave a
    // zero bfloat16 mantissa and silently turn the NaN into Inf.
    static __host__ __device__ uint16_t float_to_bfloat16(float f)
    {
        union
        {
            float    value;
            uint32_t bits;
        } pun = {f};

        const bool exponent_all_ones = (pun.bits & 0x7f800000u) == 0x7f800000u;
        if(!exponent_all_ones)
        {
            const uint32_t kept_lsb = (pun.bits >> 16) & 1u;
            pun.bits += 0x7fffu + kept_lsb; // round to nearest, ties to even
        }
        else if((pun.bits & 0xffffu) != 0)
        {
            pun.bits |= 0x10000u; // keep NaN a NaN after truncation
        }
        return static_cast<uint16_t>(pun.bits >> 16);
    }

    // Truncate a float to bfloat16 bits (round toward zero). As in the rounding
    // path, a NaN whose payload sits only in the discarded low 16 bits gets its
    // bfloat16 LSB set so it does not collapse into Inf.
    static __host__ __device__ uint16_t truncate_float_to_bfloat16(float f)
    {
        union
        {
            float    value;
            uint32_t bits;
        } pun = {f};

        const uint16_t upper_half = static_cast<uint16_t>(pun.bits >> 16);
        const bool     nan_with_low_bits
            = (pun.bits & 0x7f800000u) == 0x7f800000u && (pun.bits & 0xffffu) != 0;
        return static_cast<uint16_t>(upper_half | (nan_with_low_bits ? 1u : 0u));
    }
};
156#pragma clang diagnostic pop
157
158typedef struct
159{
160 uint16_t data;
161} hip_bfloat16_public;
162
163static_assert(std::is_standard_layout<hip_bfloat16>{},
164 "hip_bfloat16 is not a standard layout type, and thus is "
165 "incompatible with C.");
166
167static_assert(std::is_trivial<hip_bfloat16>{},
168 "hip_bfloat16 is not a trivial type, and thus is "
169 "incompatible with C.");
170
171static_assert(sizeof(hip_bfloat16) == sizeof(hip_bfloat16_public)
172 && offsetof(hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
173 "internal hip_bfloat16 does not match public hip_bfloat16");
174
175inline std::ostream& operator<<(std::ostream& os, const hip_bfloat16& bf16)
176{
177 return os << float(bf16);
178}
179inline __host__ __device__ hip_bfloat16 operator+(hip_bfloat16 a)
180{
181 return a;
182}
183inline __host__ __device__ hip_bfloat16 operator-(hip_bfloat16 a)
184{
185 a.data ^= 0x8000;
186 return a;
187}
188inline __host__ __device__ hip_bfloat16 operator+(hip_bfloat16 a, hip_bfloat16 b)
189{
190 return hip_bfloat16(float(a) + float(b));
191}
192inline __host__ __device__ hip_bfloat16 operator-(hip_bfloat16 a, hip_bfloat16 b)
193{
194 return hip_bfloat16(float(a) - float(b));
195}
196inline __host__ __device__ hip_bfloat16 operator*(hip_bfloat16 a, hip_bfloat16 b)
197{
198 return hip_bfloat16(float(a) * float(b));
199}
200inline __host__ __device__ hip_bfloat16 operator/(hip_bfloat16 a, hip_bfloat16 b)
201{
202 return hip_bfloat16(float(a) / float(b));
203}
204inline __host__ __device__ bool operator<(hip_bfloat16 a, hip_bfloat16 b)
205{
206 return float(a) < float(b);
207}
208inline __host__ __device__ bool operator==(hip_bfloat16 a, hip_bfloat16 b)
209{
210 return float(a) == float(b);
211}
212inline __host__ __device__ bool operator>(hip_bfloat16 a, hip_bfloat16 b)
213{
214 return b < a;
215}
216inline __host__ __device__ bool operator<=(hip_bfloat16 a, hip_bfloat16 b)
217{
218 return !(a > b);
219}
220inline __host__ __device__ bool operator!=(hip_bfloat16 a, hip_bfloat16 b)
221{
222 return !(a == b);
223}
224inline __host__ __device__ bool operator>=(hip_bfloat16 a, hip_bfloat16 b)
225{
226 return !(a < b);
227}
228inline __host__ __device__ hip_bfloat16& operator+=(hip_bfloat16& a, hip_bfloat16 b)
229{
230 return a = a + b;
231}
232inline __host__ __device__ hip_bfloat16& operator-=(hip_bfloat16& a, hip_bfloat16 b)
233{
234 return a = a - b;
235}
236inline __host__ __device__ hip_bfloat16& operator*=(hip_bfloat16& a, hip_bfloat16 b)
237{
238 return a = a * b;
239}
240inline __host__ __device__ hip_bfloat16& operator/=(hip_bfloat16& a, hip_bfloat16 b)
241{
242 return a = a / b;
243}
244inline __host__ __device__ hip_bfloat16& operator++(hip_bfloat16& a)
245{
246 return a += hip_bfloat16(1.0f);
247}
248inline __host__ __device__ hip_bfloat16& operator--(hip_bfloat16& a)
249{
250 return a -= hip_bfloat16(1.0f);
251}
252inline __host__ __device__ hip_bfloat16 operator++(hip_bfloat16& a, int)
253{
254 hip_bfloat16 orig = a;
255 ++a;
256 return orig;
257}
258inline __host__ __device__ hip_bfloat16 operator--(hip_bfloat16& a, int)
259{
260 hip_bfloat16 orig = a;
261 --a;
262 return orig;
263}
264
265namespace std
266{
267 constexpr __host__ __device__ bool isinf(hip_bfloat16 a)
268 {
269 return !(~a.data & 0x7f80) && !(a.data & 0x7f);
270 }
271 constexpr __host__ __device__ bool isnan(hip_bfloat16 a)
272 {
273 return !(~a.data & 0x7f80) && +(a.data & 0x7f);
274 }
275 constexpr __host__ __device__ bool iszero(hip_bfloat16 a)
276 {
277 return !(a.data & 0x7fff);
278 }
279}
280
281#endif // __cplusplus < 201103L || !defined(__HIPCC__)
282
283#endif // _HIP_BFLOAT16_H_
Struct to represent a 16 bit brain floating point number.
Definition hip_bfloat16.h:40
uint16_t data
Definition hip_bfloat16.h:41