Implementing traits for custom types in rocPRIM
Overview
This interface lets users provide additional type trait information to rocPRIM, which improves compatibility between custom types and rocPRIM algorithms.
Accurately describing custom types is important for both performance and computational correctness.
Custom types that implement arithmetic operators can behave like built-in arithmetic types, but rocPRIM algorithms might still treat them as generic struct or class types.
Much like operator overloading, the type traits interface lets you attach type-specific information to your own types, which rocPRIM algorithms then pick up automatically.
Implement only the traits that the algorithms you use require. Some traits cannot be defined because they are inferred from others.
Interface
-
template<class T>
struct define
- Overview
This template struct provides an interface for downstream libraries and users to define type traits for their custom types. Implement only the traits that specific algorithms require; some traits cannot be defined because they are inferred from others. This API is not static because of the one-definition rule (ODR).
- Example
The example below demonstrates how to implement traits for a custom floating-point type.
    #include <cstdint>
    #include <rocprim/rocprim.hpp>

    // Your type definition
    struct custom_float_type {};

    // Implement the traits
    template<>
    struct rocprim::traits::define<custom_float_type>
    {
        using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
        using number_format = rocprim::traits::number_format::values<
            rocprim::traits::number_format::kind::floating_point_type>;
        // Sign, exponent, and mantissa masks (placeholder values for illustration)
        using float_bit_mask = rocprim::traits::float_bit_mask::values<std::uint32_t, 10, 10, 10>;
    };
The example below demonstrates how to implement traits for a custom integral type.
    #include <rocprim/rocprim.hpp>

    // Your type definition
    struct custom_int_type {};

    // Implement the traits
    template<>
    struct rocprim::traits::define<custom_int_type>
    {
        using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
        using number_format = rocprim::traits::number_format::values<
            rocprim::traits::number_format::kind::integral_type>;
        using integral_sign = rocprim::traits::integral_sign::values<
            rocprim::traits::integral_sign::kind::signed_type>;
    };
- Template Parameters:
T – The type for which you want to define traits.
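To illustrate why specializing a template is enough for library code to pick up user-provided traits, the following self-contained sketch models the pattern in plain C++17. It does not use rocPRIM: the namespace mini_traits and the names bool_value, has_is_arithmetic, and my_fixed_point are hypothetical simplifications, not rocPRIM's implementation. Built-in types are answered by the standard library, and a custom type is recognized only if it specializes define<T>.

```cpp
#include <type_traits>

// Hypothetical, simplified stand-in for a traits library (not rocPRIM code).
namespace mini_traits
{
// Carries a compile-time boolean, similar in spirit to
// rocprim::traits::is_arithmetic::values<true>.
template<bool Val>
struct bool_value
{
    static constexpr bool value = Val;
};

// Primary template: no traits are defined for T by default.
template<class T>
struct define
{};

// Detects whether define<T> declares a member type named is_arithmetic.
template<class T, class = void>
struct has_is_arithmetic : std::false_type
{};

template<class T>
struct has_is_arithmetic<T, std::void_t<typename define<T>::is_arithmetic>> : std::true_type
{};

// Built-in types are answered by <type_traits>; other types must opt in.
template<class T>
constexpr bool is_arithmetic()
{
    if constexpr(std::is_arithmetic_v<T>)
    {
        return true;
    }
    else if constexpr(has_is_arithmetic<T>::value)
    {
        return define<T>::is_arithmetic::value;
    }
    else
    {
        return false;
    }
}
} // namespace mini_traits

// A user-defined numeric type (hypothetical).
struct my_fixed_point
{
    int raw;
};

// Opt in by specializing the library's template for the custom type.
namespace mini_traits
{
template<>
struct define<my_fixed_point>
{
    using is_arithmetic = bool_value<true>;
};
} // namespace mini_traits
```

In this model, a type that is neither a built-in arithmetic type nor covered by a define<T> specialization is reported as non-arithmetic, which corresponds to the generic struct or class treatment mentioned in the overview.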
-
template<class T>
struct get
- Overview
This template struct lets rocPRIM algorithms retrieve trait information from C++ built-in arithmetic types, rocPRIM types, and custom types. This API is not static because of the one-definition rule (ODR).
All member functions are compiled only when invoked. Different algorithms require different traits.
- Example
The following code demonstrates how to retrieve the traits of type T.
    #include <type_traits>
    #include <rocprim/rocprim.hpp>

    // Get the trait in a template parameter
    template<class T,
             typename std::enable_if<rocprim::traits::get<T>().is_integral()>::type* = nullptr>
    void get_traits_in_template_parameter() {}

    // Get the trait in a function body
    template<class T>
    void get_traits_in_function_body()
    {
        constexpr auto input_traits = rocprim::traits::get<T>();
        // Then you can use the member functions
        constexpr bool is_arithmetic = input_traits.is_arithmetic();
    }
- Template Parameters:
T – The type from which you want to retrieve the traits.
Public Functions
-
inline constexpr bool is_arithmetic() const
Get the value of trait is_arithmetic.
- Returns:
true if std::is_arithmetic_v<T> is true, if type T is a rocPRIM arithmetic type, or if the is_arithmetic trait has been defined as true; otherwise, returns false.
-
inline constexpr bool is_fundamental() const
Get trait is_fundamental.
- Returns:
true if T is a fundamental type (that is, a rocPRIM arithmetic type, void, or nullptr_t); otherwise, returns false.
-
inline constexpr bool is_compound() const
Check whether T is a compound type, that is, the opposite of is_fundamental().
- Returns:
false if T is a fundamental type (that is, a rocPRIM arithmetic type, void, or nullptr_t); otherwise, returns true.
-
inline constexpr bool is_floating_point() const
Check whether T is a floating-point type.
Warning
You cannot call this function when is_arithmetic() returns false; doing so results in a compile-time error.
-
inline constexpr bool is_integral() const
Check whether T is an integral type.
Warning
You cannot call this function when is_arithmetic() returns false; doing so results in a compile-time error.
-
inline constexpr bool is_signed() const
Check whether T is a signed integral type.
Warning
You cannot call this function when is_integral() returns false; doing so results in a compile-time error.
-
inline constexpr bool is_unsigned() const
Check whether T is an unsigned integral type.
Warning
You cannot call this function when is_integral() returns false; doing so results in a compile-time error.
-
inline constexpr bool is_scalar() const
Get trait is_scalar.
- Returns:
true if std::is_scalar_v<T> is true, if type T is a rocPRIM arithmetic type, or if the is_scalar trait has been defined as true; otherwise, returns false.
-
inline constexpr auto float_bit_mask() const
Get trait float_bit_mask.
Warning
You cannot call this function when is_floating_point() returns false; doing so results in a compile-time error.
- Returns:
A constexpr instance of the specialization of rocprim::traits::float_bit_mask::values provided in the traits definition of type T. If the float_bit_mask trait is not defined, it returns the rocprim::detail::float_bit_mask values instead, provided that a specialization of rocprim::detail::float_bit_mask<T> exists.
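The warnings above follow from a standard C++ rule: a member function of a class template is instantiated only when it is called. The following sketch models that behavior with a hypothetical get_model<T> built on <type_traits>; it is not rocPRIM's get<T>, and its static_assert messages are illustrative. Calling is_arithmetic() works for any T, whereas calling is_integral() or is_signed() for an unsuitable T stops compilation.

```cpp
#include <type_traits>

// Hypothetical model of a traits accessor (not rocPRIM's implementation).
template<class T>
struct get_model
{
    constexpr bool is_arithmetic() const
    {
        return std::is_arithmetic_v<T>;
    }

    constexpr bool is_integral() const
    {
        // Depends on T, so it is checked only if is_integral() is instantiated.
        static_assert(std::is_arithmetic_v<T>, "is_integral() requires an arithmetic type");
        return std::is_integral_v<T>;
    }

    constexpr bool is_signed() const
    {
        static_assert(std::is_integral_v<T>, "is_signed() requires an integral type");
        return std::is_signed_v<T>;
    }
};

// Valid: only is_arithmetic() is instantiated for void*.
static_assert(!get_model<void*>{}.is_arithmetic(), "pointers are not arithmetic");

// Invalid if uncommented: instantiates is_integral() for void* and fails its static_assert.
// static_assert(!get_model<void*>{}.is_integral(), "");
```

This is also why the same accessor object can be used for any type: only the queries a given algorithm actually makes need to be valid for that type.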
Available traits
-
struct is_arithmetic
- Definability
Undefinable: For types with predefined traits.
Optional: For other types.
- How to define
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
- How to use
rocprim::traits::get<InputType>().is_arithmetic();
-
template<bool Val>
struct values
Value of this trait.
-
struct is_scalar
Arithmetic types, enumerations, pointers, member pointers, and null pointers are considered scalar types.
- Definability
Undefinable: For types with predefined traits.
Optional: For other types. If both is_arithmetic and is_scalar are defined, their values must be consistent; otherwise, a compile-time error occurs.
- How to define
using is_scalar = rocprim::traits::is_scalar::values<true>;
- How to use
rocprim::traits::get<InputType>().is_scalar();
-
template<bool Val>
struct values
Value of this trait.
-
struct number_format
- Definability
Undefinable: For types with predefined traits and for non-arithmetic types.
Required: If you define is_arithmetic as true, you must also define this trait; otherwise, a compile-time error occurs.
- How to define
using number_format = rocprim::traits::number_format::values<number_format::kind::integral_type>;
- How to use
rocprim::traits::get<InputType>().is_integral(); rocprim::traits::get<InputType>().is_floating_point();
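The Required rule above can be checked at compile time by detecting whether a member type exists. The following sketch is a hypothetical model using only <type_traits>: number_format_kind, number_format_value, has_number_format, check_traits, and the two traits structs are illustrative names, and is_arithmetic is simplified to a static constexpr bool rather than rocPRIM's values<true> form. It accepts a definition that supplies number_format alongside is_arithmetic = true and would reject one that omits it.

```cpp
#include <type_traits>

// Hypothetical stand-ins for the trait kinds (not rocPRIM's definitions).
enum class number_format_kind
{
    integral_type,
    floating_point_type
};

template<number_format_kind K>
struct number_format_value
{
    static constexpr number_format_kind value = K;
};

// Detects whether Traits declares a member type named number_format.
template<class Traits, class = void>
struct has_number_format : std::false_type
{};

template<class Traits>
struct has_number_format<Traits, std::void_t<typename Traits::number_format>> : std::true_type
{};

// Compile-time validation of a user traits definition.
template<class Traits>
constexpr bool check_traits()
{
    static_assert(!Traits::is_arithmetic || has_number_format<Traits>::value,
                  "number_format must be defined when is_arithmetic is true");
    return true;
}

// Complete definition for a hypothetical integral type.
struct my_int_traits
{
    static constexpr bool is_arithmetic = true;
    using number_format = number_format_value<number_format_kind::integral_type>;
};

// A non-arithmetic type does not need number_format.
struct my_struct_traits
{
    static constexpr bool is_arithmetic = false;
};

static_assert(check_traits<my_int_traits>(), "");

// Would fail to compile if uncommented: arithmetic, but number_format is missing.
// struct bad_traits { static constexpr bool is_arithmetic = true; };
// static_assert(check_traits<bad_traits>(), "");
```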
Public Types
-
struct integral_sign
- Definability
Undefinable: For types with predefined traits, non-arithmetic types, and floating-point types.
Required: If you define number_format as number_format::kind::integral_type, you must also define this trait; otherwise, a compile-time error occurs.
- How to define
using integral_sign = rocprim::traits::integral_sign::values<traits::integral_sign::kind::signed_type>;
- How to use
rocprim::traits::get<InputType>().is_signed(); rocprim::traits::get<InputType>().is_unsigned();
Public Types
-
struct float_bit_mask
- Definability
Undefinable: For types with predefined traits, non-arithmetic types, and integral types.
Required: If you define number_format as number_format::kind::floating_point_type, you must also define this trait; otherwise, a compile-time error occurs.
- How to define
using float_bit_mask = rocprim::traits::float_bit_mask::values<int,1,1,1>;
- How to use
rocprim::traits::get<InputType>().float_bit_mask();
Warning
For some types, if this trait is not implemented in their traits definition, rocPRIM falls back to rocprim::detail::float_bit_mask to maintain compatibility with downstream libraries. This fallback will be removed in the next major release, so make sure these types are updated to the latest interface.
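To show what the three masks describe, the following sketch splits a float into its sign, exponent, and mantissa bits in plain standard C++. It reuses the constants listed for float under Types with predefined traits (0x80000000, 0x7F800000, 0x007FFFFF) and assumes that the masks are ordered sign, exponent, mantissa and that float is IEEE 754 binary32; the helper names are hypothetical and are not part of rocPRIM.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Masks in the order used by the predefined float traits (assumed: sign,
// exponent, mantissa). Assumes float is IEEE 754 binary32.
constexpr std::uint32_t float_sign_mask     = 0x80000000u;
constexpr std::uint32_t float_exponent_mask = 0x7F800000u;
constexpr std::uint32_t float_mantissa_mask = 0x007FFFFFu;

static_assert(sizeof(float) == sizeof(std::uint32_t), "expects a 32-bit float");

// Copy the object representation of a float into an integer (avoids type-punning UB).
inline std::uint32_t float_bits(float x)
{
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return bits;
}

// Isolate each field by masking; the fields are not shifted down.
inline std::uint32_t sign_bits(float x) { return float_bits(x) & float_sign_mask; }
inline std::uint32_t exponent_bits(float x) { return float_bits(x) & float_exponent_mask; }
inline std::uint32_t mantissa_bits(float x) { return float_bits(x) & float_mantissa_mask; }

// Example checks: 1.5f is stored as 0x3FC00000 in binary32.
inline void check_float_fields()
{
    assert(sign_bits(-1.5f) == float_sign_mask);
    assert(exponent_bits(1.5f) == 0x3F800000u);
    assert(mantissa_bits(1.5f) == 0x00400000u);
}
```

Because the three masks together cover every bit exactly once, OR-ing the three extracted fields reproduces the original bit pattern.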
-
struct is_fundamental
The trait is_fundamental is undefinable, because it is the union of std::is_fundamental and rocprim::traits::is_arithmetic.
- Definability
Undefinable: If you attempt to define this trait in any form, a compile-time error will occur.
- How to use
rocprim::traits::get<InputType>().is_fundamental(); rocprim::traits::get<InputType>().is_compound();
-
template<bool Val>
struct values
Value of this trait.
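The union described for is_fundamental, and the fact that is_compound is its negation, can be written as two one-line predicates. The following sketch models them with a hypothetical custom_arithmetic<T> flag standing in for a user-defined is_arithmetic trait, and a hypothetical my_decimal type that opts in; it is a simplified illustration, not rocPRIM's implementation.

```cpp
#include <cstddef>
#include <type_traits>

// Hypothetical stand-in for a user-defined is_arithmetic trait (false by default).
template<class T>
struct custom_arithmetic : std::false_type
{};

// A user type that opts in as arithmetic (hypothetical).
struct my_decimal
{
    long long scaled;
};

template<>
struct custom_arithmetic<my_decimal> : std::true_type
{};

// Fundamental: a standard fundamental type, or a type marked arithmetic.
template<class T>
constexpr bool is_fundamental_model()
{
    return std::is_fundamental_v<T> || custom_arithmetic<T>::value;
}

// Compound is simply the opposite of fundamental.
template<class T>
constexpr bool is_compound_model()
{
    return !is_fundamental_model<T>();
}
```

Because is_fundamental is derived this way, defining it separately could only contradict the other traits, which is why it is undefinable.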
Type traits wrappers
-
template<class T>
struct is_floating_point : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_floating_point()>
An extension of std::is_floating_point that supports additional arithmetic types, including rocprim::half, rocprim::bfloat16, and any types with the trait rocprim::traits::number_format::values<number_format::kind::floating_point_type> implemented.
-
template<class T>
struct is_integral : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_integral()>
An extension of std::is_integral that supports additional arithmetic types, including rocprim::int128_t, rocprim::uint128_t, and any types with the trait rocprim::traits::number_format::values<number_format::kind::integral_type> implemented.
-
template<class T>
struct is_arithmetic : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_arithmetic()>
An extension of std::is_arithmetic that supports additional arithmetic types, including any types with the trait rocprim::traits::is_arithmetic::values<true> implemented.
-
template<class T>
struct is_fundamental : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_fundamental()>
An extension of std::is_fundamental that supports additional arithmetic types, including any types with the trait rocprim::traits::is_arithmetic::values<true> implemented.
-
template<class T>
struct is_unsigned : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_unsigned()>
An extension of std::is_unsigned that supports additional arithmetic types, including rocprim::uint128_t and any types with the trait rocprim::traits::integral_sign::values<integral_sign::kind::unsigned_type> implemented.
-
template<class T>
struct is_signed : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_signed()>
An extension of std::is_signed that supports additional arithmetic types, including rocprim::int128_t and any types with the trait rocprim::traits::integral_sign::values<integral_sign::kind::signed_type> implemented.
-
template<class T>
struct is_scalar : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_scalar()>
An extension of std::is_scalar that supports additional arithmetic types, including any types with the trait rocprim::traits::is_scalar::values<true> implemented.
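Because each wrapper derives from std::integral_constant, it can be used anywhere a standard trait is expected, for example through ::value or inside std::enable_if_t. The following sketch reproduces that pattern with a hypothetical traits_model<T> accessor and an is_integral_ext wrapper defined in terms of <type_traits>; the names and the halve function are illustrative, and the wrappers documented above additionally recognize rocPRIM and user-defined types.

```cpp
#include <type_traits>

// Hypothetical accessor with a constexpr member function, mirroring get<T>().
template<class T>
struct traits_model
{
    constexpr bool is_integral() const
    {
        return std::is_integral_v<T>;
    }
};

// Wrapper: turns the constexpr query into a standard-style type trait.
template<class T>
struct is_integral_ext : std::integral_constant<bool, traits_model<T>{}.is_integral()>
{};

// The wrapper works with SFINAE just like std::is_integral.
template<class T, std::enable_if_t<is_integral_ext<T>::value, int> = 0>
constexpr T halve(T x)
{
    return x / 2; // integer division; not viable for floating-point T
}
```

A call such as halve(2.0) is rejected during overload resolution because is_integral_ext<double>::value is false.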
Types with predefined traits
-
template<>
struct define<float>
Public Types
-
using float_bit_mask = traits::float_bit_mask::values<uint32_t, 0x80000000, 0x7F800000, 0x007FFFFF>
-
template<>
struct define<double>
Public Types
-
using float_bit_mask = traits::float_bit_mask::values<uint64_t, 0x8000000000000000, 0x7FF0000000000000, 0x000FFFFFFFFFFFFF>
-
template<>
struct define<rocprim::bfloat16>
Public Types
-
using is_arithmetic = traits::is_arithmetic::values<true>
-
using number_format = traits::number_format::values<traits::number_format::kind::floating_point_type>
-
using float_bit_mask = traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>
-
template<>
struct define<rocprim::half>
Public Types
-
using is_arithmetic = traits::is_arithmetic::values<true>
-
using number_format = traits::number_format::values<traits::number_format::kind::floating_point_type>
-
using float_bit_mask = traits::float_bit_mask::values<uint16_t, 0x8000, 0x7C00, 0x03FF>
-
template<>
struct define<rocprim::int128_t> : public std::conditional_t<std::is_arithmetic<rocprim::int128_t>::value, traits::define<void>, detail::define_int128_t>
-
template<>
struct define<rocprim::uint128_t> : public std::conditional_t<std::is_arithmetic<rocprim::uint128_t>::value, traits::define<void>, detail::define_uint128_t>
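One way to sanity-check a float_bit_mask definition is to verify that the sign, exponent, and mantissa masks do not overlap, that together they cover every bit of the storage type, and that each field has the expected width. The following standalone sketch performs these checks at compile time; masks_partition and bit_count are hypothetical helpers, not part of rocPRIM. The float, double, and bfloat16 masks are the ones listed above; for half, the sketch uses the standard IEEE 754 binary16 layout (1 sign bit, 5 exponent bits, 10 mantissa bits).

```cpp
#include <cstdint>
#include <limits>

// Returns true if the three masks are pairwise disjoint and cover all bits of U.
template<class U>
constexpr bool masks_partition(U sign, U exponent, U mantissa)
{
    const bool disjoint = (sign & exponent) == 0 && (sign & mantissa) == 0
                          && (exponent & mantissa) == 0;
    const bool complete = static_cast<U>(sign | exponent | mantissa)
                          == std::numeric_limits<U>::max();
    return disjoint && complete;
}

// Counts set bits by repeatedly clearing the lowest one.
template<class U>
constexpr int bit_count(U v)
{
    int n = 0;
    for(; v != 0; v &= static_cast<U>(v - 1))
    {
        ++n;
    }
    return n;
}

// float: 1 + 8 + 23 bits
static_assert(masks_partition<std::uint32_t>(0x80000000u, 0x7F800000u, 0x007FFFFFu), "float");
// double: 1 + 11 + 52 bits
static_assert(masks_partition<std::uint64_t>(0x8000000000000000ull, 0x7FF0000000000000ull,
                                             0x000FFFFFFFFFFFFFull), "double");
// bfloat16: 1 + 8 + 7 bits
static_assert(masks_partition<std::uint16_t>(0x8000, 0x7F80, 0x007F), "bfloat16");
// half (IEEE 754 binary16): 1 + 5 + 10 bits
static_assert(masks_partition<std::uint16_t>(0x8000, 0x7C00, 0x03FF), "half");
static_assert(bit_count<std::uint16_t>(0x7C00) == 5, "half exponent width");
```

The width check matters because the partition test alone cannot tell two valid 16-bit layouts apart: both bfloat16 and half masks cover all 16 bits, but their exponent fields have different widths.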