/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/core/numeric/mxfp_convert.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/core/numeric/mxfp_convert.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/core/numeric/mxfp_convert.hpp Source File
mxfp_convert.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 namespace ck_tile {
7 // Adapted from include/ck/utility/mxfp_utils.hpp
8 
// Bit-field helpers for decoding the raw storage of a small floating-point
// type T (T::raw_type). Every member is static and takes the raw bits.
//
// NOTE(review): this listing is incomplete. The declaration line between
// `template <typename T>` and `{` is missing (gutter jumps 9 -> 11), as are
// gutter lines 13-14. convert_to_float refers to this scope as
// numeric_utils<T>. `traits` and `_numeric` are used below but their
// declarations are not visible here -- presumably aliases for
// numeric_traits<T> and numeric<T>; confirm against the real header.
9 template <typename T>
11 {
12 
    // Unsigned integer type that stores one encoded value of T.
15  using raw_type = typename T::raw_type;
16 
    // Mask with the low traits::exp bits set; isolates the exponent field
    // once the mantissa bits have been shifted out.
17  static constexpr int exp_mask = (1 << traits::exp) - 1;
18 
    // Returns the stored exponent field of x: the bits directly above the
    // low traits::mant mantissa bits. No bias is subtracted here.
19  static constexpr int get_exponent(raw_type x)
20  {
21  // TODO: check if repeated calls are optimized.
22  return (x >> traits::mant) & exp_mask;
23  }
    // True when everything above the exponent and mantissa fields (i.e. the
    // sign position) compares equal to _numeric::binary_zero.
    // Presumably binary_zero is the all-zero bit pattern -- TODO confirm.
24  static constexpr bool is_positive(raw_type x)
25  {
26  return (x >> (traits::exp + traits::mant)) == _numeric::binary_zero;
27  }
    // True when the stored exponent field equals _numeric::binary_zero.
    // Assuming binary_zero is 0, this is also true for an encoded zero,
    // not only for true subnormals.
28  static constexpr bool is_subnormal(raw_type x)
29  {
30  return get_exponent(x) == _numeric::binary_zero;
31  }
32  // TODO: replace double with template arg?
    // Returns the significand of x as a real number: an implicit leading 1
    // (0 when is_subnormal(x)) plus the fraction held in the low traits::mant
    // bits. Bit i, counted from the least significant bit, contributes
    // 2^-(traits::mant - i), so the top mantissa bit is worth 1/2.
    // NOTE(review): the seed literals and the ldexp argument are float even
    // though the accumulator is double. Also, std::ldexp is not constexpr
    // before C++23, so constant evaluation of this function probably is not
    // possible on older standards -- confirm.
33  static constexpr double get_mantissa(raw_type x)
34  {
35  double mantissa = is_subnormal(x) ? 0.0f : 1.0f;
36  for(uint32_t i = 0; i < traits::mant; ++i)
37  {
    // Consume the mantissa from the LSB upwards; x is a by-value copy, so
    // shifting it here does not affect the caller.
38  mantissa += std::ldexp(static_cast<float>(x & 0b1), -(traits::mant - i));
39  x >>= 1;
40  }
41  return mantissa;
42  }
43 };
44 
45 template <typename T>
46 CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp = 127)
47 {
48  using utils = numeric_utils<T>;
49  static constexpr int e8m0_bias = 127; // TODO: make it generic.
50  float sign = utils::is_positive(data) ? 1.0 : -1.0;
51  int exp = (utils::is_subnormal(data) ? 1 : utils::get_exponent(data)) - utils::bias;
52  float mant = utils::get_mantissa(data);
53 
54  return std::ldexp(sign * mant, exp + scale_exp - e8m0_bias);
55 }
56 
// Encodes the float `value` as the raw bit pattern of T and returns it as
// T::raw_type.
//
// Two paths are visible:
//   * |value| > numeric<T>::max(): saturation handling. The input is compared
//     against the point half a step above max to choose between one packed
//     result and an overflow encoding.
//   * otherwise: the float exponent is re-expressed for T, the mantissa is
//     rounded to numeric_traits<T>::mant bits (ties to even), and
//     sign / exponent / mantissa are packed.
//
// NOTE(review): this listing is incomplete. The embedded gutter numbers jump
// 72->75, 75->77, 81->83, 83->85, 119->121, 121->125, 168->170 and 204->206,
// so several initialisers and at least two statements are not shown. The
// comments below describe only what is visible and mark every gap.
57 template <typename T>
58 CK_TILE_HOST_DEVICE typename T::raw_type convert_to_type(float value)
59 {
60  using bitwise_type = typename numeric_traits<T>::bitwise_type;
61 
    // ---- Saturation path: magnitude exceeds the largest finite T ----
62  if(std::abs(value) > float(numeric<T>::max()))
63  {
64  float max_value = numeric<T>::max();
65 
66  // cppcheck-suppress redundantAssignment
67  uint32_t max_bitwise = bit_cast<uint32_t>(max_value);
68 
    // `sign` is the top bit of the float: its bit pattern shifted right by
    // the float exponent and mantissa widths.
    // NOTE(review): the initialisers of `exp` and `mantissa` below are cut
    // off in this listing (gutter 73-74 and 76 are missing).
69  // cppcheck-suppress redundantAssignment
70  bitwise_type sign =
71  bit_cast<uint32_t>(value) >> (numeric_traits<float>::exp + numeric_traits<float>::mant);
72  bitwise_type exp =
75  bitwise_type mantissa =
77 
    // T-width mantissa of max, minus one: the mantissa of the value one
    // step below max.
78  uint32_t mant_prev = max_bitwise >> (numeric_traits<float>::mant - numeric_traits<T>::mant);
79  mant_prev &= ((1 << numeric_traits<T>::mant) - 1);
80  mant_prev--;
81 
    // NOTE(review): most of the `prev_bit` initialiser is missing (gutter 82
    // and 84); only its final operand `mant_prev` is visible. Presumably it
    // rebuilds float bits for the value just below max -- confirm.
83  uint32_t prev_bit =
85  mant_prev;
86 
    // diff is the gap between max and the value described by prev_bit.
87  float prev_val = bit_cast<float>(prev_bit);
88  float diff = max_value - prev_val;
89 
    // Half of that gap above max is the cut-off between the two results.
90  float actual_max = max_value + (diff / 2);
91 
92  if(std::abs(value) < actual_max)
93  {
    // Below the cut-off: pack sign, exp and mantissa as computed above
    // (presumably the encoding of +/-max -- the initialisers are not visible).
94  return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
95  (exp << numeric_traits<T>::mant) | mantissa;
96  }
97  else
98  {
99  if constexpr(!numeric<T>::has_inf())
100  {
101 
    // No infinity in T: return all exponent and mantissa bits set.
    // `sign` is not OR-ed in, so the sign bit is clear for negative
    // inputs too.
    // NOTE(review): confirm this is the intended overflow/NaN encoding.
102  return (1 << (numeric_traits<T>::mant + numeric_traits<T>::exp)) - 1;
103  }
104  else
105  {
    // T has infinity: bump the exponent by one and return it with a zero
    // mantissa and the input's sign.
106  exp++;
107  return sign << ((numeric_traits<T>::exp + numeric_traits<T>::mant)) |
108  (exp << numeric_traits<T>::mant);
109  }
110  }
111  }
    // ---- General path ----
    // mfmt is the float mantissa width; x holds the float's raw bits.
112  const int mfmt = numeric_traits<float>::mant;
113  uint32_t x;
114  x = bit_cast<uint32_t>(value);
115 
116  uint32_t head, mantissa;
117  int32_t exponent, bias;
118  uint32_t sign;
119 
    // NOTE(review): only the mantissa extraction is visible. The assignments
    // of `head`, `exponent`, `sign` and `bias` (gutter 120, 122-124) are
    // missing. Presumably they hold the float's sign/exponent fields and the
    // float bias -- confirm.
121  mantissa = x & numeric_traits<float>::mant_mask;
125 
    // Only the all-zero bit pattern (+0.0f) returns here; -0.0f has its
    // sign bit set and continues below.
126  if(x == 0)
127  {
128  return 0b0;
129  }
130 
    // Exponent bias of T, and the unbiased exponent shared by all of T's
    // subnormals.
131  const int mini_bias = numeric_traits<T>::bias;
132  const int mini_denormal_act_exponent = 1 - mini_bias;
133 
134  int act_exponent, out_exponent, exponent_diff;
135 
136  bool is_subnorm = false;
137 
138  if(exponent == 0)
139  {
    // Stored exponent of zero: use the unbiased exponent (exponent - bias + 1),
    // add no implicit one, and treat the result as subnormal in T.
    // exponent_diff is the extra right shift needed to reach T's subnormal
    // exponent.
140  act_exponent = exponent - bias + 1;
141  exponent_diff = mini_denormal_act_exponent - act_exponent;
142  is_subnorm = true;
143  }
144  else
145  {
146  act_exponent = exponent - bias;
147  if(act_exponent <= mini_denormal_act_exponent)
148  {
    // At or below T's subnormal exponent: align the mantissa to it.
149  exponent_diff = mini_denormal_act_exponent - act_exponent;
150  is_subnorm = true;
151  }
152  else
153  {
154  exponent_diff = 0;
155  }
    // Make the implicit leading one explicit at bit position mfmt.
156  mantissa += (1UL << mfmt);
157  }
158 
    // shift_amount is the number of low mantissa bits that will be discarded.
    // It is clamped to 63 to keep the shifts below within range. `midpoint`
    // is true when the discarded bits are exactly one half of the last kept
    // bit, i.e. a rounding tie.
159  auto shift_amount = (mfmt - numeric_traits<T>::mant + exponent_diff);
160  shift_amount = (shift_amount >= 64) ? 63 : shift_amount;
161  bool midpoint = (mantissa & ((1UL << shift_amount) - 1)) == (1UL << (shift_amount - 1));
162 
    // Signed copy of numeric<T>::epsilon().
    // NOTE(review): the name says "min_subnorm" but the value comes from
    // epsilon(); confirm epsilon() is T's smallest positive subnormal.
163  float min_subnorm = float(numeric<T>::epsilon()) * (sign ? -1 : 1);
164 
165  if(is_subnorm && std::abs(value) < std::abs(min_subnorm))
166  {
    // Magnitude is below min_subnorm: choose between zero and min_subnorm,
    // whichever is nearer.
    // NOTE(review): the statement executed by the `if` below is missing from
    // this listing (gutter 169); presumably it returns a zero encoding.
    // The `else` returns mantissa = 1, exponent = 0, with the input's sign.
167  // closer to 0
168  if(std::abs(value) <= std::abs(min_subnorm - value))
170  else
171  return 1 | (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant));
172  }
173 
    // Align the mantissa with the output exponent. Only a difference of
    // exactly -1 triggers the left shift.
