30 #ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_
31 #define HIPCUB_ROCPRIM_UTIL_PTX_HPP_
34 #include <type_traits>
36 #include "../../config.hpp"
38 #include <rocprim/intrinsics/warp_shuffle.hpp>
40 #define HIPCUB_WARP_THREADS ::rocprim::warp_size()
41 #define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size()
42 #define HIPCUB_HOST_WARP_THREADS ::rocprim::host_warp_size()
46 BEGIN_HIPCUB_NAMESPACE
// Builds a linear thread index from the HIP thread coordinates and the
// caller-supplied block dimensions.
//   - z term: hipThreadIdx_z * block_dim_x * block_dim_y, skipped (0) when block_dim_z == 1
//   - y term: hipThreadIdx_y * block_dim_x,               skipped (0) when block_dim_y == 1
// NOTE(review): the return expression has no terminator in this excerpt, so at
// least one more term follows on a line that is not visible here (presumably
// the x coordinate) -- confirm against the full file.
66 int RowMajorTid(
int block_dim_x,
int block_dim_y,
int block_dim_z)
68 return ((block_dim_z == 1) ? 0 : (hipThreadIdx_z * block_dim_x * block_dim_y))
69 + ((block_dim_y == 1) ? 0 : (hipThreadIdx_y * block_dim_x))
// Forwards to ::rocprim::lane_id() and returns its result unchanged.
// NOTE(review): the enclosing signature is not visible in this excerpt;
// presumably this is the body of LaneId(), which the lane-mask helpers call -- confirm.
76 return ::rocprim::lane_id();
// Forwards to ::rocprim::warp_id() and returns its result unchanged.
// NOTE(review): the enclosing signature is not visible in this excerpt
// (presumably WarpId()) -- confirm.
82 return ::rocprim::warp_id();
// 64-bit mask with bits [0, LaneId()) set: shifting 1 left by the lane index
// and subtracting 1 sets every bit strictly below that index.
// NOTE(review): the enclosing signature is not visible in this excerpt;
// presumably this is LaneMaskLt(), which LaneMaskGe() complements -- confirm.
89 return (uint64_t(1) << LaneId()) - 1;
// 64-bit mask with bits [0, LaneId()] set (the lane's own bit included).
// The shift is done in two steps (by LaneId(), then by 1) so that no single
// shift amount reaches 64: if LaneId() is 63 the unsigned value wraps to 0 and
// the subtraction yields an all-ones mask.
// NOTE(review): the enclosing signature is not visible in this excerpt;
// presumably this is LaneMaskLe(), which LaneMaskGt() complements -- confirm.
96 return ((uint64_t(1) << LaneId()) << 1) - 1;
101 uint64_t LaneMaskGt()
103 return uint64_t(-1)^LaneMaskLe();
108 uint64_t LaneMaskGe()
110 return uint64_t(-1)^LaneMaskLt();
// Fragment of a warp-shuffle wrapper: forwards `input`, `src_offset` and the
// compile-time LOGICAL_WARP_THREADS to ::rocprim::warp_shuffle_up and returns
// the result. `member_mask` is accepted as a parameter but is not among the
// arguments visible in the forwarded call.
// NOTE(review): the function name (presumably ShuffleUp, by analogy with
// ShuffleDown), the remaining template/function parameters and the end of the
// call are not visible in this excerpt -- confirm against the full file.
116 int LOGICAL_WARP_THREADS,
123 unsigned int member_mask)
130 return ::rocprim::warp_shuffle_up(
131 input, src_offset, LOGICAL_WARP_THREADS
// Warp-shuffle wrapper: forwards `input`, `src_offset` and the compile-time
// LOGICAL_WARP_THREADS to ::rocprim::warp_shuffle_down and returns the result
// as T. `member_mask` is accepted as a parameter but is not among the
// arguments visible in the forwarded call.
// NOTE(review): the declaration of `src_offset`, the other template
// parameters and the end of the call are not visible in this excerpt.
136 int LOGICAL_WARP_THREADS,
140 T ShuffleDown(T input,
143 unsigned int member_mask)
150 return ::rocprim::warp_shuffle_down(
151 input, src_offset, LOGICAL_WARP_THREADS
// Warp-shuffle wrapper: forwards `input`, `src_lane` and the compile-time
// LOGICAL_WARP_THREADS to ::rocprim::warp_shuffle and returns the result as T.
// `member_mask` is accepted as a parameter but is not among the arguments
// visible in the forwarded call.
// NOTE(review): the declaration of `src_lane`, the other template parameters
// and the end of the call are not visible in this excerpt.
156 int LOGICAL_WARP_THREADS,
160 T ShuffleIndex(T input,
162 unsigned int member_mask)
167 return ::rocprim::warp_shuffle(
168 input, src_lane, LOGICAL_WARP_THREADS
// Shift-right-then-add: returns (x >> shift) + addend.
// NOTE(review): the declarations of `shift` and `addend` are not visible in
// this excerpt; presumably they are the remaining parameters -- confirm.
175 unsigned int SHR_ADD(
unsigned int x,
179 return (x >> shift) + addend;
// Shift-left-then-add: returns (x << shift) + addend.
// NOTE(review): the declarations of `shift` and `addend` are not visible in
// this excerpt; presumably they are the remaining parameters -- confirm.
183 unsigned int SHL_ADD(
unsigned int x,
187 return (x << shift) + addend;
/// Extracts a bit field from an 8-byte value.
///
/// This overload participates in overload resolution only when
/// sizeof(UnsignedBits) == 8.
///
/// @param source     Value to read the field from.
/// @param bit_start  Index of the lowest bit of the field.
/// @param num_bits   Width of the field in bits.
/// @return The field, moved down to bit 0 and converted to unsigned int (a
///         field wider than unsigned int is truncated by that conversion).
///
/// With __HIP_PLATFORM_AMD__ defined the work is delegated to
/// __bitextract_u64. Otherwise a portable shift-and-mask is used; it returns 0
/// for an empty field (num_bits == 0) or a start position outside the value
/// (bit_start >= 64), and never shifts by 64 or more.
template <typename UnsignedBits>
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<(sizeof(UnsignedBits) == 8), unsigned int>::type
{
#ifdef __HIP_PLATFORM_AMD__
    return __bitextract_u64(source, bit_start, num_bits);
#else
    // Shifting a 64-bit value by 64 or more is undefined, so the degenerate
    // cases are answered before any shift takes place.
    if (num_bits == 0 || bit_start >= 64)
    {
        return 0;
    }

    const unsigned long long shifted = static_cast<unsigned long long>(source) >> bit_start;
    // A width of 64 or more selects everything that is left after the shift.
    const unsigned long long mask = (num_bits >= 64) ? ~0ULL : ((1ULL << num_bits) - 1ULL);
    return static_cast<unsigned int>(shifted & mask);
#endif
}
/// Extracts a bit field from a value narrower than 8 bytes.
///
/// This overload participates in overload resolution only when
/// sizeof(UnsignedBits) < 8. The source is converted to unsigned int before
/// the field is read.
///
/// @param source     Value to read the field from.
/// @param bit_start  Index of the lowest bit of the field.
/// @param num_bits   Width of the field in bits.
/// @return The field, moved down to bit 0.
///
/// With __HIP_PLATFORM_AMD__ defined the work is delegated to
/// __bitextract_u32. Otherwise a portable shift-and-mask is used; it returns 0
/// for an empty field (num_bits == 0) or a start position outside the value
/// (bit_start >= 32), and never shifts by 32 or more.
template <typename UnsignedBits>
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<(sizeof(UnsignedBits) < 8), unsigned int>::type
{
#ifdef __HIP_PLATFORM_AMD__
    return __bitextract_u32(source, bit_start, num_bits);
#else
    // Shifting a 32-bit value by 32 or more is undefined, so the degenerate
    // cases are answered before any shift takes place.
    if (num_bits == 0 || bit_start >= 32)
    {
        return 0;
    }

    const unsigned int shifted = static_cast<unsigned int>(source) >> bit_start;
    // A width of 32 or more selects everything that is left after the shift.
    const unsigned int mask = (num_bits >= 32) ? ~0u : ((1u << num_bits) - 1u);
    return shifted & mask;
#endif
}
225 template <
typename Un
signedBits>
227 unsigned int BFE(UnsignedBits source,
228 unsigned int bit_start,
229 unsigned int num_bits)
231 static_assert(std::is_unsigned<UnsignedBits>::value,
"UnsignedBits must be unsigned");
232 return detail::unsigned_bit_extract(source, bit_start, num_bits);
// Bit-field insert: writes into `ret` a value combining bits of `x` and `y`
// according to a field of `num_bits` bits starting at `bit_start`.
// NOTE(review): the declarations of `x` and `y`, and the preprocessor line
// that separates the two branches below, are not visible in this excerpt.
238 void BFI(
unsigned int &ret,
241 unsigned int bit_start,
242 unsigned int num_bits)
244 #ifdef __HIP_PLATFORM_AMD__
// AMD HIP platform: delegate to the __bitinsert_u32 intrinsic.
245 ret = __bitinsert_u32(x, y, bit_start, num_bits);
// Portable path: MASK_X covers the field, MASK_Y everything outside it; the
// result takes the field bits from x and all other bits from y.
// NOTE(review): `1 << num_bits` shifts a signed int, which overflows for
// num_bits == 31 and is undefined for num_bits >= 32 -- consider `1u` plus an
// explicit full-width case.
248 unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start;
249 unsigned int MASK_Y = ~MASK_X;
250 ret = (y & MASK_Y) | (x & MASK_X);
// Three-operand helper taking x, y and z as unsigned ints and returning an
// unsigned int.
// NOTE(review): the body is not visible in this excerpt; presumably it returns
// x + y + z -- confirm against the full file.
255 unsigned int IADD3(
unsigned int x,
unsigned int y,
unsigned int z)
261 int PRMT(
unsigned int a,
unsigned int b,
unsigned int index)
263 return ::__byte_perm(a, b, index);
// Warp synchronisation entry point taking a `member_mask`; returns nothing.
// NOTE(review): the body is not visible in this excerpt, so what (if anything)
// it does with `member_mask` cannot be stated here -- confirm against the full file.
280 void WARP_SYNC(
unsigned int member_mask)
// Forwards `predicate` to ::__any and returns its result.
// `member_mask` is not used by the statement visible here.
// NOTE(review): the listing skips lines between the signature and the return,
// so there may be further statements not shown in this excerpt.
287 int WARP_ANY(
int predicate, uint64_t member_mask)
290 return ::__any(predicate);
// Forwards `predicate` to ::__all and returns its result.
// `member_mask` is not used by the statement visible here.
// NOTE(review): the listing skips lines between the signature and the return,
// so there may be further statements not shown in this excerpt.
294 int WARP_ALL(
int predicate, uint64_t member_mask)
297 return ::__all(predicate);
// Forwards `predicate` to __ballot and returns its result as int64_t.
// `member_mask` is not used by the statement visible here.
// NOTE(review): the return type is signed while the lane-mask helpers in this
// file use uint64_t -- confirm the signedness is intended. The listing also
// skips lines between the signature and the return, so there may be further
// statements not shown in this excerpt.
301 int64_t WARP_BALLOT(
int predicate, uint64_t member_mask)
304 return __ballot(predicate);