30 #ifndef HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_
31 #define HIPCUB_ROCPRIM_DEVICE_DEVICE_REDUCE_HPP_
36 #include <hip/hip_fp16.h>
38 #include "../../../config.hpp"
39 #include "../iterator/arg_index_input_iterator.hpp"
40 #include "../thread/thread_operators.hpp"
42 #include <rocprim/device/device_reduce.hpp>
43 #include <rocprim/device/device_reduce_by_key.hpp>
45 BEGIN_HIPCUB_NAMESPACE
53 return std::numeric_limits<T>::lowest();
58 __half get_lowest_value<__half>()
60 unsigned short lowest_half = 0xfbff;
61 __half lowest_value = *
reinterpret_cast<__half*
>(&lowest_half);
69 return std::numeric_limits<T>::max();
74 __half get_max_value<__half>()
76 unsigned short max_half = 0x7bff;
77 __half max_value = *
reinterpret_cast<__half*
>(&max_half);
87 typename InputIteratorT,
88 typename OutputIteratorT,
92 HIPCUB_RUNTIME_FUNCTION
static
93 hipError_t Reduce(
void *d_temp_storage,
94 size_t &temp_storage_bytes,
96 OutputIteratorT d_out,
98 ReduceOpT reduction_op,
100 hipStream_t stream = 0,
101 bool debug_synchronous =
false)
103 return ::rocprim::reduce(
104 d_temp_storage, temp_storage_bytes,
105 d_in, d_out, init, num_items,
106 ::hipcub::detail::convert_result_type<InputIteratorT, OutputIteratorT>(reduction_op),
107 stream, debug_synchronous
112 typename InputIteratorT,
113 typename OutputIteratorT
115 HIPCUB_RUNTIME_FUNCTION
static
116 hipError_t Sum(
void *d_temp_storage,
117 size_t &temp_storage_bytes,
119 OutputIteratorT d_out,
121 hipStream_t stream = 0,
122 bool debug_synchronous =
false)
124 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
126 d_temp_storage, temp_storage_bytes,
128 stream, debug_synchronous
133 typename InputIteratorT,
134 typename OutputIteratorT
136 HIPCUB_RUNTIME_FUNCTION
static
137 hipError_t Min(
void *d_temp_storage,
138 size_t &temp_storage_bytes,
140 OutputIteratorT d_out,
142 hipStream_t stream = 0,
143 bool debug_synchronous =
false)
145 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
147 d_temp_storage, temp_storage_bytes,
148 d_in, d_out, num_items, ::
hipcub::Min(), detail::get_max_value<T>(),
149 stream, debug_synchronous
154 typename InputIteratorT,
155 typename OutputIteratorT
157 HIPCUB_RUNTIME_FUNCTION
static
158 hipError_t ArgMin(
void *d_temp_storage,
159 size_t &temp_storage_bytes,
161 OutputIteratorT d_out,
163 hipStream_t stream = 0,
164 bool debug_synchronous =
false)
167 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
168 using O =
typename std::iterator_traits<OutputIteratorT>::value_type;
170 typename std::conditional<
171 std::is_same<O, void>::value,
172 KeyValuePair<OffsetT, T>,
176 using OutputValueT =
typename OutputTupleT::Value;
177 using IteratorT = ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;
179 IteratorT d_indexed_in(d_in);
180 OutputTupleT init(1, detail::get_max_value<T>());
183 d_temp_storage, temp_storage_bytes,
185 stream, debug_synchronous
190 typename InputIteratorT,
191 typename OutputIteratorT
193 HIPCUB_RUNTIME_FUNCTION
static
194 hipError_t Max(
void *d_temp_storage,
195 size_t &temp_storage_bytes,
197 OutputIteratorT d_out,
199 hipStream_t stream = 0,
200 bool debug_synchronous =
false)
202 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
204 d_temp_storage, temp_storage_bytes,
205 d_in, d_out, num_items, ::
hipcub::Max(), detail::get_lowest_value<T>(),
206 stream, debug_synchronous
211 typename InputIteratorT,
212 typename OutputIteratorT
214 HIPCUB_RUNTIME_FUNCTION
static
215 hipError_t ArgMax(
void *d_temp_storage,
216 size_t &temp_storage_bytes,
218 OutputIteratorT d_out,
220 hipStream_t stream = 0,
221 bool debug_synchronous =
false)
224 using T =
typename std::iterator_traits<InputIteratorT>::value_type;
225 using O =
typename std::iterator_traits<OutputIteratorT>::value_type;
227 typename std::conditional<
228 std::is_same<O, void>::value,
229 KeyValuePair<OffsetT, T>,
233 using OutputValueT =
typename OutputTupleT::Value;
234 using IteratorT = ArgIndexInputIterator<InputIteratorT, OffsetT, OutputValueT>;
236 IteratorT d_indexed_in(d_in);
237 OutputTupleT init(1, detail::get_lowest_value<T>());
240 d_temp_storage, temp_storage_bytes,
242 stream, debug_synchronous
247 typename KeysInputIteratorT,
248 typename UniqueOutputIteratorT,
249 typename ValuesInputIteratorT,
250 typename AggregatesOutputIteratorT,
251 typename NumRunsOutputIteratorT,
252 typename ReductionOpT
254 HIPCUB_RUNTIME_FUNCTION
static
255 hipError_t ReduceByKey(
void * d_temp_storage,
256 size_t& temp_storage_bytes,
257 KeysInputIteratorT d_keys_in,
258 UniqueOutputIteratorT d_unique_out,
259 ValuesInputIteratorT d_values_in,
260 AggregatesOutputIteratorT d_aggregates_out,
261 NumRunsOutputIteratorT d_num_runs_out,
262 ReductionOpT reduction_op,
264 hipStream_t stream = 0,
265 bool debug_synchronous =
false)
267 using key_compare_op =
268 ::rocprim::equal_to<typename std::iterator_traits<KeysInputIteratorT>::value_type>;
269 return ::rocprim::reduce_by_key(
270 d_temp_storage, temp_storage_bytes,
271 d_keys_in, d_values_in, num_items,
272 d_unique_out, d_aggregates_out, d_num_runs_out,
273 ::hipcub::detail::convert_result_type<ValuesInputIteratorT, AggregatesOutputIteratorT>(reduction_op),
275 stream, debug_synchronous
Definition: thread_operators.hpp:106
Definition: thread_operators.hpp:121
Definition: thread_operators.hpp:86
Definition: thread_operators.hpp:96
Definition: thread_operators.hpp:76