11 #if defined(__gfx950__)
12 #define CK_TILE_FP4_CVT_DEVICE 1
14 #define CK_TILE_FP4_CVT_DEVICE 0
// Compile-time switch between the two pk_fp4 decode implementations in this file:
// the lookup tables are guarded by `#if TEST_convert_with_table`, while the other
// conversion path is guarded by `#if TEST_convert_with_table == 0`.
17 #define TEST_convert_with_table 0
// Two-element vector of float, declared with the ext_vector_type(2) attribute.
22 using fp32x2_t =
float __attribute__((ext_vector_type(2)));
// Two-element vector of _Float16, declared with the same attribute.
23 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
45 return ck_tile::bit_cast<fp16x2_repr>(v).e[0];
49 return ck_tile::bit_cast<fp16x2_repr>(v).e[1];
54 return ck_tile::bit_cast<bf16x2_repr>(v).e[0];
58 return ck_tile::bit_cast<bf16x2_repr>(v).e[1];
63 return ck_tile::bit_cast<fp32x2_repr>(v).e[0];
67 return ck_tile::bit_cast<fp32x2_repr>(v).e[1];
71 struct pk_float4_e2m1_t;
83 template <
typename T,
typename = std::enable_if_t<std::is_
integral_v<T>>>
124 return (x1 << 4) | (x0 & 0b00001111);
127 #if TEST_convert_with_table
// Decode table from a 4-bit E2M1 code to its float value.
// Entries 0-7 are the non-negative codes; entries 8-15 hold the same magnitudes
// with the sign bit set.
// Entry 8 must be spelled -0.0f: the integer expression `-0` is just 0 and would
// decode code 0b1000 to +0.0f, disagreeing with e2m1_to_fp16_table, which encodes
// that code as 0x8000 (negative zero).
static constexpr float e2m1_to_fp32_table[16] = {
    0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
    -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};
