Composable Kernel: include/ck_tile/core/arch/arch.hpp Source File

arch.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space

#define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
#define CK_TILE_VMCNT(cnt) \
    ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
     ((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))
#define CK_TILE_EXPCNT(cnt) \
    ([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4))
#define CK_TILE_LGKMCNT(cnt) \
    ([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8))

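// Example (illustrative, not part of the original header): the three macros above
// build the immediate operand of a legacy "s_waitcnt" instruction. Composing
// vmcnt(3), expcnt(7) and lgkmcnt(0) waits until at most 3 vector-memory operations
// remain outstanding and all LDS/GDS/constant (lgkm) operations have completed,
// while leaving expcnt at its "don't care" maximum:
//
//   constexpr auto wait_imm =
//       CK_TILE_VMCNT(3) | CK_TILE_EXPCNT(0b111) | CK_TILE_LGKMCNT(0);
//   // asm volatile("s_waitcnt %0" : : "n"(wait_imm) : "memory");
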
namespace ck_tile {

template <typename, bool>
struct safe_underlying_type;

template <typename T>
struct safe_underlying_type<T, true>
{
    using type = std::underlying_type_t<T>;
};

template <typename T>
struct safe_underlying_type<T, false>
{
    using type = void;
};

template <typename T>
using safe_underlying_type_t = typename safe_underlying_type<T, std::is_enum<T>::value>::type;

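// Example (illustrative, not part of the original header): for enum types the alias
// above resolves to the enum's underlying integer type, and to void otherwise, so it
// stays well-formed where std::underlying_type_t alone would not be for non-enums.
// Using the address_space_enum declared just below:
//
//   static_assert(std::is_same_v<safe_underlying_type_t<address_space_enum>, std::uint16_t>);
//   static_assert(std::is_same_v<safe_underlying_type_t<int>, void>);
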
enum struct address_space_enum : std::uint16_t
{
    generic = 0,
    global,
    lds,
    sgpr,
    constant,
    vgpr
};

enum struct memory_operation_enum : std::uint16_t
{
    set = 0,
    atomic_add,
    atomic_max,
    add
};

CK_TILE_HOST_DEVICE constexpr index_t get_warp_size()
{
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
    return 64;
#else
    return 32;
#endif
}

CK_TILE_HOST bool is_wave32()
{
    hipDeviceProp_t props{};
    int device;
    auto status = hipGetDevice(&device);
    if(status != hipSuccess)
    {
        return false;
    }
    status = hipGetDeviceProperties(&props, device);
    if(status != hipSuccess)
    {
        return false;
    }
    return props.major > 9;
}

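// Example (illustrative, not part of the original header): host code can use
// is_wave32() to pick between wave32- and wave64-tuned kernel configurations:
//
//   const int wave_size = ck_tile::is_wave32() ? 32 : 64;
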
CK_TILE_DEVICE index_t get_grid_size() { return gridDim.x; }

CK_TILE_DEVICE index_t get_block_size() { return blockDim.x; }

// TODO: deprecate these
CK_TILE_DEVICE index_t get_thread_local_1d_id() { return threadIdx.x; }

CK_TILE_DEVICE index_t get_thread_global_1d_id() { return blockIdx.x * blockDim.x + threadIdx.x; }

CK_TILE_DEVICE index_t get_block_1d_id() { return blockIdx.x; }

// Use these instead
CK_TILE_DEVICE index_t get_lane_id() { return __lane_id(); }

template <bool ReturnSgpr = true>
CK_TILE_DEVICE index_t get_warp_id(bool_constant<ReturnSgpr> = {})
{
    const index_t warp_id = threadIdx.x / get_warp_size();
    if constexpr(ReturnSgpr)
    {
        return amd_wave_read_first_lane(warp_id);
    }
    else
    {
        return warp_id;
    }
}

CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }

CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }

CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
{
#ifdef __gfx12__
    asm volatile("s_wait_loadcnt %0 \n"
                 "s_barrier_signal -1 \n"
                 "s_barrier_wait -1"
                 :
                 : "n"(cnt)
                 : "memory");
#else
    asm volatile("s_waitcnt vmcnt(%0) \n"
                 "s_barrier"
                 :
                 : "n"(cnt)
                 : "memory");
#endif
}

struct WaitcntLayoutGfx12
{ // s_wait_loadcnt_dscnt: mem[13:8], ds[5:0]
    CK_TILE_DEVICE static constexpr index_t VM_MASK   = 0x3F; // mem
    CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F; // ds
    CK_TILE_DEVICE static constexpr bool HAS_EXP      = false;

    CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 8); }
    CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 0); }
    CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
};

struct WaitcntLayoutGfx11
{ // vm[15:10] (6), lgkm[9:4] (6), exp unused
    CK_TILE_DEVICE static constexpr index_t VM_MASK   = 0x3F;
    CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x3F;
    CK_TILE_DEVICE static constexpr bool HAS_EXP      = false;

    CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c) { return ((c & VM_MASK) << 10); }
    CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 4); }
    CK_TILE_DEVICE static constexpr index_t pack_exp(index_t) { return 0; }
};

struct WaitcntLayoutLegacy
{ // FE'DC'BA98'7'654'3210 => VV'UU'LLLL'U'EEE'VVVV
    CK_TILE_DEVICE static constexpr index_t VM_MASK   = 0x3F; // split: low4 + hi2
    CK_TILE_DEVICE static constexpr index_t LGKM_MASK = 0x0F; // [11:8]
    CK_TILE_DEVICE static constexpr index_t EXP_MASK  = 0x07; // [6:4]
    CK_TILE_DEVICE static constexpr bool HAS_EXP      = true;

    CK_TILE_DEVICE static constexpr index_t pack_vm(index_t c)
    {
        c &= VM_MASK;
        return ((c & 0xF) << 0) | ((c & 0x30) << 10);
    }
    CK_TILE_DEVICE static constexpr index_t pack_lgkm(index_t c) { return ((c & LGKM_MASK) << 8); }
    CK_TILE_DEVICE static constexpr index_t pack_exp(index_t c) { return ((c & EXP_MASK) << 4); }
};

// Select active layout
#if defined(__gfx12__)
using Waitcnt = WaitcntLayoutGfx12;
#elif defined(__gfx11__)
using Waitcnt = WaitcntLayoutGfx11;
#else
using Waitcnt = WaitcntLayoutLegacy;
#endif

//----------------------------------------------
// Public API: only from_* (constexpr templates)
//----------------------------------------------
struct waitcnt_arg
{
    // kMax* exposed for callers; match field widths per-arch
#if defined(__gfx12__) || defined(__gfx11__)
    CK_TILE_DEVICE static constexpr index_t kMaxVmCnt   = 0x3F; // 6 bits
    CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x3F; // 6 bits
    CK_TILE_DEVICE static constexpr index_t kMaxExpCnt  = 0x0;  // none
#else
    CK_TILE_DEVICE static constexpr index_t kMaxVmCnt   = 0x3F; // 6 bits (split)
    CK_TILE_DEVICE static constexpr index_t kMaxLgkmCnt = 0x0F; // 4 bits
    CK_TILE_DEVICE static constexpr index_t kMaxExpCnt  = 0x07; // 3 bits
#endif

    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_vmcnt()
    {
        static_assert((cnt & ~Waitcnt::VM_MASK) == 0, "vmcnt out of range");
        return Waitcnt::pack_vm(cnt);
    }

    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_lgkmcnt()
    {
        static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0, "lgkmcnt out of range");
        return Waitcnt::pack_lgkm(cnt);
    }

    template <index_t cnt>
    CK_TILE_DEVICE static constexpr index_t from_expcnt()
    {
        if constexpr(Waitcnt::HAS_EXP)
        {
            // EXP_MASK only exists on legacy
#if !defined(__gfx12__) && !defined(__gfx11__)
            static_assert((cnt & ~Waitcnt::EXP_MASK) == 0, "expcnt out of range");
            return Waitcnt::pack_exp(cnt);
#else
            (void)cnt;
            return 0;
#endif
        }
        else
        {
            static_assert(cnt == 0, "expcnt unsupported on this arch");
            return 0;
        }
    }
};

template <index_t vmcnt   = waitcnt_arg::kMaxVmCnt,
          index_t expcnt  = waitcnt_arg::kMaxExpCnt,
          index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt()
{
#if defined(__gfx12__)
    // GFX12 does not use __builtin_amdgcn_s_waitcnt
    constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
                                  waitcnt_arg::from_expcnt<expcnt>() |
                                  waitcnt_arg::from_lgkmcnt<lgkmcnt>();

    asm volatile("s_wait_loadcnt_dscnt %0" : : "n"(wait_mask) : "memory");
#else
    __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
                               waitcnt_arg::from_expcnt<expcnt>() |
                               waitcnt_arg::from_lgkmcnt<lgkmcnt>());
#endif
}

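// Example (illustrative, not part of the original header): a counter left at its
// kMax* default is effectively not waited on, so the common cases are:
//
//   s_waitcnt<0>();  // wait for all outstanding vector-memory ops; exp/lgkm untouched
//   s_waitcnt<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, 0>();  // wait on lgkm/ds only
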
template <index_t vmcnt   = waitcnt_arg::kMaxVmCnt,
          index_t expcnt  = waitcnt_arg::kMaxExpCnt,
          index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
CK_TILE_DEVICE void s_waitcnt_barrier()
{
#if defined(__gfx12__)
    // GFX12 optimization: a manual barrier avoids the performance penalty of
    // __builtin_amdgcn_s_barrier, which inserts an extra s_wait_loadcnt_dscnt 0x0
    constexpr index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
                                  waitcnt_arg::from_expcnt<expcnt>() |
                                  waitcnt_arg::from_lgkmcnt<lgkmcnt>();

    asm volatile("s_wait_loadcnt_dscnt %0\n"
                 "s_barrier_signal -1\n"
                 "s_barrier_wait -1"
                 :
                 : "n"(wait_mask)
                 : "memory");
#else
    s_waitcnt<vmcnt, expcnt, lgkmcnt>();
    __builtin_amdgcn_s_barrier();
#endif
}

template <index_t lgkmcnt = 0>
CK_TILE_DEVICE void block_sync_lds()
{
    s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
}

template <index_t vmcnt = 0>
CK_TILE_DEVICE void block_sync_lds_direct_load()
{
    s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
}

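// Example (illustrative sketch, not part of the original header; the function names
// above are reconstructed from the page's cross-references, and smem/partial are
// hypothetical): the canonical LDS producer/consumer pattern waits for the writes
// to land, barriers, then lets each thread read another thread's data:
//
//   smem[threadIdx.x] = partial;   // write to LDS
//   block_sync_lds();              // wait for LDS ops, then block barrier
//   auto other = smem[threadIdx.x ^ 1];
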
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
{
#if 1
    asm volatile("s_nop %0" : : "n"(cnt) :);
#else
    __builtin_amdgcn_sched_barrier(cnt);
#endif
}

#define CK_CONSTANT_ADDRESS_SPACE \
    __attribute__((address_space( \
        static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))

template <typename T>
CK_TILE_DEVICE T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
    // cast a pointer in "Constant" address space (4) to "Generic" address space (0);
    // only a C-style pointer cast seems to compile here
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T*)(p); // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

template <typename T>
CK_TILE_HOST_DEVICE T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
    // cast a pointer in "Generic" address space (0) to "Constant" address space (4);
    // only a C-style pointer cast seems to compile here
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

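// Example (illustrative sketch, not part of the original header; the kernel and
// argument names are hypothetical): kernel arguments can be carried through the
// constant address space and cast back to a generic pointer inside the kernel:
//
//   __global__ void kernel(const float CK_CONSTANT_ADDRESS_SPACE* p_arg)
//   {
//       const float* p = cast_pointer_to_generic_address_space(p_arg);
//       // ... use p as an ordinary pointer ...
//   }
//   // launch side: pass cast_pointer_to_constant_address_space(p_generic)
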
CK_TILE_HOST_DEVICE constexpr index_t get_smem_capacity()
{
#if defined(__gfx950__)
    return 163840;
#else
    return 65536;
#endif
}

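// Example (illustrative, not part of the original header; kSmemSizeOfTile is a
// hypothetical per-tile LDS byte count): a kernel can check at compile time that
// its LDS usage fits the capacity reported above:
//
//   constexpr index_t kSmemSizeOfTile = 32 * 1024;
//   static_assert(kSmemSizeOfTile <= get_smem_capacity(), "tile does not fit in LDS");
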
CK_TILE_HOST_DEVICE constexpr const char* address_space_to_string(address_space_enum addr_space)
{
    switch(addr_space)
    {
    case address_space_enum::generic: return "generic";
    case address_space_enum::global: return "global";
    case address_space_enum::lds: return "lds";
    case address_space_enum::sgpr: return "sgpr";
    case address_space_enum::constant: return "constant";
    case address_space_enum::vgpr: return "vgpr";
    default: return "unknown";
    }
}

// Architecture tags
struct gfx9_t
{
};
struct gfx950_t
{
};
struct gfx103_t
{
};
struct gfx11_t
{
};
struct gfx12_t
{
};
struct gfx_invalid_t
{
};

CK_TILE_DEVICE static constexpr auto get_device_arch()
{
// FIXME(0): on all devices except gfx11 it returns gfx12_t
// FIXME(1): during the host compilation pass it returns gfx12_t
#if defined(__gfx11__)
    return gfx11_t{};
#else
    return gfx12_t{};
#endif
}

CK_TILE_DEVICE static constexpr auto get_n_words_per_128b() { return 4; }

namespace detail {
CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx9_t) { return 32; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx103_t) { return 32; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx11_t) { return 32; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx12_t) { return 32; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx950_t) { return 64; }

CK_TILE_DEVICE static constexpr auto get_n_lds_banks(gfx_invalid_t) { return 0; }

CK_TILE_DEVICE static constexpr auto arch_tag_dispatch()
{
#if defined(__gfx103__)
    return gfx103_t{};
#elif defined(__gfx11__)
    return gfx11_t{};
#elif defined(__gfx12__)
    return gfx12_t{};
#elif defined(__gfx950__)
    return gfx950_t{};
#elif defined(__gfx9__)
    return gfx9_t{};
#else
    return gfx_invalid_t{};
#endif
}
} // namespace detail
CK_TILE_DEVICE static constexpr auto get_n_lds_banks()
{
    return detail::get_n_lds_banks(detail::arch_tag_dispatch());
}

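// Example (illustrative, not part of the original header; kRowLenDwords and kPad
// are hypothetical): the bank count is the quantity LDS padding/swizzling schemes
// are derived from, e.g. the classic "+1 dword" padding when a row pitch is a
// multiple of the bank count:
//
//   constexpr index_t kBanks        = get_n_lds_banks(); // 32, or 64 on gfx950
//   constexpr index_t kRowLenDwords = 128;
//   constexpr index_t kPad          = (kRowLenDwords % kBanks == 0) ? 1 : 0;
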
enum LLVMSchedGroupMask : int32_t
{
    NONE       = 0,
    ALU        = 1 << 0,
    VALU       = 1 << 1,
    SALU       = 1 << 2,
    MFMA       = 1 << 3,
    VMEM       = 1 << 4,
    VMEM_READ  = 1 << 5,
    VMEM_WRITE = 1 << 6,
    DS         = 1 << 7,
    DS_READ    = 1 << 8,
    DS_WRITE   = 1 << 9,
    ALL        = (DS_WRITE << 1) - 1,
};
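
// Example (illustrative, not part of the original header): these bits follow the
// mask encoding used by llvm.amdgcn.sched.group.barrier, so they can be passed to
// the clang builtin to request an instruction interleaving pattern from the
// scheduler, e.g. one MFMA followed by two DS reads:
//
//   __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0);
//   __builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::DS_READ, 2, 0);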
} // namespace ck_tile