11 #if defined(__gfx950__)
12 #define CK_TILE_FP4_CVT_DEVICE 1
14 #define CK_TILE_FP4_CVT_DEVICE 0
17 #define TEST_convert_with_table 0
// Two-element vector of float, declared with the ext_vector_type(2) attribute.
22 using fp32x2_t =
float __attribute__((ext_vector_type(2)));
// Two-element vector of _Float16, declared with the ext_vector_type(2) attribute.
23 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
33 static constexpr
int bias = 1;
40 template <
typename T,
typename = std::enable_if_t<std::is_
integral_v<T>>>
61 return (x1 << 4) | (x0 & 0b00001111);
64 #if TEST_convert_with_table
/// Lookup table mapping each 4-bit E2M1 code (array index 0..15) to its fp32 value.
/// Indices 0-7 hold the non-negative values; indices 8-15 hold the same
/// magnitudes negated.
/// Index 8 is spelled -0.0f on purpose: the integer expression `-0` evaluates
/// to plain 0 and would produce +0.0f, dropping the sign that the matching
/// fp16 table entry (bit pattern 0x8000) keeps.
static constexpr float e2m1_to_fp32_table[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
67 static constexpr
fp16_t e2m1_to_fp16_table[16] = {
68 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x0000)),
69 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x3800)),
70 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x3C00)),
71 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x3E00)),
72 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x4000)),
73 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x4200)),
74 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x4400)),
75 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x4600)),
76 bit_cast<fp16_t>(
static_cast<uint16_t
>(0x8000)),
77 bit_cast<fp16_t>(
static_cast<uint16_t
>(0xB800)),
78 bit_cast<fp16_t>(
static_cast<uint16_t
>(0xBC00)),
79 bit_cast<fp16_t>(
static_cast<uint16_t
>(0xBE00)),
80 bit_cast<fp16_t>(
static_cast<uint16_t
>(0xC000)),
81 bit_cast<fp16_t>(
static_cast<uint16_t
>(0xC200)),
82 bit_cast<fp16_t>(
static_cast<uint16_t
>(0xC400)),
83 bit_cast<fp16_t>(
static_cast<uint16_t
>(0xC600))
96 static constexpr
int exp = 2;
97 static constexpr
int mant = 1;
98 static constexpr
int bias = 1;
135 static_assert(I < 2,
"Index is out of range.");
139 return
data & 0b00001111;
144 #if CK_TILE_FP4_CVT_DEVICE
147 template <
typename T>
150 if constexpr(std::is_same_v<T, fp32_t>)
151 return fp32x2_t(__builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0))[0];
152 else if constexpr(std::is_same_v<T, fp32x2_t>)
153 return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
154 else if constexpr(std::is_same_v<T, fp16_t>)
155 return fp16x2_t(__builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0))[0];
156 else if constexpr(std::is_same_v<T, fp16x2_t>)
157 return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
158 else if constexpr(std::is_same_v<T, bf16_t>)
159 return bf16x2_t(__builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0))[0];
160 else if constexpr(std::is_same_v<T, bf16x2_t>)
161 return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
163 static_assert(std::false_type::value,
"Unsupported type.");
166 template <
typename T>
174 if constexpr(std::is_same_v<T, fp32_t>)
175 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
176 else if constexpr(std::is_same_v<T, fp32x2_t>)
177 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src[0], src[1], scale, 0);
178 else if constexpr(std::is_same_v<T, fp16_t>)
179 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32,
fp16x2_t{src, src}, scale, 0);
180 else if constexpr(std::is_same_v<T, fp16x2_t>)
181 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
182 else if constexpr(std::is_same_v<T, bf16_t>)
183 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32,
bf16x2_t{src, src}, scale, 0);
184 else if constexpr(std::is_same_v<T, bf16x2_t>)
185 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
187 static_assert(std::false_type::value,
"Unsupported type.");
196 #if CK_TILE_FP4_CVT_DEVICE
197 return impl::_from_f4<bf16_t>(
data);
204 #if CK_TILE_FP4_CVT_DEVICE
205 return impl::_from_f4<bf16x2_t>(
data);
208 type_convert<bf16_t>(convert_to_float<pk_fp4_t>(
unpack(
number<1>{})))};
215 #if CK_TILE_FP4_CVT_DEVICE
216 return impl::_to_f4(x);
218 return convert_to_type<pk_fp4_t>(x);
227 #if CK_TILE_FP4_CVT_DEVICE
228 return impl::_to_f4(x);
235 #if CK_TILE_FP4_CVT_DEVICE
236 return impl::_to_f4(x);
243 #if CK_TILE_FP4_CVT_DEVICE
244 return impl::_to_f4(x);
252 #if CK_TILE_FP4_CVT_DEVICE
253 return impl::_to_f4(x);
261 #if CK_TILE_FP4_CVT_DEVICE
262 return impl::_to_f4(x);
268 #if TEST_convert_with_table == 0
271 #if CK_TILE_FP4_CVT_DEVICE
272 return impl::_from_f4<fp32_t>(
data);
279 #if CK_TILE_FP4_CVT_DEVICE
280 return impl::_from_f4<fp32x2_t>(
data);
288 #if CK_TILE_FP4_CVT_DEVICE
289 return impl::_from_f4<fp16_t>(
data);
296 #if CK_TILE_FP4_CVT_DEVICE
297 return impl::_from_f4<fp16x2_t>(
data);
300 type_convert<fp16_t>(convert_to_float<pk_fp4_t>(
unpack(
number<1>{})))};
306 return e2m1_to_fp32_table[
data & 0xf];
310 return fp32x2_t{e2m1_to_fp32_table[
data & 0xf], e2m1_to_fp32_table[
data >> 4]};
314 return e2m1_to_fp16_table[
data & 0xf];
318 return fp16x2_t{e2m1_to_fp16_table[
data & 0xf], e2m1_to_fp16_table[
data >> 4]};
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t &x)
Definition: pk_fp4.hpp:250
_Float16 fp16_t
Definition: half.hpp:110
float fp32x2_t
Definition: pk_fp4.hpp:22
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:88
float fp32_t
Definition: pk_fp4.hpp:21
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t &x)
Definition: pk_fp4.hpp:241
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x)
Definition: pk_fp4.hpp:221
_Float16 fp16x2_t
Definition: half.hpp:385
uint16_t bf16_raw_t
Definition: bfloat16.hpp:107
constexpr CK_TILE_HOST_DEVICE uint8_t float_to_e2m1(float)
Definition: pk_fp4.hpp:213
constexpr CK_TILE_HOST_DEVICE bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t &x)
Definition: pk_fp4.hpp:223
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:412
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp16_to_pk_fp4(const fp16_t &x)
Definition: pk_fp4.hpp:225
constexpr CK_TILE_HOST_DEVICE fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t &x)
Definition: pk_fp4.hpp:222
constexpr CK_TILE_HOST_DEVICE pk_fp4_t float_to_pk_fp4(const float &x)
Definition: pk_fp4.hpp:224
typename pk_fp4_t::raw_type pk_fp4_raw_t
Definition: pk_fp4.hpp:89
constexpr CK_TILE_HOST_DEVICE pk_fp4_t bf16_to_pk_fp4(const bf16_t &x)
Definition: pk_fp4.hpp:233
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t &x)
Definition: pk_fp4.hpp:259
bf16_raw_t bf16x2_t
Definition: pk_fp4.hpp:24
Definition: integral_constant.hpp:13
static constexpr CK_TILE_HOST_DEVICE bool has_inf()
Definition: pk_fp4.hpp:123
static constexpr CK_TILE_HOST_DEVICE fp8_t denorm_min()
Definition: pk_fp4.hpp:121
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t min()
Definition: pk_fp4.hpp:115
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t infinity()
Definition: pk_fp4.hpp:125
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t round_error()
Definition: pk_fp4.hpp:119
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t epsilon()
Definition: pk_fp4.hpp:118
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t zero()
Definition: pk_fp4.hpp:120
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t quiet_NaN()
Definition: pk_fp4.hpp:127
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t signaling_NaN()
Definition: pk_fp4.hpp:129
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t lowest()
Definition: pk_fp4.hpp:117
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t max()
Definition: pk_fp4.hpp:116
pk_fp4_raw_t bitwise_type
Definition: pk_fp4.hpp:94
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
Definition: pk_fp4.hpp:30
static constexpr int bias
Definition: pk_fp4.hpp:33
static constexpr int mantissa
Definition: pk_fp4.hpp:32
constexpr CK_TILE_HOST_DEVICE raw_type & get()
Definition: pk_fp4.hpp:48
raw_type data
Definition: pk_fp4.hpp:37
static constexpr int exponent
Definition: pk_fp4.hpp:31
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t()
Definition: pk_fp4.hpp:39
uint8_t raw_type
Definition: pk_fp4.hpp:35
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t(float init)
Definition: pk_fp4.hpp:44
constexpr CK_TILE_HOST_DEVICE raw_type unpack(number< I >) const
constexpr CK_TILE_HOST_DEVICE raw_type get() const
Definition: pk_fp4.hpp:49
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t(T init)
Definition: pk_fp4.hpp:41
constexpr static CK_TILE_HOST_DEVICE pk_float4_e2m1_t pack(const type x0, const type x1)
Definition: pk_fp4.hpp:59
raw_type type
Definition: pk_fp4.hpp:36
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106