/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/numeric/pk_int4.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/numeric/pk_int4.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/numeric/pk_int4.hpp Source File
pk_int4.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
11 #include <stdint.h>
12 #include <type_traits>
14 
15 #pragma once
16 
17 namespace ck_tile {
18 
19 // Packed 2xint4
20 struct pk_int4_t
21 {
22  using type = int8_t;
24  CK_TILE_HOST_DEVICE constexpr pk_int4_t() : data{type{}} {}
25  CK_TILE_HOST_DEVICE constexpr pk_int4_t(type init) : data{init} {}
26 };
27 
// limits
//
// Primary template of the per-type numeric-limits trait. It is only declared
// here so that it can be specialised; this file does not define it.
template <class T>
struct numeric;
31 
32 template <>
34 {
35  // minimum finite value, or minimum positive normalized value for float
36  CK_TILE_HOST_DEVICE static constexpr pk_int4_t min()
37  {
38  constexpr uint8_t val = 0b10001000;
39  return pk_int4_t(bit_cast<int8_t>(val));
40  }
41 
42  // minumum finite value
44  {
45  constexpr uint8_t val = 0b10001000;
46  return pk_int4_t(bit_cast<int8_t>(val));
47  }
48 
49  // maximum finite value
50  CK_TILE_HOST_DEVICE static constexpr pk_int4_t max()
51  {
52  constexpr uint8_t val = 0b01110111;
53  return pk_int4_t(bit_cast<int8_t>(val));
54  }
55 
56  // difference between 1.0 and next value representable by float
58  {
59  return 1; // not used
60  }
61 
63  {
64  return 1; // not used
65  }
66 
67  // positive infinity value
69  {
70  return 1; // not used
71  }
72 
73  // quiet NaN
75  {
76  return 1; // not used
77  }
78 
79  // signaling NaN
81  {
82  return 1; // not used
83  }
84 
85  // smallest positive subnormal value
87  {
88  return 1; // not used
89  }
90 
91  CK_TILE_HOST_DEVICE static constexpr pk_int4_t zero() { return 0; }
92 };
93 
94 template <>
96 {
97  static constexpr int PackedSize = 2;
98 };
99 
100 using fp32x2_t = float __attribute__((ext_vector_type(2)));
101 using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
102 using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
103 
105 {
106  uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
107 
108  float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
109  float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
110 
111 #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
112  fp32x2_t res = {x_h, x_l};
113 #elif
114  fp32x2_t res = {x_l, x_h};
115 #endif
116  return res;
117 }
118 
120 {
121  uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
122 #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
123  uint32_t i4s = ((x_u8 & 0x0f) << 16) | ((x_u8 & 0xf0) >> 4);
124 #elif
125  uint32_t i4s = ((x_u8 & 0xf0) << 12) | (x_u8 & 0xf);
126 #endif
127  const int EX = 0x64006400;
128  const int SUB = 0xE408E408; //-8
129 
130  int lo = i4s | EX;
131 
132  return pk_add_f16(bit_cast<fp16x2_t>(lo), bit_cast<fp16x2_t>(SUB));
133 }
134 
136 {
137  uint8_t x_u8 = ck_tile::bit_cast<uint8_t>(x);
138 
139  float x_l = ((x_u8 & 0x0f) >> 0) - 8.f;
140  float x_h = ((x_u8 & 0xf0) >> 4) - 8.f;
141 
142 #ifdef CK_TILE_USE_PK4_LAYOUT_SHUFFLE
143  bf16x2_t res = {type_convert<bf16_t>(x_h), type_convert<bf16_t>(x_l)};
144 #elif
145  bf16x2_t res = {type_convert<bf16_t>(x_l), type_convert<bf16_t>(x_h)};
146 #endif
147  return res;
148 }
149 
150 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
ck_tile
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:104
int8_t int8_t
Definition: int8.hpp:20
float fp32x2_t
Definition: pk_int4.hpp:100
_Float16 fp16x2_t
Definition: half.hpp:385
uint16_t bf16_raw_t
Definition: bfloat16.hpp:107
CK_TILE_HOST fp16x2_t pk_add_f16(const fp16x2_t &x, const fp16x2_t &y)
Definition: half.hpp:387
CK_TILE_HOST_DEVICE fp16x2_t pk_int4_t_to_halfx2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:119
CK_TILE_HOST_DEVICE bf16x2_t pk_int4_t_to_bfloat16x2_t(const pk_int4_t &x)
Definition: pk_int4.hpp:135
bf16_raw_t bf16x2_t
Definition: pk_int4.hpp:102
static constexpr CK_TILE_HOST_DEVICE pk_int4_t denorm_min()
Definition: pk_int4.hpp:86
static constexpr CK_TILE_HOST_DEVICE pk_int4_t zero()
Definition: pk_int4.hpp:91
static constexpr CK_TILE_HOST_DEVICE pk_int4_t quiet_NaN()
Definition: pk_int4.hpp:74
static constexpr CK_TILE_HOST_DEVICE pk_int4_t min()
Definition: pk_int4.hpp:36
static constexpr CK_TILE_HOST_DEVICE pk_int4_t infinity()
Definition: pk_int4.hpp:68
static constexpr CK_TILE_HOST_DEVICE pk_int4_t max()
Definition: pk_int4.hpp:50
static constexpr CK_TILE_HOST_DEVICE pk_int4_t epsilon()
Definition: pk_int4.hpp:57
static constexpr CK_TILE_HOST_DEVICE pk_int4_t lowest()
Definition: pk_int4.hpp:43
static constexpr CK_TILE_HOST_DEVICE pk_int4_t round_error()
Definition: pk_int4.hpp:62
static constexpr CK_TILE_HOST_DEVICE pk_int4_t signaling_NaN()
Definition: pk_int4.hpp:80
ck_tile::numeric_traits
Definition: numeric.hpp:81
static constexpr int PackedSize
Definition: numeric.hpp:82
ck_tile::numeric
Definition: numeric.hpp:18
ck_tile::pk_int4_t
Definition: pk_int4.hpp:21
type data
Definition: pk_int4.hpp:23
constexpr CK_TILE_HOST_DEVICE pk_int4_t()
Definition: pk_int4.hpp:24
constexpr CK_TILE_HOST_DEVICE pk_int4_t(type init)
Definition: pk_int4.hpp:25
int8_t type
Definition: pk_int4.hpp:22