130 static constexpr
fp16_t e2m1_to_fp16_table[16] = {
131 bit_cast<fp16_t>(
static_cast<uint16_t>(0x0000)),
132 bit_cast<fp16_t>(
static_cast<uint16_t>(0x3800)),
133 bit_cast<fp16_t>(
static_cast<uint16_t>(0x3C00)),
134 bit_cast<fp16_t>(
static_cast<uint16_t>(0x3E00)),
135 bit_cast<fp16_t>(
static_cast<uint16_t>(0x4000)),
136 bit_cast<fp16_t>(
static_cast<uint16_t>(0x4200)),
137 bit_cast<fp16_t>(
static_cast<uint16_t>(0x4400)),
138 bit_cast<fp16_t>(
static_cast<uint16_t>(0x4600)),
139 bit_cast<fp16_t>(
static_cast<uint16_t>(0x8000)),
140 bit_cast<fp16_t>(
static_cast<uint16_t>(0xB800)),
141 bit_cast<fp16_t>(
static_cast<uint16_t>(0xBC00)),
142 bit_cast<fp16_t>(
static_cast<uint16_t>(0xBE00)),
143 bit_cast<fp16_t>(
static_cast<uint16_t>(0xC000)),
144 bit_cast<fp16_t>(
static_cast<uint16_t>(0xC200)),
145 bit_cast<fp16_t>(
static_cast<uint16_t>(0xC400)),
146 bit_cast<fp16_t>(
static_cast<uint16_t>(0xC600))
// Format parameters of one FP4 element. NOTE(review): presumably the E2M1 field
// widths (2 exponent bits, 1 mantissa bit) and the exponent bias; the enclosing
// struct is not visible here, so confirm against its declaration.
159 static constexpr
int exp = 2;
160 static constexpr
int mant = 1;
161 static constexpr
int bias = 1;
198 static_assert(I < 2,
"Index is out of range.");
202 return
data & 0b00001111;
207 #if CK_TILE_FP4_CVT_DEVICE
210 template <
typename T>
213 if constexpr(std::is_same_v<T, fp32_t>)
215 fp32x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
218 else if constexpr(std::is_same_v<T, fp32x2_t>)
219 return __builtin_amdgcn_cvt_scalef32_pk_f32_fp4(src, scale, 0);
220 else if constexpr(std::is_same_v<T, fp16_t>)
222 fp16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
225 else if constexpr(std::is_same_v<T, fp16x2_t>)
226 return __builtin_amdgcn_cvt_scalef32_pk_f16_fp4(src, scale, 0);
227 else if constexpr(std::is_same_v<T, bf16_t>)
229 bf16x2_t tmp = __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
232 else if constexpr(std::is_same_v<T, bf16x2_t>)
233 return __builtin_amdgcn_cvt_scalef32_pk_bf16_fp4(src, scale, 0);
238 template <
typename T>
246 if constexpr(std::is_same_v<T, fp32_t>)
247 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(cvt.u32, src, src, scale, 0);
248 else if constexpr(std::is_same_v<T, fp32x2_t>)
249 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f32(
251 else if constexpr(std::is_same_v<T, fp16_t>)
252 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32,
fp16x2_t{src, src}, scale, 0);
253 else if constexpr(std::is_same_v<T, fp16x2_t>)
254 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_f16(cvt.u32, src, scale, 0);
255 else if constexpr(std::is_same_v<T, bf16_t>)
256 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32,
bf16x2_t{src, src}, scale, 0);
257 else if constexpr(std::is_same_v<T, bf16x2_t>)
258 cvt.u32 = __builtin_amdgcn_cvt_scalef32_pk_fp4_bf16(cvt.u32, src, scale, 0);
269 #if CK_TILE_FP4_CVT_DEVICE
270 return impl::_from_f4<bf16_t>(
data, scale);
278 #if CK_TILE_FP4_CVT_DEVICE
279 return impl::_from_f4<bf16x2_t>(
data, scale);
282 type_convert<bf16_t>(convert_to_float<pk_fp4_t>(
_unpack(
number<1>{}), scale))};
289 #if CK_TILE_FP4_CVT_DEVICE
290 return impl::_to_f4(x, scale);
292 return convert_to_type<pk_fp4_t>(x, scale);
297 #if CK_TILE_FP4_CVT_DEVICE
298 return impl::_to_f4(x, scale);
300 auto res = convert_to_type<pk_fp4_t>(x, scale);
306 #if CK_TILE_FP4_CVT_DEVICE
307 return impl::_to_f4(x, scale);
315 #if CK_TILE_FP4_CVT_DEVICE
316 return impl::_to_f4(x, scale);
324 #if CK_TILE_FP4_CVT_DEVICE
325 return impl::_to_f4(x, scale);
333 #if CK_TILE_FP4_CVT_DEVICE
334 return impl::_to_f4(x, scale);
342 #if CK_TILE_FP4_CVT_DEVICE
343 return impl::_to_f4(x, scale);
375 #if TEST_convert_with_table == 0
378 #if CK_TILE_FP4_CVT_DEVICE
379 return impl::_from_f4<fp32_t>(
data, scale);
386 #if CK_TILE_FP4_CVT_DEVICE
387 return impl::_from_f4<fp32x2_t>(
data, scale);
396 #if CK_TILE_FP4_CVT_DEVICE
397 return impl::_from_f4<fp16_t>(
data, scale);
404 #if CK_TILE_FP4_CVT_DEVICE
405 return impl::_from_f4<fp16x2_t>(
data, scale);
408 type_convert<fp16_t>(convert_to_float<pk_fp4_t>(
_unpack(
number<1>{}), scale))};
418 return fp32x2_t{e2m1_to_fp32_table[
_unpack(number<0>{})] * scale, e2m1_to_fp32_table[
_unpack(number<1>{}] * scale};
422 return type_convert<float>(e2m1_to_fp16_table[
_unpack(number<0>{})]) * scale;
427 type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[
_unpack(number<0>{})]) * scale),
428 type_convert<fp16_t>(type_convert<float>(e2m1_to_fp16_table[
_unpack(number<1>{})]) *
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
constexpr CK_TILE_HOST_DEVICE _Float16 lane0(const fp16x2_t &v)
Definition: pk_fp4.hpp:43
constexpr CK_TILE_HOST_DEVICE _Float16 lane1(const fp16x2_t &v)
Definition: pk_fp4.hpp:47
Definition: cluster_descriptor.hpp:13
typename pk_fp4_t::type pk_fp4_raw_t
Definition: pk_fp4.hpp:152
ushort bfloat16_t
Definition: bfloat16.hpp:111
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp16_to_pk_fp4(const fp16_t &x, float scale)
Definition: pk_fp4.hpp:304
bfloat16_t bf16x2_t
Definition: pk_fp4.hpp:24
_Float16 fp16_t
Definition: half.hpp:110
float fp32x2_t
Definition: pk_fp4.hpp:22
bfloat16_t bf16_t
Definition: bfloat16.hpp:113
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp32x2_to_pk_fp4(const fp32x2_t &x, float scale)
Definition: pk_fp4.hpp:340
pk_float4_e2m1_t pk_fp4_t
Definition: pk_fp4.hpp:151
float fp32_t
Definition: pk_fp4.hpp:21
_Float16 fp16x2_t
Definition: half.hpp:385
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t float_to_pk_fp4(const float &x, float scale=1.f)
Definition: pk_fp4.hpp:295
constexpr CK_TILE_HOST_DEVICE float pk_fp4_to_float(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:362
constexpr CK_TILE_HOST_DEVICE fp16_t pk_fp4_to_fp16(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:366
CK_TILE_DEVICE bfloat16_t exp(bfloat16_t x)
Definition: bfloat16.hpp:411
constexpr CK_TILE_HOST_DEVICE pk_fp4_t bf16x2_to_pk_fp4(const bf16x2_t &x, float scale)
Definition: pk_fp4.hpp:331
constexpr CK_TILE_HOST_DEVICE pk_fp4_t fp16x2_to_pk_fp4(const fp16x2_t &x, float scale)
Definition: pk_fp4.hpp:322
constexpr CK_TILE_HOST_DEVICE fp32x2_t pk_fp4_to_fp32x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:350
constexpr CK_TILE_HOST_DEVICE pk_fp4_raw_t float_to_mxfp4(float x, float scale)
Definition: pk_fp4.hpp:287
constexpr CK_TILE_HOST_DEVICE pk_fp4_t bf16_to_pk_fp4(const bf16_t &x, float scale)
Definition: pk_fp4.hpp:313
constexpr CK_TILE_HOST_DEVICE fp16x2_t pk_fp4_to_fp16x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:354
constexpr CK_TILE_HOST_DEVICE bf16_t pk_fp4_to_bf16(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:370
constexpr CK_TILE_HOST_DEVICE bf16x2_t pk_fp4_to_bf16x2(const pk_fp4_t &x, float scale)
Definition: pk_fp4.hpp:358
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned short uint16_t
Definition: stdint.h:125
unsigned int uint32_t
Definition: stdint.h:126
unsigned char uint8_t
Definition: stdint.h:124
Definition: integral_constant.hpp:13
Definition: pk_fp4.hpp:35
bfloat16_t e[2]
Definition: pk_fp4.hpp:36
Definition: pk_fp4.hpp:31
_Float16 e[2]
Definition: pk_fp4.hpp:32
Definition: pk_fp4.hpp:39
float e[2]
Definition: pk_fp4.hpp:40
static constexpr CK_TILE_HOST_DEVICE bool has_inf()
Definition: pk_fp4.hpp:186
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t min()
Definition: pk_fp4.hpp:178
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t denorm_min()
Definition: pk_fp4.hpp:184
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t infinity()
Definition: pk_fp4.hpp:188
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t round_error()
Definition: pk_fp4.hpp:182
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t epsilon()
Definition: pk_fp4.hpp:181
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t zero()
Definition: pk_fp4.hpp:183
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t quiet_NaN()
Definition: pk_fp4.hpp:190
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t signaling_NaN()
Definition: pk_fp4.hpp:192
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t lowest()
Definition: pk_fp4.hpp:180
static constexpr CK_TILE_HOST_DEVICE pk_fp4_t max()
Definition: pk_fp4.hpp:179
pk_fp4_raw_t bitwise_type
Definition: pk_fp4.hpp:157
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
Definition: numeric.hpp:18
static constexpr CK_TILE_HOST_DEVICE T max()
Definition: numeric.hpp:26
Definition: pk_fp4.hpp:76
constexpr CK_TILE_HOST_DEVICE bf16x2_t to_bf16x2(float scale=1.f) const
Definition: pk_fp4.hpp:276
constexpr CK_TILE_HOST_DEVICE fp16x2_t to_fp16x2(float scale=1.f) const
Definition: pk_fp4.hpp:402
constexpr CK_TILE_HOST_DEVICE fp16_t to_fp16(float scale=1.f) const
Definition: pk_fp4.hpp:394
constexpr CK_TILE_HOST_DEVICE float to_float(float scale=1.f) const
Definition: pk_fp4.hpp:376
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t()
Definition: pk_fp4.hpp:82
uint8_t raw_type
Definition: pk_fp4.hpp:78
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t(float init, float scale=1.f)
Definition: pk_fp4.hpp:87
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t unpack(number< I >) const
Definition: pk_fp4.hpp:110
constexpr CK_TILE_HOST_DEVICE type & get()
Definition: pk_fp4.hpp:92
constexpr CK_TILE_HOST_DEVICE type _unpack(number< I >) const
constexpr CK_TILE_HOST_DEVICE type get() const
Definition: pk_fp4.hpp:93
constexpr CK_TILE_HOST_DEVICE fp32x2_t to_fp32x2(float scale=1.f) const
Definition: pk_fp4.hpp:384
constexpr CK_TILE_HOST_DEVICE bf16_t to_bf16(float scale=1.f) const
Definition: pk_fp4.hpp:267
constexpr CK_TILE_HOST_DEVICE pk_float4_e2m1_t(T init)
Definition: pk_fp4.hpp:84
type data
Definition: pk_fp4.hpp:80
raw_type type
Definition: pk_fp4.hpp:79
constexpr static CK_TILE_HOST_DEVICE pk_float4_e2m1_t pack(const pk_float4_e2m1_t &x0, const pk_float4_e2m1_t &x1)
Definition: pk_fp4.hpp:114
constexpr static CK_TILE_HOST_DEVICE type _pack(const type x0, const type x1)
Definition: pk_fp4.hpp:122
#define CK_TILE_ARITHMETIC_USING_FLOAT(attr_, type_)
Definition: numeric.hpp:106