/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/numeric/bfloat16.hpp Source File
bfloat16.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
9 #if CK_TILE_USE_LLVM_BUILTIN_BF16
10 #include <hip/hip_bfloat16.h>
11 #endif
12 #include <stdint.h>
13 
14 #pragma once
15 
16 namespace ck_tile {
17 
19 {
20  standard = 0, // rtn
22  truncate,
24  rta_asm, // round to nearest away
25 };
26 
27 template <bf16_rounding_mode rounding =
29 CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
30 
31 template <bf16_rounding_mode rounding =
33 CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
34 
36 constexpr float bf16_to_float_raw(uint16_t x);
37 
39 constexpr double bf16_to_double_raw(uint16_t x);
40 
41 #if CK_TILE_USE_CUSTOM_DATA_TYPE
42 // HIP use __hip_bfloat16 as struct
43 struct alignas(2) bfloat16_t
44 {
45  using raw_type = uint16_t;
46  raw_type data;
47 
49  static constexpr bfloat16_t bit_cast(raw_type x)
50  {
51  bfloat16_t y;
52  y.data = x;
53  return y;
54  }
55 
56  // constructor
57  constexpr bfloat16_t() : data() {}
58 
59  // construct from float
61  explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
62 
63  // construct from double
65  explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
66 
67  // construct from int
69  explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
70 
71  // construct from unsigned int
73  explicit constexpr bfloat16_t(const unsigned int& x)
74  : data(float_to_bf16_raw(static_cast<float>(x)))
75  {
76  }
77 
78  // cast to float
80  explicit constexpr operator float() const { return bf16_to_float_raw(data); }
81 
82  // cast to float
84  explicit constexpr operator double() const { return bf16_to_double_raw(data); }
85 
86  // cast to int
88  explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
89 
90  // internal access
92  constexpr raw_type& get() { return data; }
93 
95  constexpr raw_type get() const { return data; }
96 };
97 template <typename>
98 struct native_t;
99 
100 template <>
101 struct native_t<bfloat16_t>
102 {
103  using type = ushort;
104 };
105 using bf16_t = bfloat16_t;
106 using bf16_raw_t = typename bf16_t::raw_type;
107 #else
108 #if CK_TILE_USE_LLVM_BUILTIN_BF16
109 using bfloat16_t = __bf16;
110 #else
111 using bfloat16_t = ushort;
112 #endif
115 #endif
116 // round to nearest
118 constexpr uint16_t float_to_bf16_rtn_raw(float f)
119 {
120  uint32_t bits = bit_cast<uint32_t>(f);
121  if(~bits & 0x7f800000)
122  {
123  // When the exponent bits are not all 1s, then the value is zero, normal,
124  // or subnormal. We round the bfloat16 mantissa up by adding 0x7FFF, plus
125  // 1 if the least significant bit of the bfloat16 mantissa is 1 (odd).
126  // This causes the bfloat16's mantissa to be incremented by 1 if the 16
127  // least significant bits of the float mantissa are greater than 0x8000,
128  // or if they are equal to 0x8000 and the least significant bit of the
129  // bfloat16 mantissa is 1 (odd). This causes it to be rounded to even when
130  // the lower 16 bits are exactly 0x8000. If the bfloat16 mantissa already
131  // has the value 0x7f, then incrementing it causes it to become 0x00 and
132  // the exponent is incremented by one, which is the next higher FP value
133  // to the unrounded bfloat16 value. When the bfloat16 value is subnormal
134  // with an exponent of 0x00 and a mantissa of 0x7F, it may be rounded up
135  // to a normal value with an exponent of 0x01 and a mantissa of 0x00.
136  // When the bfloat16 value has an exponent of 0xFE and a mantissa of 0x7F,
137  // incrementing it causes it to become an exponent of 0xFF and a mantissa
138  // of 0x00, which is Inf, the next higher value to the unrounded value.
139  bits += 0x7fff + ((bits >> 16) & 1); // Round to nearest, round to even
140  }
141  else if(bits & 0xffff)
142  {
143  // When all of the exponent bits are 1, the value is Inf or NaN.
144  // Inf is indicated by a zero mantissa. NaN is indicated by any nonzero
145  // mantissa bit. Quiet NaN is indicated by the most significant mantissa
146  // bit being 1. Signaling NaN is indicated by the most significant
147  // mantissa bit being 0 but some other bit(s) being 1. If any of the
148  // lower 16 bits of the mantissa are 1, we set the least significant bit
149  // of the bfloat16 mantissa, in order to preserve signaling NaN in case
150  // the bloat16's mantissa bits are all 0.
151  bits |= 0x10000; // Preserve signaling NaN
152  }
153  return uint16_t(bits >> 16);
154 }
155 
157 constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
158 
161 {
162  union
163  {
164  float fp32;
165  uint32_t int32;
166  } u = {f};
167 
168  static constexpr uint32_t FP32_NAN = 0x7fff0000;
169  static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff;
170 
171 #if defined(__GFX9__)
172  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
173  uint32x2_t check_nan;
174 #else
175  uint32_t check_nan;
176 #endif
177  uint32_t tmp;
178  asm volatile("\n \
179  v_cmp_u_f32 %0, %2, %2 \n \
180  v_bfe_u32 %1, %2, 16, 1 \n \
181  v_add3_u32 %1, %2, %1, %3 \n \
182  v_cndmask_b32 %2, %1, %4, %0 \n \
183  v_lshrrev_b32 %2, 16, %2 \n \
184  "
185  : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
186  : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
187 
188  return uint16_t(u.int32);
189 }
190 
191 // TODO: do we need this on host?
194 
197 {
198  union
199  {
200  float fp32;
201  struct
202  {
203  uint16_t lo;
204  uint16_t hi;
205  };
206  } u = {f};
207 
208  const uint32_t low_nan = 0x7fff;
209  const uint32_t hi_nan = 0x7fff0000;
210 
211 #if defined(__GFX9__)
212  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
213  uint32x2_t check_nan;
214 #else
215  uint32_t check_nan;
216 #endif
217 
218  asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n"
219  "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n"
220  "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]"
221  : [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32)
222  : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
223 
224  // Note: in the above code snippet, we use the high 16 bits
225  return u.hi;
226 }
227 
228 // Truncate instead of rounding, preserving SNaN
231 {
232  uint32_t bits = bit_cast<uint32_t>(f);
233  return static_cast<uint16_t>(bits >> 16) | (!(~bits & 0x7f800000) && (bits & 0xffff));
234 }
235 
236 // Fast truncate instead of rounding, RTZ
239 {
240  uint32_t bits = bit_cast<uint32_t>(f);
241  return static_cast<uint16_t>(bits >> 16);
242 }
243 
244 template <bf16_rounding_mode rounding>
246 {
247  if constexpr(rounding == bf16_rounding_mode::standard)
248  return float_to_bf16_rtn_raw(f);
249  else if constexpr(rounding == bf16_rounding_mode::standard_asm)
250  return float_to_bf16_rtn_asm(f);
251  else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
252  return float_to_bf16_truc_nan_raw(f);
253  else if constexpr(rounding == bf16_rounding_mode::rta_asm)
254  return float_to_bf16_rta_asm(f);
255  else
256  return float_to_bf16_truc_raw(f);
257 }
258 
259 template <bf16_rounding_mode rounding>
261 {
262  return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
263 }
264 
266 constexpr float bf16_to_float_raw(uint16_t x)
267 {
268  union
269  {
270  uint32_t int32;
271  float fp32;
272  } u = {uint32_t(x) << 16};
273  return u.fp32;
274 }
275 
277 constexpr double bf16_to_double_raw(uint16_t x)
278 {
279  return static_cast<double>(bf16_to_float_raw(x));
280 }
281 
282 template <bf16_rounding_mode rounding =
285 {
286 // Use builtin bfloat16 conversion only on gfx950 as its predecessors do not support bf16 cvt
287 // instructions, resulting in suboptimal performance; add host side macro check for consistency
288 // during accuracy tests.
289 #if CK_TILE_USE_LLVM_BUILTIN_BF16 && (defined(__gfx950__) || defined(CK_GFX950_SUPPORT))
290  return static_cast<bfloat16_t>(f);
291 #else
292  return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
293 #endif
294 }
295 
296 template <bf16_rounding_mode rounding =
299 {
300  return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
301 }
302 
304 constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
305 
307 constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
308 
309 template <bf16_rounding_mode rounding =
312 {
313  return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
314 }
315 
317 constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
318 
319 template <class T>
320 struct numeric;
321 
322 template <>
324 {
325  // minimum finite value, or minimum positive normalized value for float
326  CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
327  {
328  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
329  }
330 
331  // lowest (most negative) finite value
333  {
334  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
335  }
336 
337  // maximum finite value
338  CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
339  {
340  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
341  }
342 
343  // difference between 1.0 and next value representable by float (NOTE(review): 0x1000 decodes to 2^-95, whereas the bfloat16 epsilon 2^-7 is 0x3C00 -- confirm)
345  {
346  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x1000));
347  }
348 
349  // maximum rounding error
350  // maximum rounding error
351  // bin : f edcba 9876543210
352  // bits: s eeeeeeee mmmmmmm
353  // 0 01111110 0000000 (0.5)
354  //
356  {
357  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
358  }
359 
360  // positive infinity value
362  {
363  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
364  }
365 
366  // quiet NaN
368  {
369  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
370  }
371 
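372  // signaling NaN (NOTE(review): returns 0x7FFF, the same pattern as quiet_NaN(), whose most significant mantissa bit is set -- confirm)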
374  {
375  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
376  }
377 
378  // smallest positive subnormal value
380  {
381  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
382  }
384  {
385  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
386  }
387 };
388 
389 template <>
391 {
392  static constexpr int exp = 8;
393  static constexpr int mant = 7;
394  static constexpr int PackedSize = 1;
395 };
396 
397 #if CK_TILE_USE_CUSTOM_DATA_TYPE
399 #endif
400 
401 // math
404 {
405  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
406 }
407 
409 bool isnan(const bfloat16_t& x)
410 {
411  uint16_t xx = bit_cast<bf16_raw_t>(x);
412  return (xx & 0x7FFF) > 0x7C00;
413 }
414 
417 {
418  return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
419 };
420 
423 {
424  return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
425 };
426 
428 bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
429 
431 bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
432 
433 using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
434 using fp32x2_t = float __attribute__((ext_vector_type(2)));
435 
436 template <bf16_rounding_mode rounding =
439 {
440  return bf16x2_t{float_to_bf16<rounding>(x.x), float_to_bf16<rounding>(x.y)};
441 }
442 
443 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
Definition: config.hpp:76
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_nan_raw(float f)
Definition: bfloat16.hpp:230
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:431
ushort bfloat16_t
Definition: bfloat16.hpp:111
uint32_t uint32x2_t
Definition: vector_type.hpp:163
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
bfloat16_t bf16x2_t
Definition: bfloat16.hpp:433
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant< rounding >={})
Definition: bfloat16.hpp:260
constexpr CK_TILE_HOST_DEVICE float bf16_to_float_raw(uint16_t x)
Definition: bfloat16.hpp:266
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_raw(float f)
Definition: bfloat16.hpp:238
float fp32x2_t
Definition: bfloat16.hpp:434
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
CK_TILE_HOST uint16_t float_to_bf16_rta_asm(float f)
Definition: bfloat16.hpp:193
constexpr CK_TILE_HOST_DEVICE double bf16_to_double_raw(uint16_t x)
Definition: bfloat16.hpp:277
constexpr CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant< rounding >={})
Definition: bfloat16.hpp:284
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:416
constexpr CK_TILE_HOST_DEVICE float bf16_to_float(bfloat16_t x)
Definition: bfloat16.hpp:304
constexpr CK_TILE_HOST_DEVICE bf16x2_t fp32x2_to_bf16x2(const fp32x2_t &x)
Definition: bfloat16.hpp:438
uint16_t bf16_raw_t
Definition: bfloat16.hpp:114
constexpr CK_TILE_HOST_DEVICE half_t bf16_to_fp16(bfloat16_t x)
Definition: bfloat16.hpp:317
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:422
bf16_rounding_mode
Definition: bfloat16.hpp:19
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:403
constexpr CK_TILE_HOST_DEVICE double bf16_to_double(bfloat16_t x)
Definition: bfloat16.hpp:307
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_rtn_raw(float f)
Definition: bfloat16.hpp:118
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant< rounding >={})
Definition: bfloat16.hpp:245
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:409
constexpr CK_TILE_HOST uint16_t float_to_bf16_rtn_asm(float f)
Definition: bfloat16.hpp:157
CK_TILE_HOST_DEVICE constexpr bfloat16_t fp16_to_bf16(half_t f, constant< rounding >={})
Definition: bfloat16.hpp:311
constexpr CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant< rounding >={})
Definition: bfloat16.hpp:298
_Float16 half_t
Definition: half.hpp:111
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:428
unsigned short uint16_t
Definition: stdint.h:125
unsigned int uint32_t
Definition: stdint.h:126
Definition: integral_constant.hpp:13
Definition: bfloat16.hpp:324
static constexpr CK_TILE_HOST_DEVICE bfloat16_t round_error()
Definition: bfloat16.hpp:355
static constexpr CK_TILE_HOST_DEVICE bfloat16_t infinity()
Definition: bfloat16.hpp:361
static constexpr CK_TILE_HOST_DEVICE bfloat16_t max()
Definition: bfloat16.hpp:338
static constexpr CK_TILE_HOST_DEVICE bfloat16_t denorm_min()
Definition: bfloat16.hpp:379
static constexpr CK_TILE_HOST_DEVICE bfloat16_t min()
Definition: bfloat16.hpp:326
static constexpr CK_TILE_HOST_DEVICE bfloat16_t lowest()
Definition: bfloat16.hpp:332
static constexpr CK_TILE_HOST_DEVICE bfloat16_t epsilon()
Definition: bfloat16.hpp:344
static constexpr CK_TILE_HOST_DEVICE bfloat16_t quiet_NaN()
Definition: bfloat16.hpp:367
static constexpr CK_TILE_HOST_DEVICE bfloat16_t zero()
Definition: bfloat16.hpp:383
static constexpr CK_TILE_HOST_DEVICE bfloat16_t signaling_NaN()
Definition: bfloat16.hpp:373
Definition: bfloat16.hpp:391
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106