30 #ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_
31 #define HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_
36 #include <hip/hip_fp16.h>
37 #include <hip/hip_bfloat16.h>
39 #include "../../../config.hpp"
40 #include "../iterator/arg_index_input_iterator.hpp"
41 #include "../thread/thread_operators.hpp"
43 #include <rocprim/device/device_reduce.hpp>
44 #include <rocprim/device/device_reduce_by_key.hpp>
46 BEGIN_HIPCUB_NAMESPACE
// set_half_bits: fills the first two bytes of a T object from a raw 16-bit
// pattern. The visible statements view `half_value` as unsigned char and
// store the low 8 bits of `value` in byte 0 and the high 8 bits in byte 1,
// i.e. the pattern is laid out least-significant byte first.
// NOTE(review): this listing omits the lines numbered 52-53 and everything
// after 56 (presumably the declaration of `half_value` and the return) -
// confirm against the real header.
51 HIPCUB_HOST_DEVICE T set_half_bits(uint16_t value)
54 unsigned char* char_representation =
reinterpret_cast<unsigned char*
>(&half_value);
// Assigning a uint16_t to an unsigned char keeps only the low byte.
55 char_representation[0] = value;
// Shift the high byte down before the same narrowing store.
56 char_representation[1] = value >> 8;
// Generic get_lowest_value<T>: returns std::numeric_limits<T>::lowest().
// The __half and hip_bfloat16 versions in this file return explicit bit
// patterns instead of going through std::numeric_limits.
61 HIPCUB_HOST_DEVICE
inline T get_lowest_value()
63 return std::numeric_limits<T>::lowest();
// get_lowest_value for __half: built from the raw pattern 0xfbff via
// set_half_bits. 0xfbff = 1 11110 1111111111 (sign, exponent, mantissa);
// under the IEEE binary16 layout that is the most negative finite value.
67 HIPCUB_HOST_DEVICE
inline __half get_lowest_value<__half>()
70 return set_half_bits<__half>(0xfbff);
// get_lowest_value for hip_bfloat16: built from the raw pattern 0xff7f via
// set_half_bits. 0xff7f = 1 11111110 1111111 (sign, exponent, mantissa);
// under the bfloat16 layout that is the most negative finite value.
74 HIPCUB_HOST_DEVICE
inline hip_bfloat16 get_lowest_value<hip_bfloat16>()
77 return set_half_bits<hip_bfloat16>(0xff7f);
// Generic get_max_value<T>: returns std::numeric_limits<T>::max().
// The __half and hip_bfloat16 versions in this file return explicit bit
// patterns instead of going through std::numeric_limits.
81 HIPCUB_HOST_DEVICE
inline T get_max_value()
83 return std::numeric_limits<T>::max();
// get_max_value for __half: built from the raw pattern 0x7bff via
// set_half_bits. 0x7bff = 0 11110 1111111111 (sign, exponent, mantissa);
// under the IEEE binary16 layout that is the largest finite value.
87 HIPCUB_HOST_DEVICE
inline __half get_max_value<__half>()
90 return set_half_bits<__half>(0x7bff);
// get_max_value for hip_bfloat16: built from the raw pattern 0x7f7f via
// set_half_bits. 0x7f7f = 0 11111110 1111111 (sign, exponent, mantissa);
// under the bfloat16 layout that is the largest finite value.
94 HIPCUB_HOST_DEVICE
inline hip_bfloat16 get_max_value<hip_bfloat16>()
97 return set_half_bits<hip_bfloat16>(0x7f7f);
// get_lowest_special_value, non-floating-point overload: participates in
// overload resolution only when rocprim::is_floating_point<T>::value is
// false, and simply forwards to get_lowest_value<T>().
102 inline auto get_lowest_special_value() ->
103 typename std::enable_if_t<!rocprim::is_floating_point<T>::value, T>
105 return get_lowest_value<T>();
// get_lowest_special_value, floating-point overload: participates in
// overload resolution only when rocprim::is_floating_point<T>::value is
// true, and returns negative infinity rather than the lowest finite value.
110 inline auto get_lowest_special_value() ->
111 typename std::enable_if_t<rocprim::is_floating_point<T>::value, T>
113 return -std::numeric_limits<T>::infinity();
// get_lowest_special_value for __half: raw pattern 0xfc00 via set_half_bits.
// 0xfc00 = 1 11111 0000000000 (sign, exponent, mantissa); under the IEEE
// binary16 layout that encodes negative infinity.
117 inline __half get_lowest_special_value<__half>()
120 return set_half_bits<__half>(0xfc00);
// get_lowest_special_value for hip_bfloat16: raw pattern 0xff80 via
// set_half_bits. 0xff80 = 1 11111111 0000000 (sign, exponent, mantissa);
// under the bfloat16 layout that encodes negative infinity.
124 inline hip_bfloat16 get_lowest_special_value<hip_bfloat16>()
127 return set_half_bits<hip_bfloat16>(0xff80);
// get_max_special_value, non-floating-point overload: participates in
// overload resolution only when rocprim::is_floating_point<T>::value is
// false, and simply forwards to get_max_value<T>().
132 inline auto get_max_special_value() ->
133 typename std::enable_if_t<!rocprim::is_floating_point<T>::value, T>
135 return get_max_value<T>();
// get_max_special_value, floating-point overload: participates in overload
// resolution only when rocprim::is_floating_point<T>::value is true, and
// returns positive infinity rather than the largest finite value.
140 inline auto get_max_special_value() ->
141 typename std::enable_if_t<rocprim::is_floating_point<T>::value, T>
143 return std::numeric_limits<T>::infinity();
// get_max_special_value for __half: raw pattern 0x7c00 via set_half_bits.
// 0x7c00 = 0 11111 0000000000 (sign, exponent, mantissa); under the IEEE
// binary16 layout that encodes positive infinity.
147 inline __half get_max_special_value<__half>()
150 return set_half_bits<__half>(0x7c00);
// get_max_special_value for hip_bfloat16: raw pattern 0x7f80 via
// set_half_bits. 0x7f80 = 0 11111111 0000000 (sign, exponent, mantissa);
// under the bfloat16 layout that encodes positive infinity.
154 inline hip_bfloat16 get_max_special_value<hip_bfloat16>()
157 return set_half_bits<hip_bfloat16>(0x7f80);
// Reduce: forwards a user-supplied reduction to ::rocprim::reduce.
// The call passes d_temp_storage, temp_storage_bytes, d_in, d_out, init and
// num_items through unchanged, wraps reduction_op with
// ::hipcub::detail::convert_result_type<InputIteratorT, OutputIteratorT>,
// and returns rocPRIM's hipError_t directly.
// NOTE(review): the declarations of d_in, num_items, init and the ReduceOpT
// template parameter are on lines this listing omits (168-170, 174, 176,
// 178); they are visible only through their use in the call below.
166 typename InputIteratorT,
167 typename OutputIteratorT,
171 HIPCUB_RUNTIME_FUNCTION
static
172 hipError_t Reduce(
void *d_temp_storage,
173 size_t &temp_storage_bytes,
175 OutputIteratorT d_out,
177 ReduceOpT reduction_op,
179 hipStream_t stream = 0,
180 bool debug_synchronous =
false)
182 return ::rocprim::reduce(
183 d_temp_storage, temp_storage_bytes,
184 d_in, d_out, init, num_items,
// Adapt the functor before handing it to rocPRIM; the adapter is defined
// outside this listing.
185 ::hipcub::detail::convert_result_type<InputIteratorT, OutputIteratorT>(reduction_op),
186 stream, debug_synchronous
// Sum: sum-reduction entry point returning hipError_t.
// T is the value type of the input iterator. The visible part of the body
// forwards the temporary-storage pair and stream/debug_synchronous.
// NOTE(review): the callee, the operator and the initial value sit on lines
// this listing omits (204, 206) - presumably a call to Reduce with a sum
// functor and a zero of type T; confirm against the real header.
191 typename InputIteratorT,
192 typename OutputIteratorT
194 HIPCUB_RUNTIME_FUNCTION
static
195 hipError_t Sum(
void *d_temp_storage,
196 size_t &temp_storage_bytes,
198 OutputIteratorT d_out,
200 hipStream_t stream = 0,
201 bool debug_synchronous =
false)
203 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
205 d_temp_storage, temp_storage_bytes,
207 stream, debug_synchronous
// Min: minimum-reduction entry point returning hipError_t.
// The visible call arguments are d_in, d_out, num_items, a ::hipcub::Min
// functor, and detail::get_max_value<T>() as the initial value, where T is
// the input iterator's value type.
// NOTE(review): the callee name is on an omitted line (225) - presumably
// Reduce, whose parameter order matches these arguments; confirm.
212 typename InputIteratorT,
213 typename OutputIteratorT
215 HIPCUB_RUNTIME_FUNCTION
static
216 hipError_t Min(
void *d_temp_storage,
217 size_t &temp_storage_bytes,
219 OutputIteratorT d_out,
221 hipStream_t stream = 0,
222 bool debug_synchronous =
false)
224 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
226 d_temp_storage, temp_storage_bytes,
// The initial value is the largest value get_max_value<T>() provides.
227 d_in, d_out, num_items, ::
hipcub::Min(), detail::get_max_value<T>(),
228 stream, debug_synchronous
// ArgMin: arg-min reduction entry point returning hipError_t.
// T is the input value type and O the output iterator's value type. A
// std::conditional falls back to KeyValuePair<OffsetT, T> when O is void.
// The input is wrapped in an ArgIndexInputIterator (d_indexed_in) whose
// value component is OutputTupleT::Value.
// NOTE(review): OffsetT, the alias name and the else-branch of the
// conditional, the head of the init statement and the final reduce call are
// on lines this listing omits (245, 248, 252-254, 261, 265, 267) - confirm.
233 typename InputIteratorT,
234 typename OutputIteratorT
236 HIPCUB_RUNTIME_FUNCTION
static
237 hipError_t ArgMin(
void *d_temp_storage,
238 size_t &temp_storage_bytes,
240 OutputIteratorT d_out,
242 hipStream_t stream = 0,
243 bool debug_synchronous =
false)
246 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
247 using O =
typename std::iterator_traits<OutputIteratorT>::value_type;
249 typename std::conditional<
250 std::is_same<O, void>::value,
251 KeyValuePair<OffsetT, T>,
255 using OutputValueT =
typename OutputTupleT::Value;
256 using IteratorT = ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;
258 IteratorT d_indexed_in(d_in);
// Value selection: non-empty input uses get_max_special_value<T>()
// (+infinity for floating-point T), empty input uses get_max_value<T>().
// NOTE(review): presumably the infinity is needed so inputs made only of
// infinities still reduce correctly - confirm.
262 num_items > 0 ? detail::get_max_special_value<T>()
263 : detail::get_max_value<T>());
266 d_temp_storage, temp_storage_bytes,
268 stream, debug_synchronous
// Max: maximum-reduction entry point returning hipError_t.
// The visible call arguments are d_in, d_out, num_items, a ::hipcub::Max
// functor, and detail::get_lowest_value<T>() as the initial value, where T
// is the input iterator's value type.
// NOTE(review): the callee name is on an omitted line (286) - presumably
// Reduce, whose parameter order matches these arguments; confirm.
273 typename InputIteratorT,
274 typename OutputIteratorT
276 HIPCUB_RUNTIME_FUNCTION
static
277 hipError_t Max(
void *d_temp_storage,
278 size_t &temp_storage_bytes,
280 OutputIteratorT d_out,
282 hipStream_t stream = 0,
283 bool debug_synchronous =
false)
285 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
287 d_temp_storage, temp_storage_bytes,
// The initial value is the lowest value get_lowest_value<T>() provides.
288 d_in, d_out, num_items, ::
hipcub::Max(), detail::get_lowest_value<T>(),
289 stream, debug_synchronous
// ArgMax: arg-max reduction entry point returning hipError_t.
// T is the input value type and O the output iterator's value type. A
// std::conditional falls back to KeyValuePair<OffsetT, T> when O is void.
// The input is wrapped in an ArgIndexInputIterator (d_indexed_in) whose
// value component is OutputTupleT::Value.
// NOTE(review): OffsetT, the alias name and the else-branch of the
// conditional, and the final reduce call are on lines this listing omits
// (306, 309, 313-315, 326, 328) - confirm against the real header.
294 typename InputIteratorT,
295 typename OutputIteratorT
297 HIPCUB_RUNTIME_FUNCTION
static
298 hipError_t ArgMax(
void *d_temp_storage,
299 size_t &temp_storage_bytes,
301 OutputIteratorT d_out,
303 hipStream_t stream = 0,
304 bool debug_synchronous =
false)
307 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
308 using O =
typename std::iterator_traits<OutputIteratorT>::value_type;
310 typename std::conditional<
311 std::is_same<O, void>::value,
312 KeyValuePair<OffsetT, T>,
316 using OutputValueT =
typename OutputTupleT::Value;
317 using IteratorT = ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;
319 IteratorT d_indexed_in(d_in);
// Initial tuple: first component 1; second component is
// get_lowest_special_value<T>() (-infinity for floating-point T) for
// non-empty input and get_lowest_value<T>() for empty input.
// NOTE(review): presumably the infinity is needed so inputs made only of
// infinities still reduce correctly - confirm.
322 const OutputTupleT init(1,
323 num_items > 0 ? detail::get_lowest_special_value<T>()
324 : detail::get_lowest_value<T>());
327 d_temp_storage, temp_storage_bytes,
329 stream, debug_synchronous
// ReduceByKey: forwards a keyed reduction to ::rocprim::reduce_by_key and
// returns its hipError_t.
// Note the argument reordering for rocPRIM: inputs and count first
// (d_keys_in, d_values_in, num_items), then outputs (d_unique_out,
// d_aggregates_out, d_num_runs_out). reduction_op is wrapped with
// ::hipcub::detail::convert_result_type<ValuesInputIteratorT,
// AggregatesOutputIteratorT>.
// NOTE(review): the num_items parameter (line 350) and the argument on
// line 361 are omitted by this listing - presumably an instance of
// key_compare_op, which is otherwise unused in the visible text; confirm.
334 typename KeysInputIteratorT,
335 typename UniqueOutputIteratorT,
336 typename ValuesInputIteratorT,
337 typename AggregatesOutputIteratorT,
338 typename NumRunsOutputIteratorT,
339 typename ReductionOpT
341 HIPCUB_RUNTIME_FUNCTION
static
342 hipError_t ReduceByKey(
void * d_temp_storage,
343 size_t& temp_storage_bytes,
344 KeysInputIteratorT d_keys_in,
345 UniqueOutputIteratorT d_unique_out,
346 ValuesInputIteratorT d_values_in,
347 AggregatesOutputIteratorT d_aggregates_out,
348 NumRunsOutputIteratorT d_num_runs_out,
349 ReductionOpT reduction_op,
351 hipStream_t stream = 0,
352 bool debug_synchronous =
false)
// Key comparator type: rocPRIM's equal_to over the key iterator's value type.
354 using key_compare_op =
355 ::rocprim::equal_to<typename std::iterator_traits<KeysInputIteratorT>::value_type>;
356 return ::rocprim::reduce_by_key(
357 d_temp_storage, temp_storage_bytes,
358 d_keys_in, d_values_in, num_items,
359 d_unique_out, d_aggregates_out, d_num_runs_out,
360 ::hipcub::detail::convert_result_type<ValuesInputIteratorT, AggregatesOutputIteratorT>(reduction_op),
362 stream, debug_synchronous
Definition: thread_operators.hpp:126
Definition: thread_operators.hpp:141
Definition: thread_operators.hpp:106
Definition: thread_operators.hpp:116
Definition: thread_operators.hpp:76