Implementing traits for custom types in rocPRIM
Overview
This interface lets users provide additional type trait information to rocPRIM, improving compatibility with custom types.
Describing custom types accurately matters for both performance optimization and computational correctness.
Custom types that implement arithmetic operators can behave like built-in arithmetic types, yet rocPRIM algorithms might still treat them as generic struct or class types.
The rocPRIM type traits interface closes this gap by letting users attach trait information to their own types.
The idea is similar to operator overloading: just as overloaded operators let a custom type take part in built-in expressions, defined traits let it take part in rocPRIM algorithms as if it were a built-in type.
Implement only the traits that specific algorithms require. Some traits cannot be defined because they are inferred from others.
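The mismatch described above is visible with the standard library alone. In this minimal sketch, fixed16 is a hypothetical 16.16 fixed-point type invented for illustration: it overloads arithmetic operators, yet std::is_arithmetic and std::is_scalar still report it as an ordinary class type, which is the information gap the traits interface is meant to close.

```cpp
#include <cstdint>
#include <type_traits>

// Hypothetical 16.16 fixed-point number, used only for illustration.
struct fixed16
{
    std::int32_t raw; // stored value = real value * 65536

    friend constexpr fixed16 operator+(fixed16 a, fixed16 b) { return {a.raw + b.raw}; }
    friend constexpr fixed16 operator-(fixed16 a, fixed16 b) { return {a.raw - b.raw}; }
    friend constexpr bool    operator<(fixed16 a, fixed16 b) { return a.raw < b.raw; }
};

// It behaves like a number in expressions...
constexpr fixed16 one{1 << 16};
constexpr fixed16 two = one + one;
static_assert(two.raw == (2 << 16), "operators work like built-in arithmetic");
static_assert(one < two, "comparison works too");

// ...but the standard traits classify it as a generic class type.
static_assert(!std::is_arithmetic_v<fixed16>, "not arithmetic for the standard library");
static_assert(!std::is_scalar_v<fixed16>, "not scalar either");
static_assert(std::is_class_v<fixed16>, "just a struct");
```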
Interface
template<class T>
struct define

#include <type_traits_interface.hpp>

Overview
This template struct provides an interface for downstream libraries to implement type traits for their custom types. Users specialize it for a custom type to define that type's traits. Implement only the traits that specific algorithms require; some traits cannot be defined because they are inferred from others. This API is not static because of the one-definition rule (ODR).

Example
The examples below demonstrate how to implement traits for a custom floating-point type and for a custom integral type.

```cpp
// Your type definition
struct custom_float_type {};

// Implement the traits
template<>
struct rocprim::traits::define<custom_float_type>
{
    using is_arithmetic  = rocprim::traits::is_arithmetic::values<true>;
    using number_format  = rocprim::traits::number_format::values<traits::number_format::kind::floating_point_type>;
    // Bit type, then sign, exponent, and mantissa masks (here an IEEE 754 binary32 layout)
    using float_bit_mask = rocprim::traits::float_bit_mask::values<uint32_t, 0x80000000, 0x7F800000, 0x007FFFFF>;
};
```

```cpp
// Your type definition
struct custom_int_type {};

// Implement the traits
template<>
struct rocprim::traits::define<custom_int_type>
{
    using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
    using number_format = rocprim::traits::number_format::values<traits::number_format::kind::integral_type>;
    using integral_sign = rocprim::traits::integral_sign::values<traits::integral_sign::kind::signed_type>;
};
```

Template Parameters:
T – The type for which you want to define traits.
template<class T>
struct get

#include <type_traits_interface.hpp>

Overview
This template struct lets rocPRIM algorithms retrieve trait information from C++ built-in arithmetic types, rocPRIM types, and custom types. This API is not static because of the one-definition rule (ODR).
- All member functions are compiled only when invoked.
- Different algorithms require different traits.

Example
The following code demonstrates how to retrieve the traits of type T.

```cpp
#include <type_traits>

// Get the trait in a template parameter
template<class T, std::enable_if_t<rocprim::traits::get<T>().is_integral()>* = nullptr>
void get_traits_in_template_parameter() {}

// Get the trait in a function body
template<class T>
void get_traits_in_function_body()
{
    constexpr auto input_traits = rocprim::traits::get<T>();
    // Then you can use the member functions
    constexpr bool is_arithmetic = input_traits.is_arithmetic();
}
```

Template Parameters:
T – The type from which you want to retrieve the traits.
Public Functions

inline constexpr bool is_arithmetic() const
Gets the value of trait is_arithmetic.
Returns: true if std::is_arithmetic_v<T> is true, if T is a rocPRIM arithmetic type, or if the is_arithmetic trait has been defined as true; otherwise, false.

inline constexpr bool is_fundamental() const
Gets the value of trait is_fundamental.
Returns: true if T is a fundamental type (that is, a rocPRIM arithmetic type, void, or nullptr_t); otherwise, false.

inline constexpr bool is_compound() const
Gets the value of trait is_compound, the negation of is_fundamental().
Returns: false if T is a fundamental type (that is, a rocPRIM arithmetic type, void, or nullptr_t); otherwise, true.

inline constexpr bool is_floating_point() const
Checks whether T is a floating-point type.
Warning: You cannot call this function when is_arithmetic() returns false; doing so results in a compile-time error.

inline constexpr bool is_integral() const
Checks whether T is an integral type.
Warning: You cannot call this function when is_arithmetic() returns false; doing so results in a compile-time error.

inline constexpr bool is_signed() const
Checks whether T is a signed integral type.
Warning: You cannot call this function when is_integral() returns false; doing so results in a compile-time error.

inline constexpr bool is_unsigned() const
Checks whether T is an unsigned integral type.
Warning: You cannot call this function when is_integral() returns false; doing so results in a compile-time error.

inline constexpr bool is_scalar() const
Gets the value of trait is_scalar.
Returns: true if std::is_scalar_v<T> is true, if T is a rocPRIM arithmetic type, or if the is_scalar trait has been defined as true; otherwise, false.

inline constexpr auto float_bit_mask() const
Gets the value of trait float_bit_mask.
Warning: You cannot call this function when is_floating_point() returns false; doing so results in a compile-time error.
Returns: A constexpr instance of the rocprim::traits::float_bit_mask::values specialization provided in the traits definition of type T. If the float_bit_mask trait is not defined, it returns the rocprim::detail::float_bit_mask values, provided that a specialization of rocprim::detail::float_bit_mask<T> exists.
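The fallback rules for these member functions can be modeled in plain C++17. The sketch below is a simplified, hypothetical model (the demo namespace, format_kind, sign_kind, and my_int are invented for illustration) and is not rocPRIM's implementation; in particular, it stores trait values as static constexpr members rather than as values<...> types, and it does not enforce every compile-time restriction described above.

```cpp
#include <type_traits>

namespace demo
{
enum class format_kind { floating_point, integral };
enum class sign_kind { is_signed, is_unsigned };

// Primary template: no user-defined traits unless specialized.
template<class T>
struct define
{};

// Detects whether define<T> declares an is_arithmetic member.
template<class T, class = void>
struct has_is_arithmetic : std::false_type
{};
template<class T>
struct has_is_arithmetic<T, std::void_t<decltype(define<T>::is_arithmetic)>> : std::true_type
{};

// Simplified query object: standard traits first, then user definitions.
template<class T>
struct get
{
    constexpr bool is_arithmetic() const
    {
        if constexpr(std::is_arithmetic_v<T>)
            return true;
        else if constexpr(has_is_arithmetic<T>::value)
            return define<T>::is_arithmetic;
        else
            return false;
    }

    // Union of std::is_fundamental and is_arithmetic().
    constexpr bool is_fundamental() const { return std::is_fundamental_v<T> || is_arithmetic(); }

    constexpr bool is_compound() const { return !is_fundamental(); }

    // For custom types without a number_format member, instantiating this
    // fails to compile, mirroring the warning on is_integral().
    constexpr bool is_integral() const
    {
        if constexpr(std::is_arithmetic_v<T>)
            return std::is_integral_v<T>;
        else
            return define<T>::number_format == format_kind::integral;
    }

    // Meant for integral types only (not enforced in this sketch).
    constexpr bool is_signed() const
    {
        if constexpr(std::is_arithmetic_v<T>)
            return std::is_signed_v<T>;
        else
            return define<T>::integral_sign == sign_kind::is_signed;
    }
};

// Hypothetical user type registered as a signed integral type.
struct my_int
{
    long long value;
};

template<>
struct define<my_int>
{
    static constexpr bool        is_arithmetic = true;
    static constexpr format_kind number_format = format_kind::integral;
    static constexpr sign_kind   integral_sign = sign_kind::is_signed;
};
} // namespace demo
```

With this model, demo::get<demo::my_int>().is_fundamental() is true because the user definition marks the type as arithmetic, while a pointer such as demo::my_int* stays compound.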
 
 
Available traits

struct is_arithmetic
#include <type_traits_interface.hpp>
Definability
- Undefinable: For types with predefined traits.
- Optional: For other types.
How to define
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
How to use
rocprim::traits::get<InputType>().is_arithmetic();

template<bool Val>
struct values
#include <type_traits_interface.hpp>
Value of this trait.
struct is_scalar
#include <type_traits_interface.hpp>
Arithmetic types, enumeration types, pointers, member pointers, and null pointers are considered scalar types.
Definability
- Undefinable: For types with predefined traits.
- Optional: For other types. If both is_arithmetic and is_scalar are defined, their values must be consistent; otherwise, a compile-time error occurs.
How to define
using is_scalar = rocprim::traits::is_scalar::values<true>;
How to use
rocprim::traits::get<InputType>().is_scalar();

template<bool Val>
struct values
#include <type_traits_interface.hpp>
Value of this trait.
struct number_format
#include <type_traits_interface.hpp>
Definability
- Undefinable: For types with predefined traits and for non-arithmetic types.
- Required: If you define is_arithmetic as true, you must also define this trait; otherwise, a compile-time error occurs.
How to define
using number_format = rocprim::traits::number_format::values<rocprim::traits::number_format::kind::integral_type>;
How to use
rocprim::traits::get<InputType>().is_integral();
rocprim::traits::get<InputType>().is_floating_point();
struct integral_sign
#include <type_traits_interface.hpp>
Definability
- Undefinable: For types with predefined traits, non-arithmetic types, and floating-point types.
- Required: If you define number_format as number_format::kind::integral_type, you must also define this trait; otherwise, a compile-time error occurs.
How to define
using integral_sign = rocprim::traits::integral_sign::values<traits::integral_sign::kind::signed_type>;
How to use
rocprim::traits::get<InputType>().is_signed();
rocprim::traits::get<InputType>().is_unsigned();
struct float_bit_mask
#include <type_traits_interface.hpp>
Definability
- Undefinable: For types with predefined traits, non-arithmetic types, and integral types.
- Required: If you define number_format as number_format::kind::floating_point_type, you must also define this trait; otherwise, a compile-time error occurs.
How to define
using float_bit_mask = rocprim::traits::float_bit_mask::values<uint32_t, 0x80000000, 0x7F800000, 0x007FFFFF>;
How to use
rocprim::traits::get<InputType>().float_bit_mask();
Warning: For some types, if this trait is not implemented in their traits definition, it falls back to rocprim::detail::float_bit_mask to maintain compatibility with downstream libraries. However, this fallback will be removed in the next major release. Please update these types to the latest interface.
struct is_fundamental
#include <type_traits_interface.hpp>
The trait is_fundamental is undefinable, as it is the union of std::is_fundamental and rocprim::traits::is_arithmetic.
Definability
- Undefinable: If you attempt to define this trait in any form, a compile-time error occurs.
How to use
rocprim::traits::get<InputType>().is_fundamental();
rocprim::traits::get<InputType>().is_compound();

template<bool Val>
struct values
#include <type_traits_interface.hpp>
Value of this trait.
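Assuming the intended reading of the Definability entries (number_format is required once is_arithmetic is true, integral_sign is required for an integral format, and float_bit_mask for a floating-point format), such a dependency chain can be checked at compile time with the detection idiom. The sketch below is hypothetical throughout: the demo namespace, the static constexpr member layout, required_traits_present, and the sample types are invented for illustration, and rocPRIM's own checks may be implemented differently.

```cpp
#include <type_traits>

namespace demo
{
enum class format_kind { floating_point, integral };
enum class sign_kind { is_signed, is_unsigned };

template<class T>
struct define
{};

// Detection idiom: does define<T> declare the member in question?
template<class T, class = void>
struct has_number_format : std::false_type
{};
template<class T>
struct has_number_format<T, std::void_t<decltype(define<T>::number_format)>> : std::true_type
{};

template<class T, class = void>
struct has_integral_sign : std::false_type
{};
template<class T>
struct has_integral_sign<T, std::void_t<decltype(define<T>::integral_sign)>> : std::true_type
{};

template<class T, class = void>
struct has_float_bit_mask : std::false_type
{};
template<class T>
struct has_float_bit_mask<T, std::void_t<typename define<T>::float_bit_mask>> : std::true_type
{};

// For a custom type whose define<T> sets is_arithmetic to true:
// are all the traits required by the rules above present?
template<class T>
constexpr bool required_traits_present()
{
    if constexpr(!has_number_format<T>::value)
        return false;
    else if constexpr(define<T>::number_format == format_kind::integral)
        return has_integral_sign<T>::value;
    else
        return has_float_bit_mask<T>::value;
}

// Hypothetical custom types.
struct good_int {};
struct no_format {};
struct bad_float {};

template<>
struct define<good_int>
{
    static constexpr bool        is_arithmetic = true;
    static constexpr format_kind number_format = format_kind::integral;
    static constexpr sign_kind   integral_sign = sign_kind::is_signed;
};

template<>
struct define<no_format>
{
    static constexpr bool is_arithmetic = true; // number_format is missing
};

template<>
struct define<bad_float>
{
    static constexpr bool        is_arithmetic = true;
    static constexpr format_kind number_format = format_kind::floating_point;
    // float_bit_mask is missing
};
} // namespace demo

// A library could turn the check into a hard error with a readable message.
static_assert(demo::required_traits_present<demo::good_int>(), "incomplete traits");
```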
 
Type traits wrappers

template<class T>
struct is_floating_point : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_floating_point()>
#include <type_traits_interface.hpp>
An extension of std::is_floating_point that supports additional arithmetic types, including rocprim::half, rocprim::bfloat16, and any types with the trait rocprim::traits::number_format::values<number_format::kind::floating_point_type> implemented.

template<class T>
struct is_integral : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_integral()>
#include <type_traits_interface.hpp>
An extension of std::is_integral that supports additional arithmetic types, including rocprim::int128_t, rocprim::uint128_t, and any types with the trait rocprim::traits::number_format::values<number_format::kind::integral_type> implemented.

template<class T>
struct is_arithmetic : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_arithmetic()>
#include <type_traits_interface.hpp>
An extension of std::is_arithmetic that supports additional arithmetic types, including any types with the trait rocprim::traits::is_arithmetic::values<true> implemented.

template<class T>
struct is_fundamental : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_fundamental()>
#include <type_traits_interface.hpp>
An extension of std::is_fundamental that supports additional arithmetic types, including any types with the trait rocprim::traits::is_arithmetic::values<true> implemented.

template<class T>
struct is_unsigned : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_unsigned()>
#include <type_traits_interface.hpp>
An extension of std::is_unsigned that supports additional arithmetic types, including rocprim::uint128_t, and any types with the trait rocprim::traits::integral_sign::values<integral_sign::kind::unsigned_type> implemented.

template<class T>
struct is_signed : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_signed()>
#include <type_traits_interface.hpp>
An extension of std::is_signed that supports additional arithmetic types, including rocprim::int128_t, and any types with the trait rocprim::traits::integral_sign::values<integral_sign::kind::signed_type> implemented.

template<class T>
struct is_scalar : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_scalar()>
#include <type_traits_interface.hpp>
An extension of std::is_scalar that supports additional arithmetic types, including any types with the trait rocprim::traits::is_scalar::values<true> implemented.
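Because every wrapper derives from std::integral_constant<bool, ...>, it composes with standard metaprogramming tools such as std::enable_if_t and std::conjunction. The sketch below reproduces the wrapper pattern with a hypothetical demo::get that treats built-in integers plus one invented type, wide_int, as integral; it illustrates the pattern only and is not the rocPRIM source.

```cpp
#include <type_traits>

namespace demo
{
// Hypothetical 128-bit integer stand-in.
struct wide_int
{
    unsigned long long lo;
    unsigned long long hi;
};

// Stand-in for a trait query such as rocprim::traits::get<T>().is_integral().
template<class T>
struct get
{
    constexpr bool is_integral() const
    {
        return std::is_integral_v<T> || std::is_same_v<T, wide_int>;
    }
};

// The wrapper pattern: expose a constexpr query as a standard-style trait.
template<class T>
struct is_integral : std::integral_constant<bool, get<T>().is_integral()>
{};

template<class T>
inline constexpr bool is_integral_v = is_integral<T>::value;

// A function template that only accepts (extended) integral types.
template<class T, std::enable_if_t<is_integral_v<T>, int> = 0>
constexpr int bits_of()
{
    return static_cast<int>(sizeof(T) * 8);
}
} // namespace demo
```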
Types with predefined traits

template<>
struct define<float>
Public Types
using float_bit_mask = traits::float_bit_mask::values<uint32_t, 0x80000000, 0x7F800000, 0x007FFFFF>
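To see how a float_bit_mask specialization is meant to be read, the sketch below applies the float masks listed above to the bits of a float. It assumes float is a 32-bit IEEE 754 binary32 type (only the size is checked), and the helper names float_bits, split, and check_example are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Masks as listed for define<float>: sign, exponent, mantissa.
constexpr std::uint32_t f32_sign_mask     = 0x80000000u;
constexpr std::uint32_t f32_exponent_mask = 0x7F800000u;
constexpr std::uint32_t f32_mantissa_mask = 0x007FFFFFu;

// The masks are disjoint and together cover every bit of the 32-bit word.
static_assert((f32_sign_mask & f32_exponent_mask) == 0, "overlap");
static_assert((f32_sign_mask & f32_mantissa_mask) == 0, "overlap");
static_assert((f32_exponent_mask & f32_mantissa_mask) == 0, "overlap");
static_assert((f32_sign_mask | f32_exponent_mask | f32_mantissa_mask) == 0xFFFFFFFFu, "gap");

// Copies the object representation of a float into a same-sized integer.
inline std::uint32_t float_bits(float x)
{
    static_assert(sizeof(float) == sizeof(std::uint32_t), "expects a 32-bit float");
    std::uint32_t bits = 0;
    std::memcpy(&bits, &x, sizeof(bits));
    return bits;
}

struct float_fields
{
    bool          negative;
    std::uint32_t biased_exponent; // exponent field shifted down to bit 0
    std::uint32_t fraction;        // stored mantissa bits (no implicit leading 1)
};

inline float_fields split(float x)
{
    const std::uint32_t bits = float_bits(x);
    return {(bits & f32_sign_mask) != 0,
            (bits & f32_exponent_mask) >> 23, // 23 = number of mantissa bits
            bits & f32_mantissa_mask};
}

// -1.5f is -1.1 (binary) x 2^0: biased exponent 127, only the top fraction bit set.
inline void check_example()
{
    const float_fields f = split(-1.5f);
    assert(f.negative);
    assert(f.biased_exponent == 127u);
    assert(f.fraction == 0x00400000u);
}
```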
template<>
struct define<double>
Public Types
using float_bit_mask = traits::float_bit_mask::values<uint64_t, 0x8000000000000000, 0x7FF0000000000000, 0x000FFFFFFFFFFFFF>
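The double masks can be sanity-checked entirely at compile time: they should be disjoint, cover all 64 bits, and have the IEEE 754 binary64 field widths of 1 sign bit, 11 exponent bits, and 52 mantissa bits. In this small sketch, count_bits is a hypothetical helper written out by hand to avoid requiring C++20 <bit>.

```cpp
#include <cstdint>

// Masks as listed for define<double>: sign, exponent, mantissa.
constexpr std::uint64_t f64_sign_mask     = 0x8000000000000000ull;
constexpr std::uint64_t f64_exponent_mask = 0x7FF0000000000000ull;
constexpr std::uint64_t f64_mantissa_mask = 0x000FFFFFFFFFFFFFull;

// Counts set bits by repeatedly clearing the lowest one.
constexpr int count_bits(std::uint64_t x)
{
    int n = 0;
    for(; x != 0; x &= x - 1)
        ++n;
    return n;
}

// Disjoint and gap-free...
static_assert((f64_sign_mask & (f64_exponent_mask | f64_mantissa_mask)) == 0, "overlap");
static_assert((f64_exponent_mask & f64_mantissa_mask) == 0, "overlap");
static_assert((f64_sign_mask | f64_exponent_mask | f64_mantissa_mask) == ~std::uint64_t{0}, "gap");

// ...with the binary64 field widths 1 + 11 + 52.
static_assert(count_bits(f64_sign_mask) == 1, "sign width");
static_assert(count_bits(f64_exponent_mask) == 11, "exponent width");
static_assert(count_bits(f64_mantissa_mask) == 52, "mantissa width");
```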
template<>
struct define<rocprim::bfloat16>
Public Types
using is_arithmetic = traits::is_arithmetic::values<true>
using number_format = traits::number_format::values<traits::number_format::kind::floating_point_type>
using float_bit_mask = traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>
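The bfloat16 masks are the upper 16 bits of the float masks, because bfloat16 keeps a float's sign bit, all 8 exponent bits, and the 7 most significant mantissa bits. The sketch below checks that relationship and shows a truncating float-to-bfloat16 bit conversion; bf16_bits_truncate and check_example are hypothetical helpers for illustration only, and real conversions usually round rather than truncate.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Masks as listed for define<float> and define<rocprim::bfloat16>.
constexpr std::uint32_t f32_sign = 0x80000000u, f32_exponent = 0x7F800000u, f32_mantissa = 0x007FFFFFu;
constexpr std::uint16_t bf16_sign = 0x8000u, bf16_exponent = 0x7F80u, bf16_mantissa = 0x007Fu;

// Each bfloat16 mask is the upper half of the corresponding float mask.
static_assert((f32_sign >> 16) == bf16_sign, "sign bit");
static_assert((f32_exponent >> 16) == bf16_exponent, "all 8 exponent bits");
static_assert((f32_mantissa >> 16) == bf16_mantissa, "top 7 mantissa bits");

// Hypothetical helper: keeps the upper 16 bits of a float's bit pattern.
// Truncation only; it illustrates the layout, not a production conversion.
inline std::uint16_t bf16_bits_truncate(float x)
{
    static_assert(sizeof(float) == sizeof(std::uint32_t), "expects a 32-bit float");
    std::uint32_t bits = 0;
    std::memcpy(&bits, &x, sizeof(bits));
    return static_cast<std::uint16_t>(bits >> 16);
}

inline void check_example()
{
    // 1.0f is 0x3F800000, so its bfloat16 pattern is 0x3F80.
    assert(bf16_bits_truncate(1.0f) == 0x3F80u);
}
```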
template<>
struct define<rocprim::half>
Public Types
using is_arithmetic = traits::is_arithmetic::values<true>
using number_format = traits::number_format::values<traits::number_format::kind::floating_point_type>
using float_bit_mask = traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>
template<>
struct define<rocprim::int128_t> : public std::conditional_t<std::is_arithmetic<rocprim::int128_t>::value, traits::define<void>, detail::define_int128_t>

template<>
struct define<rocprim::uint128_t> : public std::conditional_t<std::is_arithmetic<rocprim::uint128_t>::value, traits::define<void>, detail::define_uint128_t>
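The std::conditional_t base class in the two 128-bit specializations selects between two definitions: if the compiler already classifies the type as arithmetic (as it may when rocprim::int128_t maps to a native 128-bit integer), the specialization inherits define<void>, presumably leaving the standard traits to answer; otherwise, it inherits the explicit definitions from the detail namespace. The sketch below reproduces only the selection mechanism, with hypothetical stand-in types (no_extra_traits, explicit_int_traits, define_wide, emulated_int128).

```cpp
#include <type_traits>

// Hypothetical stand-ins: an empty definition (in the role of define<void>) and
// an explicit definition (in the role of the detail:: definitions).
struct no_extra_traits
{};
struct explicit_int_traits
{
    static constexpr bool is_arithmetic = true;
};

// Choose the base class depending on whether the compiler already treats T as arithmetic.
template<class T>
struct define_wide
    : std::conditional_t<std::is_arithmetic<T>::value, no_extra_traits, explicit_int_traits>
{};

// Hypothetical software 128-bit integer for compilers without a native one.
struct emulated_int128
{
    unsigned long long lo;
    unsigned long long hi;
};
```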