Implementing traits for custom types in rocPRIM#
This interface is designed to enable users to provide additional type trait information to rocPRIM, facilitating better compatibility with custom types.
Accurately describing custom types is important for performance optimization and computational correctness.
Custom types that implement arithmetic operators can behave like built-in arithmetic types but might still be interpreted by rocPRIM algorithms as generic struct or class types.
The rocPRIM type traits interface lets users add custom trait information for their types, improving compatibility between these types and rocPRIM algorithms.
This interface is similar to operator overloading.
Traits should be implemented as required by specific algorithms. Some traits can’t be defined if they can be inferred from others.
Interface#
-
template<class T>
struct define#
- Overview
This template struct provides an interface for downstream libraries to implement type traits for their custom types. Users can utilize this template struct to define traits for these types. Users should only implement traits as required by specific algorithms, and some traits cannot be defined if they can be inferred from others. This API is not static because of ODR.
- Example
The example below demonstrates how to implement traits for a custom floating-point type.
    // Your type definition
    struct custom_float_type {};
    // Implement the traits
    template<>
    struct rocprim::traits::define<custom_float_type>
    {
        using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
        using number_format = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
        using float_bit_mask = rocprim::traits::float_bit_mask::values<uint32_t, 10, 10, 10>;
    };
The example below demonstrates how to implement traits for a custom integral type.
    // Your type definition
    struct custom_int_type {};
    // Implement the traits
    template<>
    struct rocprim::traits::define<custom_int_type>
    {
        using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
        using number_format = rocprim::traits::number_format::values<traits::number_format::kind::integral_type>;
        using integral_sign = rocprim::traits::integral_sign::values<traits::integral_sign::kind::signed_type>;
    };
- Template Parameters:
T – The type for which you want to define traits.
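The opt-in pattern behind `define` can be illustrated without rocPRIM. The following is a minimal stand-alone sketch, not rocPRIM's implementation: the `sketch` namespace, `has_is_arithmetic`, `is_arithmetic_v`, `my_fixed_point`, and `plain_struct` are hypothetical names introduced only for illustration. It shows how an empty primary template plus a user specialization lets a library detect a trait that a custom type has opted into.

```cpp
#include <type_traits>

namespace sketch // hypothetical stand-in namespace; not rocPRIM code
{
// Value holder for a boolean trait, mirroring the shape `trait::values<true>`.
struct is_arithmetic
{
    template<bool Val>
    struct values
    {
        static constexpr bool value = Val;
    };
};

// Primary template: defines nothing, so types have no custom traits by default.
template<class T>
struct define
{};

// C++14-compatible void_t helper.
template<class...>
struct make_void
{
    using type = void;
};
template<class... Ts>
using void_t = typename make_void<Ts...>::type;

// Detects whether define<T> provides an is_arithmetic member and reads its value.
template<class T, class = void>
struct has_is_arithmetic : std::false_type
{};

template<class T>
struct has_is_arithmetic<T, void_t<typename define<T>::is_arithmetic>>
    : std::integral_constant<bool, define<T>::is_arithmetic::value>
{};

// Built-in arithmetic types count automatically; custom types must opt in.
template<class T>
constexpr bool is_arithmetic_v
    = std::is_arithmetic<T>::value || has_is_arithmetic<T>::value;
} // namespace sketch

// A user type that behaves like a number.
struct my_fixed_point
{
    int raw;
};

// A user type with no traits defined.
struct plain_struct
{};

namespace sketch
{
// The user opts in by specializing define for the custom type.
template<>
struct define<my_fixed_point>
{
    using is_arithmetic = sketch::is_arithmetic::values<true>;
};
} // namespace sketch

static_assert(sketch::is_arithmetic_v<int>, "built-in types need no definition");
static_assert(sketch::is_arithmetic_v<my_fixed_point>, "opted in through define");
static_assert(!sketch::is_arithmetic_v<plain_struct>, "no traits defined");
```

The specialization has to be visible before the first use of the trait for that type, which is why it appears ahead of the `static_assert` lines.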
-
template<class T>
struct get#
- Overview
This template struct is designed to allow rocPRIM algorithms to retrieve trait information from C++ built-in arithmetic types, rocPRIM types, and custom types. This API is not static because of ODR.
All member functions are compiled only when invoked. Different algorithms require different traits.
- Example
The following code demonstrates how to retrieve the traits of type T.
    // Get the trait in a template parameter
    template<class T, typename std::enable_if<rocprim::traits::get<T>().is_integral()>::type* = nullptr>
    void get_traits_in_template_parameter() {}
    // Get the trait in a function body
    template<class T>
    void get_traits_in_function_body()
    {
        constexpr auto input_traits = rocprim::traits::get<T>();
        // Then you can use the member functions
        constexpr bool is_arithmetic = input_traits.is_arithmetic();
    }
- Template Parameters:
T – The type from which you want to retrieve the traits.
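The statement that member functions are compiled only when invoked follows from how C++ instantiates members of class templates. The sketch below illustrates that language rule with a hypothetical stand-in `sketch::get` built on the standard library only; it is not rocPRIM's `get`, and `opaque_key` is an invented type. A member whose precondition fails causes a compile-time error only if that member is actually called.

```cpp
#include <type_traits>

namespace sketch // hypothetical stand-in namespace; not rocPRIM code
{
template<class T>
struct get
{
    // Safe to call for every type.
    constexpr bool is_arithmetic() const
    {
        return std::is_arithmetic<T>::value;
    }

    // The body is instantiated only if this member is called, so the
    // static_assert cannot fire for types that never ask this question.
    constexpr bool is_integral() const
    {
        static_assert(std::is_arithmetic<T>::value,
                      "is_integral() requires an arithmetic type");
        return std::is_integral<T>::value;
    }
};
} // namespace sketch

// A non-arithmetic user type.
struct opaque_key
{};

// Compiles: only is_arithmetic() is instantiated for opaque_key.
static_assert(!sketch::get<opaque_key>().is_arithmetic(), "");

// Compiles: int is arithmetic, so is_integral() may be called.
static_assert(sketch::get<int>().is_integral(), "");

// Would not compile: instantiating is_integral() trips the static_assert.
// static_assert(!sketch::get<opaque_key>().is_integral(), "");
```

This is why a check such as `is_arithmetic()` is typically made first, with the more specific queries guarded behind it.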
Public Functions
-
inline constexpr bool is_arithmetic() const#
Get the value of trait is_arithmetic.
- Returns:
true if std::is_arithmetic_v<T> is true, or if type T is a rocPRIM arithmetic type, or if the is_arithmetic trait has been defined as true; otherwise, returns false.
-
inline constexpr bool is_fundamental() const#
Get trait is_fundamental.
- Returns:
true if T is a fundamental type (that is, rocPRIM arithmetic type, void, or nullptr_t); otherwise, returns false.
-
inline constexpr bool is_build_in() const#
Check whether the type is a build_in (built-in) type. This function differs from is_fundamental: by implementing traits, downstream code can make rocPRIM treat a custom type as arithmetic, and, following the rules of std::is_fundamental, rocprim::is_fundamental returns the union of std::is_fundamental and rocprim::is_arithmetic. To check whether a type is a built-in type, use this function.
- Returns:
true if T is a build_in type (that is, char, unsigned char, short, unsigned short, int, unsigned int, long long, unsigned long long, rocprim::int128_t, rocprim::uint128_t, rocprim::half, float, double); otherwise, returns false.
-
inline constexpr bool is_compound() const#
Check whether T is a compound type, that is, not a fundamental type.
- Returns:
false if T is a fundamental type (that is, rocPRIM arithmetic type, void, or nullptr_t); otherwise, returns true.
-
inline constexpr bool is_floating_point() const#
Check whether T is a floating-point type.
Warning
You cannot call this function when is_arithmetic() returns false; doing so will result in a compile-time error.
-
inline constexpr bool is_integral() const#
Check whether T is an integral type.
Warning
You cannot call this function when is_arithmetic() returns false; doing so will result in a compile-time error.
-
inline constexpr bool is_signed() const#
Check whether T is a signed integral type.
Warning
You cannot call this function when is_integral() returns false; doing so will result in a compile-time error.
-
inline constexpr bool is_unsigned() const#
Check whether T is an unsigned integral type.
Warning
You cannot call this function when is_integral() returns false; doing so will result in a compile-time error.
-
inline constexpr bool is_scalar() const#
Get trait is_scalar.
- Returns:
true if std::is_scalar_v<T> is true, or if type T is a rocPRIM arithmetic type, or if the is_scalar trait has been defined as true; otherwise, returns false.
-
inline constexpr auto float_bit_mask() const#
Get trait float_bit_mask.
Warning
You cannot call this function when is_floating_point() returns false; doing so will result in a compile-time error.
- Returns:
A constexpr instance of the specialization of rocprim::traits::float_bit_mask::values as provided in the traits definition of type T. If the float_bit_mask trait is not defined, it returns the rocprim::detail::float_bit_mask values, provided a specialization of rocprim::detail::float_bit_mask<T> exists.
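To make the role of the three masks concrete, the sketch below splits the bit pattern of a 32-bit float into its sign, exponent, and mantissa fields, using the same mask values listed for `float` under the predefined traits. Everything in the `sketch` namespace is a hypothetical stand-in written for this illustration; in particular, the member names `sign_bit`, `exponent`, and `mantissa` are assumptions and are not taken from the rocPRIM headers.

```cpp
#include <cstdint>

namespace sketch // hypothetical stand-in namespace; not rocPRIM code
{
// Holds the storage type and the sign, exponent, and mantissa masks.
template<class BitType, BitType SignBit, BitType Exponent, BitType Mantissa>
struct float_masks
{
    using bit_type = BitType;
    static constexpr BitType sign_bit = SignBit;
    static constexpr BitType exponent = Exponent;
    static constexpr BitType mantissa = Mantissa;
};

// IEEE 754 single precision: 1 sign bit, 8 exponent bits, 23 mantissa bits.
using ieee_single = float_masks<std::uint32_t, 0x80000000u, 0x7F800000u, 0x007FFFFFu>;

// Number of trailing zero bits of a non-zero mask, used to shift a field down.
constexpr int shift_of(std::uint32_t mask)
{
    return (mask & 1u) ? 0 : 1 + shift_of(mask >> 1);
}

constexpr bool is_negative(std::uint32_t bits)
{
    return (bits & ieee_single::sign_bit) != 0;
}

constexpr std::uint32_t exponent_field(std::uint32_t bits)
{
    return (bits & ieee_single::exponent) >> shift_of(ieee_single::exponent);
}

constexpr std::uint32_t mantissa_field(std::uint32_t bits)
{
    return bits & ieee_single::mantissa;
}
} // namespace sketch

// 0x3FC00000 is the bit pattern of 1.5f: positive, biased exponent 127,
// and a mantissa whose top bit is set (the ".5" part).
static_assert(!sketch::is_negative(0x3FC00000u), "");
static_assert(sketch::exponent_field(0x3FC00000u) == 127u, "");
static_assert(sketch::mantissa_field(0x3FC00000u) == 0x400000u, "");
```

Algorithms that reorder or compare floating-point keys by their bit patterns need this kind of field layout, which is presumably why the trait is required for custom floating-point types.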
-
template<bool Descending = false>
inline constexpr auto radix_key_codec() const#
Get trait radix_key_codec.
- Returns:
A constexpr instance of the specialization of rocprim::traits::radix_key_codec::codec as provided in the traits definition of type T.
Available traits#
-
struct is_arithmetic#
- Definability
Undefinable: For types with predefined traits.
Optional: For other types.
- How to define
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
- How to use
rocprim::traits::get<InputType>().is_arithmetic();
-
template<bool Val>
struct values#
Value of this trait.
-
struct is_scalar#
Arithmetic types, pointers, member pointers, and null pointers are considered scalar types.
- Definability
Undefinable: For types with predefined traits.
Optional: For other types. If both is_arithmetic and is_scalar are defined, their values must be consistent; otherwise, a compile-time error will occur.
- How to define
using is_scalar = rocprim::traits::is_scalar::values<true>;
- How to use
rocprim::traits::get<InputType>().is_scalar();
-
template<bool Val>
struct values#
Value of this trait.
-
struct number_format#
- Definability
Undefinable: For types with predefined traits and non-arithmetic types.
Required: If you define is_arithmetic as true, you must also define this trait; otherwise, a compile-time error will occur.
- How to define
using number_format = rocprim::traits::number_format::values<number_format::kind::integral_type>;
- How to use
rocprim::traits::get<InputType>().is_integral(); rocprim::traits::get<InputType>().is_floating_point();
Public Types
-
struct integral_sign#
- Definability
Undefinable: For types with predefined traits, non-arithmetic types and floating-point types.
Required: If you define number_format as number_format::kind::integral_type, you must also define this trait; otherwise, a compile-time error will occur.
- How to define
using integral_sign = rocprim::traits::integral_sign::values<traits::integral_sign::kind::signed_type>;
- How to use
rocprim::traits::get<InputType>().is_signed(); rocprim::traits::get<InputType>().is_unsigned();
Public Types
-
struct float_bit_mask#
- Definability
Undefinable: For types with predefined traits, non-arithmetic types and integral types.
Required: If you define number_format as number_format::kind::floating_point_type, you must also define this trait; otherwise, a compile-time error will occur.
- How to define
using float_bit_mask = rocprim::traits::float_bit_mask::values<int,1,1,1>;
- How to use
rocprim::traits::get<InputType>().float_bit_mask();
Type traits wrappers#
-
template<class T>
struct is_floating_point : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_floating_point()>
An extension of std::is_floating_point that supports additional arithmetic types, including rocprim::half, rocprim::bfloat16, and any types with trait rocprim::traits::number_format::values<number_format::kind::floating_point_type> implemented.
-
template<class T>
struct is_integral : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_integral()>
An extension of std::is_integral that supports additional arithmetic types, including rocprim::int128_t, rocprim::uint128_t, and any types with trait rocprim::traits::number_format::values<number_format::kind::integral_type> implemented.
-
template<class T>
struct is_arithmetic : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_arithmetic()>
An extension of std::is_arithmetic that supports additional arithmetic types, including any types with trait rocprim::traits::is_arithmetic::values<true> implemented.
-
template<class T>
struct is_fundamental : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_fundamental()>
An extension of std::is_fundamental that supports additional arithmetic types, including any types with trait rocprim::traits::is_arithmetic::values<true> implemented.
-
template<class T>
struct is_unsigned : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_unsigned()>
An extension of std::is_unsigned that supports additional arithmetic types, including rocprim::uint128_t, and any types with trait rocprim::traits::integral_sign::values<integral_sign::kind::unsigned_type> implemented.
-
template<class T>
struct is_signed : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_signed()>
An extension of std::is_signed that supports additional arithmetic types, including rocprim::int128_t, and any types with trait rocprim::traits::integral_sign::values<integral_sign::kind::signed_type> implemented.
-
template<class T>
struct is_scalar : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_scalar()>
An extension of std::is_scalar that supports additional arithmetic types, including any types with trait rocprim::traits::is_scalar::values<true> implemented.
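The wrappers above all follow one pattern: the result of a constexpr member function is lifted into a `std::integral_constant`, so existing code written against the `::value` or tag-dispatch style of the standard type traits keeps working. The sketch below reproduces that pattern with a hypothetical stand-in `sketch::get` that only forwards to the standard library; `pick` and `classify` are invented helper names.

```cpp
#include <type_traits>

namespace sketch // hypothetical stand-in namespace; not rocPRIM code
{
// Stand-in for a traits accessor; here it simply defers to the standard trait.
template<class T>
struct get
{
    constexpr bool is_integral() const
    {
        return std::is_integral<T>::value;
    }
};

// Wrapper: exposes the member-function result through the familiar ::value interface.
template<class T>
struct is_integral : std::integral_constant<bool, get<T>().is_integral()>
{};
} // namespace sketch

// Tag dispatch on the wrapper, exactly as with std::true_type / std::false_type.
constexpr int pick(std::true_type)
{
    return 1;
}
constexpr int pick(std::false_type)
{
    return 0;
}

template<class T>
constexpr int classify()
{
    return pick(sketch::is_integral<T>{});
}

static_assert(sketch::is_integral<int>::value, "");
static_assert(classify<unsigned long>() == 1, "");
static_assert(classify<float>() == 0, "");
```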
Types with predefined traits#
-
template<>
struct define<float>
Public Types
-
using float_bit_mask = traits::float_bit_mask::values<uint32_t, 0x80000000, 0x7F800000, 0x007FFFFF>
-
template<>
struct define<double>
Public Types
-
using float_bit_mask = traits::float_bit_mask::values<uint64_t, 0x8000000000000000, 0x7FF0000000000000, 0x000FFFFFFFFFFFFF>
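A quick way to sanity-check mask triples such as the ones listed for `float` and `double` is to verify that the three masks are pairwise disjoint and together cover every bit of the storage type. The helper names `count_bits` and `partitions_word` below are hypothetical and written for this illustration; they are not part of rocPRIM. The same check may be useful when writing a `float_bit_mask` for a custom floating-point type.

```cpp
#include <cstdint>

// Counts set bits; recursion keeps it usable in C++11/14 constant expressions.
constexpr int count_bits(std::uint64_t v)
{
    return v == 0 ? 0 : static_cast<int>(v & 1u) + count_bits(v >> 1);
}

// True when the masks are pairwise disjoint and together cover every bit of Bits.
template<class Bits>
constexpr bool partitions_word(Bits sign, Bits exponent, Bits mantissa)
{
    return (sign & exponent) == 0 && (sign & mantissa) == 0 && (exponent & mantissa) == 0
           && static_cast<Bits>(sign | exponent | mantissa) == static_cast<Bits>(~Bits{0});
}

// Mask values copied from the predefined traits for float and double.
static_assert(partitions_word<std::uint32_t>(0x80000000, 0x7F800000, 0x007FFFFF), "float");
static_assert(partitions_word<std::uint64_t>(0x8000000000000000,
                                             0x7FF0000000000000,
                                             0x000FFFFFFFFFFFFF),
              "double");

// Field widths: 8 exponent and 23 mantissa bits for float, 11 and 52 for double.
static_assert(count_bits(0x7F800000) == 8 && count_bits(0x007FFFFF) == 23, "");
static_assert(count_bits(0x7FF0000000000000) == 11 && count_bits(0x000FFFFFFFFFFFFF) == 52, "");
```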
-
template<>
struct define<rocprim::bfloat16>
Public Types
-
using is_arithmetic = traits::is_arithmetic::values<true>
-
using number_format = traits::number_format::values<traits::number_format::kind::floating_point_type>
-
using float_bit_mask = traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>
-
template<>
struct define<rocprim::half>
Public Types
-
using is_arithmetic = traits::is_arithmetic::values<true>
-
using number_format = traits::number_format::values<traits::number_format::kind::floating_point_type>
-
using float_bit_mask = traits::float_bit_mask::values<uint16_t, 0x8000, 0x7C00, 0x03FF>
-
template<>
struct define<rocprim::int128_t> : public std::conditional_t<std::is_arithmetic<rocprim::int128_t>::value, traits::define<void>, detail::define_int128_t>
-
template<>
struct define<rocprim::uint128_t> : public std::conditional_t<std::is_arithmetic<rocprim::uint128_t>::value, traits::define<void>, detail::define_uint128_t>
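The two 128-bit specializations select their base class with `std::conditional_t`. A plausible reading, stated here as an assumption rather than taken from the rocPRIM sources, is that whether the standard library already reports a 128-bit integer as arithmetic can depend on the toolchain and language mode, so the traits are only supplied when `std::is_arithmetic` does not cover the type. The sketch below shows the mechanism with hypothetical names (`sketch::define`, `fallback_traits`, `traits_base`, `wide_int`).

```cpp
#include <type_traits>

namespace sketch // hypothetical stand-in namespace; not rocPRIM code
{
// Primary template: empty, so define<void> serves as a "nothing to add" base.
template<class T>
struct define
{};

namespace detail
{
// Traits to fall back on when the standard library does not classify the type.
struct fallback_traits
{
    static constexpr bool is_arithmetic = true;
};
} // namespace detail

// Inherit nothing if std::is_arithmetic already covers T; otherwise inherit the fallback.
template<class T>
using traits_base
    = std::conditional_t<std::is_arithmetic<T>::value, define<void>, detail::fallback_traits>;

// A two-word integer that the standard library knows nothing about.
struct wide_int
{
    unsigned long long lo;
    long long          hi;
};

template<>
struct define<wide_int> : traits_base<wide_int>
{};
} // namespace sketch

// int is already arithmetic, so the empty base is selected.
static_assert(std::is_same<sketch::traits_base<int>, sketch::define<void>>::value, "");

// wide_int is not, so define<wide_int> picks up the fallback traits.
static_assert(std::is_base_of<sketch::detail::fallback_traits,
                              sketch::define<sketch::wide_int>>::value,
              "");
static_assert(sketch::define<sketch::wide_int>::is_arithmetic, "");
```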