// Encoding helpers for the SIMM16 operand of the AMDGPU `s_waitcnt`
// instruction. Field layout as implied by the masks/shifts below:
//   vmcnt   : 6 bits, low 4 bits at [3:0], high 2 bits at [15:14]
//   expcnt  : 3 bits at [6:4]
//   lgkmcnt : 4 bits at [11:8]
// CK_TILE_S_CNT_MAX (0xCF7F) has exactly those field bits set, i.e. every
// counter field at its maximum value.
// NOTE(review): this is the gfx9-style layout; the file also has non-gfx9
// paths (e.g. one emitting `s_wait_loadcnt`) where it may not apply — confirm
// the target before reusing these macros.
13 #define CK_TILE_S_CNT_MAX 0b1100'1111'0111'1111
// Encodes a compile-time vmcnt value (0..63) into its split field; all other
// fields of the result are zero. The immediately-invoked lambda only carries a
// static_assert range check (so `cnt` must be a constant expression); the
// comma operator then yields the encoded value.
14 #define CK_TILE_VMCNT(cnt) \
15 ([]() { static_assert(!((cnt) >> 6), "VMCNT only has 6 bits"); }(), \
16 ((cnt) & 0b1111) | (((cnt) & 0b110000) << 10))
// Encodes a compile-time expcnt value (0..7) into bits [6:4]; all other fields
// of the result are zero. Range is checked at compile time as above.
17 #define CK_TILE_EXPCNT(cnt) \
18 ([]() { static_assert(!((cnt) >> 3), "EXP only has 3 bits"); }(), ((cnt) << 4))
// Encodes a compile-time lgkmcnt value (0..15) into bits [11:8]; all other
// fields of the result are zero. Range is checked at compile time as above.
19 #define CK_TILE_LGKMCNT(cnt) \
20 ([]() { static_assert(!((cnt) >> 4), "LGKM only has 4 bits"); }(), ((cnt) << 8))
// Maps T to its underlying integer type when T is an enum. Only the primary
// template is declared here. The partial specializations on the bool
// parameter, selected via std::is_enum<T> in safe_underlying_type_t, supply
// the nested `type`; the `true` case uses std::underlying_type_t<T>.
// Presumably this exists so that a non-enum T never instantiates
// std::underlying_type (the `false` case's body is not visible here).
24 template <
typename,
bool>
25 struct safe_underlying_type;
28 struct safe_underlying_type<T, true>
30 using type = std::underlying_type_t<T>;
34 struct safe_underlying_type<T, false>
40 using safe_underlying_type_t =
typename safe_underlying_type<T, std::is_enum<T>::value>::type;
42 enum struct address_space_enum : std::uint16_t
52 enum struct memory_operation_enum : std::uint16_t
62 #if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
85 return __builtin_amdgcn_readfirstlane(threadIdx.x /
get_warp_size());
94 #if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
100 __builtin_amdgcn_s_waitcnt(0xc07f);
101 __builtin_amdgcn_s_barrier();
110 asm volatile(
"s_wait_loadcnt %0 \n"
111 "s_barrier_signal -1 \n"
117 asm volatile(
"s_waitcnt vmcnt(%0) \n"
136 template <index_t cnt>
139 static_assert(cnt >= 0 && !(cnt >> 6),
"valid range is [0..63]");
140 return MAX & ((cnt & 0b1111) | ((cnt & 0b110000) << 10));
143 template <index_t cnt>
146 static_assert(cnt >= 0 && !(cnt >> 3),
"valid range is [0..7]");
147 return MAX & (cnt << 4);
150 template <index_t cnt>
153 static_assert(cnt >= 0 && !(cnt >> 4),
"valid range is [0..15]");
154 return MAX & (cnt << 8);
158 template <
index_t vmcnt = waitcnt_arg::kMaxVmCnt,
159 index_t expcnt = waitcnt_arg::kMaxExpCnt,
160 index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
163 __builtin_amdgcn_s_waitcnt(waitcnt_arg::from_vmcnt<vmcnt>() |
164 waitcnt_arg::from_expcnt<expcnt>() |
165 waitcnt_arg::from_lgkmcnt<lgkmcnt>());
168 template <
index_t vmcnt = waitcnt_arg::kMaxVmCnt,
169 index_t expcnt = waitcnt_arg::kMaxExpCnt,
170 index_t lgkmcnt = waitcnt_arg::kMaxLgkmCnt>
173 s_waitcnt<vmcnt, expcnt, lgkmcnt>();
174 __builtin_amdgcn_s_barrier();
182 s_waitcnt_barrier<0, waitcnt_arg::kMaxExpCnt, 0>();
186 s_waitcnt vmcnt(0) \n \
187 s_waitcnt lgkmcnt(0) \n \
196 asm volatile(
"s_nop %0" : :
"n"(cnt) :);
198 __builtin_amdgcn_sched_barrier(cnt);
// Pointer qualifier for the address space named by address_space_enum::constant,
// used as `T CK_CONSTANT_ADDRESS_SPACE* p`. Expands to Clang's address_space
// attribute, whose argument is the enumerator's integer value obtained through
// safe_underlying_type_t (the attribute needs an integer, not a scoped enum).
202 #define CK_CONSTANT_ADDRESS_SPACE \
203 __attribute__((address_space( \
204 static_cast<safe_underlying_type_t<address_space_enum>>(address_space_enum::constant))))
206 template <
typename T>
211 #pragma clang diagnostic push
212 #pragma clang diagnostic ignored "-Wold-style-cast"
214 #pragma clang diagnostic pop
217 template <
typename T>
222 #pragma clang diagnostic push
223 #pragma clang diagnostic ignored "-Wold-style-cast"
225 #pragma clang diagnostic pop
230 #if defined(__gfx950__)
#define CK_CONSTANT_ADDRESS_SPACE
Definition: arch.hpp:202
constexpr CK_TILE_HOST_DEVICE index_t get_smem_capacity()
Definition: arch.hpp:228
CK_TILE_DEVICE void s_nop(index_t cnt=0)
Definition: arch.hpp:193
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: arch.hpp:207
CK_TILE_DEVICE void s_waitcnt_barrier()
Definition: arch.hpp:171
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: arch.hpp:218
CK_TILE_DEVICE void block_sync_lds_direct_load()
Definition: arch.hpp:177
CK_TILE_DEVICE void s_waitcnt()
Definition: arch.hpp:161
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE void atomic_add(X *p_dst, const X &x)
CK_TILE_HOST_DEVICE T add(const T &a, const T &b)
Definition: generic_memory_space_atomic.hpp:12
int32_t index_t
Definition: integer.hpp:9
__device__ index_t get_grid_size()
Definition: get_id.hpp:27
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
__device__ index_t get_block_size()
Definition: get_id.hpp:29
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:21
__device__ X atomic_max(X *p_dst, const X &x)
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
__device__ void block_sync_lds()
Definition: synchronization.hpp:10