/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/arch.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/arch.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/arch/arch.hpp Source File
arch.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 // Address Space for AMDGCN
7 // https://llvm.org/docs/AMDGPUUsage.html#address-space
8 
15 
// Maximum legal s_waitcnt immediate: every counter field saturated
// (vmcnt bits [3:0]|[15:14], expcnt bits [6:4], lgkmcnt bits [11:8]).
#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111

// Encode a 6-bit vector-memory counter into the split s_waitcnt layout:
// the low 4 bits stay at [3:0], the high 2 bits move up to [15:14].
#define CK_TILE_VMCNT(cnt) \
    ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
     ((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))

// Encode a 3-bit export counter into bits [6:4].
#define CK_TILE_EXPCNT(cnt) \
    ([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4))

// Encode a 4-bit LGKM (LDS/GDS/K-cache/message) counter into bits [11:8].
#define CK_TILE_LGKMCNT(cnt) \
    ([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8))
24 
25 namespace ck_tile {
26 
// Trait: safe_underlying_type<T, IsEnum>::type is std::underlying_type_t<T>
// when T is an enum and void otherwise. Instantiating std::underlying_type
// on a non-enum type is ill-formed, so this wrapper makes the query safe
// to perform on arbitrary types.
template <typename, bool>
struct safe_underlying_type;

// Non-enum case: there is no underlying type, expose void instead.
template <typename T>
struct safe_underlying_type<T, false>
{
    using type = void;
};

// Enum case: defer to the standard trait.
template <typename T>
struct safe_underlying_type<T, true>
{
    using type = std::underlying_type_t<T>;
};

// Convenience alias that dispatches on std::is_enum automatically.
template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;
44 
// Logical storage spaces a value may live in. `constant` (4) is the value
// handed to __attribute__((address_space(...))) by CK_CONSTANT_ADDRESS_SPACE.
// Values are fixed explicitly; do not reorder.
enum struct address_space_enum : std::uint16_t
{
    generic  = 0,
    global   = 1,
    lds      = 2,
    sgpr     = 3,
    constant = 4,
    vgpr     = 5
};
54 
// How a store-like operation combines the incoming value with memory.
// Values are fixed explicitly; do not reorder.
enum struct memory_operation_enum : std::uint16_t
{
    set        = 0, // plain overwrite
    atomic_add = 1,
    atomic_max = 2,
    add        = 3
};
62 
64 {
65 #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
66  return 64;
67 #else
68  return 32;
69 #endif
70 }
71 
72 CK_TILE_HOST bool is_wave32()
73 {
74  hipDeviceProp_t props{};
75  int device;
76  auto status = hipGetDevice(&device);
77  if(status != hipSuccess)
78  {
79  return false;
80  }
81  status = hipGetDeviceProperties(&props, device);
82  if(status != hipSuccess)
83  {
84  return false;
85  }
86  return props.major > 9;
87 }
88 
89 CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }
90 
91 CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }
92 
93 // TODO: deprecate these
94 CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }
95 
96 CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }
97 
98 CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }
99 
100 // Use these instead
101 CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }
102 
103 template <bool ReturnSgpr = true>
104 CK_TILE_DEVICE index_t get_warp_id(bool_constant<ReturnSgpr> = {})
105 {
106  const index_t warp_id = threadIdx.x / get_warp_size();
107  if constexpr(ReturnSgpr)
108  {
109  return amd_wave_read_first_lane(warp_id);
110  }
111  else
112  {
113  return warp_id;
114  }
115 }
116 
117 CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
118 
119 CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
120 
// Wait until at most `cnt` vector-memory loads remain outstanding, then
// execute a full workgroup barrier.
// NOTE(review): `cnt` is bound to an "n" (immediate) inline-asm constraint,
// so every call must inline with a compile-time-constant argument —
// confirm all call sites pass literals or constexpr values.
121 CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
122 {
123 #ifdef __gfx12__
// gfx12 splits the legacy waitcnt/barrier into s_wait_loadcnt plus an
// explicit barrier signal/wait pair (-1 selects the workgroup barrier).
124  asm volatile("s_wait_loadcnt %0 \n"
125  "s_barrier_signal -1 \n"
126  "s_barrier_wait -1"
127  :
128  : "n"(cnt)
129  : "memory");
130 #else
// Pre-gfx12: a single s_waitcnt on vmcnt followed by s_barrier.
131  asm volatile("s_waitcnt vmcnt(%0) \n"
132  "s_barrier"
133  :
134  : "n"(cnt)
135  : "memory");
136 #endif
137 }
138 
139 // https://llvm.org/docs/AMDGPU/gfx9_waitcnt.html
// Compile-time encoder for the s_waitcnt / s_wait_loadcnt_dscnt simm16
// immediate. The field layout differs per architecture, so the two
// preprocessor branches expose the same interface (MAX, kMax*Cnt,
// from_vmcnt / from_expcnt / from_lgkmcnt) with arch-specific bit packing.
140 struct waitcnt_arg
141 {
142 #if defined(__gfx12__)
143  // use s_wait_loadcnt_dscnt in this instruction; in this instruction, ds [5:0]; mem [13:8]
144  CK_TILE_DEVICE static constexpr index_t MAX = 0b00'111111'00'111111;
145 
// Per-counter maxima for the gfx12 encoding: 6-bit load counter,
// 6-bit ds counter; the 3-bit exp maximum is kept only for interface parity.
146  CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
147  CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
148  CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b111111;
149 
// Vector-memory (load) counter -> bits [13:8] of the immediate.
150  template <index_t cnt>
151  CK_TILE_DEVICE static constexpr index_t from_vmcnt()
152  {
153  static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
154  return MAX & (cnt << 8);
155  }
156 
157  template <index_t cnt>
158  CK_TILE_DEVICE static constexpr index_t from_expcnt()
159  {
160  return 0; // s_wait_loadcnt_dscnt carries no export-counter field, so there is nothing to encode
161  }
162 
// DS (LDS) counter -> bits [5:0] of the immediate.
163  template <index_t cnt>
164  CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
165  {
166  static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
167  return MAX & cnt;
168  }
169 #else
170  // bit numbers (hex) -------------------------> FE'DC'BA98'7'654'3210
171  // [V]M [E]XP [L]GKM counters and [U]NUSED ---> VV'UU'LLLL'U'EEE'VVVV
172  CK_TILE_DEVICE static constexpr index_t MAX = 0b11'00'1111'0'111'1111;
173 
// Per-counter maxima for the classic s_waitcnt encoding:
// 6-bit vmcnt (split field), 3-bit expcnt, 4-bit lgkmcnt.
174  CK_TILE_DEVICE static constexpr index_t kMaxVmCnt = 0b111111;
175  CK_TILE_DEVICE static constexpr index_t kMaxExpCnt = 0b111;
176  CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0b1111;
177 
// Vector-memory counter: low 4 bits -> [3:0], high 2 bits -> [15:14].
178  template <index_t cnt>
179  CK_TILE_DEVICE static constexpr index_t from_vmcnt()
180  {
181  static_assert(cnt >= 0 && !(cnt >> 6), "valid range is [0..63]");
182  return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10));
183  }
184 
// Export counter -> bits [6:4].
185  template <index_t cnt>
186  CK_TILE_DEVICE static constexpr index_t from_expcnt()
187  {
188  static_assert(cnt >= 0 && !(cnt >> 3), "valid range is [0..7]");
189  return MAX & (cnt << 4);
190  }
191 
// LGKM (LDS/GDS/K-cache/message) counter -> bits [11:8].
192  template <index_t cnt>
193  CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
194  {
195  static_assert(cnt >= 0 && !(cnt >> 4), "valid range is [0..15]");
196  return MAX & (cnt << 8);
197  }
198 #endif
199 };
200 
201 template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
202  index_t expcnt = waitcnt_arg::kMaxExpCnt,
203  index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
205 {
206 #if defined(__gfx12__)
207  // GFX12 do't use __builtin_amdgcn_s_waitcnt
208  constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
209  waitcnt_arg::from_expcnt<expcnt>() |
210  waitcnt_arg::from_lgkmcnt<lgkmcnt>();
211 
212  asm volatile("s_wait_loadcnt_dscnt %0" : : "n"(wait_mask) : "memory");
213 #else
214  __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
215  waitcnt_arg::from_expcnt<expcnt>() |
216  waitcnt_arg::from_lgkmcnt<lgkmcnt>());
217 #endif
218 }
219 
220 template <index_t vmcnt = waitcnt_arg::kMaxVmCnt,
221  index_t expcnt = waitcnt_arg::kMaxExpCnt,
222  index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
224 {
225 #if defined(__gfx12__)
226  // GFX12 optimization: Manual barrier implementation avoids performance penalty
227  // from __builtin_amdgcn_s_barrier which inserts extra s_wait_loadcnt_dscnt 0x0
228  constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
229  waitcnt_arg::from_expcnt<expcnt>() |
230  waitcnt_arg::from_lgkmcnt<lgkmcnt>();
231 
232  asm volatile("s_wait_loadcnt_dscnt %0\n"
233  "s_barrier_signal -1\n"
234  "s_barrier_wait -1"
235  :
236  : "n"(wait_mask)
237  : "memory");
238 #else
239  s_waitcnt<vmcnt, expcnt, lgkmcnt>();
240  __builtin_amdgcn_s_barrier();
241 #endif
242 }
243 
244 template <index_t lgkmcnt = 0>
246 {
247  s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
248 }
249 
250 template <index_t vmcnt = 0>
252 {
253  s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
254 }
255 
257 {
258 #if 1
259  asm volatile("s_nop %0" : : "n"(cnt) :);
260 #else
261  __builtin_amdgcn_sched_barrier(cnt);
262 #endif
263 }
264 
// Annotates a pointee type with the AMDGPU "constant" address space
// (address_space_enum::constant == 4) via clang's address_space attribute.
// See https://llvm.org/docs/AMDGPUUsage.html#address-space
#define CK_CONSTANT_ADDRESS_SPACE \
    __attribute__((address_space( \
        static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
268 
269 template <typename T>
271 {
272  // cast a pointer in "Constant" address space (4) to "Generic" address space (0)
273  // only c-style pointer cast seems be able to be compiled
274 #pragma clang diagnostic push
275 #pragma clang diagnostic ignored "-Wold-style-cast"
276  return (T*)(p); // NOLINT(old-style-cast)
277 #pragma clang diagnostic pop
278 }
279 
280 template <typename T>
282 {
283  // cast a pointer in "Generic" address space (0) to "Constant" address space (4)
284  // only c-style pointer cast seems be able to be compiled;
285 #pragma clang diagnostic push
286 #pragma clang diagnostic ignored "-Wold-style-cast"
287  return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
288 #pragma clang diagnostic pop
289 }
290 
292 {
293 #if defined(__gfx950__)
294  return 163840;
295 #else
296  return 65536;
297 #endif
298 }
299 
301 CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
302 {
303  switch(addr_space)
304  {
305  case address_space_enum::generic: return "generic";
306  case address_space_enum::global: return "global";
307  case address_space_enum::lds: return "lds";
308  case address_space_enum::sgpr: return "sgpr";
309  case address_space_enum::constant: return "constant";
310  case address_space_enum::vgpr: return "vgpr";
311  default: return "unknown";
312  }
313 }
314 
// Architecture tags
// Empty tag type selecting gfx11 code paths (see get_device_arch()).
struct gfx11_t
{
};
// Empty tag type selecting gfx12 code paths (see get_device_arch()).
struct gfx12_t
{
};
322 
323 CK_TILE_DEVICE static constexpr auto get_device_arch()
324 {
325 #if defined(__gfx11__)
326  return gfx11_t{};
327 #else // if defined(__gfx12__)
328  return gfx12_t{};
329 #endif
330 }
331 
333 {
334  NONE = 0,
335  ALU = 1 << 0,
336  VALU = 1 << 1,
337  SALU = 1 << 2,
338  MFMA = 1 << 3,
339  VMEM = 1 << 4,
340  VMEM_READ = 1 << 5,
341  VMEM_WRITE = 1 << 6,
342  DS = 1 << 7,
343  DS_READ = 1 << 8,
344  DS_WRITE = 1 << 9,
345  ALL = (DS_WRITE << 1) - 1,
346 };
347 } // namespace ck_tile
#define CK_CONSTANT_ADDRESS_SPACE
Definition: arch.hpp:265
constexpr CK_TILE_HOST_DEVICE const char * address_space_to_string(address_space_enum addr_space)
Helper function to convert address space enum to string.
Definition: arch.hpp:301
constexpr CK_TILE_HOST_DEVICE index_t get_smem_capacity()
Definition: arch.hpp:291
CK_TILE_DEVICE void s_nop(index_t cnt=0)
Definition: arch.hpp:256
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: arch.hpp:270
CK_TILE_DEVICE void block_sync_lds_direct_load()
Definition: arch.hpp:251
CK_TILE_DEVICE void s_waitcnt_barrier()
Definition: arch.hpp:223
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: arch.hpp:281
LLVMSchedGroupMask
Definition: arch.hpp:333
@ VMEM_READ
Definition: arch.hpp:340
@ DS_READ
Definition: arch.hpp:343
@ SALU
Definition: arch.hpp:337
@ VMEM
Definition: arch.hpp:339
@ ALL
Definition: arch.hpp:345
@ DS
Definition: arch.hpp:342
@ VALU
Definition: arch.hpp:336
@ NONE
Definition: arch.hpp:334
@ MFMA
Definition: arch.hpp:338
@ DS_WRITE
Definition: arch.hpp:344
@ VMEM_WRITE
Definition: arch.hpp:341
@ ALU
Definition: arch.hpp:335
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:245
CK_TILE_DEVICE void s_waitcnt()
Definition: arch.hpp:204
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void atomic_add(X *p_dst, const X &x)
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:33
CK_TILE_HOST_DEVICE T add(const T &a, const T &b)
Definition: generic_memory_space_atomic.hpp:16
int32_t index_t
Definition: integer.hpp:9
__device__ index_t get_grid_size()
Definition: get_id.hpp:49
__device__ index_t get_block_size()
Definition: get_id.hpp:51
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:43
__device__ X atomic_max(X *p_dst, const X &x)
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
unsigned short uint16_t
Definition: stdint.h:125
signed int int32_t
Definition: stdint.h:123
Definition: arch.hpp:317
Definition: arch.hpp:320