/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/numeric/float8.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/numeric/float8.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/numeric/float8.hpp Source File
float8.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
12 #include <stdint.h>
13 #include <type_traits>
14 
15 #pragma once
16 
17 #if(defined(__gfx94__) || defined(__gfx12__)) && __HIP_DEVICE_COMPILE__
18 #define CK_TILE_FP8_CVT_DEVICE 1
19 #else
20 #define CK_TILE_FP8_CVT_DEVICE 0
21 #endif
22 
23 namespace ck_tile {
24 
25 // fp8 rounding modes
26 // use standard for rounding to nearest, the faster one
27 // use stochastic for stochastic rounding, helps to avoid error accumulation
29 {
30  standard = 0,
32 };
33 
38 {
39  E4M3_OCP = 0, // OCP FP8 E4M3
40  E5M2_OCP = 1, // OCP BF8 E5M2
41  E4M3_FNUZ = 2, // FNUZ FP8 E4M3
42  E5M2_FNUZ = 3, // FNUZ BF8 E5M2
43 };
44 
45 /*
46  * ______________FNUZ_________________ | ______________OCP________________
47  * e4m3 e5m2 | e4m3 e5m2
48  * bias : 8 16 | 7 15
49  * inf : 1.0000.000 1.00000.00 | N/A s.11111.00
50  * Nan : 1.0000.000 1.00000.00 | s.1111.111 s.11111.{01, 10, 11}
51  * zero : 0.0000.000 0.00000.00 | s.0000.000 s.00000.00
52  * Max(norm) : s.1111.111 (240) s.11111.11(57344) | s.1111.110(448) s.11110.11(57344)
53  * Max(snorm): s.0000.111 s.00000.11 | s.0000.111 s.00000.11
54  * 0.0068359375 2.288818e-05 | 0.013671875 4.57763671875e-05
55  * Min(norm) : s.0001.000 s.00001.00 | s.0001.000 s.00001.00
56  * 2^-7(0.0078125) 2^-15(3.05176e-05) | 2^-6(0.015625) 2^-14(6.10352e-05)
57  * Min(snorm): s.0000.001 s.00000.01 | s.0000.001 s.00000.01
58  * 2^-10(0.00097656) 2^-17(7.629395e-06)| 2^-9(0.001953125) 2^-16(1.52588e-05)
59  */
60 
61 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
62 CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant<rounding> = {});
63 
64 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
65 CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant<rounding> = {});
66 
69 
70 #if CK_TILE_USE_CUSTOM_DATA_TYPE
71 struct alignas(1) float8_e4m3_t
72 {
73  static constexpr int exponent = 4;
74  static constexpr int mantissa = 3;
75 #if CK_TILE_USE_OCP_FP8
76  static constexpr int bias = 7; // OCP
77 #else
78  static constexpr int bias = 8; // FNUZ
79 #endif
80  using raw_type = uint8_t;
81  raw_type data;
82 
84  static constexpr float8_e4m3_t bit_cast(raw_type x)
85  {
86  float8_e4m3_t y;
87  y.data = x;
88  return y;
89  }
90 
91  // constructor
92  constexpr float8_e4m3_t() : data() {}
93 
94  // construct from float
96  explicit constexpr float8_e4m3_t(const float& x) : data(float_to_fp8_raw(x)) {}
97 
98  // construct from int
100  explicit constexpr float8_e4m3_t(const int& x) : data(float_to_fp8_raw(static_cast<float>(x)))
101  {
102  }
103 
104  // construct from unsigned int
106  explicit constexpr float8_e4m3_t(const unsigned int& x)
107  : data(float_to_fp8_raw(static_cast<float>(x)))
108  {
109  }
110 
111  // cast to float
113  explicit constexpr operator float() const { return fp8_to_float_raw(data); }
114 
115  // cast to int
117  explicit constexpr operator int() const { return static_cast<int>(fp8_to_float_raw(data)); }
118 
119  // internal access
121  constexpr raw_type& get() { return data; }
122 
124  constexpr raw_type get() const { return data; }
125 };
126 using fp8_t = float8_e4m3_t;
127 using fp8_raw_t = typename fp8_t::raw_type;
128 
129 struct alignas(1) float8_e5m2_t
130 {
131  static constexpr int exponent = 5;
132  static constexpr int mantissa = 2;
133 #if CK_TILE_USE_OCP_FP8
134  static constexpr int bias = 15; // OCP
135 #else
136  static constexpr int bias = 16; // FNUZ
137 #endif
138  using raw_type = uint8_t;
139  raw_type data;
140 
142  static constexpr float8_e5m2_t bit_cast(raw_type x)
143  {
144  float8_e5m2_t y;
145  y.data = x;
146  return y;
147  }
148 
149  // constructor
150  constexpr float8_e5m2_t() : data() {}
151 
152  // construct from float
154  explicit constexpr float8_e5m2_t(const float& x) : data(float_to_bf8_raw(x)) {}
155 
156  // construct from int
158  explicit constexpr float8_e5m2_t(const int& x) : data(float_to_bf8_raw(static_cast<float>(x)))
159  {
160  }
161 
162  // construct from unsigned int
164  explicit constexpr float8_e5m2_t(const unsigned int& x)
165  : data(float_to_bf8_raw(static_cast<float>(x)))
166  {
167  }
168 
169  // cast to float
171  explicit constexpr operator float() const { return bf8_to_float_raw(data); }
172 
173  // cast to int
175  explicit constexpr operator int() const { return static_cast<int>(bf8_to_float_raw(data)); }
176 
177  // internal access
179  constexpr raw_type& get() { return data; }
180 
182  constexpr raw_type get() const { return data; }
183 };
184 using bf8_t = float8_e5m2_t;
185 using bf8_raw_t = typename bf8_t::raw_type;
186 
187 template <typename>
188 struct native_t;
189 
190 template <>
191 struct native_t<fp8_t>
192 {
193  using type = _BitInt(8);
194 };
195 
196 template <>
197 struct native_t<bf8_t>
198 {
199  using type = unsigned _BitInt(8);
200 };
201 
202 #else
203 
204 using fp8_t = _BitInt(8);
205 using fp8_raw_t = uint8_t;
206 using bf8_t = unsigned _BitInt(8);
207 using bf8_raw_t = uint8_t;
208 #endif
209 
210 template <>
212 {
214 
215  static constexpr int exp = 4;
216  static constexpr int mant = 3;
217 #if CK_TILE_USE_OCP_FP8
218  static constexpr int bias = 7;
219  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_OCP;
220 #else
221  static constexpr int bias = 8;
222  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E4M3_FNUZ;
223 #endif
224  static constexpr uint8_t abs_mask = 0x7F;
225  static constexpr int PackedSize = 1;
226 };
227 
228 template <>
230 {
232 
233  static constexpr int exp = 5;
234  static constexpr int mant = 2;
235 #if CK_TILE_USE_OCP_FP8
236  static constexpr int bias = 15;
237  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_OCP;
238 #else
239  static constexpr int bias = 16;
240  static constexpr fp8_interpretation f8_interpret = fp8_interpretation::E5M2_FNUZ;
241 #endif
242  static constexpr uint8_t abs_mask = 0x7F;
243  static constexpr int PackedSize = 1;
244 };
245 
246 // below is sw fp8 conversion, not utilizing hw instruction
247 namespace impl {
248 
249 template <typename SrcT, typename DstT, bool clip = true, bool stoch = false>
250 CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng = 0)
251 {
252  static_assert(std::is_same<DstT, fp8_t>::value || std::is_same<DstT, bf8_t>::value,
253  "DstT type must be fp8 or bf8.");
254 
255  constexpr bool is_half = std::is_same<SrcT, half_t>::value;
256  constexpr bool is_float = std::is_same<SrcT, float>::value;
257  static_assert(is_half || is_float, "Only half and float can be cast to f8");
258 
259  // fp8/bf8 type exponent/mantissa layout
260  constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
261  constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
262  constexpr bool is_fnuz =
265 
266  constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
267  constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
268 
269  using SrcT_bitwise = typename numeric_traits<SrcT>::bitwise_type;
270  SrcT_bitwise src_bitwise = bit_cast<SrcT_bitwise>(src);
271 
272  unsigned long long head, mantissa;
273  int exponent, bias;
274  unsigned int sign;
275  unsigned long long fInf, abs_mask;
276 
277  head = src_bitwise & numeric_traits<SrcT>::head_mask;
278  mantissa = src_bitwise & numeric_traits<SrcT>::mant_mask;
279  exponent = (head >> SrcT_mant) & numeric_traits<SrcT>::exp_mask;
280  sign = head >> (SrcT_exp + SrcT_mant);
284 
285  unsigned int signed_inf = 0;
286  unsigned int nan = 0;
287  if constexpr(is_fnuz)
288  {
289  signed_inf = clip ? ((sign << 7) + 0x7f) : 0x80;
290  nan = 0x80;
291  }
292  else
293  {
294  if constexpr(DstT_exp == 4)
295  { // e4m3
296  signed_inf = (sign << 7) + (clip ? 0x7e : 0x7f);
297  }
298  else
299  { // e5m2
300  signed_inf = (sign << 7) + (clip ? 0x7b : 0x7c);
301  }
302  nan = (sign << 7) + 0x7f;
303  }
304  // Max values
305  unsigned long long ifmax = 0;
306  if constexpr(is_float)
307  {
308  if constexpr(DstT_exp == 5)
309  {
310  ifmax = 0x47600000;
311  }
312  else
313  {
314  if constexpr(is_fnuz)
315  {
316  ifmax = 0x43700000;
317  }
318  else
319  {
320  ifmax = 0x43E00000;
321  }
322  }
323  }
324  else if constexpr(is_half)
325  {
326  if constexpr(DstT_exp == 5)
327  {
328  ifmax = 0x7B00;
329  }
330  else
331  {
332  if constexpr(is_fnuz)
333  {
334  ifmax = 0x5B80;
335  }
336  else
337  {
338  ifmax = 0x5F00;
339  }
340  }
341  }
342 
343  // Deal with inf and NaNs
344  if((src_bitwise & fInf) == fInf)
345  {
346  if constexpr(is_fnuz)
347  return signed_inf;
348 
349  return mantissa != 0 ? nan : signed_inf;
350  }
351 
352  if((src_bitwise & abs_mask) > ifmax)
353  {
354  return signed_inf;
355  }
356 
357  if(src_bitwise == 0)
358  {
359  return 0;
360  }
361 
362  // First need to check if it is normal or denorm as there is a difference of
363  // implicit 1 Then need to adjust the exponent to align with the F8 exponent,
364  // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
365  // to mantissa and truncate. And for RNE, no need to add rng. Then probably
366  // need to check whether there is carry and adjust exponent and mantissa again
367 
368  // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
369  // bits
370  const int f8_bias = (1 << (DstT_exp - 1)) - 1 + (is_fnuz ? 1 : 0);
371  const int f8_denormal_act_exponent = 1 - f8_bias; // actual exponent of f8 denormal
372  // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
373  // f8_exponent is the converted f8 exponent with bias encoding
374  // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
375  // the difference needs to be adjusted and mantissa shifted
376  int act_exponent, f8_exponent, exponent_diff;
377 
378  if(exponent == 0)
379  { // fp32/fp16 is in denormal.
380  /* fp32 denormal is below 2^-127 so it is usually not a concern here, we
381  mostly concern fp16 here. In this case, f8 is usually in denormal. But there
382  could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
383  exponent bias 16. It means that there are some numbers in fp16 denormal but they
384  are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
385  where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
386  (NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
387  act_exponent = exponent - bias + 1;
388  exponent_diff = f8_denormal_act_exponent -
389  act_exponent; // actual exponent is exponent-bias+1 as it is denormal
390  }
391  else
392  { // fp32/fp16 is normal with implicit 1
393  act_exponent = exponent - bias;
394  if(act_exponent <= f8_denormal_act_exponent)
395  {
396  /* This is the case where fp32/fp16 is normal but it is in f8 denormal
397  range. For example fp8 nanoo mode, denormal exponent is -7, but if the fp32/fp16
398  actual exponent is -7, it is actually larger due to the implicit 1,
399  Therefore it needs to be adjust to -6 and mantissa shift right by 1.
400  So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
401  exponent_diff = f8_denormal_act_exponent - act_exponent;
402  }
403  else
404  { // both fp32/fp16 and f8 are in normal range
405  exponent_diff = 0; // exponent_diff=0 does not mean there is no difference
406  // for this case, act_exponent could be larger. Just
407  // that it does not need shift mantissa
408  }
409  mantissa += (1ull << SrcT_mant); // Add the implicit 1 into mantissa
410  }
411 
412  bool midpoint = (mantissa & ((1ull << (SrcT_mant - DstT_mant + exponent_diff)) - 1)) ==
413  (1ull << (SrcT_mant - DstT_mant + exponent_diff - 1));
414  /* This part is a bit tricky. The judgment of whether it is a tie needs to be
415  done before we shift right as shift right could rip off some residual part and
416  make something not midpoint look like midpoint. For example, the fp16 number
417  0x1002 (0 00100 0000000010), it is larger than midpoint, but after shift right
418  by 4 bits, it would look like midpoint.
419  */
420 
421  if(exponent_diff > 0)
422  mantissa >>= exponent_diff;
423  else if(exponent_diff == -1)
424  mantissa <<= -exponent_diff;
425  bool implicit_one = mantissa & (1ull << SrcT_mant);
426  // if there is no implicit 1, it means the f8 is denormal and need to adjust
427  // to denorm exponent
428  f8_exponent =
429  (act_exponent + exponent_diff) /*actual f8 exponent*/ + f8_bias - (implicit_one ? 0 : 1);
430 
431  // Now we have the exponent and mantissa adjusted
432  unsigned long long drop_mask = (1ull << (SrcT_mant - DstT_mant)) - 1;
433  bool odd =
434  mantissa & (1ull << (SrcT_mant -
435  DstT_mant)); // if the least significant bit that is not truncated is 1
436  mantissa +=
437  (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1ull) : mantissa)) & drop_mask;
438 
439  // Now we deal with overflow
440  if(f8_exponent == 0)
441  {
442  if((1ull << SrcT_mant) & mantissa)
443  {
444  f8_exponent = 1; // denormal overflow to become normal, promote exponent
445  }
446  }
447  else
448  {
449  if((1ull << (SrcT_mant + 1)) & mantissa)
450  {
451  mantissa >>= 1;
452  f8_exponent++;
453  }
454  }
455 
456  mantissa >>= (SrcT_mant - DstT_mant);
457 
458  // above range: quantize to maximum possible float of the same sign
459  const int max_exp = (1 << DstT_exp) - 1;
460  if(f8_exponent > max_exp)
461  {
462  if constexpr(clip)
463  {
464  mantissa = (1 << DstT_mant) - 1;
465  f8_exponent = max_exp;
466  }
467  else
468  {
469  return signed_inf;
470  }
471  }
472 
473  if(f8_exponent == 0 && mantissa == 0)
474  return is_fnuz ? 0 : (sign << 7);
475  mantissa &= (1 << DstT_mant) - 1;
476  return (sign << 7) | (f8_exponent << DstT_mant) | mantissa;
477 }
478 
479 template <typename SrcT, typename DstT, bool clip = true>
481 {
482  static_assert(std::is_same<SrcT, fp8_t>::value || std::is_same<SrcT, bf8_t>::value,
483  "SrcT type must be fp8 or bf8.");
484  constexpr int SrcT_exp = numeric_traits<SrcT>::exp;
485  constexpr int SrcT_mant = numeric_traits<SrcT>::mant;
486  constexpr bool is_fnuz =
489 
490  constexpr bool is_half = std::is_same<DstT, half_t>::value;
491  constexpr bool is_float = std::is_same<DstT, float>::value;
492  static_assert(is_half || is_float, "DstT type must be half_t or float.");
493 
494  // destination type exponent/mantissa layout
495  constexpr int DstT_exp = numeric_traits<DstT>::exp; // exponent width of the destination type
496  constexpr int DstT_mant = numeric_traits<DstT>::mant; // mantissa width of the destination type
497 
498  constexpr DstT fInf = bit_cast<DstT>(numeric_traits<DstT>::Inf);
499  constexpr DstT fNegInf = bit_cast<DstT>(numeric_traits<DstT>::NegInf);
500  constexpr DstT fNaN = bit_cast<DstT>(numeric_traits<DstT>::NaN);
501  constexpr DstT fNeg0 = bit_cast<DstT>(numeric_traits<DstT>::Neg0);
502 
503  DstT fmax{0}, fmin{0};
504  // Max number in e5m2 57344
505  if constexpr(is_half)
506  {
507  fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x7B00));
508  fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xFB00));
509  }
510  else if constexpr(is_float)
511  {
512  fmax = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0x47600000));
513  fmin = bit_cast<DstT>(static_cast<typename numeric_traits<DstT>::bitwise_type>(0xC7600000));
514  }
515 
516  if(x == 0)
517  {
518  return 0;
519  }
520 
521  unsigned long long sign = x >> 7;
522  unsigned long long mantissa = x & ((1 << SrcT_mant) - 1);
523  int exponent = (x & 0x7F) >> SrcT_mant;
524  if constexpr(is_fnuz)
525  {
526  if((x & 0xff) == 0x80)
527  {
528  return fNaN;
529  }
530  }
531  else
532  {
533  if(x == SrcT(0x80))
534  {
535  return fNeg0;
536  }
537  if constexpr(SrcT_exp == 4)
538  { // e4m3
539  if((x & 0x7F) == 0x7F)
540  {
541  return fNaN;
542  }
543  }
544  else if((x & 0x7C) == 0x7C)
545  { // e5m2
546  if((x & 0x3) == 0)
547  {
548  if constexpr(clip)
549  {
550  return sign ? fmin : fmax;
551  }
552  return sign ? fNegInf : fInf;
553  }
554  return fNaN;
555  }
556  }
557 
558  typename numeric_traits<DstT>::bitwise_type retval;
559 
560  if constexpr(SrcT_exp == 5 && is_half && !is_fnuz)
561  {
562  retval = x << 8;
563  return bit_cast<DstT>(retval);
564  }
565 
566  const int exp_low_cutoff =
567  (1 << (DstT_exp - 1)) - (1 << (SrcT_exp - 1)) + 1 - (is_fnuz ? 1 : 0);
568 
569  // subnormal input
570  if(exponent == 0)
571  {
572  int sh = 1 + clz(mantissa) - (32 - SrcT_mant);
573  mantissa <<= sh;
574  exponent += 1 - sh;
575  mantissa &= ((1ull << SrcT_mant) - 1);
576  }
577  exponent += exp_low_cutoff - 1;
578  mantissa <<= DstT_mant - SrcT_mant;
579 
580  // subnormal output (occurs when DstT is half_t, we=5, is_fnuz=true)
581  if(exponent <= 0)
582  {
583  mantissa |= 1 << DstT_mant;
584  mantissa >>= 1 - exponent;
585  exponent = 0;
586  }
587 
588  retval = (sign << (DstT_exp + DstT_mant)) | (exponent << DstT_mant) | mantissa;
589 
590  return bit_cast<DstT>(retval);
591 }
592 
593 template <typename X, typename Y, bool clip, bool stoch>
594 CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
595 {
596  return bit_cast<Y>(run_cast_to_f8<X, Y, clip, stoch>(x, rng));
597 }
598 
599 #if CK_TILE_FP8_CVT_DEVICE
603 template <fp8_interpretation interpret, bool saturate, bool stochastic_rounding = false>
604 CK_TILE_DEVICE uint8_t cast_to_f8_from_f32(float v, unsigned int rng = 0)
605 {
606  uint8_t i8data;
607  union
608  {
609  float fval;
610  unsigned int i32val;
611  unsigned char i8val[4]; // NOTE: not endian independent
612  } val;
613 
614  unsigned int ival = 0;
615  val.fval = v;
616 
617  if constexpr(saturate)
618  {
619  if constexpr(interpret == fp8_interpretation::E4M3_FNUZ)
620  {
621  if((val.i32val & 0x7F800000) != 0x7F800000)
622  {
623  val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0);
624  }
625  }
626  else if constexpr(interpret == fp8_interpretation::E4M3_OCP)
627  { // OCP type
628  if((val.i32val & 0x7F800000) != 0x7F800000)
629  {
630  val.fval = __builtin_amdgcn_fmed3f(val.fval, 448.0, -448.0);
631  }
632  }
633  else
634  {
635  if((val.i32val & 0x7F800000) != 0x7F800000)
636  {
637  val.fval = __builtin_amdgcn_fmed3f(val.fval, 57344.0, -57344.0);
638  }
639  }
640  }
641 
642  if constexpr(stochastic_rounding)
643  {
644  ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
645  (interpret == fp8_interpretation::E4M3_OCP)
646  ? __builtin_amdgcn_cvt_sr_fp8_f32(val.fval, rng, ival, 0)
647  : __builtin_amdgcn_cvt_sr_bf8_f32(val.fval, rng, ival, 0); // 0 pos
648  val.i32val = ival;
649  i8data = val.i8val[0]; // little endian
650  }
651  else
652  { // RNE CVT
653  ival = (interpret == fp8_interpretation::E4M3_FNUZ) ||
654  (interpret == fp8_interpretation::E4M3_OCP)
655  ? __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, false)
656  : __builtin_amdgcn_cvt_pk_bf8_f32(val.fval,
657  val.fval,
658  ival,
659  false); // false -> WORD0
660  val.i32val = ival;
661  i8data = val.i8val[0];
662  }
663  return i8data;
664 }
665 #endif // CK_TILE_FP8_CVT_DEVICE
666 
667 } // namespace impl
668 
682 template <typename SrcT, typename DstT>
684 {
685  constexpr bool clip = true;
686  constexpr int seed = 42;
687  uint32_t rng = prand_generator_t<SrcT, seed>{}(reinterpret_cast<uintptr_t>(&x), x);
688 #if CK_TILE_FP8_CVT_DEVICE
689  return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, true>(x, rng);
690 #else
691  return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
692  impl::cast_to_f8<SrcT, DstT, clip, true>(x, rng));
693 #endif
694 }
695 
708 template <typename SrcT, typename DstT>
710 {
711  constexpr bool clip = true;
712 #if CK_TILE_FP8_CVT_DEVICE
713  return impl::cast_to_f8_from_f32<numeric_traits<DstT>::f8_interpret, clip, false>(x, 0);
714 #else
715  return bit_cast<typename numeric_traits<DstT>::bitwise_type>(
716  impl::cast_to_f8<SrcT, DstT, clip, false>(x, 0));
717 #endif
718 }
719 
720 template <fp8_rounding_mode rounding>
722 {
723  if constexpr(rounding == fp8_rounding_mode::standard)
724  {
725  return float_to_fp8_rtn_raw<float, fp8_t>(x);
726  }
727  else if constexpr(rounding == fp8_rounding_mode::stochastic)
728  {
729  return float_to_fp8_sr_raw<float, fp8_t>(x);
730  }
731  else
732  {
733  return fp8_raw_t{0};
734  }
735 }
736 
737 template <fp8_rounding_mode rounding>
739 {
740  if constexpr(rounding == fp8_rounding_mode::standard)
741  {
742  return float_to_fp8_rtn_raw<float, bf8_t>(x);
743  }
744  else if constexpr(rounding == fp8_rounding_mode::stochastic)
745  {
746  return float_to_fp8_sr_raw<float, bf8_t>(x);
747  }
748  else
749  {
750  return bf8_raw_t{0};
751  }
752 }
753 
755 {
756 #if CK_TILE_FP8_CVT_DEVICE
757  float fval;
758  uint32_t i32val = static_cast<uint32_t>(x);
759  fval = __builtin_amdgcn_cvt_f32_fp8(i32val, 0);
760  // asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
761  return fval;
762 #else
763  return impl::run_cast_from_f8<fp8_t, float>(bit_cast<fp8_t>(x));
764 #endif
765 }
766 
768 {
769 #if CK_TILE_FP8_CVT_DEVICE
770  float fval;
771  uint32_t i32val = static_cast<uint32_t>(x);
772  fval = __builtin_amdgcn_cvt_f32_bf8(i32val, 0);
773  // asm volatile("v_cvt_f32_bf8 %0, %1 src0_sel:BYTE_0" : "=v"(fval) : "v"(i32val));
774  return fval;
775 #else
776  return impl::run_cast_from_f8<bf8_t, float>(bit_cast<bf8_t>(x));
777 #endif
778 }
779 
780 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
782 {
783  return bit_cast<fp8_t>(float_to_fp8_raw(x, constant<rounding>{}));
784 }
785 
786 template <fp8_rounding_mode rounding = static_cast<fp8_rounding_mode>(CK_TILE_FLOAT_TO_FP8_DEFAULT)>
788 {
789  return bit_cast<bf8_t>(float_to_bf8_raw(x, constant<rounding>{}));
790 }
791 
792 CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x) { return fp8_to_float_raw(bit_cast<fp8_raw_t>(x)); }
793 
794 CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x) { return bf8_to_float_raw(bit_cast<bf8_raw_t>(x)); }
795 
796 template <class T>
797 struct numeric;
798 
799 #if CK_TILE_USE_OCP_FP8
800 template <>
801 struct numeric<fp8_t>
802 {
803  // minimum finite value, or minimum positive normal value
804  CK_TILE_HOST_DEVICE static constexpr fp8_t min()
805  {
806  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08)); // 0b00001000 = 2^-6
807  }
808 
809  // minumum finite value
810  CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
811  {
812  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xfe)); // 0b11111110 = -448
813  }
814 
815  // maximum finite value
816  CK_TILE_HOST_DEVICE static constexpr fp8_t max()
817  {
818  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7e)); // 0b01111110 = 448
819  }
820 
821  // difference between 1.0 and next representable f8 value (1.125)
822  // returns fp8_t(0.125)
823  CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
824  {
825  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20)); // 0.125
826  }
827 
828  // rounding error (0.0625)
829  // half of epsilon
830  CK_TILE_HOST_DEVICE static constexpr fp8_t round_error()
831  {
832  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x18)); // 0.0625
833  }
834 
835  // quiet NaN
836  CK_TILE_HOST_DEVICE static constexpr fp8_t quiet_NaN()
837  {
838  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7F)); // 0b01111111
839  }
840 
841  // signaling NaN
842  CK_TILE_HOST_DEVICE static constexpr fp8_t signaling_NaN()
843  {
844  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xFF)); // 0b11111111
845  }
846 
847  // smallest positive subnormal value
848  CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min()
849  {
850  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
851  }
852 
853  CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
854  {
855  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
856  }
857 };
858 
859 template <>
860 struct numeric<bf8_t>
861 {
862  // minimum finite value, or minimum positive normalized value for float
863  CK_TILE_HOST_DEVICE static constexpr bf8_t min()
864  {
865  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04)); // 0b00000100 = 2^-14
866  }
867 
868  // minumum finite value
869  CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
870  {
871  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xfb)); // 0b11111011 = -57344
872  }
873 
874  // maximum finite value
875  CK_TILE_HOST_DEVICE static constexpr bf8_t max()
876  {
877  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7b)); // 0b01111011 = 57344
878  }
879 
880  // difference between 1.0 and next representable bf8 value (1.25)
881  CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
882  {
883  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34)); // 0.25
884  }
885 
886  // rounding error (0.125)
887  // half of epsilon
888  CK_TILE_HOST_DEVICE static constexpr bf8_t round_error()
889  {
890  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x30)); // 0.125
891  }
892 
893  // positive infinity value
894  CK_TILE_HOST_DEVICE static constexpr bf8_t infinity()
895  {
896  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7c)); // 0b01111100
897  }
898 
899  // quiet NaN
900  CK_TILE_HOST_DEVICE static constexpr bf8_t quiet_NaN()
901  {
902  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7F)); // 0b01111111
903  }
904 
905  // signaling NaN
906  CK_TILE_HOST_DEVICE static constexpr bf8_t signaling_NaN()
907  {
908  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xFF));
909  }
910 
911  // smallest positive subnormal value
912  CK_TILE_HOST_DEVICE static constexpr bf8_t denorm_min()
913  {
914  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
915  }
916 
917  CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
918  {
919  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
920  }
921 };
922 #else
923 template <>
924 struct numeric<fp8_t>
925 {
926  // minimum finite value, or minimum positive normalized value for float
927  CK_TILE_HOST_DEVICE static constexpr fp8_t min()
928  {
929  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x08));
930  }
931 
932  // minumum finite value
933  CK_TILE_HOST_DEVICE static constexpr fp8_t lowest()
934  {
935  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0xff));
936  }
937 
938  // maximum finite value
939  CK_TILE_HOST_DEVICE static constexpr fp8_t max()
940  {
941  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x7f));
942  }
943 
944  // difference between 1.0 and next value representable by float
945  CK_TILE_HOST_DEVICE static constexpr fp8_t epsilon()
946  {
947  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x20));
948  }
949 
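950  // maximum rounding error
951  // bin : 7 6543 210
952  // bits: s eeee mmm
953  // 0 0110 000 (0.25 with the FNUZ bias of 8)
954  // NOTE(review): 0.5 would be 0 0111 000 (0x38); confirm which value is intended.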
956  {
957  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x30));
958  }
959 
960  // positive infinity value
961  CK_TILE_HOST_DEVICE static constexpr fp8_t infinity()
962  {
963  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
964  }
965 
966  // quiet NaN
968  {
969  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
970  }
971 
972  // signaling NaN
974  {
975  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x80));
976  }
977 
978  // smallest positive subnormal value
980  {
981  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0x01));
982  }
983 
984  CK_TILE_HOST_DEVICE static constexpr fp8_t zero()
985  {
986  return bit_cast<fp8_t>(static_cast<fp8_raw_t>(0));
987  }
988 };
989 
990 template <>
991 struct numeric<bf8_t>
992 {
993  // minimum finite value, or minimum positive normalized value for float
994  CK_TILE_HOST_DEVICE static constexpr bf8_t min()
995  {
996  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x04));
997  }
998 
999  // minumum finite value
1000  CK_TILE_HOST_DEVICE static constexpr bf8_t lowest()
1001  {
1002  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0xff));
1003  }
1004 
1005  // maximum finite value
1006  CK_TILE_HOST_DEVICE static constexpr bf8_t max()
1007  {
1008  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x7f));
1009  }
1010 
1011  // difference between 1.0 and next value representable by float
1012  CK_TILE_HOST_DEVICE static constexpr bf8_t epsilon()
1013  {
1014  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x34));
1015  }
1016 
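1017  // maximum rounding error
1018  // bin : 7 65432 10
1019  // bits: s eeeee mm
1020  // 0 01110 00 (0.25 with the FNUZ bias of 16)
1021  // NOTE(review): 0.5 would be 0 01111 00 (0x3C); confirm which value is intended.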
1023  {
1024  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x38));
1025  }
1026 
1027  // positive infinity value
1029  {
1030  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1031  }
1032 
1033  // quiet NaN
1035  {
1036  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1037  }
1038 
1039  // signaling NaN
1041  {
1042  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x80));
1043  }
1044 
1045  // smallest positive subnormal value
1047  {
1048  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0x01));
1049  }
1050 
1051  CK_TILE_HOST_DEVICE static constexpr bf8_t zero()
1052  {
1053  return bit_cast<bf8_t>(static_cast<bf8_raw_t>(0));
1054  }
1055 };
1056 #endif
1057 
1058 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1061 #endif
1062 
1063 // math
1064 template <typename T>
1066 {
1067  static_assert(std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>,
1068  "Only fp8_t and bf8_t are supported");
1069  return bit_cast<T>(static_cast<uint8_t>(bit_cast<uint8_t>(x) & numeric_traits<T>::abs_mask));
1070 }
1071 
1073 bool isnan(const fp8_t& x)
1074 {
1075  uint8_t xx = bit_cast<fp8_raw_t>(x);
1076 
1077 #if CK_TILE_USE_OCP_FP8
1078  return (xx & 0x7f) == 0x7f;
1079 #else
1080  return xx == 0x80;
1081 #endif
1082 }
1083 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1085 fp8_t sqrt(fp8_t x) { return static_cast<fp8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
1086 
1088 fp8_t exp(fp8_t x) { return static_cast<fp8_t>(__ocml_exp_f32(static_cast<float>(x))); };
1089 
1091 fp8_t exp2(fp8_t x) { return static_cast<fp8_t>(exp2f(static_cast<float>(x))); };
1092 
1094 fp8_t log(fp8_t x) { return static_cast<fp8_t>(__logf(static_cast<float>(x))); };
1095 #endif
1096 
1098 bool isnan(const bf8_t& x)
1099 {
1100  uint8_t xx = bit_cast<bf8_raw_t>(x);
1101 
1102 #if CK_TILE_USE_OCP_FP8
1103  return (xx & 0x7f) > 0x7c;
1104 #else
1105  return xx == 0x80;
1106 #endif
1107 }
1108 
1109 #if CK_TILE_USE_CUSTOM_DATA_TYPE
1111 bf8_t sqrt(bf8_t x) { return static_cast<bf8_t>(__builtin_amdgcn_sqrtf(static_cast<float>(x))); };
1112 
1114 bf8_t exp(bf8_t x) { return static_cast<bf8_t>(__ocml_exp_f32(static_cast<float>(x))); };
1115 
1117 bf8_t exp2(bf8_t x) { return static_cast<bf8_t>(exp2f(static_cast<float>(x))); };
1118 
1120 bf8_t log(bf8_t x) { return static_cast<bf8_t>(__logf(static_cast<float>(x))); };
1121 #endif
1122 
1123 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_FLOAT_TO_FP8_DEFAULT
Definition: config.hpp:78
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
CK_TILE_HOST_DEVICE DstT run_cast_to_f8(SrcT src, unsigned int rng=0)
Definition: float8.hpp:250
CK_TILE_HOST_DEVICE DstT run_cast_from_f8(SrcT x)
Definition: float8.hpp:480
CK_TILE_HOST_DEVICE Y cast_to_f8(X x, uint32_t rng)
Definition: float8.hpp:594
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE bfloat16_t log(bfloat16_t x)
Definition: bfloat16.hpp:421
fp8_interpretation
FP8 interpretation used in conversion algorithms.
Definition: float8.hpp:38
_BitInt(8) fp8_t
Definition: float8.hpp:204
CK_TILE_HOST_DEVICE fp8_t float_to_fp8(float x, constant< rounding >={})
Definition: float8.hpp:781
CK_TILE_HOST_DEVICE float fp8_to_float_raw(uint8_t)
Definition: float8.hpp:754
constexpr CK_TILE_HOST_DEVICE Y bit_cast(const X &x)
Definition: bit_cast.hpp:11
CK_TILE_HOST_DEVICE float bf8_to_float_raw(uint8_t)
Definition: float8.hpp:767
fp8_rounding_mode
Definition: float8.hpp:29
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:406
CK_TILE_HOST_DEVICE uint8_t float_to_fp8_raw(float, constant< rounding >={})
Definition: float8.hpp:721
uint8_t fp8_raw_t
Definition: float8.hpp:205
CK_TILE_HOST_DEVICE float bf8_to_float(bf8_t x)
Definition: float8.hpp:794
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_sr_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with stochastic rounding.
Definition: float8.hpp:683
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:412
CK_TILE_HOST int clz(uint32_t x)
Definition: math.hpp:264
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:393
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
uint8_t bf8_raw_t
Definition: float8.hpp:207
CK_TILE_HOST_DEVICE bf8_t float_to_bf8(float x, constant< rounding >={})
Definition: float8.hpp:787
CK_TILE_HOST_DEVICE bool isnan(const bfloat16_t &x)
Definition: bfloat16.hpp:399
CK_TILE_HOST_DEVICE numeric_traits< DstT >::bitwise_type float_to_fp8_rtn_raw(SrcT x)
Converts a floating-point value to an 8-bit floating-point representation with rounding to nearest even.
Definition: float8.hpp:709
CK_TILE_HOST_DEVICE float fp8_to_float(fp8_t x)
Definition: float8.hpp:792
CK_TILE_HOST_DEVICE uint8_t float_to_bf8_raw(float, constant< rounding >={})
Definition: float8.hpp:738
CK_TILE_DEVICE bfloat16_t exp2(bfloat16_t x)
Definition: bfloat16.hpp:418
Definition: integral_constant.hpp:13
remove_cvref_t< T > type
Definition: vector_type.hpp:26
static constexpr CK_TILE_HOST_DEVICE bf8_t min()
Definition: float8.hpp:994
static constexpr CK_TILE_HOST_DEVICE bf8_t quiet_NaN()
Definition: float8.hpp:1034
static constexpr CK_TILE_HOST_DEVICE bf8_t lowest()
Definition: float8.hpp:1000
static constexpr CK_TILE_HOST_DEVICE bf8_t round_error()
Definition: float8.hpp:1022
static constexpr CK_TILE_HOST_DEVICE bf8_t signaling_NaN()
Definition: float8.hpp:1040
static constexpr CK_TILE_HOST_DEVICE bf8_t denorm_min()
Definition: float8.hpp:1046
static constexpr CK_TILE_HOST_DEVICE bf8_t epsilon()
Definition: float8.hpp:1012
static constexpr CK_TILE_HOST_DEVICE bf8_t infinity()
Definition: float8.hpp:1028
static constexpr CK_TILE_HOST_DEVICE bf8_t max()
Definition: float8.hpp:1006
static constexpr CK_TILE_HOST_DEVICE bf8_t zero()
Definition: float8.hpp:1051
static constexpr CK_TILE_HOST_DEVICE fp8_t signaling_NaN()
Definition: float8.hpp:973
static constexpr CK_TILE_HOST_DEVICE fp8_t zero()
Definition: float8.hpp:984
static constexpr CK_TILE_HOST_DEVICE fp8_t min()
Definition: float8.hpp:927
static constexpr CK_TILE_HOST_DEVICE fp8_t lowest()
Definition: float8.hpp:933
static constexpr CK_TILE_HOST_DEVICE fp8_t epsilon()
Definition: float8.hpp:945
static constexpr CK_TILE_HOST_DEVICE fp8_t quiet_NaN()
Definition: float8.hpp:967
static constexpr CK_TILE_HOST_DEVICE fp8_t max()
Definition: float8.hpp:939
static constexpr CK_TILE_HOST_DEVICE fp8_t denorm_min()
Definition: float8.hpp:979
static constexpr CK_TILE_HOST_DEVICE fp8_t round_error()
Definition: float8.hpp:955
static constexpr CK_TILE_HOST_DEVICE fp8_t infinity()
Definition: float8.hpp:961
bf8_raw_t bitwise_type
Definition: float8.hpp:231
fp8_raw_t bitwise_type
Definition: float8.hpp:213
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T lowest()
Definition: numeric.hpp:23
static constexpr CK_TILE_HOST_DEVICE T min()
Definition: numeric.hpp:20
static constexpr CK_TILE_HOST_DEVICE T quiet_NaN()
Definition: numeric.hpp:41
static constexpr CK_TILE_HOST_DEVICE T signaling_NaN()
Definition: numeric.hpp:47
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
static constexpr CK_TILE_HOST_DEVICE T round_error()
Definition: numeric.hpp:32
static constexpr CK_TILE_HOST_DEVICE T zero()
Definition: numeric.hpp:58
static constexpr CK_TILE_HOST_DEVICE T denorm_min()
Definition: numeric.hpp:53
static constexpr CK_TILE_HOST_DEVICE T epsilon()
Definition: numeric.hpp:29
static constexpr CK_TILE_HOST_DEVICE T infinity()
Definition: numeric.hpp:38
Definition: random.hpp:17
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106