30 #ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_
31 #define HIPCUB_ROCPRIM_UTIL_PTX_HPP_
34 #include <type_traits>
36 #include "../../config.hpp"
38 #include <rocprim/intrinsics/warp_shuffle.hpp>
// Warp-size query macros. Each one forwards to a rocPRIM helper:
// ::rocprim::warp_size(), ::rocprim::device_warp_size() and
// ::rocprim::host_warp_size(). The names suggest generic, device-side and
// host-side variants; confirm against the rocPRIM documentation.
40 #define HIPCUB_WARP_THREADS ::rocprim::warp_size()
41 #define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size()
42 #define HIPCUB_HOST_WARP_THREADS ::rocprim::host_warp_size()
46 BEGIN_HIPCUB_NAMESPACE
66 int RowMajorTid(
int block_dim_x,
int block_dim_y,
int block_dim_z)
68 return ((block_dim_z == 1) ? 0 : (hipThreadIdx_z * block_dim_x * block_dim_y))
69 + ((block_dim_y == 1) ? 0 : (hipThreadIdx_y * block_dim_x))
// NOTE(review): the signature of the function owning this body is not
// visible in this view. The body forwards to ::rocprim::lane_id(). This is
// presumably the LaneId() helper called by the LaneMask* functions; confirm.
76 return ::rocprim::lane_id();
// NOTE(review): the signature of the function owning this body is not
// visible in this view. The body forwards to ::rocprim::warp_id(). This is
// presumably a WarpId()-style helper; confirm.
82 return ::rocprim::warp_id();
85 template <
int LOGICAL_WARP_THREADS,
int = 0>
87 uint64_t WarpMask(
unsigned int warp_id) {
88 constexpr
bool is_pow_of_two = ::rocprim::detail::is_power_of_two(LOGICAL_WARP_THREADS);
89 constexpr
bool is_arch_warp =
90 LOGICAL_WARP_THREADS == ::rocprim::device_warp_size();
92 uint64_t member_mask = uint64_t(-1) >> (64 - LOGICAL_WARP_THREADS);
94 if (is_pow_of_two && !is_arch_warp) {
95 member_mask <<= warp_id * LOGICAL_WARP_THREADS;
103 uint64_t LaneMaskLt()
105 return (uint64_t(1) << LaneId()) - 1;
110 uint64_t LaneMaskLe()
112 return ((uint64_t(1) << LaneId()) << 1) - 1;
117 uint64_t LaneMaskGt()
119 return uint64_t(-1)^LaneMaskLe();
124 uint64_t LaneMaskGe()
126 return uint64_t(-1)^LaneMaskLt();
// NOTE(review): fragment of what appears to be a logical-warp shuffle-up
// wrapper, presumably ShuffleUp. Its name, template header and several
// parameters are not visible in this view.
// The visible body forwards `input` and `src_offset` to
// ::rocprim::warp_shuffle_up, using LOGICAL_WARP_THREADS as the width.
// `member_mask` is not referenced in the visible lines.
132 int LOGICAL_WARP_THREADS,
139 unsigned int member_mask)
146 return ::rocprim::warp_shuffle_up(
147 input, src_offset, LOGICAL_WARP_THREADS
// Shuffle-down within a logical warp of LOGICAL_WARP_THREADS lanes. It
// forwards `input` and `src_offset` to ::rocprim::warp_shuffle_down.
// NOTE(review): part of the template header and parameter list (including
// the declaration of `src_offset`) is not visible in this view.
// `member_mask` is not referenced in the visible lines.
152 int LOGICAL_WARP_THREADS,
156 T ShuffleDown(T input,
159 unsigned int member_mask)
166 return ::rocprim::warp_shuffle_down(
167 input, src_offset, LOGICAL_WARP_THREADS
// Reads `input` from lane `src_lane` of a logical warp of
// LOGICAL_WARP_THREADS lanes, via ::rocprim::warp_shuffle.
// NOTE(review): part of the template header and parameter list (including
// the declaration of `src_lane`) is not visible in this view.
// `member_mask` is not referenced in the visible lines.
172 int LOGICAL_WARP_THREADS,
176 T ShuffleIndex(T input,
178 unsigned int member_mask)
183 return ::rocprim::warp_shuffle(
184 input, src_lane, LOGICAL_WARP_THREADS
// Shift-right then add: returns (x >> shift) + addend in unsigned int
// arithmetic.
// NOTE(review): the `shift` and `addend` parameter declarations are not
// visible in this view. A `shift` at or above the bit width of unsigned int
// would be undefined behaviour.
191 unsigned int SHR_ADD(
unsigned int x,
195 return (x >> shift) + addend;
// Shift-left then add: returns (x << shift) + addend in unsigned int
// arithmetic.
// NOTE(review): the `shift` and `addend` parameter declarations are not
// visible in this view. A `shift` at or above the bit width of unsigned int
// would be undefined behaviour.
199 unsigned int SHL_ADD(
unsigned int x,
203 return (x << shift) + addend;
/// Extracts `num_bits` bits of a 64-bit value, starting at bit `bit_start`.
///
/// On AMD this uses the __bitextract_u64 intrinsic. Elsewhere it uses a
/// portable shift-based fallback.
///
/// @param source     Value to extract from (an 8-byte unsigned type).
/// @param bit_start  Index of the lowest extracted bit.
/// @param num_bits   Field width. Requires bit_start + num_bits <= 64.
/// @return The extracted field, right-aligned and converted to unsigned int.
template <typename UnsignedBits>
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<sizeof(UnsignedBits) == 8, unsigned int>::type
{
#ifdef __HIP_PLATFORM_AMD__
    return __bitextract_u64(source, bit_start, num_bits);
#else
    // An empty field has no bits. Returning early also avoids a right shift
    // by 64 below, which is undefined behaviour.
    if (num_bits == 0)
    {
        return 0u;
    }
    // Discard the bits above the field, then the bits below it.
    return static_cast<unsigned int>((source << (64 - bit_start - num_bits)) >> (64 - num_bits));
#endif
}
/// Extracts `num_bits` bits of a value narrower than 8 bytes, starting at
/// bit `bit_start`.
///
/// On AMD this uses the __bitextract_u32 intrinsic. Elsewhere it uses a
/// portable shift-based fallback on the value widened to unsigned int.
///
/// @param source     Value to extract from (an unsigned type narrower than 8 bytes).
/// @param bit_start  Index of the lowest extracted bit.
/// @param num_bits   Field width. Requires bit_start + num_bits <= 32.
/// @return The extracted field, right-aligned.
template <typename UnsignedBits>
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<(sizeof(UnsignedBits) < 8), unsigned int>::type
{
#ifdef __HIP_PLATFORM_AMD__
    return __bitextract_u32(source, bit_start, num_bits);
#else
    // An empty field has no bits. Returning early also avoids a right shift
    // by 32 below, which is undefined behaviour.
    if (num_bits == 0)
    {
        return 0u;
    }
    // Discard the bits above the field, then the bits below it.
    return (static_cast<unsigned int>(source) << (32 - bit_start - num_bits)) >> (32 - num_bits);
#endif
}
241 template <
typename Un
signedBits>
243 unsigned int BFE(UnsignedBits source,
244 unsigned int bit_start,
245 unsigned int num_bits)
247 static_assert(std::is_unsigned<UnsignedBits>::value,
"UnsignedBits must be unsigned");
248 return detail::unsigned_bit_extract(source, bit_start, num_bits);
/// Bitfield insert: writes to `ret` the value of `x` with its bits
/// [bit_start, bit_start + num_bits) replaced by the low `num_bits` bits of `y`.
///
/// On AMD this uses __bitinsert_u32(x, y, ...). The portable fallback below
/// uses the same operand roles, so both paths agree.
///
/// @param ret        Output: the combined value.
/// @param x          Base value whose bits outside the field are kept.
/// @param y          Value whose low `num_bits` bits are inserted.
/// @param bit_start  Field offset. Requires bit_start < 32.
/// @param num_bits   Field width. Requires bit_start + num_bits <= 32.
void BFI(unsigned int &ret,
         unsigned int x,
         unsigned int y,
         unsigned int bit_start,
         unsigned int num_bits)
{
#ifdef __HIP_PLATFORM_AMD__
    ret = __bitinsert_u32(x, y, bit_start, num_bits);
#else
    // Build the field mask in unsigned arithmetic. A full-width field is a
    // special case because shifting a 32-bit value by 32 is undefined (and
    // `1 << 31` on a signed int overflows).
    const unsigned int field_bits = (num_bits >= 32u) ? ~0u : ((1u << num_bits) - 1u);
    const unsigned int mask       = field_bits << bit_start;
    // Keep x outside the field and place y's low bits inside it.
    ret = (x & ~mask) | ((y << bit_start) & mask);
#endif
}
// Three-operand unsigned add helper taking x, y and z.
// NOTE(review): the body is not visible in this view. Presumably it returns
// x + y + z (wrapping in unsigned int); confirm against the full file.
271 unsigned int IADD3(
unsigned int x,
unsigned int y,
unsigned int z)
277 int PRMT(
unsigned int a,
unsigned int b,
unsigned int index)
279 return ::__byte_perm(a, b, index);
296 void WARP_SYNC(
unsigned int member_mask)
299 ::rocprim::wave_barrier();
303 int WARP_ANY(
int predicate, uint64_t member_mask)
306 return ::__any(predicate);
310 int WARP_ALL(
int predicate, uint64_t member_mask)
313 return ::__all(predicate);
317 int64_t WARP_BALLOT(
int predicate, uint64_t member_mask)
320 return __ballot(predicate);