30 #ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_
31 #define HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_
36 #include <hip/hip_fp16.h>
37 #include <hip/hip_bfloat16.h>
39 #include "../../../config.hpp"
40 #include "../iterator/arg_index_input_iterator.hpp"
41 #include "../thread/thread_operators.hpp"
43 #include <rocprim/device/device_reduce.hpp>
44 #include <rocprim/device/device_reduce_by_key.hpp>
46 BEGIN_HIPCUB_NAMESPACE
54 return std::numeric_limits<T>::lowest();
59 __half get_lowest_value<__half>()
61 unsigned short lowest_half = 0xfbff;
62 __half lowest_value = *
reinterpret_cast<__half*
>(&lowest_half);
68 hip_bfloat16 get_lowest_value<hip_bfloat16>()
70 return hip_bfloat16(-3.38953138925e+38f);
77 return std::numeric_limits<T>::max();
82 __half get_max_value<__half>()
84 unsigned short max_half = 0x7bff;
85 __half max_value = *
reinterpret_cast<__half*
>(&max_half);
91 hip_bfloat16 get_max_value<hip_bfloat16>()
93 return hip_bfloat16(3.38953138925e+38f);
102 typename InputIteratorT,
103 typename OutputIteratorT,
107 HIPCUB_RUNTIME_FUNCTION
static
108 hipError_t Reduce(
void *d_temp_storage,
109 size_t &temp_storage_bytes,
111 OutputIteratorT d_out,
113 ReduceOpT reduction_op,
115 hipStream_t stream = 0,
116 bool debug_synchronous =
false)
118 return ::rocprim::reduce(
119 d_temp_storage, temp_storage_bytes,
120 d_in, d_out, init, num_items,
121 ::hipcub::detail::convert_result_type<InputIteratorT, OutputIteratorT>(reduction_op),
122 stream, debug_synchronous
127 typename InputIteratorT,
128 typename OutputIteratorT
130 HIPCUB_RUNTIME_FUNCTION
static
131 hipError_t Sum(
void *d_temp_storage,
132 size_t &temp_storage_bytes,
134 OutputIteratorT d_out,
136 hipStream_t stream = 0,
137 bool debug_synchronous =
false)
139 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
141 d_temp_storage, temp_storage_bytes,
143 stream, debug_synchronous
148 typename InputIteratorT,
149 typename OutputIteratorT
151 HIPCUB_RUNTIME_FUNCTION
static
152 hipError_t Min(
void *d_temp_storage,
153 size_t &temp_storage_bytes,
155 OutputIteratorT d_out,
157 hipStream_t stream = 0,
158 bool debug_synchronous =
false)
160 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
162 d_temp_storage, temp_storage_bytes,
163 d_in, d_out, num_items, ::
hipcub::Min(), detail::get_max_value<T>(),
164 stream, debug_synchronous
169 typename InputIteratorT,
170 typename OutputIteratorT
172 HIPCUB_RUNTIME_FUNCTION
static
173 hipError_t ArgMin(
void *d_temp_storage,
174 size_t &temp_storage_bytes,
176 OutputIteratorT d_out,
178 hipStream_t stream = 0,
179 bool debug_synchronous =
false)
182 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
183 using O =
typename std::iterator_traits<OutputIteratorT>::value_type;
185 typename std::conditional<
186 std::is_same<O, void>::value,
187 KeyValuePair<OffsetT, T>,
191 using OutputValueT =
typename OutputTupleT::Value;
192 using IteratorT = ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;
194 IteratorT d_indexed_in(d_in);
195 OutputTupleT init(1, detail::get_max_value<T>());
198 d_temp_storage, temp_storage_bytes,
200 stream, debug_synchronous
205 typename InputIteratorT,
206 typename OutputIteratorT
208 HIPCUB_RUNTIME_FUNCTION
static
209 hipError_t Max(
void *d_temp_storage,
210 size_t &temp_storage_bytes,
212 OutputIteratorT d_out,
214 hipStream_t stream = 0,
215 bool debug_synchronous =
false)
217 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
219 d_temp_storage, temp_storage_bytes,
220 d_in, d_out, num_items, ::
hipcub::Max(), detail::get_lowest_value<T>(),
221 stream, debug_synchronous
226 typename InputIteratorT,
227 typename OutputIteratorT
229 HIPCUB_RUNTIME_FUNCTION
static
230 hipError_t ArgMax(
void *d_temp_storage,
231 size_t &temp_storage_bytes,
233 OutputIteratorT d_out,
235 hipStream_t stream = 0,
236 bool debug_synchronous =
false)
239 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
240 using O =
typename std::iterator_traits<OutputIteratorT>::value_type;
242 typename std::conditional<
243 std::is_same<O, void>::value,
244 KeyValuePair<OffsetT, T>,
248 using OutputValueT =
typename OutputTupleT::Value;
249 using IteratorT = ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;
251 IteratorT d_indexed_in(d_in);
252 OutputTupleT init(1, detail::get_lowest_value<T>());
255 d_temp_storage, temp_storage_bytes,
257 stream, debug_synchronous
262 typename KeysInputIteratorT,
263 typename UniqueOutputIteratorT,
264 typename ValuesInputIteratorT,
265 typename AggregatesOutputIteratorT,
266 typename NumRunsOutputIteratorT,
267 typename ReductionOpT
269 HIPCUB_RUNTIME_FUNCTION
static
270 hipError_t ReduceByKey(
void * d_temp_storage,
271 size_t& temp_storage_bytes,
272 KeysInputIteratorT d_keys_in,
273 UniqueOutputIteratorT d_unique_out,
274 ValuesInputIteratorT d_values_in,
275 AggregatesOutputIteratorT d_aggregates_out,
276 NumRunsOutputIteratorT d_num_runs_out,
277 ReductionOpT reduction_op,
279 hipStream_t stream = 0,
280 bool debug_synchronous =
false)
282 using key_compare_op =
283 ::rocprim::equal_to<typename std::iterator_traits<KeysInputIteratorT>::value_type>;
284 return ::rocprim::reduce_by_key(
285 d_temp_storage, temp_storage_bytes,
286 d_keys_in, d_values_in, num_items,
287 d_unique_out, d_aggregates_out, d_num_runs_out,
288 ::hipcub::detail::convert_result_type<ValuesInputIteratorT, AggregatesOutputIteratorT>(reduction_op),
290 stream, debug_synchronous
Definition: thread_operators.hpp:126
Definition: thread_operators.hpp:141
Definition: thread_operators.hpp:106
Definition: thread_operators.hpp:116
Definition: thread_operators.hpp:76