30 #ifndef HIPCUB_ROCPRIM_UTIL_TYPE_HPP_
31 #define HIPCUB_ROCPRIM_UTIL_TYPE_HPP_
#include <limits>
#include <type_traits>

#include "../../config.hpp"

#include <rocprim/detail/various.hpp>

#include <hip/hip_fp16.h>
#include <hip/hip_bfloat16.h>
43 BEGIN_HIPCUB_NAMESPACE
45 #ifndef DOXYGEN_SHOULD_SKIP_THIS
47 using NullType = ::rocprim::empty_type;
51 template<
bool B,
typename T,
typename F>
54 using Type =
typename std::conditional<B, T, F>::type;
60 static constexpr
bool VALUE = std::is_pointer<T>::value;
66 static constexpr
bool VALUE = std::is_volatile<T>::value;
72 using Type =
typename std::remove_cv<T>::type;
78 static constexpr
bool VALUE = ::rocprim::detail::is_power_of_two<N>();
84 template<
int N,
int CURRENT_VAL = N,
int COUNT = 0>
87 static constexpr
int VALUE = Log2Impl<N, (CURRENT_VAL >> 1), COUNT + 1>::VALUE;
90 template<
int N,
int COUNT>
91 struct Log2Impl<N, 0, COUNT>
93 static constexpr
int VALUE = (1 << (COUNT - 1) < N) ? COUNT : COUNT - 1;
101 static_assert(N != 0,
"The logarithm of zero is undefined");
102 static constexpr
int VALUE = detail::Log2Impl<N>::VALUE;
112 HIPCUB_HOST_DEVICE
inline
116 d_buffers[0] =
nullptr;
117 d_buffers[1] =
nullptr;
120 HIPCUB_HOST_DEVICE
inline
124 d_buffers[0] = d_current;
125 d_buffers[1] = d_alternate;
128 HIPCUB_HOST_DEVICE
inline
131 return d_buffers[selector];
134 HIPCUB_HOST_DEVICE
inline
137 return d_buffers[selector ^ 1];
147 #ifndef DOXYGEN_SHOULD_SKIP_THIS
153 using KeyValuePair = ::rocprim::key_value_pair<Key, Value>;
164 return ::rocprim::double_buffer<T>(source.Current(), source.Alternate());
169 void update_double_buffer(DoubleBuffer<T>& target, ::rocprim::double_buffer<T>& source)
171 if(target.Current() != source.current())
173 target.selector ^= 1;
177 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Boolean std::integral_constant that holds true exactly when T is an
// integral type or an enumeration type (written here in De Morgan form:
// "not (neither integral nor enum)").
template <typename T>
using is_integral_or_enum
    = std::integral_constant<bool,
                             !(!std::is_integral<T>::value && !std::is_enum<T>::value)>;
