30 #ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_
31 #define HIPCUB_ROCPRIM_UTIL_PTX_HPP_
34 #include <type_traits>
36 #include "../../config.hpp"
38 #include <rocprim/intrinsics/warp_shuffle.hpp>
40 BEGIN_HIPCUB_NAMESPACE
/// RowMajorTid: combines the z and y thread indices into a linear offset in
/// which the z index is scaled by block_dim_x * block_dim_y and the y index by
/// block_dim_x.
///
/// \param block_dim_x  Block extent along x.
/// \param block_dim_y  Block extent along y.
/// \param block_dim_z  Block extent along z.
///
/// A dimension whose extent is 1 contributes 0 without reading its thread
/// index.
/// NOTE(review): this excerpt shows no braces and no x-term closing the sum;
/// presumably hipThreadIdx_x is added last -- confirm against the full file.
60 int RowMajorTid(
int block_dim_x,
int block_dim_y,
int block_dim_z)
62 return ((block_dim_z == 1) ? 0 : (hipThreadIdx_z * block_dim_x * block_dim_y))
63 + ((block_dim_y == 1) ? 0 : (hipThreadIdx_y * block_dim_x))
// Forwards the result of ::rocprim::lane_id() unchanged.
// NOTE(review): the enclosing function header is not visible in this excerpt;
// presumably this is the body of LaneId(), which the lane-mask helpers call.
70 return ::rocprim::lane_id();
// Forwards the result of ::rocprim::warp_id() unchanged.
// NOTE(review): the enclosing function header is not visible in this excerpt;
// presumably this is the body of WarpId() -- confirm against the full file.
76 return ::rocprim::warp_id();
/// WarpMask: builds a 64-bit mask whose set bits correspond to the threads of
/// one logical warp.
///
/// \tparam LOGICAL_WARP_THREADS  Number of threads in the logical warp.
/// The second, unnamed int template parameter defaults to 0 and is not
/// referenced in the visible lines.
///
/// \param warp_id  Index of the logical warp; only used to position the mask
///                 when the logical warp is a power-of-two size that differs
///                 from ::rocprim::device_warp_size().
///
/// NOTE(review): the return statement and closing braces are not visible in
/// this excerpt; presumably member_mask is returned -- confirm.
79 template <
int LOGICAL_WARP_THREADS,
int = 0>
81 uint64_t WarpMask(
unsigned int warp_id) {
82 constexpr
bool is_pow_of_two = ::rocprim::detail::is_power_of_two(LOGICAL_WARP_THREADS);
83 constexpr
bool is_arch_warp =
84 LOGICAL_WARP_THREADS == ::rocprim::device_warp_size();
// Low LOGICAL_WARP_THREADS bits set. NOTE(review): a value of 0 would shift
// a 64-bit operand by 64, which is undefined behavior.
86 uint64_t member_mask = uint64_t(-1) >> (64 - LOGICAL_WARP_THREADS);
// Move the mask up by warp_id whole logical warps.
88 if (is_pow_of_two && !is_arch_warp) {
89 member_mask <<= warp_id * LOGICAL_WARP_THREADS;
// 64-bit mask with one bit set for every lane index strictly below LaneId().
// NOTE(review): the enclosing function header is not visible in this excerpt;
// presumably this is the body of LaneMaskLt(), which LaneMaskGe() calls.
99 return (uint64_t(1) << LaneId()) - 1;
/// LaneMaskLe: 64-bit mask with bits 0 through LaneId() (inclusive) set.
///
/// The shift is performed in two steps (by LaneId(), then by 1), which keeps
/// each individual shift amount below 64 even when LaneId() is 63; in that
/// case the intermediate wraps to 0 and the subtraction yields all ones.
/// NOTE(review): braces are not visible in this excerpt.
104 uint64_t LaneMaskLe()
106 return ((uint64_t(1) << LaneId()) << 1) - 1;
/// LaneMaskGt: bitwise complement of LaneMaskLe(), i.e. a 64-bit mask with a
/// bit set for every position above LaneId(). The complement is taken by
/// XOR-ing with an all-ones 64-bit value.
/// NOTE(review): braces are not visible in this excerpt.
111 uint64_t LaneMaskGt()
113 return uint64_t(-1)^LaneMaskLe();
/// LaneMaskGe: bitwise complement of LaneMaskLt(), taken by XOR-ing with an
/// all-ones 64-bit value.
/// NOTE(review): LaneMaskLt()'s header is not visible in this excerpt;
/// presumably it masks the lanes below LaneId(), making this the mask of
/// positions at or above LaneId() -- confirm. Braces are not visible either.
118 uint64_t LaneMaskGe()
120 return uint64_t(-1)^LaneMaskLt();
// Forwards input, src_offset and LOGICAL_WARP_THREADS to
// ::rocprim::warp_shuffle_up and returns its result.
// NOTE(review): the function name, the remaining template/function parameters
// and the end of the call are not visible in this excerpt; presumably this is
// ShuffleUp. member_mask is not referenced in the visible lines.
126 int LOGICAL_WARP_THREADS,
133 unsigned int member_mask)
140 return ::rocprim::warp_shuffle_up(
141 input, src_offset, LOGICAL_WARP_THREADS
// ShuffleDown: forwards input, src_offset and LOGICAL_WARP_THREADS to
// ::rocprim::warp_shuffle_down and returns its result as T.
// NOTE(review): the template header, the src_offset parameter and the end of
// the call are not visible in this excerpt. member_mask is declared as
// unsigned int (32-bit) and is not referenced in the visible lines.
146 int LOGICAL_WARP_THREADS,
150 T ShuffleDown(T input,
153 unsigned int member_mask)
160 return ::rocprim::warp_shuffle_down(
161 input, src_offset, LOGICAL_WARP_THREADS
// ShuffleIndex: forwards input, src_lane and LOGICAL_WARP_THREADS to
// ::rocprim::warp_shuffle and returns its result as T.
// NOTE(review): the template header, the src_lane parameter and the end of
// the call are not visible in this excerpt. member_mask is declared as
// unsigned int (32-bit) and is not referenced in the visible lines.
166 int LOGICAL_WARP_THREADS,
170 T ShuffleIndex(T input,
172 unsigned int member_mask)
177 return ::rocprim::warp_shuffle(
178 input, src_lane, LOGICAL_WARP_THREADS
/// SHR_ADD: returns (x >> shift) + addend.
///
/// \param x  Value to shift right.
/// NOTE(review): the declarations of shift and addend, and the braces, are not
/// visible in this excerpt; their types should be confirmed in the full file.
185 unsigned int SHR_ADD(
unsigned int x,
189 return (x >> shift) + addend;
/// SHL_ADD: returns (x << shift) + addend.
///
/// \param x  Value to shift left.
/// NOTE(review): the declarations of shift and addend, and the braces, are not
/// visible in this excerpt; their types should be confirmed in the full file.
193 unsigned int SHL_ADD(
unsigned int x,
197 return (x << shift) + addend;
/// unsigned_bit_extract (8-byte overload): extracts num_bits bits of source
/// starting at bit_start. Selected only when sizeof(UnsignedBits) == 8.
///
/// \param source     Value to extract from.
/// \param bit_start  Index of the lowest bit of the field.
/// \param num_bits   Width of the field.
/// \return The field as unsigned int, so a result wider than unsigned int is
///         narrowed.
///
/// With __HIP_PLATFORM_AMD__ defined the work is delegated to
/// __bitextract_u64; the other return shifts the field to the top of the
/// 64-bit value and back down to bit 0.
/// NOTE(review): in the shift-based return, num_bits == 0 produces a right
/// shift by 64, which is undefined behavior. The directives separating the two
/// returns, the braces, and the enclosing namespace are not visible here.
202 template <
typename Un
signedBits>
204 auto unsigned_bit_extract(UnsignedBits source,
205 unsigned int bit_start,
206 unsigned int num_bits)
207 ->
typename std::enable_if<
sizeof(UnsignedBits) == 8,
unsigned int>::type
209 #ifdef __HIP_PLATFORM_AMD__
210 return __bitextract_u64(source, bit_start, num_bits);
212 return (source << (64 - bit_start - num_bits)) >> (64 - num_bits);
/// unsigned_bit_extract (sub-8-byte overload): extracts num_bits bits of
/// source starting at bit_start. Selected only when sizeof(UnsignedBits) < 8.
///
/// \param source     Value to extract from.
/// \param bit_start  Index of the lowest bit of the field.
/// \param num_bits   Width of the field.
/// \return The field as unsigned int.
///
/// With __HIP_PLATFORM_AMD__ defined the work is delegated to
/// __bitextract_u32; the other return casts source to unsigned int, shifts the
/// field up and then back down by (32 - num_bits).
/// NOTE(review): in the shift-based return, num_bits == 0 produces a right
/// shift by 32, which is undefined behavior. The directives separating the two
/// returns and the braces are not visible in this excerpt.
216 template <
typename Un
signedBits>
218 auto unsigned_bit_extract(UnsignedBits source,
219 unsigned int bit_start,
220 unsigned int num_bits)
221 ->
typename std::enable_if<
sizeof(UnsignedBits) < 8,
unsigned int>::type
223 #ifdef __HIP_PLATFORM_AMD__
224 return __bitextract_u32(source, bit_start, num_bits);
226 return (
static_cast<unsigned int>(source) << (32 - bit_start - num_bits)) >> (32 - num_bits);
/// BFE: bit-field extract. Rejects non-unsigned UnsignedBits at compile time
/// and forwards to detail::unsigned_bit_extract.
///
/// \param source     Value to extract from.
/// \param bit_start  Index of the lowest bit of the field.
/// \param num_bits   Width of the field.
/// \return The result of detail::unsigned_bit_extract as unsigned int.
/// NOTE(review): braces are not visible in this excerpt.
235 template <
typename Un
signedBits>
237 unsigned int BFE(UnsignedBits source,
238 unsigned int bit_start,
239 unsigned int num_bits)
241 static_assert(std::is_unsigned<UnsignedBits>::value,
"UnsignedBits must be unsigned");
242 return detail::unsigned_bit_extract(source, bit_start, num_bits);
/// BFI: bit-field insert. Writes into ret a value that takes the bits covered
/// by a num_bits-wide mask at bit_start from x and all remaining bits from y.
///
/// \param ret        Output; overwritten with the combined value.
/// \param bit_start  Index of the lowest bit of the inserted field.
/// \param num_bits   Width of the inserted field.
///
/// With __HIP_PLATFORM_AMD__ defined the work is delegated to
/// __bitinsert_u32; the remaining statements build the masks by hand.
/// NOTE(review): the declarations of x and y, any statements between the two
/// paths (for example a shift of x into position), the directives separating
/// the paths, and the braces are not visible in this excerpt.
/// NOTE(review): the literal 1 in the mask computation is a signed int, so
/// num_bits == 32 shifts by the full width of int (undefined behavior) and
/// num_bits == 31 shifts into the sign bit.
248 void BFI(
unsigned int &ret,
251 unsigned int bit_start,
252 unsigned int num_bits)
254 #ifdef __HIP_PLATFORM_AMD__
255 ret = __bitinsert_u32(x, y, bit_start, num_bits);
258 unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start;
259 unsigned int MASK_Y = ~MASK_X;
260 ret = (y & MASK_Y) | (x & MASK_X);
/// IADD3: takes three unsigned int operands and returns an unsigned int.
///
/// \param x  First operand.
/// \param y  Second operand.
/// \param z  Third operand.
/// NOTE(review): the body is not visible in this excerpt; presumably it
/// returns x + y + z -- confirm against the full file.
265 unsigned int IADD3(
unsigned int x,
unsigned int y,
unsigned int z)
/// PRMT: forwards a, b and index to ::__byte_perm and returns its result.
///
/// \param a      First source word.
/// \param b      Second source word.
/// \param index  Selector passed through to ::__byte_perm.
/// NOTE(review): the visible return type is int while all operands are
/// unsigned int; the declaration may be truncated in this excerpt -- confirm.
/// Braces are not visible either.
271 int PRMT(
unsigned int a,
unsigned int b,
unsigned int index)
273 return ::__byte_perm(a, b, index);
/// WARP_SYNC: calls ::rocprim::wave_barrier().
///
/// \param member_mask  Not referenced in the visible lines.
/// NOTE(review): braces and any statements around the barrier call are not
/// visible in this excerpt.
290 void WARP_SYNC(
unsigned int member_mask)
293 ::rocprim::wave_barrier();
/// WARP_ANY: returns the result of ::__any(predicate).
///
/// \param predicate    Value passed to ::__any.
/// \param member_mask  Not referenced in the visible lines.
/// NOTE(review): braces and any statements before the return are not visible
/// in this excerpt.
297 int WARP_ANY(
int predicate, uint64_t member_mask)
300 return ::__any(predicate);
/// WARP_ALL: returns the result of ::__all(predicate).
///
/// \param predicate    Value passed to ::__all.
/// \param member_mask  Not referenced in the visible lines.
/// NOTE(review): braces and any statements before the return are not visible
/// in this excerpt.
304 int WARP_ALL(
int predicate, uint64_t member_mask)
307 return ::__all(predicate);
/// WARP_BALLOT: returns the result of __ballot(predicate).
///
/// \param predicate    Value passed to __ballot.
/// \param member_mask  Not referenced in the visible lines.
/// NOTE(review): the visible return type is the signed int64_t, whereas the
/// mask helpers in this file use uint64_t; the declaration may be truncated in
/// this excerpt -- confirm. Braces are not visible either.
311 int64_t WARP_BALLOT(
int predicate, uint64_t member_mask)
314 return __ballot(predicate);