30 #ifndef HIPCUB_ROCPRIM_UTIL_TYPE_HPP_
31 #define HIPCUB_ROCPRIM_UTIL_TYPE_HPP_
34 #include <type_traits>
36 #include "../../config.hpp"
38 #include <rocprim/detail/various.hpp>
39 #include <rocprim/types/future_value.hpp>
41 #include <hip/hip_fp16.h>
42 #include <hip/hip_bfloat16.h>
44 BEGIN_HIPCUB_NAMESPACE
46 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Placeholder "no type" marker: an alias of rocPRIM's empty_type.
// NumericTraits<NullType> is specialized with NULL_TYPE = true.
48 using NullType = ::rocprim::empty_type;
// Compile-time type selection: Type is T when B is true and F otherwise
// (a thin alias over std::conditional). The enclosing struct header is not
// visible in this listing; presumably it is `If`, which later code uses as
// If<...>::Type.
52 template<
bool B,
typename T,
typename F>
55 using Type =
typename std::conditional<B, T, F>::type;
// Trait member: VALUE is true when T is a pointer type (std::is_pointer).
61 static constexpr
bool VALUE = std::is_pointer<T>::value;
// Trait member: VALUE is true when T is volatile-qualified (std::is_volatile).
67 static constexpr
bool VALUE = std::is_volatile<T>::value;
// Trait member: Type is T with top-level const/volatile removed
// (std::remove_cv). Presumably this is RemoveQualifiers, which Traits<T>
// uses to strip qualifiers before looking up NumericTraits.
73 using Type =
typename std::remove_cv<T>::type;
// Trait member: VALUE is true when the integral constant N is a power of
// two, delegating to rocPRIM's constexpr is_power_of_two.
79 static constexpr
bool VALUE = ::rocprim::detail::is_power_of_two(N);
// Log2Impl<N, CURRENT_VAL, COUNT>: recursive helper for Log2. Each step
// shifts CURRENT_VAL right by one and increments COUNT. When CURRENT_VAL
// reaches 0, COUNT equals the number of significant bits of N.
// NOTE(review): for negative N the arithmetic right shift never reaches 0,
// so instantiation would recurse until the compiler's template-depth limit.
// Only N == 0 is rejected explicitly (by Log2's static_assert).
85 template<
int N,
int CURRENT_VAL = N,
int COUNT = 0>
88 static constexpr
int VALUE = Log2Impl<N, (CURRENT_VAL >> 1), COUNT + 1>::VALUE;
// Terminating partial specialization (CURRENT_VAL == 0).
// `1 << (COUNT - 1) < N` parses as `(1 << (COUNT - 1)) < N`, because <<
// binds tighter than <. If N is exactly 2^(COUNT-1) the result is COUNT - 1,
// otherwise COUNT. This gives ceil(log2(N)) for N >= 1
// (e.g. 1 -> 0, 4 -> 2, 5 -> 3).
91 template<
int N,
int COUNT>
92 struct Log2Impl<N, 0, COUNT>
94 static constexpr
int VALUE = (1 << (COUNT - 1) < N) ? COUNT : COUNT - 1;
// Compile-time ceil(log2(N)) computed via detail::Log2Impl (presumably the
// body of Log2<N>). N == 0 is rejected at compile time.
102 static_assert(N != 0,
"The logarithm of zero is undefined");
103 static constexpr
int VALUE = detail::Log2Impl<N>::VALUE;
// DoubleBuffer member functions. The class header and the declarations of
// d_buffers and selector are not visible in this listing.
// Presumably the default constructor: both buffer pointers start as nullptr.
113 HIPCUB_HOST_DEVICE
inline
117 d_buffers[0] =
nullptr;
118 d_buffers[1] =
nullptr;
// Two-pointer constructor: stores d_current in d_buffers[0] and d_alternate
// in d_buffers[1]. How selector is initialized is not visible here.
121 HIPCUB_HOST_DEVICE
inline
125 d_buffers[0] = d_current;
126 d_buffers[1] = d_alternate;
// Accessor returning the buffer indexed by selector (presumably Current()).
129 HIPCUB_HOST_DEVICE
inline
132 return d_buffers[selector];
// Accessor returning the other buffer, index selector ^ 1
// (presumably Alternate()).
135 HIPCUB_HOST_DEVICE
inline
138 return d_buffers[selector ^ 1];
148 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// hipCUB names aliased onto their rocPRIM equivalents:
// KeyValuePair<Key, Value> is rocprim::key_value_pair.
// FutureValue<T, Iter> is rocprim::future_value, with Iter defaulting to T*.
154 using KeyValuePair = ::rocprim::key_value_pair<Key, Value>;
158 template <
typename T,
typename Iter = T*>
159 using FutureValue = ::rocprim::future_value<T, Iter>;
// Body of the DoubleBuffer -> rocprim::double_buffer conversion. It builds a
// rocPRIM double buffer from source's current and alternate pointers, in
// that order.
168 return ::rocprim::double_buffer<T>(source.Current(), source.Alternate());
// Propagates the buffer selection from a rocPRIM double_buffer back into a
// hipCUB DoubleBuffer. If source's current pointer differs from target's,
// target.selector is toggled.
// NOTE(review): this assumes target and source wrap the same two pointers
// (e.g. source was built from target). Otherwise toggling does not make
// target.Current() equal source.current(). Confirm at call sites.
173 void update_double_buffer(DoubleBuffer<T>& target, ::rocprim::double_buffer<T>& source)
175 if(target.Current() != source.current())
177 target.selector ^= 1;
181 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// is_integral_or_enum<T>: a std::integral_constant<bool, ...> that is true
// for integral and enumeration types. DivideAndRoundUp's static_assert
// uses it.
183 template <
typename T>
184 using is_integral_or_enum =
185 std::integral_constant<bool, std::is_integral<T>::value || std::is_enum<T>::value>;
// Ceiling division: returns n / d rounded up, cast back to NumeratorT.
// Both operand types must be integral or enum (checked by static_assert).
// It computes the truncating quotient plus one when the remainder is
// nonzero.
// NOTE(review): this only rounds up for a non-negative quotient. For
// negative n, e.g. DivideAndRoundUp(-3, 2), it yields 0 rather than
// ceil(-1.5) == -1. d must be nonzero. Callers presumably pass
// non-negative sizes; confirm.
191 template <
typename NumeratorT,
typename DenominatorT>
192 __host__ __device__ __forceinline__ constexpr NumeratorT
193 DivideAndRoundUp(NumeratorT n, DenominatorT d)
195 static_assert(hipcub::detail::is_integral_or_enum<NumeratorT>::value &&
196 hipcub::detail::is_integral_or_enum<DenominatorT>::value,
197 "DivideAndRoundUp is only intended for integral types.");
200 return static_cast<NumeratorT
>(n / d + (n % d != 0 ? 1 : 0));
203 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Primary AlignBytes<T>: ALIGN_BYTES is sizeof(Pad) - sizeof(T).
// Pad is not defined in the visible lines. Presumably it is a struct holding
// a T followed by a char, in which case the difference equals the alignment
// of T. TODO confirm.
210 template <
typename T>
222 ALIGN_BYTES =
sizeof(Pad) -
sizeof(T)
// __HIPCUB_ALIGN_BYTES(t, b) emits an explicit AlignBytes<t> specialization
// with ALIGN_BYTES = b and Type = t declared __align__(b). This overrides
// the computed alignment for the vector/wide types listed below.
233 #define __HIPCUB_ALIGN_BYTES(t, b) \
234 template <> struct AlignBytes<t> \
235 { enum { ALIGN_BYTES = b }; typedef __align__(b) t Type; };
237 __HIPCUB_ALIGN_BYTES(short4, 8)
238 __HIPCUB_ALIGN_BYTES(ushort4, 8)
239 __HIPCUB_ALIGN_BYTES(int2, 8)
240 __HIPCUB_ALIGN_BYTES(uint2, 8)
241 __HIPCUB_ALIGN_BYTES(
long long, 8)
242 __HIPCUB_ALIGN_BYTES(
unsigned long long, 8)
243 __HIPCUB_ALIGN_BYTES(float2, 8)
244 __HIPCUB_ALIGN_BYTES(
double, 8)
// NOTE(review): in this listing long2/ulong2 are specialized twice, once
// with 8 bytes and once with 16 bytes. Two explicit specializations of
// AlignBytes<long2> would be a redefinition error. The pairs must therefore
// be separated by a platform conditional on lines not shown. The listing's
// numbering skips 245, 248 and 251, where #if/#else/#endif would sit;
// presumably the 8-byte variant is Windows-only, as in CUB. Confirm the
// guard is present.
246 __HIPCUB_ALIGN_BYTES(long2, 8)
247 __HIPCUB_ALIGN_BYTES(ulong2, 8)
249 __HIPCUB_ALIGN_BYTES(long2, 16)
250 __HIPCUB_ALIGN_BYTES(ulong2, 16)
252 __HIPCUB_ALIGN_BYTES(int4, 16)
253 __HIPCUB_ALIGN_BYTES(uint4, 16)
254 __HIPCUB_ALIGN_BYTES(float4, 16)
255 __HIPCUB_ALIGN_BYTES(long4, 16)
256 __HIPCUB_ALIGN_BYTES(ulong4, 16)
257 __HIPCUB_ALIGN_BYTES(longlong2, 16)
258 __HIPCUB_ALIGN_BYTES(ulonglong2, 16)
259 __HIPCUB_ALIGN_BYTES(double2, 16)
260 __HIPCUB_ALIGN_BYTES(longlong4, 16)
261 __HIPCUB_ALIGN_BYTES(ulonglong4, 16)
262 __HIPCUB_ALIGN_BYTES(double4, 16)
// cv-qualified types share the alignment of the unqualified type.
264 template <typename T> struct AlignBytes<volatile T> : AlignBytes<T> {};
265 template <
typename T>
struct AlignBytes<const T> : AlignBytes<T> {};
266 template <
typename T>
struct AlignBytes<const volatile T> : AlignBytes<T> {};
// Primary UnitWord<T>: selects word types that T can be moved as.
// ALIGN_BYTES caches AlignBytes<T>::ALIGN_BYTES.
// IsMultiple<Unit>::IS_MULTIPLE is true when sizeof(T) is a whole multiple
// of sizeof(Unit) and T's alignment is a multiple of Unit's alignment.
// The "then" branches of the If<> chains below are on lines not visible here:
//   ShuffleWord  - conditions on IsMultiple<int>, then IsMultiple<short>,
//                  falling back to unsigned char;
//   VolatileWord - conditions on IsMultiple<long long>, falling back to
//                  ShuffleWord;
//   DeviceWord   - conditions on IsMultiple<longlong2>, falling back to
//                  VolatileWord (the biggest memory-access word T is a whole
//                  multiple of and not larger than T's alignment);
//   TextureWord  - conditions on IsMultiple<int4>, then IsMultiple<int2>,
//                  falling back to ShuffleWord.
270 template <
typename T>
274 ALIGN_BYTES = AlignBytes<T>::ALIGN_BYTES
277 template <
typename Unit>
281 UNIT_ALIGN_BYTES = AlignBytes<Unit>::ALIGN_BYTES,
282 IS_MULTIPLE = (
sizeof(T) %
sizeof(Unit) == 0) && (
int(ALIGN_BYTES) % int(UNIT_ALIGN_BYTES) == 0)
287 typedef typename If<IsMultiple<int>::IS_MULTIPLE,
289 typename If<IsMultiple<short>::IS_MULTIPLE,
291 unsigned char>::Type>::Type ShuffleWord;
294 typedef typename If<IsMultiple<long long>::IS_MULTIPLE,
296 ShuffleWord>::Type VolatileWord;
299 typedef typename If<IsMultiple<longlong2>::IS_MULTIPLE,
301 VolatileWord>::Type DeviceWord;
304 typedef typename If<IsMultiple<int4>::IS_MULTIPLE,
306 typename If<IsMultiple<int2>::IS_MULTIPLE,
308 ShuffleWord>::Type>::Type TextureWord;
// Explicit UnitWord specialization for float2 (its `template <>` line is not
// visible here). Shuffles as int, uses unsigned long long for volatile and
// device access, and float2 for texture access.
314 struct UnitWord <float2>
316 typedef int ShuffleWord;
317 typedef unsigned long long VolatileWord;
318 typedef unsigned long long DeviceWord;
319 typedef float2 TextureWord;
// Explicit UnitWord specialization for float4. Shuffles as int, uses
// unsigned long long for volatile access, ulonglong2 for device access,
// and float4 for texture access.
324 struct UnitWord <float4>
326 typedef int ShuffleWord;
327 typedef unsigned long long VolatileWord;
328 typedef ulonglong2 DeviceWord;
329 typedef float4 TextureWord;
// Explicit UnitWord specialization for char2. Every word type is
// unsigned short.
335 struct UnitWord <char2>
337 typedef unsigned short ShuffleWord;
338 typedef unsigned short VolatileWord;
339 typedef unsigned short DeviceWord;
340 typedef unsigned short TextureWord;
// cv-qualified types use the same word types as the unqualified type.
344 template <
typename T>
struct UnitWord<volatile T> : UnitWord<T> {};
345 template <
typename T>
struct UnitWord<const T> : UnitWord<T> {};
346 template <
typename T>
struct UnitWord<const volatile T> : UnitWord<T> {};
// Storage-backing wrapper that lets types with non-trivial constructors be
// aliased in unions (presumably Uninitialized<T>). Alias() reinterprets
// this storage object as a T&. The storage member itself is not visible
// here.
361 template <
typename T>
376 __host__ __device__ __forceinline__ T&
Alias()
378 return reinterpret_cast<T&
>(*this);
394 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Primary BaseTraits<CATEGORY, PRIMITIVE, NULL_TYPE, UnsignedBits, T>.
// It exposes the category and the primitive / null-type flags as
// compile-time constants. Partial specializations for UNSIGNED_INTEGER,
// SIGNED_INTEGER and FLOATING_POINT add key limits and bit-twiddling
// helpers.
411 template <Category _CATEGORY,
bool _PRIMITIVE,
bool _NULL_TYPE,
typename _UnsignedBits,
typename T>
415 static const Category CATEGORY = _CATEGORY;
418 PRIMITIVE = _PRIMITIVE,
419 NULL_TYPE = _NULL_TYPE,
// BaseTraits for unsigned integers. UnsignedBits is the unsigned type used
// for bit-level keys. LOWEST_KEY is all zero bits and MAX_KEY is all one
// bits, i.e. the smallest and largest UnsignedBits values.
// TwiddleIn/TwiddleOut presumably map keys to and from their radix-sortable
// form; their bodies are not visible in this listing.
// Max()/Lowest() reinterpret MAX_KEY/LOWEST_KEY as a T.
// NOTE(review): reinterpret_cast<T&> on an UnsignedBits local is type
// punning. It relies on T and UnsignedBits having identical size and
// representation.
427 template <
typename _UnsignedBits,
typename T>
428 struct BaseTraits<UNSIGNED_INTEGER, true, false, _UnsignedBits, T>
430 typedef _UnsignedBits UnsignedBits;
432 static const Category CATEGORY = UNSIGNED_INTEGER;
433 static const UnsignedBits LOWEST_KEY = UnsignedBits(0);
434 static const UnsignedBits MAX_KEY = UnsignedBits(-1);
443 static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
448 static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
453 static __host__ __device__ __forceinline__ T Max()
455 UnsignedBits retval = MAX_KEY;
456 return reinterpret_cast<T&
>(retval);
459 static __host__ __device__ __forceinline__ T Lowest()
461 UnsignedBits retval = LOWEST_KEY;
462 return reinterpret_cast<T&
>(retval);
// BaseTraits for signed integers, keyed by the same-width unsigned type.
// HIGH_BIT is the sign bit. TwiddleIn/TwiddleOut both XOR the sign bit,
// which maps two's-complement order onto unsigned order and back.
// LOWEST_KEY (only the sign bit set) and MAX_KEY (every bit except the sign
// bit) are the bit patterns of the minimum and maximum signed values.
// Max()/Lowest() reinterpret them as a T.
// NOTE(review): reinterpret_cast<T&> on an UnsignedBits local is type
// punning. It relies on T and UnsignedBits having identical size.
470 template <
typename _UnsignedBits,
typename T>
471 struct BaseTraits<SIGNED_INTEGER, true, false, _UnsignedBits, T>
473 typedef _UnsignedBits UnsignedBits;
475 static const Category CATEGORY = SIGNED_INTEGER;
476 static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((
sizeof(UnsignedBits) * 8) - 1);
477 static const UnsignedBits LOWEST_KEY = HIGH_BIT;
478 static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT;
486 static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
488 return key ^ HIGH_BIT;
491 static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
493 return key ^ HIGH_BIT;
496 static __host__ __device__ __forceinline__ T Max()
498 UnsignedBits retval = MAX_KEY;
499 return reinterpret_cast<T&
>(retval);
502 static __host__ __device__ __forceinline__ T Lowest()
504 UnsignedBits retval = LOWEST_KEY;
505 return reinterpret_cast<T&
>(retval);
// FpLimits<T>: the largest finite value (Max) and the lowest finite value
// (Lowest) for the floating-point types used with NumericTraits.
// float/double use std::numeric_limits, with Lowest == -max().
// __half and hip_bfloat16 build the values from bit patterns:
// 0x7BFF / 0xFBFF are +/-65504 in IEEE binary16, and 0x7F7F / 0xFF7F are
// the largest-magnitude finite bfloat16 values.
// NOTE(review): std::numeric_limits requires <limits>, which is not among
// the includes visible in this listing. Confirm it is included, directly or
// transitively.
// NOTE(review): reinterpret_cast from unsigned short to __half& /
// hip_bfloat16& is type punning. It assumes both types are plain 16-bit
// storage.
509 template <
typename _T>
513 struct FpLimits<float>
515 static __host__ __device__ __forceinline__
float Max() {
516 return std::numeric_limits<float>::max();
519 static __host__ __device__ __forceinline__
float Lowest() {
520 return std::numeric_limits<float>::max() * float(-1);
525 struct FpLimits<double>
527 static __host__ __device__ __forceinline__
double Max() {
528 return std::numeric_limits<double>::max();
531 static __host__ __device__ __forceinline__
double Lowest() {
532 return std::numeric_limits<double>::max() * double(-1);
537 struct FpLimits<__half>
539 static __host__ __device__ __forceinline__ __half Max() {
540 unsigned short max_word = 0x7BFF;
541 return reinterpret_cast<__half&
>(max_word);
544 static __host__ __device__ __forceinline__ __half Lowest() {
545 unsigned short lowest_word = 0xFBFF;
546 return reinterpret_cast<__half&
>(lowest_word);
551 struct FpLimits<hip_bfloat16>
553 static __host__ __device__ __forceinline__ hip_bfloat16 Max() {
554 unsigned short max_word = 0x7F7F;
555 return reinterpret_cast<hip_bfloat16 &
>(max_word);
558 static __host__ __device__ __forceinline__ hip_bfloat16 Lowest() {
559 unsigned short lowest_word = 0xFF7F;
560 return reinterpret_cast<hip_bfloat16 &
>(lowest_word);
// BaseTraits for floating-point types. UnsignedBits is the same-width
// unsigned type holding the raw bit pattern (e.g. unsigned int for float).
// HIGH_BIT is the sign bit.
// TwiddleIn picks a mask that flips every bit of negative values and only
// the sign bit of non-negative values, so unsigned comparison of twiddled
// keys follows floating-point order. TwiddleOut picks the inverse mask.
// The `key ^ mask` return lines are not visible in this listing;
// presumably they follow each mask computation.
// LOWEST_KEY (all ones) and MAX_KEY (all ones except the sign bit) are raw
// patterns that TwiddleIn maps to all-zeros and all-ones, the two ends of
// the twiddled order.
// Max()/Lowest() return the finite limits from FpLimits<T> instead.
567 template <
typename _UnsignedBits,
typename T>
568 struct BaseTraits<FLOATING_POINT, true, false, _UnsignedBits, T>
570 typedef _UnsignedBits UnsignedBits;
572 static const Category CATEGORY = FLOATING_POINT;
573 static const UnsignedBits HIGH_BIT = UnsignedBits(1) << ((
sizeof(UnsignedBits) * 8) - 1);
574 static const UnsignedBits LOWEST_KEY = UnsignedBits(-1);
575 static const UnsignedBits MAX_KEY = UnsignedBits(-1) ^ HIGH_BIT;
583 static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key)
585 UnsignedBits mask = (key & HIGH_BIT) ? UnsignedBits(-1) : HIGH_BIT;
589 static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key)
591 UnsignedBits mask = (key & HIGH_BIT) ? HIGH_BIT : UnsignedBits(-1);
595 static __host__ __device__ __forceinline__ T Max() {
596 return FpLimits<T>::Max();
599 static __host__ __device__ __forceinline__ T Lowest() {
600 return FpLimits<T>::Lowest();
// NumericTraits<T>: per-type BaseTraits selection. Unlisted types default
// to NOT_A_NUMBER and non-primitive. NullType is flagged NULL_TYPE. Each
// built-in arithmetic type maps to its category and to the same-width
// unsigned type used for its bit-level key. Plain char follows the
// platform's signedness of char.
// NOTE(review): for bool, UnsignedBits comes from
// UnitWord<bool>::VolatileWord, which presumably resolves to unsigned char
// when sizeof(bool) == 1. In that case Max() reinterprets
// UnsignedBits(-1) == 0xFF as a bool, which is not a valid bool value
// representation. Confirm that no caller relies on
// NumericTraits<bool>::Max().
608 template <
typename T>
struct NumericTraits : BaseTraits<NOT_A_NUMBER, false, false, T, T> {};
610 template <>
struct NumericTraits<NullType> : BaseTraits<NOT_A_NUMBER, false, true, NullType, NullType> {};
612 template <>
struct NumericTraits<char> : BaseTraits<(std::numeric_limits<char>::is_signed) ? SIGNED_INTEGER : UNSIGNED_INTEGER, true, false, unsigned char, char> {};
613 template <>
struct NumericTraits<signed char> : BaseTraits<SIGNED_INTEGER, true, false, unsigned char, signed char> {};
614 template <>
struct NumericTraits<short> : BaseTraits<SIGNED_INTEGER, true, false, unsigned short, short> {};
615 template <>
struct NumericTraits<int> : BaseTraits<SIGNED_INTEGER, true, false, unsigned int, int> {};
616 template <>
struct NumericTraits<long> : BaseTraits<SIGNED_INTEGER, true, false, unsigned long, long> {};
617 template <>
struct NumericTraits<long long> : BaseTraits<SIGNED_INTEGER, true, false, unsigned long long, long long> {};
619 template <>
struct NumericTraits<unsigned char> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned char, unsigned char> {};
620 template <>
struct NumericTraits<unsigned short> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned short, unsigned short> {};
621 template <>
struct NumericTraits<unsigned int> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned int, unsigned int> {};
622 template <>
struct NumericTraits<unsigned long> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned long, unsigned long> {};
623 template <>
struct NumericTraits<unsigned long long> : BaseTraits<UNSIGNED_INTEGER, true, false, unsigned long long, unsigned long long> {};
625 template <>
struct NumericTraits<float> : BaseTraits<FLOATING_POINT, true, false, unsigned int, float> {};
626 template <>
struct NumericTraits<double> : BaseTraits<FLOATING_POINT, true, false, unsigned long long, double> {};
627 template <>
struct NumericTraits<__half> : BaseTraits<FLOATING_POINT, true, false, unsigned short, __half> {};
628 template <>
struct NumericTraits<hip_bfloat16 > : BaseTraits<FLOATING_POINT, true, false, unsigned short, hip_bfloat16 > {};
630 template <>
struct NumericTraits<bool> : BaseTraits<UNSIGNED_INTEGER, true, false, typename UnitWord<bool>::VolatileWord, bool> {};
// Traits<T>: the NumericTraits of T after RemoveQualifiers strips any
// const/volatile qualifiers, so cv-qualified types get the same traits as
// their unqualified type.
635 template <
typename T>
636 struct Traits : NumericTraits<typename RemoveQualifiers<T>::Type> {};
Definition: util_type.hpp:108
Definition: util_type.hpp:54
Definition: util_type.hpp:144
Definition: util_type.hpp:60
Definition: util_type.hpp:66
Definition: util_type.hpp:101
Definition: util_type.hpp:78
Definition: util_type.hpp:72
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:363
__host__ __device__ __forceinline__ T & Alias()
Alias.
Definition: util_type.hpp:376
UnitWord< T >::DeviceWord DeviceWord
Biggest memory-access word that T is a whole multiple of and is not larger than the alignment of T.
Definition: util_type.hpp:365