174  if(exponent_diff > 0)
175  mantissa >>= exponent_diff;
176  else if(exponent_diff == -1)
177  mantissa <<= -exponent_diff;
    // If the leading one survived the shift the result is normal; otherwise
    // the biased output exponent is reduced by one.
178  bool implicit_one = mantissa & (1 << mfmt);
179  out_exponent = (act_exponent + exponent_diff) + mini_bias - (implicit_one ? 0 : 1);
180 
    // Round to nearest, ties to even, without branching on the fraction.
    // Adding the discarded low bits to themselves carries into the kept bits
    // exactly when the discarded fraction is at least one half. On a tie with
    // an even kept LSB, one is subtracted first so no carry happens.
181  uint32_t drop_mask = (1UL << (mfmt - numeric_traits<T>::mant)) - 1;
182  bool odd = mantissa & (1UL << (mfmt - numeric_traits<T>::mant));
183  mantissa += (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa) & drop_mask;
184 
185  if(out_exponent == 0)
186  {
    // Rounding carried a subnormal up into the implicit-one position, so it
    // became the smallest normal.
187  if((1UL << mfmt) & mantissa)
188  {
189  out_exponent = 1;
190  }
191  }
192  else
193  {
    // Rounding overflowed past the implicit one: renormalise.
194  if((1UL << (mfmt + 1)) & mantissa)
195  {
196  mantissa >>= 1;
197  out_exponent++;
198  }
199  }
200 
    // Discard the low bits that do not fit T's mantissa.
