/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/numeric/bfloat16.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/numeric/bfloat16.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/numeric/bfloat16.hpp Source File
bfloat16.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
9 #include <stdint.h>
10 
11 #pragma once
12 
13 namespace ck_tile {
14 
16 {
17  standard = 0, // rtn
19  truncate,
21  rta_asm, // round to nearest away
22 };
23 
24 template <bf16_rounding_mode rounding =
26 CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<rounding> = {});
27 
28 template <bf16_rounding_mode rounding =
30 CK_TILE_HOST_DEVICE constexpr uint16_t double_to_bf16_raw(double f, constant<rounding> = {});
31 
33 constexpr float bf16_to_float_raw(uint16_t x);
34 
36 constexpr double bf16_to_double_raw(uint16_t x);
37 
38 #if CK_TILE_USE_CUSTOM_DATA_TYPE
39 // Custom bfloat16 storage type (CK_TILE_USE_CUSTOM_DATA_TYPE): a 16-bit wrapper, like HIP's struct-based __hip_bfloat16
40 struct alignas(2) bfloat16_t
41 {
42  using raw_type = uint16_t; // raw bit pattern: 1 sign, 8 exponent, 7 mantissa bits
43  raw_type data;
44 
46  static constexpr bfloat16_t bit_cast(raw_type x) // build a value directly from raw bits (no rounding)
47  {
48  bfloat16_t y;
49  y.data = x;
50  return y;
51  }
52 
53  // default constructor: raw bits value-initialized to 0 (+0.0)
54  constexpr bfloat16_t() : data() {}
55 
56  // construct from float, rounded with float_to_bf16_raw's default rounding mode
58  explicit constexpr bfloat16_t(const float& x) : data(float_to_bf16_raw(x)) {}
59 
60  // construct from double, rounded with double_to_bf16_raw's default rounding mode
62  explicit constexpr bfloat16_t(const double& x) : data(double_to_bf16_raw(x)) {}
63 
64  // construct from int (converted to float first, then rounded)
66  explicit constexpr bfloat16_t(const int& x) : data(float_to_bf16_raw(static_cast<float>(x))) {}
67 
68  // construct from unsigned int (converted to float first, then rounded)
70  explicit constexpr bfloat16_t(const unsigned int& x)
71  : data(float_to_bf16_raw(static_cast<float>(x)))
72  {
73  }
74 
75  // cast to float (exact: the bf16 bits become the upper half of a float)
77  explicit constexpr operator float() const { return bf16_to_float_raw(data); }
78 
79  // cast to double (exact widening)
81  explicit constexpr operator double() const { return bf16_to_double_raw(data); }
82 
83  // cast to int: goes through float, then truncates toward zero
85  explicit constexpr operator int() const { return static_cast<int>(bf16_to_float_raw(data)); }
86 
87  // internal access: mutable reference to the raw bits
89  constexpr raw_type& get() { return data; }
90 
92  constexpr raw_type get() const { return data; } // read-only copy of the raw bits
93 };
94 template <typename>
95 struct native_t;
96 
97 template <>
98 struct native_t<bfloat16_t>
99 {
100  using type = ushort;
101 };
102 using bf16_t = bfloat16_t;
103 using bf16_raw_t = typename bf16_t::raw_type;
104 #else
105 using bfloat16_t = ushort;
107 using bf16_raw_t = uint16_t;
108 #endif
// Round-to-nearest-even conversion of a float to raw bfloat16 bits.
//
// Finite inputs (zero, subnormal, normal): add 0x7fff plus the lowest kept
// mantissa bit (bit 16), then keep the upper half. Dropped bits above 0x8000
// round up, below round down, and an exact 0x8000 tie rounds to the even
// neighbour. A carry out of the 7-bit mantissa bumps the exponent, so a subnormal
// can become the smallest normal, and the largest finite floats round up to Inf.
//
// Inf/NaN (exponent all ones): no rounding is done. If any of the 16 dropped
// mantissa bits is set, bit 16 is forced on so a NaN whose surviving mantissa
// bits are all zero (e.g. a signaling NaN) does not collapse into Inf.
constexpr uint16_t float_to_bf16_rtn_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } word = {f};

    const bool inf_or_nan = (word.int32 & 0x7f800000) == 0x7f800000;
    if(!inf_or_nan)
    {
        const uint32_t kept_lsb = (word.int32 >> 16) & 1u;
        word.int32 += 0x7fffu + kept_lsb; // round to nearest, ties to even
    }
    else if((word.int32 & 0xffffu) != 0)
    {
        word.int32 |= 0x10000u; // keep the result a NaN
    }
    return static_cast<uint16_t>(word.int32 >> 16);
}
152 
154 constexpr uint16_t float_to_bf16_rtn_asm(float f) { return float_to_bf16_rtn_raw(f); }
155 
157 uint16_t float_to_bf16_rtn_asm(float f) // inline-asm round-to-nearest-even; every NaN -> 0x7fff
158 {
159  union
160  {
161  float fp32;
162  uint32_t int32;
163  } u = {f};
164 
165  static constexpr uint32_t FP32_NAN = 0x7fff0000; // selected for NaN inputs; >> 16 gives bf16 0x7fff
166  static constexpr uint32_t ROUND_BIAS_FOR_BF16 = 0x7fff; // bias + kept LSB (bit 16) = round to nearest, ties to even
167 
168  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2))); // asm: mask = isnan(x); tmp = bit16(x); tmp = x + tmp + 0x7fff; x = mask ? 0x7fff0000 : tmp; x >>= 16
169  uint32x2_t check_nan; // 64-bit scalar ("s") lane mask written by v_cmp_u_f32 (x unordered with itself => NaN)
170  uint32_t tmp; // scratch VGPR: kept LSB, then the biased sum. NOTE(review): bound "+v" while uninitialized; "=&v" would state write-before-read — confirm
171  asm volatile("\n \
172  v_cmp_u_f32 %0, %2, %2 \n \
173  v_bfe_u32 %1, %2, 16, 1 \n \
174  v_add3_u32 %1, %2, %1, %3 \n \
175  v_cndmask_b32 %2, %1, %4, %0 \n \
176  v_lshrrev_b32 %2, 16, %2 \n \
177  "
178  : "=s"(check_nan), "+v"(tmp), "+v"(u.fp32)
179  : "v"(ROUND_BIAS_FOR_BF16), "v"(FP32_NAN));
180 
181  return uint16_t(u.int32); // the asm already shifted the result into the low 16 bits
182 }
183 
// Host fallback for round-to-nearest, ties away from zero (bf16_rounding_mode::rta_asm).
// Reproduces the device inline-asm version bit for bit, so host-computed results
// match the device: every NaN maps to 0x7fff; any other input gets 0x8000 (half a
// bf16 ulp) added to its sign-magnitude bits and is then truncated. That rounds
// exact ties away from zero and lets the largest finite values carry into Inf.
uint16_t float_to_bf16_rta_asm(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } u = {f};

    const bool is_nan =
        (u.int32 & 0x7f800000u) == 0x7f800000u && (u.int32 & 0x007fffffu) != 0;
    if(is_nan)
        return uint16_t(0x7fff);

    // Cannot wrap: every bit pattern >= 0xffff8000 is a NaN, handled above.
    return uint16_t((u.int32 + 0x8000u) >> 16);
}
187 
189 uint16_t float_to_bf16_rta_asm(float f) // inline-asm round to nearest, ties away from zero; every NaN -> 0x7fff
190 {
191  union
192  {
193  float fp32;
194  struct
195  {
196  uint16_t lo;
197  uint16_t hi; // upper 16 bits of fp32, assuming a little-endian target
198  };
199  } u = {f};
200 
201  const uint32_t low_nan = 0x7fff; // rounding bias; the asm adds 1 more, giving 0x8000 (half a bf16 ulp)
202  const uint32_t hi_nan = 0x7fff0000; // replacement for NaN inputs; its upper half is bf16 0x7fff
203 
204  using uint32x2_t = uint32_t __attribute__((ext_vector_type(2)));
205  uint32x2_t check_nan; // 64-bit scalar ("s") lane mask written by v_cmp_u_f32
206 
207  asm volatile("v_cmp_u_f32 %[s_cnan], %[v_x], %[v_x] \n" // mask = isnan(x)
208  "v_add3_u32 %[v_x], %[v_x], %[v_blo], 1 \n" // x += 0x7fff + 1
209  "v_cndmask_b32 %[v_x], %[v_x], %[v_bhi], %[s_cnan]" // x = mask ? hi_nan : x
210  : [s_cnan] "+s"(check_nan), [v_x] "+v"(u.fp32) // NOTE(review): check_nan is only written by the asm; "=s" would avoid reading it uninitialized — confirm
211  : [v_blo] "v"(low_nan), [v_bhi] "v"(hi_nan));
212 
213  // Note: the asm leaves the rounded bf16 in the upper 16 bits of u.fp32, so return u.hi
214  return u.hi;
215 }
216 
// Truncate instead of rounding, preserving SNaN.
// The low 16 mantissa bits are dropped (round toward zero). When the exponent
// is all ones and any dropped bit is set (so the input is a NaN), bit 0 of the
// result is forced on, so a NaN whose surviving 7 mantissa bits are all zero
// stays a NaN instead of turning into Inf.
constexpr uint16_t float_to_bf16_truc_nan_raw(float f)
{
    union
    {
        float fp32;
        uint32_t int32;
    } pun = {f};

    const uint32_t word          = pun.int32;
    const bool exponent_all_ones = (word & 0x7f800000) == 0x7f800000;
    const bool dropped_bits_set  = (word & 0xffff) != 0;

    uint16_t upper = static_cast<uint16_t>(word >> 16);
    if(exponent_all_ones && dropped_bits_set)
    {
        upper = static_cast<uint16_t>(upper | 1u);
    }
    return upper;
}
228 
// Fast truncate instead of rounding (RTZ): keep the upper 16 bits of the float
// and drop the rest of the mantissa. Inf/NaN keep their exponent, but a NaN whose
// payload lives only in the dropped bits comes out as Inf (0x7f80 / 0xff80);
// float_to_bf16_truc_nan_raw is the variant that keeps such inputs NaN.
constexpr uint16_t float_to_bf16_truc_raw(float f)
{
    union
    {
        float as_float;
        uint32_t as_bits;
    } pun = {f};

    const uint32_t high_half = pun.as_bits >> 16;
    return static_cast<uint16_t>(high_half);
}
240 
241 template <bf16_rounding_mode rounding>
243 {
244  if constexpr(rounding == bf16_rounding_mode::standard)
245  return float_to_bf16_rtn_raw(f);
246  else if constexpr(rounding == bf16_rounding_mode::standard_asm)
247  return float_to_bf16_rtn_asm(f);
248  else if constexpr(rounding == bf16_rounding_mode::truncate_with_nan)
249  return float_to_bf16_truc_nan_raw(f);
250  else if constexpr(rounding == bf16_rounding_mode::rta_asm)
251  return float_to_bf16_rta_asm(f);
252  else
253  return float_to_bf16_truc_raw(f);
254 }
255 
256 template <bf16_rounding_mode rounding>
258 {
259  return float_to_bf16_raw(static_cast<float>(f), constant<rounding>{});
260 }
261 
// Widen raw bfloat16 bits to float. bf16 is the upper half of an IEEE-754
// binary32, so placing the 16 bits at positions 31..16 (low half zero)
// reproduces the value exactly, including Inf/NaN and subnormals.
constexpr float bf16_to_float_raw(uint16_t x)
{
    const uint32_t high_bits = static_cast<uint32_t>(x) << 16;
    union
    {
        uint32_t bits;
        float value;
    } widened = {high_bits};
    return widened.value;
}
272 
274 constexpr double bf16_to_double_raw(uint16_t x)
275 {
276  return static_cast<double>(bf16_to_float_raw(x));
277 }
278 
279 template <bf16_rounding_mode rounding =
282 {
283  return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
284 }
285 
286 template <bf16_rounding_mode rounding =
289 {
290  return bit_cast<bfloat16_t>(double_to_bf16_raw(f, constant<rounding>{}));
291 }
292 
294 constexpr float bf16_to_float(bfloat16_t x) { return bf16_to_float_raw(bit_cast<uint16_t>(x)); }
295 
297 constexpr double bf16_to_double(bfloat16_t x) { return static_cast<double>(bf16_to_float_raw(x)); }
298 
299 template <bf16_rounding_mode rounding =
302 {
303  return bit_cast<bfloat16_t>(float_to_bf16_raw(static_cast<float>(f), constant<rounding>{}));
304 }
305 
307 constexpr half_t bf16_to_fp16(bfloat16_t x) { return static_cast<fp16_t>(static_cast<float>(x)); }
308 
309 template <class T>
310 struct numeric;
311 
312 template <>
314 {
315  // minimum positive normalized value (2^-126), as std::numeric_limits<float>::min
316  CK_TILE_HOST_DEVICE static constexpr bfloat16_t min()
317  {
318  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0080));
319  }
320 
321  // lowest (most negative) finite value
323  {
324  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0xff7f));
325  }
326 
327  // maximum finite value
328  CK_TILE_HOST_DEVICE static constexpr bfloat16_t max()
329  {
330  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f7f));
331  }
332 
333  // difference between 1.0 and the next representable bf16 value: 2^-7 (7 mantissa bits)
335  {
336  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3c00)); // exponent 127-7=120, mantissa 0
337  }
338 
339  // maximum rounding error
340  // (0.5, as for std::numeric_limits<float>::round_error)
341  // bin : f edcba 9876543210
342  // bits: s eeeeeeee mmmmmmm
343  // 0 01111110 0000000 (0.5)
344  //
346  {
347  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x3f00));
348  }
349 
350  // positive infinity value
352  {
353  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7f80));
354  }
355 
356  // quiet NaN: exponent all ones, mantissa MSB set
358  {
359  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FFF));
360  }
361 
362  // signaling NaN: exponent all ones, mantissa MSB clear, mantissa nonzero
364  {
365  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x7FA0));
366  }
367 
368  // smallest positive subnormal value
370  {
371  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0x0001));
372  }
374  {
375  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(0));
376  }
377 };
378 
379 template <typename T>
381 
382 template <>
384 {
385  static constexpr int exp = 8;
386  static constexpr int mant = 7;
387 };
388 
389 #if CK_TILE_USE_CUSTOM_DATA_TYPE
391 #endif
392 
393 // math
396 {
397  return bit_cast<bfloat16_t>(static_cast<bf16_raw_t>(bit_cast<bf16_raw_t>(x) & 0x7fff));
398 }
399 
401 bool isnan(const bfloat16_t& x)
402 {
403  uint16_t xx = bit_cast<bf16_raw_t>(x);
404  return (xx & 0x7FFF) > 0x7C00;
405 }
406 
409 {
410  return static_cast<bfloat16_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x)));
411 };
412 
415 {
416  return static_cast<bfloat16_t>(__ocml_exp_f32(static_cast<float>(x)));
417 };
418 
420 bfloat16_t exp2(bfloat16_t x) { return static_cast<bfloat16_t>(exp2f(static_cast<float>(x))); };
421 
423 bfloat16_t log(bfloat16_t x) { return static_cast<bfloat16_t>(__logf(static_cast<float>(x))); };
424 
425 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
Definition: config.hpp:70
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_nan_raw(float f)
Definition: bfloat16.hpp:219
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:423
ushort bfloat16_t
Definition: bfloat16.hpp:105
uint32_t uint32x2_t
Definition: vector_type.hpp:122
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
_Float16 fp16_t
Definition: half.hpp:110
constexpr CK_TILE_HOST_DEVICE uint16_t double_to_bf16_raw(double f, constant< rounding >={})
Definition: bfloat16.hpp:257
constexpr CK_TILE_HOST_DEVICE float bf16_to_float_raw(uint16_t x)
Definition: bfloat16.hpp:263
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_truc_raw(float f)
Definition: bfloat16.hpp:231
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
CK_TILE_HOST uint16_t float_to_bf16_rta_asm(float f)
Definition: bfloat16.hpp:186
constexpr CK_TILE_HOST_DEVICE double bf16_to_double_raw(uint16_t x)
Definition: bfloat16.hpp:274
constexpr CK_TILE_HOST_DEVICE bfloat16_t float_to_bf16(float f, constant< rounding >={})
Definition: bfloat16.hpp:281
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:408
constexpr CK_TILE_HOST_DEVICE float bf16_to_float(bfloat16_t x)
Definition: bfloat16.hpp:294
uint16_t bf16_raw_t
Definition: bfloat16.hpp:107
constexpr CK_TILE_HOST_DEVICE half_t bf16_to_fp16(bfloat16_t x)
Definition: bfloat16.hpp:307
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:414
bf16_rounding_mode
Definition: bfloat16.hpp:16
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:395
constexpr CK_TILE_HOST_DEVICE double bf16_to_double(bfloat16_t x)
Definition: bfloat16.hpp:297
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_rtn_raw(float f)
Definition: bfloat16.hpp:111
constexpr CK_TILE_HOST_DEVICE uint16_t float_to_bf16_raw(float f, constant< rounding >={})
Definition: bfloat16.hpp:242
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:401
constexpr CK_TILE_HOST uint16_t float_to_bf16_rtn_asm(float f)
Definition: bfloat16.hpp:154
CK_TILE_HOST_DEVICE constexpr bfloat16_t fp16_to_bf16(half_t f, constant< rounding >={})
Definition: bfloat16.hpp:301
constexpr CK_TILE_HOST_DEVICE bfloat16_t double_to_bf16(double f, constant< rounding >={})
Definition: bfloat16.hpp:288
_Float16 half_t
Definition: half.hpp:111
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:420
Definition: integral_constant.hpp:13
Definition: bfloat16.hpp:314
static constexpr CK_TILE_HOST_DEVICE bfloat16_t round_error()
Definition: bfloat16.hpp:345
static constexpr CK_TILE_HOST_DEVICE bfloat16_t infinity()
Definition: bfloat16.hpp:351
static constexpr CK_TILE_HOST_DEVICE bfloat16_t max()
Definition: bfloat16.hpp:328
static constexpr CK_TILE_HOST_DEVICE bfloat16_t denorm_min()
Definition: bfloat16.hpp:369
static constexpr CK_TILE_HOST_DEVICE bfloat16_t min()
Definition: bfloat16.hpp:316
static constexpr CK_TILE_HOST_DEVICE bfloat16_t lowest()
Definition: bfloat16.hpp:322
static constexpr CK_TILE_HOST_DEVICE bfloat16_t epsilon()
Definition: bfloat16.hpp:334
static constexpr CK_TILE_HOST_DEVICE bfloat16_t quiet_NaN()
Definition: bfloat16.hpp:357
static constexpr CK_TILE_HOST_DEVICE bfloat16_t zero()
Definition: bfloat16.hpp:373
static constexpr CK_TILE_HOST_DEVICE bfloat16_t signaling_NaN()
Definition: bfloat16.hpp:363
Definition: bfloat16.hpp:384
Definition: bfloat16.hpp:380
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:102