// All-ones pattern for the counter fields packed by the macros below:
// bits [3:0] and [15:14] (the two pieces written by CK_TILE_VMCNT),
// bits [6:4] (CK_TILE_EXPCNT) and bits [11:8] (CK_TILE_LGKMCNT).
// Bits 7, 12 and 13 stay clear.
#define CK_TILE_S_CNT_MAX 0xCF7F
// Packs a 6-bit vmcnt value into its split field position.
//   cnt[3:0] stays at bits [3:0]; cnt[5:4] is moved up to bits [15:14].
// `cnt` must be a compile-time constant: the immediately-invoked lambda exists
// only to host a static_assert that rejects values wider than 6 bits, and the
// comma operator then yields the packed value.
#define CK_TILE_VMCNT(cnt)                                                      \
    ([]() { static_assert(((cnt) >> 6) == 0, "VMCNT only has 6 bits"); }(),     \
     (((cnt) & 0x30) << 10) | ((cnt) & 0xF))
// Packs a 3-bit expcnt value into bits [6:4].
// `cnt` must be a compile-time constant; the lambda only carries the
// static_assert that rejects values needing more than 3 bits.
#define CK_TILE_EXPCNT(cnt)                                                 \
    ([]() { static_assert(((cnt) >> 3) == 0, "EXP only has 3 bits"); }(),   \
     ((cnt) << 4))
// Packs a 4-bit lgkmcnt value into bits [11:8].
// `cnt` must be a compile-time constant; the lambda only carries the
// static_assert that rejects values needing more than 4 bits.
#define CK_TILE_LGKMCNT(cnt)                                                \
    ([]() { static_assert(((cnt) >> 4) == 0, "LGKM only has 4 bits"); }(),  \
     ((cnt) << 8))
