include/ck_tile/core/numeric/pk_fp4.hpp Source File

include/ck_tile/core/numeric/pk_fp4.hpp Source File#

Composable Kernel: include/ck_tile/core/numeric/pk_fp4.hpp Source File
pk_fp4.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <cmath>
10 
11 #if defined(__gfx950__)
12 #define CK_TILE_FP4_CVT_DEVICE 1
13 #else
14 #define CK_TILE_FP4_CVT_DEVICE 0
15 #endif
16 
17 #define TEST_convert_with_table 0
18 
19 namespace ck_tile {
20 
21 using fp32_t = float;
22 using fp32x2_t = float __attribute__((ext_vector_type(2)));
23 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
24 using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
25 
26 CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float);
27 
28 // TODO: Add stochastic method
30 {
31  static constexpr int exponent = 2;
32  static constexpr int mantissa = 1;
33  static constexpr int bias = 1;
34  // TODO: Can we merge raw_type and type?
35  using raw_type = uint8_t;
36  using type = raw_type;
38 
40  template <typename T, typename = std::enable_if_t<std::is_integral_v<T>>>
41  CK_TILE_HOST_DEVICE constexpr pk_float4_e2m1_t(T init) : data{static_cast<type>(init)}
42  {
43  }
44  CK_TILE_HOST_DEVICE explicit constexpr pk_float4_e2m1_t(float init) : data{float_to_e2m1(init)}
45  {
46  }
47  CK_TILE_HOST_DEVICE constexpr operator type() const { return data; }
48  CK_TILE_HOST_DEVICE constexpr raw_type& get() { return data; }
49  CK_TILE_HOST_DEVICE constexpr raw_type get() const { return data; }
50  CK_TILE_HOST_DEVICE constexpr operator float() const;
51  CK_TILE_HOST_DEVICE constexpr operator fp32x2_t() const;
52  CK_TILE_HOST_DEVICE constexpr operator fp16_t() const;
53  CK_TILE_HOST_DEVICE constexpr operator fp16x2_t() const;
54  CK_TILE_HOST_DEVICE constexpr operator bf16_t() const;
55  CK_TILE_HOST_DEVICE constexpr operator bf16x2_t() const;
56 
57  template <index_t I>
59  CK_TILE_HOST_DEVICE constexpr static pk_float4_e2m1_t pack(const type x0, const type x1)
60  {
61  return (x1 << 4) | (x0 & 0b00001111);
62  }
63 
64 #if TEST_convert_with_table
// Decode table indexed by a 4-bit e2m1 code; codes 8..15 are the negated values of
// codes 0..7. Every entry is written as a float literal: a bare `-0` is an integer
// expression equal to 0 and would store +0.0f, whereas code 8 must decode to negative
// zero (the fp16 table uses 0x8000 for the same code).
static constexpr float e2m1_to_fp32_table[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
67  static constexpr fp16_t e2m1_to_fp16_table[16] = {
68  bit_cast<fp16_t>(static_cast<uint16_t>(0x0000)), // 0
69  bit_cast<fp16_t>(static_cast<uint16_t>(0x3800)), // 0.5
70  bit_cast<fp16_t>(static_cast<uint16_t>(0x3C00)), // 1
71  bit_cast<fp16_t>(static_cast<uint16_t>(0x3E00)), // 1.5
72  bit_cast<fp16_t>(static_cast<uint16_t>(0x4000)), // 2
73  bit_cast<fp16_t>(static_cast<uint16_t>(0x4200)), // 3
74  bit_cast<fp16_t>(static_cast<uint16_t>(0x4400)), // 4
75  bit_cast<fp16_t>(static_cast<uint16_t>(0x4600)), // 6
76  bit_cast<fp16_t>(static_cast<uint16_t>(0x8000)), // -0
77  bit_cast<fp16_t>(static_cast<uint16_t>(0xB800)), // -0.5
78  bit_cast<fp16_t>(static_cast<uint16_t>(0xBC00)), // -1
79  bit_cast<fp16_t>(static_cast<uint16_t>(0xBE00)), // -1.5
80  bit_cast<fp16_t>(static_cast<uint16_t>(0xC000)), // -2
81  bit_cast<fp16_t>(static_cast<uint16_t>(0xC200)), // -3
82  bit_cast<fp16_t>(static_cast<uint16_t>(0xC400)), // -4
83  bit_cast<fp16_t>(static_cast<uint16_t>(0xC600)) // -6
84  };
85 #endif
86 };
87 
90 
91 template <>
93 {
95 
96  static constexpr int exp = 2;
97  static constexpr int mant = 1;
98  static constexpr int bias = 1;
99  static constexpr int PackedSize = 2;
100 };
101 
102 // limits
103 template <class T>
104 struct numeric;
105 
106 template <>
108 {
109  static constexpr pk_fp4_raw_t binary_min_normal = 0b00100010; // 1
110  static constexpr pk_fp4_raw_t binary_max_normal = 0b01110111; // 6
111  static constexpr pk_fp4_raw_t binary_lowest_normal = 0b11111111; // -6
112  static constexpr pk_fp4_raw_t binary_min_subnorm = 0b00010001; // 0.5
113  static constexpr pk_fp4_raw_t binary_max_subnorm = 0b00010001; // 0.5
114  static constexpr pk_fp4_raw_t binary_zero = 0b00000000; // 0
115  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t min() { return binary_min_normal; }
116  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t max() { return binary_max_normal; }
117  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t lowest() { return binary_lowest_normal; }
118  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t epsilon() { return binary_min_subnorm; }
119  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t round_error() { return binary_min_subnorm; }
120  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t zero() { return binary_zero; }
121  CK_TILE_HOST_DEVICE static constexpr fp8_t denorm_min() { return binary_min_subnorm; }
122 
123  CK_TILE_HOST_DEVICE static constexpr bool has_inf() { return false; }
124  // N/A
125  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t infinity() { return max(); }
126  // N/A
127  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t quiet_NaN() { return max(); }
128  // N/A
129  CK_TILE_HOST_DEVICE static constexpr pk_fp4_t signaling_NaN() { return max(); }
130 };
131 
132 template <index_t I>
133 CK_TILE_HOST_DEVICE constexpr pk_fp4_raw_t pk_fp4_t::unpack(number<I>) const
134 {
135  static_assert(I < 2, "Index is out of range.");
136  if constexpr(I == 1)
137  return (data >> 4);
138  else
139  return data & 0b00001111;
140 }
142 // TODO: consider replacing this macro to improve performance
143 
144 #if CK_TILE_FP4_CVT_DEVICE
145 namespace impl {
146 
147 template <typename T>
148 CK_TILE_DEVICE T _from_f4(pk_fp4_raw_t src, float scale = 1.0f)
149 {
150  if constexpr(std::is_same_v<T, fp32_t>)
151  return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
152  else if constexpr(std::is_same_v<T, fp32x2_t>)
153  return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
154  else if constexpr(std::is_same_v<T, fp16_t>)
155  return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
156  else if constexpr(std::is_same_v<T, fp16x2_t>)
157  return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
158  else if constexpr(std::is_same_v<T, bf16_t>)
159  return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
160  else if constexpr(std::is_same_v<T, bf16x2_t>)
161  return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
162  else
163  static_assert(std::false_type::value, "Unsupported type.");
164  return T{};
165 }
166 template <typename T>
167 CK_TILE_DEVICE pk_fp4_raw_t _to_f4(T src, float scale = 1.0f)
168 {
169  union
170  {
171  uint32_t u32;
172  pk_fp4_raw_t pf4[4];
173  } cvt{0};
174  if constexpr(std::is_same_v<T, fp32_t>)
175  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
176  else if constexpr(std::is_same_v<T, fp32x2_t>)
177  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
178  else if constexpr(std::is_same_v<T, fp16_t>)
179  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, fp16x2_t{src, src}, scale, 0);
180  else if constexpr(std::is_same_v<T, fp16x2_t>)
181  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
182  else if constexpr(std::is_same_v<T, bf16_t>)
183  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, bf16x2_t{src, src}, scale, 0);
184  else if constexpr(std::is_same_v<T, bf16x2_t>)
185  cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
186  else
187  static_assert(std::false_type::value, "Unsupported type.");
188  return cvt.pf4[0];
189 }
190 
191 } // namespace impl
192 #endif
193 
194 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16_t() const
195 {
196 #if CK_TILE_FP4_CVT_DEVICE
197  return impl::_from_f4<bf16_t>(data);
198 #else
199  return bf16_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
200 #endif
201 }
202 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator bf16x2_t() const
203 {
204 #if CK_TILE_FP4_CVT_DEVICE
205  return impl::_from_f4<bf16x2_t>(data);
206 #else
207  return bf16x2_t{type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
208  type_convert<bf16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
209 #endif
210 }
211 
212 // TODO: make float_to_e2m1 generic so that we can convert from other source types directly.
214 {
215 #if CK_TILE_FP4_CVT_DEVICE
216  return impl::_to_f4(x);
217 #else
218  return convert_to_type<pk_fp4_t>(x);
219 #endif
220 }
224 CK_TILE_HOST_DEVICE constexpr pk_fp4_t float_to_pk_fp4(const float& x) { return float_to_e2m1(x); }
226 {
227 #if CK_TILE_FP4_CVT_DEVICE
228  return impl::_to_f4(x);
229 #else
230  return float_to_e2m1(type_convert<float>(x));
231 #endif
232 }
234 {
235 #if CK_TILE_FP4_CVT_DEVICE
236  return impl::_to_f4(x);
237 #else
238  return float_to_e2m1(type_convert<float>(x));
239 #endif
240 }
242 {
243 #if CK_TILE_FP4_CVT_DEVICE
244  return impl::_to_f4(x);
245 #else
246  return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
247  float_to_e2m1(type_convert<float>(x[1])));
248 #endif
249 }
251 {
252 #if CK_TILE_FP4_CVT_DEVICE
253  return impl::_to_f4(x);
254 #else
255  return pk_fp4_t::pack(float_to_e2m1(type_convert<float>(x[0])),
256  float_to_e2m1(type_convert<float>(x[1])));
257 #endif
258 }
260 {
261 #if CK_TILE_FP4_CVT_DEVICE
262  return impl::_to_f4(x);
263 #else
264  return pk_fp4_t::pack(float_to_e2m1(x[0]), float_to_e2m1(x[1]));
265 #endif
266 }
267 
268 #if TEST_convert_with_table == 0
269 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
270 {
271 #if CK_TILE_FP4_CVT_DEVICE
272  return impl::_from_f4<fp32_t>(data);
273 #else
274  return convert_to_float<pk_fp4_t>(unpack(number<0>{}));
275 #endif
276 }
277 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
278 {
279 #if CK_TILE_FP4_CVT_DEVICE
280  return impl::_from_f4<fp32x2_t>(data);
281 #else
282  return fp32x2_t{convert_to_float<pk_fp4_t>(unpack(number<0>{})),
283  convert_to_float<pk_fp4_t>(unpack(number<1>{}))};
284 #endif
285 }
286 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
287 {
288 #if CK_TILE_FP4_CVT_DEVICE
289  return impl::_from_f4<fp16_t>(data);
290 #else
291  return fp16_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{})))};
292 #endif
293 }
294 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
295 {
296 #if CK_TILE_FP4_CVT_DEVICE
297  return impl::_from_f4<fp16x2_t>(data);
298 #else
299  return fp16x2_t{type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<0>{}))),
300  type_convert<fp16_t>(convert_to_float<pk_fp4_t>(unpack(number<1>{})))};
301 #endif
302 }
303 #else
304 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator float() const
305 {
306  return e2m1_to_fp32_table[data & 0xf];
307 }
308 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp32x2_t() const
309 {
310  return fp32x2_t{e2m1_to_fp32_table[data & 0xf], e2m1_to_fp32_table[data >> 4]};
311 }
312 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16_t() const
313 {
314  return e2m1_to_fp16_table[data & 0xf];
315 }
316 CK_TILE_HOST_DEVICE constexpr pk_fp4_t::operator fp16x2_t() const
317 {
318  return fp16x2_t{e2m1_to_fp16_table[data & 0xf], e2m1_to_fp16_table[data >> 4]};
319 }
320 #endif
321 
322 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t &x)
Definition: pk_fp4.hpp:250
_Float16 fp16_t
Definition: half.hpp:110
float fp32x2_t
Definition: pk_fp4.hpp:22
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:88
float fp32_t
Definition: pk_fp4.hpp:21
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t &x)
Definition: pk_fp4.hpp:241
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x)
Definition: pk_fp4.hpp:221
_Float16 fp16x2_t
Definition: half.hpp:385
uint16_t bf16_raw_t
Definition: bfloat16.hpp:107
constexpr CK_TILE_HOST_DEVICE uint8_t float_to_e2m1(float)
Definition: pk_fp4.hpp:213
constexpr CK_TILE_HOST_DEVICE bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t &x)
Definition: pk_fp4.hpp:223
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:412
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp16_to_pk_fp4(const fp16_t &x)
Definition: pk_fp4.hpp:225
constexpr CK_TILE_HOST_DEVICE fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t &x)
Definition: pk_fp4.hpp:222
constexpr CK_TILE_HOST_DEVICE pk_fp4_t float_to_pk_fp4(const float &x)
Definition: pk_fp4.hpp:224
typename pk_fp4_t::raw_type pk_fp4_raw_t
Definition: pk_fp4.hpp:89
constexpr CK_TILE_HOST_DEVICE pk_fp4_t bf16_to_pk_fp4(const bf16_t &x)
Definition: pk_fp4.hpp:233
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t &x)
Definition: pk_fp4.hpp:259
bf16_raw_t bf16x2_t
Definition: pk_fp4.hpp:24
Definition: integral_constant.hpp:13
static constexpr CK_TILE_HOST_DEVICE bool has_inf()
Definition: pk_fp4.hpp:123
static constexpr CK_TILE_HOST_DEVICE fp8_t denorm_min()
Definition: pk_fp4.hpp:121
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t min()
Definition: pk_fp4.hpp:115
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t infinity()
Definition: pk_fp4.hpp:125
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t round_error()
Definition: pk_fp4.hpp:119
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t epsilon()
Definition: pk_fp4.hpp:118
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t zero()
Definition: pk_fp4.hpp:120
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t quiet_NaN()
Definition: pk_fp4.hpp:127
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t signaling_NaN()
Definition: pk_fp4.hpp:129
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t lowest()
Definition: pk_fp4.hpp:117
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t max()
Definition: pk_fp4.hpp:116
pk_fp4_raw_t bitwise_type
Definition: pk_fp4.hpp:94
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
Definition: pk_fp4.hpp:30
static constexpr int bias
Definition: pk_fp4.hpp:33
static constexpr int mantissa
Definition: pk_fp4.hpp:32
constexpr CK_TILE_HOST_DEVICE raw_type & get()
Definition: pk_fp4.hpp:48
raw_type data
Definition: pk_fp4.hpp:37
static constexpr int exponent
Definition: pk_fp4.hpp:31
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t()
Definition: pk_fp4.hpp:39
uint8_t raw_type
Definition: pk_fp4.hpp:35
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t(float init)
Definition: pk_fp4.hpp:44
constexpr CK_TILE_HOST_DEVICE raw_type unpack(number< I >) const
constexpr CK_TILE_HOST_DEVICE raw_type get() const
Definition: pk_fp4.hpp:49
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t(T init)
Definition: pk_fp4.hpp:41
constexpr static CK_TILE_HOST_DEVICE pk_float4_e2m1_t pack(const type x0, const type x1)
Definition: pk_fp4.hpp:59
raw_type type
Definition: pk_fp4.hpp:36
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106