201  mantissa >>= (mfmt - numeric_traits<T>::mant);
202 
    // NOTE(review): the body of this branch is missing from the listing
    // (gutter 205); presumably it returns a (signed) zero encoding.
203  if(out_exponent == 0 && mantissa == 0)
204  {
206  }
207 
    // Strip the explicit leading one and pack sign | exponent | mantissa.
208  mantissa &= (1UL << numeric_traits<T>::mant) - 1;
209  return (sign << (numeric_traits<T>::exp + numeric_traits<T>::mant)) |
210  (out_exponent << numeric_traits<T>::mant) | mantissa;
211 }
212 
213 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
__host__ T exp(T x)
Definition: math_v2.hpp:391
__host__ __device__ bool is_subnormal(T x)
Definition: mxfp_utils.hpp:45
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE T::raw_type convert_to_type(float value)
Definition: mxfp_convert.hpp:58
CK_TILE_HOST_DEVICE float convert_to_float(typename T::raw_type data, int scale_exp=127)
Definition: mxfp_convert.hpp:46
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:412
int32_t int32_t
Definition: integer.hpp:10
Definition: numeric.hpp:81
Definition: mxfp_convert.hpp:11
typename T::raw_type raw_type
Definition: mxfp_convert.hpp:15
static constexpr bool is_positive(raw_type x)
Definition: mxfp_convert.hpp:24
static constexpr double get_mantissa(raw_type x)
Definition: mxfp_convert.hpp:33
static constexpr int exp_mask
Definition: mxfp_convert.hpp:17
static constexpr int get_exponent(raw_type x)
Definition: mxfp_convert.hpp:19
static constexpr bool is_subnormal(raw_type x)
Definition: mxfp_convert.hpp:28
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26