// Primary template: declared here, defined only through partial
// specializations keyed on the bool parameter. The `true` specialization
// exposes `type` as std::underlying_type_t<T>.
// NOTE(review): the bool is presumably std::is_enum<T>::value supplied by a
// helper alias - confirm against the safe_underlying_type_t definition.
template <typename T, bool UseUnderlying>
struct safe_underlying_type;
31 struct safe_underlying_type<T, true>
33 using type = std::underlying_type_t<T>;
37 struct safe_underlying_type<T, false>
65 #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
74 hipDeviceProp_t props{};
76 auto status = hipGetDevice(&device);
77 if(status != hipSuccess)
81 status = hipGetDeviceProperties(&props, device);
82 if(status != hipSuccess)
86 return props.major > 9;
103 template <
bool ReturnSgpr = true>
107 if constexpr(ReturnSgpr)
124 asm volatile(
"s_wait_loadcnt %0 \n"
125 "s_barrier_signal -1 \n"
131 asm volatile(
"s_waitcnt vmcnt(%0) \n"
139 struct WaitcntLayoutGfx12
150 struct WaitcntLayoutGfx11
161 struct WaitcntLayoutLegacy
171 return ((c & 0xF) << 0) | ((c & 0x30) << 10);
178 #if defined(__gfx12__)
179 using Waitcnt = WaitcntLayoutGfx12;
180 #elif defined(__gfx11__)
181 using Waitcnt = WaitcntLayoutGfx11;
183 using Waitcnt = WaitcntLayoutLegacy;
192 #if defined(__gfx12__) || defined(__gfx11__)
202 template <index_t cnt>
205 static_assert((cnt & ~Waitcnt::VM_MASK) == 0,
"vmcnt out of range");
206 return Waitcnt::pack_vm(cnt);
209 template <index_t cnt>
212 static_assert((cnt & ~Waitcnt::LGKM_MASK) == 0,
"lgkmcnt out of range");
213 return Waitcnt::pack_lgkm(cnt);
216 template <index_t cnt>
219 if constexpr(Waitcnt::HAS_EXP)
222 #if !defined(__gfx12__) && !defined(__gfx11__)
223 static_assert((cnt & ~Waitcnt::EXP_MASK) == 0,
"expcnt out of range");
224 return Waitcnt::pack_exp(cnt);
232 static_assert(cnt == 0,
"expcnt unsupported on this arch");
238 template <
index_t vmcnt = waitcnt_arg::kMaxVmCnt,
239 index_t expcnt = waitcnt_arg::kMaxExpCnt,
240 index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
243 #if defined(__gfx12__)
245 constexpr
index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
246 waitcnt_arg::from_expcnt<expcnt>() |
247 waitcnt_arg::from_lgkmcnt<lgkmcnt>();
249 asm volatile(
"s_wait_loadcnt_dscnt %0" : :
"n"(wait_mask) :
"memory");
251 __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
252 waitcnt_arg::from_expcnt<expcnt>() |
253 waitcnt_arg::from_lgkmcnt<lgkmcnt>());
257 template <
index_t vmcnt = waitcnt_arg::kMaxVmCnt,
258 index_t expcnt = waitcnt_arg::kMaxExpCnt,
259 index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
262 #if defined(__gfx12__)
265 constexpr
index_t wait_mask = waitcnt_arg::from_vmcnt<vmcnt>() |
266 waitcnt_arg::from_expcnt<expcnt>() |
267 waitcnt_arg::from_lgkmcnt<lgkmcnt>();
269 asm volatile(
"s_wait_loadcnt_dscnt %0\n"
270 "s_barrier_signal -1\n"
276 s_waitcnt<vmcnt, expcnt, lgkmcnt>();
277 __builtin_amdgcn_s_barrier();
281 template <index_t lgkmcnt = 0>
284 s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
287 template <index_t vmcnt = 0>
290 s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
296 asm volatile(
"s_nop %0" : :
"n"(cnt) :);
298 __builtin_amdgcn_sched_barrier(cnt);
// Expands to a Clang `address_space(N)` attribute, where N is
// address_space_enum::constant converted to its integer representation
// through safe_underlying_type_t.
#define CK_CONSTANT_ADDRESS_SPACE                                 \
    __attribute__((address_space(                                 \
        static_cast<safe_underlying_type_t<address_space_enum>>(  \
            address_space_enum::constant))))
306 template <
typename T>
311 #pragma clang diagnostic push
312 #pragma clang diagnostic ignored "-Wold-style-cast"
314 #pragma clang diagnostic pop
317 template <
typename T>
322 #pragma clang diagnostic push
323 #pragma clang diagnostic ignored "-Wold-style-cast"
325 #pragma clang diagnostic pop
330 #if defined(__gfx950__)
338 CK_TILE_HOST_DEVICE constexpr
const char* address_space_to_string(address_space_enum addr_space)
342 case address_space_enum::generic:
return "generic";
343 case address_space_enum::global:
return "global";
344 case address_space_enum::lds:
return "lds";
345 case address_space_enum::sgpr:
return "sgpr";
346 case address_space_enum::constant:
return "constant";
347 case address_space_enum::vgpr:
return "vgpr";
348 default:
return "unknown";
376 #if defined(__gfx11__)
383 CK_TILE_DEVICE static constexpr
auto get_n_words_per_128b() {
return 4; }
386 CK_TILE_DEVICE static constexpr
auto get_n_lds_banks(gfx9_t) {
return 32; }
388 CK_TILE_DEVICE static constexpr
auto get_n_lds_banks(gfx103_t) {
return 32; }
390 CK_TILE_DEVICE static constexpr
auto get_n_lds_banks(gfx11_t) {
return 32; }
392 CK_TILE_DEVICE static constexpr
auto get_n_lds_banks(gfx12_t) {
return 32; }
394 CK_TILE_DEVICE static constexpr
auto get_n_lds_banks(gfx950_t) {
return 64; }
396 CK_TILE_DEVICE static constexpr
auto get_n_lds_banks(gfx_invalid_t) {
return 0; }
400 #if defined(__gfx103__)
402 #elif defined(__gfx11__)
404 #elif defined(__gfx12__)
406 #elif defined(__gfx950__)
408 #elif defined(__gfx9__)
411 return gfx_invalid_t{};
417 return detail::get_n_lds_banks(detail::arch_tag_dispatch());
420 enum LLVMSchedGroupMask :
int32_t
433 ALL = (DS_WRITE << 1) - 1,
#define CK_CONSTANT_ADDRESS_SPACE
Definition: ck.hpp:23
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void atomic_add(X *p_dst, const X &x)
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
CK_TILE_HOST_DEVICE T add(const T &a, const T &b)
Definition: generic_memory_space_atomic.hpp:16
int32_t index_t
Definition: integer.hpp:9
__device__ index_t get_grid_size()
Definition: get_id.hpp:49
__device__ void s_nop()
Definition: synchronization.hpp:61
__device__ index_t get_block_size()
Definition: get_id.hpp:51
__device__ void block_sync_lds_direct_load()
Definition: synchronization.hpp:43
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:47
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:43
__device__ X atomic_max(X *p_dst, const X &x)
constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: amd_address_space.hpp:35
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: amd_address_space.hpp:24
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:41
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
unsigned short uint16_t
Definition: stdint.h:125
signed int int32_t
Definition: stdint.h:123