187 template <
typename NumeratorT,
typename DenominatorT>
188 __host__ __device__ __forceinline__ constexpr NumeratorT
189 DivideAndRoundUp(NumeratorT n, DenominatorT d)
191 static_assert(hipcub::detail::is_integral_or_enum<NumeratorT>::value &&
192 hipcub::detail::is_integral_or_enum<DenominatorT>::value,
193 "DivideAndRoundUp is only intended for integral types.");
196 return static_cast<NumeratorT
>(n / d + (n % d != 0 ? 1 : 0));
199 #ifndef DOXYGEN_SHOULD_SKIP_THIS
206 template <
typename T>
218 ALIGN_BYTES =
sizeof(Pad) -
sizeof(T)
229 #define __HIPCUB_ALIGN_BYTES(t, b) \
230 template <> struct AlignBytes<t> \
231 { enum { ALIGN_BYTES = b }; typedef __align__(b) t Type; };
233 __HIPCUB_ALIGN_BYTES(short4, 8)
234 __HIPCUB_ALIGN_BYTES(ushort4, 8)
235 __HIPCUB_ALIGN_BYTES(int2, 8)
236 __HIPCUB_ALIGN_BYTES(uint2, 8)
237 __HIPCUB_ALIGN_BYTES(
long long, 8)
238 __HIPCUB_ALIGN_BYTES(
unsigned long long, 8)
239 __HIPCUB_ALIGN_BYTES(float2, 8)
240 __HIPCUB_ALIGN_BYTES(
double, 8)
242 __HIPCUB_ALIGN_BYTES(long2, 8)
243 __HIPCUB_ALIGN_BYTES(ulong2, 8)
245 __HIPCUB_ALIGN_BYTES(long2, 16)
246 __HIPCUB_ALIGN_BYTES(ulong2, 16)
248 __HIPCUB_ALIGN_BYTES(int4, 16)
249 __HIPCUB_ALIGN_BYTES(uint4, 16)
250 __HIPCUB_ALIGN_BYTES(float4, 16)
251 __HIPCUB_ALIGN_BYTES(long4, 16)
252 __HIPCUB_ALIGN_BYTES(ulong4, 16)
253 __HIPCUB_ALIGN_BYTES(longlong2, 16)
254 __HIPCUB_ALIGN_BYTES(ulonglong2, 16)
255 __HIPCUB_ALIGN_BYTES(double2, 16)
256 __HIPCUB_ALIGN_BYTES(longlong4, 16)
257 __HIPCUB_ALIGN_BYTES(ulonglong4, 16)
258 __HIPCUB_ALIGN_BYTES(double4, 16)
260 template <typename T> struct AlignBytes<volatile T> : AlignBytes<T> {};
261 template <
typename T>
struct AlignBytes<const T> : AlignBytes<T> {};
262 template <
typename T>
struct AlignBytes<const volatile T> : AlignBytes<T> {};
266 template <
typename T>
270 ALIGN_BYTES = AlignBytes<T>::ALIGN_BYTES
273 template <
typename Unit>
277 UNIT_ALIGN_BYTES = AlignBytes<Unit>::ALIGN_BYTES,
278 IS_MULTIPLE = (
sizeof(T) %
sizeof(Unit) == 0) && (
int(ALIGN_BYTES) % int(UNIT_ALIGN_BYTES) == 0)
283 typedef typename If<IsMultiple<int>::IS_MULTIPLE,
285 typename If<IsMultiple<short>::IS_MULTIPLE,
287 unsigned char>::Type>::Type ShuffleWord;
290 typedef typename If<IsMultiple<long long>::IS_MULTIPLE,
292 ShuffleWord>::Type VolatileWord;
295 typedef typename If<IsMultiple<longlong2>::IS_MULTIPLE,
297 VolatileWord>::Type DeviceWord;
300 typedef typename If<IsMultiple<int4>::IS_MULTIPLE,
302 typename If<IsMultiple<int2>::IS_MULTIPLE,
304 ShuffleWord>::Type>::Type TextureWord;
310 struct UnitWord <float2>
312 typedef int ShuffleWord;
313 typedef unsigned long long VolatileWord;
314 typedef unsigned long long DeviceWord;
315 typedef float2 TextureWord;
320 struct UnitWord <float4>
322 typedef int ShuffleWord;
323 typedef unsigned long long VolatileWord;
324 typedef ulonglong2 DeviceWord;
325 typedef float4 TextureWord;
331 struct UnitWord <char2>
333 typedef unsigned short ShuffleWord;
334 typedef unsigned short VolatileWord;
335 typedef unsigned short DeviceWord;
336 typedef unsigned short TextureWord;
340 template <
typename T>
struct UnitWord<volatile T> : UnitWord<T> {};
341 template <
typename T>
struct UnitWord<const T> : UnitWord<T> {};
342 template <
typename T>
struct UnitWord<const volatile T> : UnitWord<T> {};
357 template <
typename T>
372 __host__ __device__ __forceinline__ T&
Alias()
374 return reinterpret_cast<T&
>(*this);
390 #ifndef DOXYGEN_SHOULD_SKIP_THIS
407 template <Category _CATEGORY,
bool _PRIMITIVE,
bool _NULL_TYPE,
typename _Un
signedBits,
typename T>
411 static const Category CATEGORY = _CATEGORY;
414 PRIMITIVE = _PRIMITIVE,
415 NULL_TYPE = _NULL_TYPE,
423 template <
typename _Un
signedBits,
typename T>
424 struct BaseTraits<UNSIGNED_INTEGER, true, false, _UnsignedBits, T>
426 typedef _UnsignedBits UnsignedBits;
428 static const Category CATEGORY = UNSIGNED_INTEGER;
429 static const UnsignedBits LOWEST_KEY = UnsignedBits(0);
430 static const UnsignedBits MAX_KEY = UnsignedBits(-1);
439 static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
444 static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
449 static __host__ __device__ __forceinline__ T Max()
451 UnsignedBits retval = MAX_KEY;
452 return reinterpret_cast<T&
>(retval);
455 static __host__ __device__ __forceinline__ T Lowest()
457 UnsignedBits retval = LOWEST_KEY;
458 return reinterpret_cast<T&
>(retval);
466 template <
typename _Un
signedBits,
typename T>
467 struct BaseTraits<SIGNED_INTEGER, true, false, _UnsignedBits, T>
469 typedef _UnsignedBits UnsignedBits;
471 static const Category CATEGORY = SIGNED_INTEGER;
472 static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((
sizeof(UnsignedBits) * 8) - 1);
473 static const UnsignedBits LOWEST_KEY = HIGH_BIT;
474 static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT;
482 static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
484 return key ^ HIGH_BIT;
487 static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
489 return key ^ HIGH_BIT;
492 static __host__ __device__ __forceinline__ T Max()
494 UnsignedBits retval = MAX_KEY;
495 return reinterpret_cast<T&
>(retval);
498 static __host__ __device__ __forceinline__ T Lowest()
500 UnsignedBits retval = LOWEST_KEY;
501 return reinterpret_cast<T&
>(retval);
505 template <
typename _T>
509 struct FpLimits<float>
511 static __host__ __device__ __forceinline__
float Max() {
512 return std::numeric_limits<float>::max();
515 static __host__ __device__ __forceinline__
float Lowest() {
516 return std::numeric_limits<float>::max() * float(-1);
521 struct FpLimits<double>
523 static __host__ __device__ __forceinline__
double Max() {
524 return std::numeric_limits<double>::max();
527 static __host__ __device__ __forceinline__
double Lowest() {
528 return std::numeric_limits<double>::max() * double(-1);
533 struct FpLimits<__half>
535 static __host__ __device__ __forceinline__ __half Max() {
536 unsigned short max_word = 0x7BFF;
537 return reinterpret_cast<__half&
>(max_word);
540 static __host__ __device__ __forceinline__ __half Lowest() {
541 unsigned short lowest_word = 0xFBFF;
542 return reinterpret_cast<__half&
>(lowest_word);
547 struct FpLimits<hip_bfloat16>
549 static __host__ __device__ __forceinline__ hip_bfloat16 Max() {
550 unsigned short max_word = 0x7F7F;
551 return reinterpret_cast<hip_bfloat16 &
>(max_word);
554 static __host__ __device__ __forceinline__ hip_bfloat16 Lowest() {
555 unsigned short lowest_word = 0xFF7F;
556 return reinterpret_cast<hip_bfloat16 &
>(lowest_word);
563 template <
typename _Un
signedBits,
typename T>
564 struct BaseTraits<FLOATING_POINT, true, false, _UnsignedBits, T>
566 typedef _UnsignedBits UnsignedBits;
568 static const Category CATEGORY = FLOATING_POINT;
569 static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((
sizeof(UnsignedBits) * 8) - 1);
570 static const UnsignedBits LOWEST_KEY = UnsignedBits(-1);
571 static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT;
579 static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
581 UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT;
585 static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
587 UnsignedBits mask = (key & HIGH_BIT) ? HIGH_BIT : UnsignedBits(-1);
591 static __host__ __device__ __forceinline__ T Max() {
592 return FpLimits<T>::Max();
595 static __host__ __device__ __forceinline__ T Lowest() {
596 return FpLimits<T>::Lowest();
604 template <
typename T>
struct NumericTraits : BaseTraits<NOT_A_NUMBER, false, false, T, T> {};
606 template <>
struct NumericTraits<NullType> : BaseTraits<NOT_A_NUMBER, false, true, NullType, NullType> {};
608 template <>
struct NumericTraits<char> : BaseTraits<(std::numeric_limits<char>::is_signed) ? SIGNED_INTEGER : UNSIGNED_INTEGER, true, false, unsigned char, char> {};
609 template <>
struct NumericTraits<signed char> : BaseTraits<SIGNED_INTEGER, true, false, unsigned char, signed char> {};
610 template <>
struct NumericTraits<short> : BaseTraits<SIGNED_INTEGER, true, false, unsigned short, short> {};
611 template <>
struct NumericTraits<int> : BaseTraits<SIGNED_INTEGER, true, false, unsigned int, int> {};
612 template <>
struct NumericTraits<long> : BaseTraits<SIGNED_INTEGER, true, false, unsigned long, long> {};
613 template <>
struct NumericTraits<long long> : BaseTraits<SIGNED_INTEGER, true, false, unsigned long long, long long> {};
615 template <>
struct NumericTraits<unsigned char> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned char, unsigned char> {};
616 template <>
struct NumericTraits<unsigned short> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned short, unsigned short> {};
617 template <>
struct NumericTraits<unsigned int> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned int, unsigned int> {};
618 template <>
struct NumericTraits<unsigned long> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned long, unsigned long> {};
619 template <>
struct NumericTraits<unsigned long long> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned long long, unsigned long long> {};
621 template <>
struct NumericTraits<float> : BaseTraits<FLOATING_POINT, true, false, unsigned int, float> {};
622 template <>
struct NumericTraits<double> : BaseTraits<FLOATING_POINT, true, false, unsigned long long, double> {};
623 template <>
struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
624 template <>
struct NumericTraits<hip_bfloat16 > : BaseTraits<FLOATING_POINT, true, false, unsigned short, hip_bfloat16 > {};
626 template <>
struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};
631 template <
typename T>
632 struct Traits : NumericTraits<typename RemoveQualifiers<T>::Type> {};
Definition: util_type.hpp:107
Definition: util_type.hpp:53
Definition: util_type.hpp:143
Definition: util_type.hpp:59
Definition: util_type.hpp:65
Definition: util_type.hpp:100
Definition: util_type.hpp:77
Definition: util_type.hpp:71
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:359
__host__ __device__ __forceinline__ T & Alias()
Alias.
Definition: util_type.hpp:372
UnitWord< T >::DeviceWord DeviceWord
Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T.
Definition: util_type.hpp:361