Implementing traits for custom types in rocPRIM

Overview

This interface lets users provide additional type trait information to rocPRIM, which improves rocPRIM's compatibility with custom types.

Describing custom types accurately is important for both performance and computational correctness.

Custom types that implement arithmetic operators can behave like built-in arithmetic types, but rocPRIM algorithms might still interpret them as generic struct or class types.
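The following standalone sketch (it does not use rocPRIM) illustrates the problem with a hypothetical 16.16 fixed-point type, fixed16. It supports + and * like a number, yet the standard std::is_arithmetic trait still classifies it as a plain class type.

```cpp
#include <cassert>
#include <cstdint>
#include <type_traits>

// Hypothetical 16.16 fixed-point type, used only for illustration.
struct fixed16
{
    std::int32_t raw; // value scaled by 2^16

    static constexpr fixed16 from_int(std::int32_t v) { return fixed16{v * 65536}; }

    friend constexpr fixed16 operator+(fixed16 a, fixed16 b) { return fixed16{a.raw + b.raw}; }
    friend constexpr fixed16 operator*(fixed16 a, fixed16 b)
    {
        // Widen to 64 bits before rescaling to avoid intermediate overflow.
        return fixed16{static_cast<std::int32_t>((std::int64_t{a.raw} * b.raw) >> 16)};
    }
};

// It behaves like a number...
static_assert((fixed16::from_int(2) * fixed16::from_int(3)).raw == fixed16::from_int(6).raw, "");
// ...but the standard traits still see an ordinary struct.
static_assert(!std::is_arithmetic<fixed16>::value, "");
static_assert(std::is_class<fixed16>::value, "");
```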

The rocPRIM type traits interface lets users attach trait information to such types, so that rocPRIM algorithms can handle them accordingly.

Using the interface is similar to operator overloading: you supply definitions for your own type, and rocPRIM picks them up.

Implement traits only as specific algorithms require them. Some traits cannot be defined because rocPRIM infers them from other traits; for example, is_fundamental is derived from is_arithmetic.

Interface

template<class T>
struct define

Overview

This template struct provides an interface for downstream libraries to implement type traits for their custom types: specialize it for a type to define that type's traits. Implement only the traits that specific algorithms require; some traits cannot be defined because they are inferred from others. This API is not declared static, to comply with the One Definition Rule (ODR).

Example

The example below demonstrates how to implement traits for a custom floating-point type.

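#include <cstdint>
#include <rocprim/rocprim.hpp>

// Your type definition (assumed here to store a 32-bit IEEE 754 value)
struct custom_float_type
{};
// Implement the traits
template<>
struct rocprim::traits::define<custom_float_type>
{
    using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
    using number_format = rocprim::traits::number_format::values<rocprim::traits::number_format::kind::floating_point_type>;
    // Sign, exponent, and mantissa masks for the assumed 32-bit IEEE 754 layout
    using float_bit_mask = rocprim::traits::float_bit_mask::values<std::uint32_t, 0x80000000, 0x7F800000, 0x007FFFFF>;
};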
The example below demonstrates how to implement traits for a custom integral type.

#include <rocprim/rocprim.hpp>

// Your type definition
struct custom_int_type
{};
// Implement the traits
template<>
struct rocprim::traits::define<custom_int_type>
{
    using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;
    using number_format = rocprim::traits::number_format::values<rocprim::traits::number_format::kind::integral_type>;
    using integral_sign = rocprim::traits::integral_sign::values<rocprim::traits::integral_sign::kind::signed_type>;
};

Template Parameters:

T – The type for which you want to define traits.
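To show how a library can pick up such a specialization, the following self-contained sketch reproduces the general pattern in miniature: a primary template with no traits, a user specialization, and a getter that falls back to std::is_arithmetic when no trait is defined. Everything in namespace demo, including the std::void_t detection helper has_is_arithmetic, is a hypothetical stand-in so the code compiles without rocPRIM; it is not rocPRIM's implementation. It requires C++17.

```cpp
#include <cassert>
#include <type_traits>

namespace demo // hypothetical stand-in for rocprim::traits, not rocPRIM code
{
// Same shape as a values<Val> trait: a flag carried in a static member.
struct is_arithmetic
{
    template<bool Val>
    struct values
    {
        static constexpr bool value = Val;
    };
};

// Primary template: by default a type defines no traits.
template<class T>
struct define
{};

// Detects whether define<T> declares an is_arithmetic member alias.
template<class T, class = void>
struct has_is_arithmetic : std::false_type
{};
template<class T>
struct has_is_arithmetic<T, std::void_t<typename define<T>::is_arithmetic>> : std::true_type
{};

// Getter: true for built-in arithmetic types or types that opted in via define<T>.
template<class T>
struct get
{
    constexpr bool is_arithmetic() const
    {
        if constexpr(has_is_arithmetic<T>::value)
            return std::is_arithmetic<T>::value || define<T>::is_arithmetic::value;
        else
            return std::is_arithmetic<T>::value;
    }
};
} // namespace demo

// User code: a custom type opts in by specializing demo::define.
struct my_number
{};
struct not_a_number
{};

namespace demo
{
template<>
struct define<my_number>
{
    using is_arithmetic = demo::is_arithmetic::values<true>;
};
} // namespace demo

static_assert(demo::get<int>().is_arithmetic(), "built-in types need no specialization");
static_assert(demo::get<my_number>().is_arithmetic(), "opted in through define<my_number>");
static_assert(!demo::get<not_a_number>().is_arithmetic(), "no traits defined");
```

The sketch reopens namespace demo for the specialization; the rocPRIM examples above instead use a qualified name (struct rocprim::traits::define<...>), which C++17 also permits.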

template<class T>
struct get

Overview

This template struct allows rocPRIM algorithms to retrieve trait information from C++ built-in arithmetic types, rocPRIM types, and custom types. This API is not declared static, to comply with the One Definition Rule (ODR).

  • All member functions are compiled only when invoked.

  • Different algorithms require different traits.

Example

The following code demonstrates how to retrieve the traits of type T.

#include <type_traits>
#include <rocprim/rocprim.hpp>

// Get the trait in a template parameter
template<class T, std::enable_if_t<rocprim::traits::get<T>().is_integral()>* = nullptr>
void get_traits_in_template_parameter(){}
// Get the trait in a function body
template<class T>
void get_traits_in_function_body(){
    constexpr auto input_traits = rocprim::traits::get<T>();
    // Then you can use the member functions
    constexpr bool is_arithmetic = input_traits.is_arithmetic();
}

Template Parameters:

T – The type from which you want to retrieve the traits.

Public Functions

inline constexpr bool is_arithmetic() const

Get the value of trait is_arithmetic.

Returns:

true if std::is_arithmetic_v<T> is true, or if type T is a rocPRIM arithmetic type, or if the is_arithmetic trait has been defined as true; otherwise, returns false.

inline constexpr bool is_fundamental() const

Get trait is_fundamental.

Returns:

true if T is a fundamental type (that is, a rocPRIM arithmetic type, void, or std::nullptr_t); otherwise, returns false.

inline constexpr bool is_compound() const

Get trait is_compound, which is the negation of is_fundamental().

Returns:

false if T is a fundamental type (that is, a rocPRIM arithmetic type, void, or std::nullptr_t); otherwise, returns true.

inline constexpr bool is_floating_point() const

Checks whether T is a floating-point type.

Warning

You cannot call this function when is_arithmetic() returns false; doing so will result in a compile-time error.

inline constexpr bool is_integral() const

Checks whether T is an integral type.

Warning

You cannot call this function when is_arithmetic() returns false; doing so will result in a compile-time error.

inline constexpr bool is_signed() const

Checks whether T is a signed integral type.

Warning

You cannot call this function when is_integral() returns false; doing so will result in a compile-time error.

inline constexpr bool is_unsigned() const

Checks whether T is an unsigned integral type.

Warning

You cannot call this function when is_integral() returns false; doing so will result in a compile-time error.

inline constexpr bool is_scalar() const

Get trait is_scalar.

Returns:

true if std::is_scalar_v<T> is true, or if type T is a rocPRIM arithmetic type, or if the is_scalar trait has been defined as true; otherwise, returns false.

inline constexpr auto float_bit_mask() const

Get trait float_bit_mask.

Warning

You cannot call this function when is_floating_point() returns false; doing so will result in a compile-time error.

Returns:

A constexpr instance of the specialization of rocprim::traits::float_bit_mask::values as provided in the traits definition of type T. If the float_bit_mask trait is not defined, it returns the rocprim::detail::float_bit_mask values, provided a specialization of rocprim::detail::float_bit_mask<T> exists.
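Because several of the functions above are only valid after a prior check (for example, is_signed() requires is_integral() to be true), generic code can branch with if constexpr so that the restricted query is never instantiated for other types. The sketch below shows that guard pattern with a hypothetical strict_traits struct standing in for rocprim::traits::get, so it compiles without rocPRIM; with rocPRIM, the outer check would also need is_arithmetic() before is_integral(). Note that std::is_signed<float>::value is true, so the standard trait alone does not distinguish signed integers from floating-point types.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical stand-in for a traits getter whose is_signed() is only valid
// for integral types, mirroring the warnings on is_signed() and is_unsigned().
template<class T>
struct strict_traits
{
    constexpr bool is_integral() const
    {
        return std::is_integral<T>::value;
    }
    constexpr bool is_signed() const
    {
        // Member bodies are instantiated only when called, so this fires
        // only if is_signed() is actually used with a non-integral T.
        static_assert(std::is_integral<T>::value, "is_signed() requires an integral type");
        return std::is_signed<T>::value;
    }
};

// Returns -1 for signed integral types, 1 for unsigned integral types,
// and 0 for everything else.
template<class T>
constexpr int sign_category()
{
    if constexpr(strict_traits<T>{}.is_integral())
    {
        // For non-integral T this branch is discarded, so is_signed()
        // is never instantiated and its static_assert never fires.
        return strict_traits<T>{}.is_signed() ? -1 : 1;
    }
    else
    {
        return 0;
    }
}
```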

Available traits

struct is_arithmetic

Definability

  • Undefinable: For types with predefined traits.

  • Optional: For other types.

How to define
using is_arithmetic = rocprim::traits::is_arithmetic::values<true>;

How to use
rocprim::traits::get<InputType>().is_arithmetic();

template<bool Val>
struct values

Value of this trait.

Public Static Attributes

static constexpr auto value = Val

Indicates whether InputType is arithmetic.

struct is_scalar

Arithmetic types, pointers, member pointers, and null pointers are considered scalar types.

Definability

  • Undefinable: For types with predefined traits.

  • Optional: For other types. If both is_arithmetic and is_scalar are defined, their values must be consistent; otherwise, a compile-time error will occur.
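In the standard library, every arithmetic type is scalar, while scalar types also include pointers, pointers to members, enumerations, and std::nullptr_t. The standalone check below illustrates that classification using std traits only; widget is a hypothetical type. Presumably the consistency rule above enforces the same implication, so a type whose is_arithmetic trait is true should not define is_scalar as false.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// Hypothetical class type for illustration.
struct widget
{
    int id;
};

// Every arithmetic type is also scalar...
static_assert(std::is_arithmetic<double>::value && std::is_scalar<double>::value, "");
// ...but scalar also covers pointers, pointers to members, and std::nullptr_t.
static_assert(!std::is_arithmetic<int*>::value && std::is_scalar<int*>::value, "");
static_assert(!std::is_arithmetic<int widget::*>::value && std::is_scalar<int widget::*>::value, "");
static_assert(!std::is_arithmetic<std::nullptr_t>::value && std::is_scalar<std::nullptr_t>::value, "");
// Class types are neither arithmetic nor scalar.
static_assert(!std::is_arithmetic<widget>::value && !std::is_scalar<widget>::value, "");
```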

How to define
using is_scalar = rocprim::traits::is_scalar::values<true>;

How to use
rocprim::traits::get<InputType>().is_scalar();

template<bool Val>
struct values

Value of this trait.

Public Static Attributes

static constexpr auto value = Val

Indicates whether InputType is scalar.

struct number_format

Definability

  • Undefinable: For types with predefined traits and non-arithmetic types.

  • Required: If you define is_arithmetic as true, you must also define this trait; otherwise, a compile-time error will occur.

How to define
using number_format = rocprim::traits::number_format::values<rocprim::traits::number_format::kind::integral_type>;

How to use
rocprim::traits::get<InputType>().is_integral();
rocprim::traits::get<InputType>().is_floating_point();

Public Types

enum class kind

The kind enum that indicates the values available for this trait.

Values:

enumerator unknown_type
enumerator floating_point_type
enumerator integral_type
template<kind Val>
struct values

Value of this trait.

Public Static Attributes

static constexpr auto value = Val

Indicates whether InputType is floating_point_type, integral_type, or unknown_type.
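For built-in types, the three kind values line up with the standard classification. The standalone sketch below maps a type to a hypothetical copy of the enum (number_kind) using std::is_floating_point and std::is_integral; rocPRIM's own logic also consults the user-defined number_format trait, which this sketch does not model.

```cpp
#include <cassert>
#include <type_traits>

// Hypothetical copy of the three values of number_format::kind.
enum class number_kind
{
    unknown_type,
    floating_point_type,
    integral_type
};

// Classifies a type using only the standard traits.
template<class T>
constexpr number_kind classify()
{
    if constexpr(std::is_floating_point<T>::value)
        return number_kind::floating_point_type;
    else if constexpr(std::is_integral<T>::value)
        return number_kind::integral_type;
    else
        return number_kind::unknown_type;
}

// Hypothetical non-numeric type.
struct opaque
{};

static_assert(classify<double>() == number_kind::floating_point_type, "");
static_assert(classify<unsigned char>() == number_kind::integral_type, "");
static_assert(classify<opaque>() == number_kind::unknown_type, "");
```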

struct integral_sign

Definability

  • Undefinable: For types with predefined traits, non-arithmetic types and floating-point types.

  • Required: If you define number_format as number_format::kind::integral_type, you must also define this trait; otherwise, a compile-time error will occur.

How to define
using integral_sign = rocprim::traits::integral_sign::values<rocprim::traits::integral_sign::kind::signed_type>;

How to use
rocprim::traits::get<InputType>().is_signed();
rocprim::traits::get<InputType>().is_unsigned();

Public Types

enum class kind

The kind enum that indicates the values available for this trait.

Values:

enumerator unknown_type
enumerator signed_type
enumerator unsigned_type
template<kind Val>
struct values

Value of this trait.

Public Static Attributes

static constexpr auto value = Val

Indicates whether InputType is signed_type, unsigned_type, or unknown_type.

struct float_bit_mask

Definability

  • Undefinable: For types with predefined traits, non-arithmetic types and integral types.

  • Required: If you define number_format as number_format::kind::floating_point_type, you must also define this trait; otherwise, a compile-time error will occur.

How to define
using float_bit_mask = rocprim::traits::float_bit_mask::values<uint32_t, 0x80000000, 0x7F800000, 0x007FFFFF>;

How to use
rocprim::traits::get<InputType>().float_bit_mask();

Warning

For some types, if this trait is not implemented in their traits definition, rocPRIM falls back to rocprim::detail::float_bit_mask to maintain compatibility with downstream libraries. This fallback will be removed in the next major release, so make sure these types are updated to the latest interface.

template<class BitType, BitType SignBit, BitType Exponent, BitType Mantissa>
struct values

Value of this trait.

Public Static Attributes

static constexpr BitType sign_bit = SignBit

Bit mask that selects the sign bit of InputType.

static constexpr BitType exponent = Exponent

Bit mask that selects the exponent bits of InputType.

static constexpr BitType mantissa = Mantissa

Bit mask that selects the mantissa bits of InputType.
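To make the three masks concrete, the standalone sketch below takes the masks that define<float> uses (listed under Types with predefined traits) and splits a float into its sign, exponent, and mantissa fields. It assumes a 32-bit IEEE 754 float and reads the bit pattern with std::memcpy; the float_fields struct and split_float helper are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

static_assert(sizeof(float) == sizeof(std::uint32_t), "assumes a 32-bit float");

// The masks define<float> uses (IEEE 754 binary32).
constexpr std::uint32_t f32_sign_bit = 0x80000000u;
constexpr std::uint32_t f32_exponent = 0x7F800000u;
constexpr std::uint32_t f32_mantissa = 0x007FFFFFu;

// The three masks are disjoint and together cover all 32 bits.
static_assert((f32_sign_bit & f32_exponent) == 0 && (f32_sign_bit & f32_mantissa) == 0
                  && (f32_exponent & f32_mantissa) == 0, "");
static_assert((f32_sign_bit | f32_exponent | f32_mantissa) == 0xFFFFFFFFu, "");

struct float_fields
{
    bool          negative;
    std::uint32_t biased_exponent; // exponent field shifted down to bits 0..7
    std::uint32_t fraction;        // the 23 stored mantissa bits
};

inline float_fields split_float(float f)
{
    std::uint32_t bits;
    std::memcpy(&bits, &f, sizeof bits); // well-defined way to read the bit pattern
    return float_fields{(bits & f32_sign_bit) != 0, (bits & f32_exponent) >> 23, bits & f32_mantissa};
}
```

For example, 1.0f has a biased exponent of 127 and a zero fraction, while -2.5f (bit pattern 0xC0200000) has the sign bit set, a biased exponent of 128, and a fraction of 0x200000.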

struct is_fundamental

The trait is_fundamental is undefinable, as it is the union of std::is_fundamental and rocprim::traits::is_arithmetic.

Definability

  • Undefinable: If you attempt to define this trait in any form, a compile-time error will occur.

How to use
rocprim::traits::get<InputType>().is_fundamental();
rocprim::traits::get<InputType>().is_compound();

template<bool Val>
struct values

Value of this trait.

Public Static Attributes

static constexpr auto value = Val

Indicates whether InputType is fundamental.
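rocPRIM's is_fundamental() and is_compound() follow the standard library, where std::is_compound is the negation of std::is_fundamental; rocPRIM additionally counts rocPRIM arithmetic types, including types whose is_arithmetic trait is defined as true, as fundamental. The standalone check below confirms the standard relationship for a few types; the helper compound_is_not_fundamental and the type point are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>

// True when std::is_compound<T> is exactly the negation of std::is_fundamental<T>.
template<class T>
constexpr bool compound_is_not_fundamental()
{
    return std::is_compound<T>::value == !std::is_fundamental<T>::value;
}

// Hypothetical class type for illustration.
struct point
{
    int x;
    int y;
};

static_assert(compound_is_not_fundamental<int>(), "");            // fundamental
static_assert(compound_is_not_fundamental<void>(), "");           // fundamental
static_assert(compound_is_not_fundamental<std::nullptr_t>(), ""); // fundamental
static_assert(compound_is_not_fundamental<int*>(), "");           // compound
static_assert(compound_is_not_fundamental<point>(), "");          // compound
```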

Type traits wrappers

template<class T>
struct is_floating_point : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_floating_point()>

An extension of std::is_floating_point that supports additional arithmetic types, including rocprim::half, rocprim::bfloat16, and any types with trait rocprim::traits::number_format::values<number_format::kind::floating_point_type> implemented.

template<class T>
struct is_integral : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_integral()>

An extension of std::is_integral that supports additional arithmetic types, including rocprim::int128_t, rocprim::uint128_t, and any types with trait rocprim::traits::number_format::values<number_format::kind::integral_type> implemented.

template<class T>
struct is_arithmetic : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_arithmetic()>

An extension of std::is_arithmetic that supports additional arithmetic types, including any types with trait rocprim::traits::is_arithmetic::values<true> implemented.

template<class T>
struct is_fundamental : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_fundamental()>

An extension of std::is_fundamental that supports additional arithmetic types, including any types with trait rocprim::traits::is_arithmetic::values<true> implemented.

template<class T>
struct is_unsigned : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_unsigned()>

An extension of std::is_unsigned that supports additional arithmetic types, including rocprim::uint128_t, and any types with trait rocprim::traits::integral_sign::values<integral_sign::kind::unsigned_type> implemented.

template<class T>
struct is_signed : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_signed()>

An extension of std::is_signed that supports additional arithmetic types, including rocprim::int128_t, and any types with trait rocprim::traits::integral_sign::values<integral_sign::kind::signed_type> implemented.

template<class T>
struct is_scalar : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_scalar()>

An extension of std::is_scalar that supports additional arithmetic types, including any types with trait rocprim::traits::is_scalar::values<true> implemented.

template<class T>
struct is_compound : public std::integral_constant<bool, ::rocprim::traits::get<T>().is_compound()>

An extension of std::is_compound that treats rocPRIM arithmetic types, including any types with trait rocprim::traits::is_arithmetic::values<true> implemented, as non-compound.
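Each wrapper above lifts a constexpr member function of rocprim::traits::get into a std::integral_constant, so it can be used wherever a standard trait is expected. The sketch below reproduces that pattern in a hypothetical demo namespace so it compiles without rocPRIM; its getter only delegates to std::is_floating_point and does not model rocPRIM's extended type support. The function halve is also hypothetical.

```cpp
#include <cassert>
#include <type_traits>

namespace demo // hypothetical, not rocPRIM
{
// Minimal getter exposing one constexpr query.
template<class T>
struct get
{
    constexpr bool is_floating_point() const
    {
        return std::is_floating_point<T>::value;
    }
};

// Lifts the constexpr member function into a std::integral_constant,
// the same shape as the wrappers listed above.
template<class T>
struct is_floating_point : std::integral_constant<bool, get<T>().is_floating_point()>
{};
} // namespace demo

// The wrapper works wherever a standard trait does, for example in SFINAE.
template<class T, std::enable_if_t<demo::is_floating_point<T>::value>* = nullptr>
constexpr T halve(T x)
{
    return x / 2;
}

static_assert(demo::is_floating_point<double>::value, "");
static_assert(!demo::is_floating_point<int>::value, "");
```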

Types with predefined traits

template<>
struct define<float>

Public Types

using float_bit_mask = traits::float_bit_mask::values<uint32_t, 0x80000000, 0x7F800000, 0x007FFFFF>
template<>
struct define<double>

Public Types

using float_bit_mask = traits::float_bit_mask::values<uint64_t, 0x8000000000000000, 0x7FF0000000000000, 0x000FFFFFFFFFFFFF>
template<>
struct define<rocprim::bfloat16>

Public Types

using is_arithmetic = traits::is_arithmetic::values<true>
using number_format = traits::number_format::values<traits::number_format::kind::floating_point_type>
using float_bit_mask = traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>
template<>
struct define<rocprim::half>

Public Types

using is_arithmetic = traits::is_arithmetic::values<true>
using number_format = traits::number_format::values<traits::number_format::kind::floating_point_type>
using float_bit_mask = traits::float_bit_mask::values<uint16_t, 0x8000, 0x7F80, 0x007F>
template<>
struct define<rocprim::int128_t> : public std::conditional_t<std::is_arithmetic<rocprim::int128_t>::value, traits::define<void>, detail::define_int128_t>
template<>
struct define<rocprim::uint128_t> : public std::conditional_t<std::is_arithmetic<rocprim::uint128_t>::value, traits::define<void>, detail::define_uint128_t>
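The bfloat16 masks listed above for define<rocprim::bfloat16> are exactly the upper halves of the float masks, because bfloat16 keeps a float's sign bit and 8-bit exponent and truncates the mantissa to 7 bits. The standalone check below verifies that relationship; the conversion helper float_bits_to_bf16_bits is hypothetical and truncates without rounding.

```cpp
#include <cassert>
#include <cstdint>

// float (binary32) masks from define<float>.
constexpr std::uint32_t f32_sign = 0x80000000u, f32_exp = 0x7F800000u, f32_man = 0x007FFFFFu;
// bfloat16 masks from define<rocprim::bfloat16>.
constexpr std::uint16_t bf16_sign = 0x8000u, bf16_exp = 0x7F80u, bf16_man = 0x007Fu;

// bfloat16 keeps the upper 16 bits of a float: the same sign bit and 8-bit
// exponent, with the mantissa truncated from 23 to 7 bits.
static_assert((f32_sign >> 16) == bf16_sign, "");
static_assert((f32_exp >> 16) == bf16_exp, "");
static_assert((f32_man >> 16) == bf16_man, "");

// Truncating conversion of a float bit pattern (no rounding), for illustration only.
constexpr std::uint16_t float_bits_to_bf16_bits(std::uint32_t float_bits)
{
    return static_cast<std::uint16_t>(float_bits >> 16);
}
```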