/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.5.1/hipcub/include/hipcub/backend/rocprim/thread/thread_operators.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.5.1/hipcub/include/hipcub/backend/rocprim/thread/thread_operators.hpp Source File#

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.5.1/hipcub/include/hipcub/backend/rocprim/thread/thread_operators.hpp Source File
thread_operators.hpp
1 /******************************************************************************
2  * Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
4  * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  * * Redistributions in binary form must reproduce the above copyright
11  * notice, this list of conditions and the following disclaimer in the
12  * documentation and/or other materials provided with the distribution.
13  * * Neither the name of the NVIDIA CORPORATION nor the
14  * names of its contributors may be used to endorse or promote products
15  * derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  *
28  ******************************************************************************/
29 
30 #ifndef HIBCUB_ROCPRIM_THREAD_THREAD_OPERATORS_HPP_
31 #define HIBCUB_ROCPRIM_THREAD_THREAD_OPERATORS_HPP_
32 
33 #include "../../../config.hpp"
34 
35 #include "../util_type.hpp"
36 
37 BEGIN_HIPCUB_NAMESPACE
38 
/// Equality predicate functor.
///
/// Compares two operands of the same type with `operator==` and returns the result.
struct Equality
{
    template<class Value>
    HIPCUB_HOST_DEVICE inline
    constexpr bool operator()(const Value& lhs, const Value& rhs) const
    {
        return lhs == rhs;
    }
};
48 
/// Inequality predicate functor.
///
/// Compares two operands of the same type with `operator!=` and returns the result.
struct Inequality
{
    template<class Value>
    HIPCUB_HOST_DEVICE inline
    constexpr bool operator()(const Value& lhs, const Value& rhs) const
    {
        return lhs != rhs;
    }
};
58 
/// Adapts an equality predicate into an inequality predicate.
///
/// Stores a copy of `EqualityOp` and returns the logical negation of its
/// result, i.e. `InequalityWrapper(op)(a, b) == !op(a, b)`.
template <class EqualityOp>
struct InequalityWrapper
{
    /// Wrapped equality predicate.
    EqualityOp op;

    HIPCUB_HOST_DEVICE inline
    InequalityWrapper(EqualityOp op) : op(op) {}

    /// Negated equality. Non-const so that predicates whose call operator is
    /// not const-qualified remain usable.
    template<class T>
    HIPCUB_HOST_DEVICE inline
    bool operator()(const T &a, const T &b)
    {
        return !op(a, b);
    }

    /// Negated equality for const wrapper objects. Only instantiated when
    /// called through a const wrapper, which requires `EqualityOp` to be
    /// callable as a const object.
    template<class T>
    HIPCUB_HOST_DEVICE inline
    bool operator()(const T &a, const T &b) const
    {
        return !op(a, b);
    }
};
74 
/// Addition functor: returns `lhs + rhs`, converted to the operand type.
struct Sum
{
    template<class Value>
    HIPCUB_HOST_DEVICE inline
    constexpr Value operator()(const Value& lhs, const Value& rhs) const
    {
        return lhs + rhs;
    }
};
84 
/// Subtraction functor: returns `lhs - rhs`, converted to the operand type.
struct Difference
{
    template <class Value>
    HIPCUB_HOST_DEVICE inline
    constexpr Value operator()(const Value& lhs, const Value& rhs) const
    {
        return lhs - rhs;
    }
};
94 
/// Division functor: returns `lhs / rhs`, converted to the operand type.
/// No check is made for a zero divisor.
struct Division
{
    template <class Value>
    HIPCUB_HOST_DEVICE inline
    constexpr Value operator()(const Value& lhs, const Value& rhs) const
    {
        return lhs / rhs;
    }
};
104 
/// Maximum functor built only on `operator<`.
///
/// Returns `b` when `a < b`, otherwise `a`. For operands that compare
/// equivalent, the first operand is returned.
struct Max
{
    template<class T>
    HIPCUB_HOST_DEVICE inline
    constexpr T operator()(const T &a, const T &b) const
    {
        if (a < b)
        {
            return b;
        }
        return a;
    }
};
114 
/// Minimum functor built only on `operator<`.
///
/// Returns `a` when `a < b`, otherwise `b`. For operands that compare
/// equivalent, the second operand is returned.
struct Min
{
    template<class T>
    HIPCUB_HOST_DEVICE inline
    constexpr T operator()(const T &a, const T &b) const
    {
        if (a < b)
        {
            return a;
        }
        return b;
    }
};
124 
125 struct ArgMax
126 {
127  template<
128  class Key,
129  class Value
130  >
131  HIPCUB_HOST_DEVICE inline
132  constexpr KeyValuePair<Key, Value>
133  operator()(const KeyValuePair<Key, Value>& a,
134  const KeyValuePair<Key, Value>& b) const
135  {
136  return ((b.value > a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;
137  }
138 };
139 
140 struct ArgMin
141 {
142  template<
143  class Key,
144  class Value
145  >
146  HIPCUB_HOST_DEVICE inline
147  constexpr KeyValuePair<Key, Value>
148  operator()(const KeyValuePair<Key, Value>& a,
149  const KeyValuePair<Key, Value>& b) const
150  {
151  return ((b.value < a.value) || ((a.value == b.value) && (b.key < a.key))) ? b : a;
152  }
153 };
154 
/// Unary conversion functor: converts its argument to `B`.
///
/// Uses functional-notation cast `B(value)`, which for a single argument is
/// equivalent to the C-style cast `(B)value`.
template <typename B>
struct CastOp
{
    template <typename A>
    HIPCUB_HOST_DEVICE inline
    B operator()(const A& value) const
    {
        return B(value);
    }
};
165 
/// Binary functor that invokes the wrapped scan operator with its two
/// arguments swapped: `SwizzleScanOp(op)(a, b)` returns `op(b, a)`.
template <typename ScanOp>
struct SwizzleScanOp
{
private:
    /// Wrapped scan operator.
    ScanOp scan_op;

public:
    HIPCUB_HOST_DEVICE inline
    SwizzleScanOp(ScanOp scan_op) : scan_op(scan_op)
    {
    }

    template <typename T>
    HIPCUB_HOST_DEVICE inline
    T operator()(const T &a, const T &b)
    {
        // Copy into locals so the wrapped operator receives non-const
        // lvalues (presumably to support operators taking non-const
        // references — confirm against callers).
        T _a(a);
        T _b(b);

        return scan_op(_b, _a);
    }
};
188 
/// Combining operator for key/value pairs used in segmented reductions.
///
/// The result's `key` is `first.key + second.key`. If `second.key` is
/// nonzero (truthy), the result's `value` is `second.value`, so accumulation
/// restarts. Otherwise it is `op(first.value, second.value)`.
/// Presumably `key` counts segment boundaries; confirm against callers.
template <typename ReductionOpT>
struct ReduceBySegmentOp
{
    /// Wrapped reduction operator applied to the `value` fields.
    ReductionOpT op;

    /// Default constructor; value-initializes the wrapped operator so that
    /// scalar operator types (e.g. function pointers) are not left indeterminate.
    HIPCUB_HOST_DEVICE inline
    ReduceBySegmentOp() : op()
    {
    }

    HIPCUB_HOST_DEVICE inline
    ReduceBySegmentOp(ReductionOpT op) : op(op)
    {
    }

    template <typename KeyValuePairT>
    HIPCUB_HOST_DEVICE inline
    KeyValuePairT operator()(
        const KeyValuePairT &first,
        const KeyValuePairT &second)
    {
        KeyValuePairT retval;
        retval.key = first.key + second.key;
        // A nonzero key on the right-hand side discards the left-hand value.
        retval.value = (second.key) ?
                second.value :
                op(first.value, second.value);
        return retval;
    }
};
218 
/// Combining operator for key/value pairs used in reduce-by-key.
///
/// Returns a copy of `second`. When `first.key == second.key`, the copy's
/// `value` is replaced with `op(first.value, second.value)`.
template <typename ReductionOpT>
struct ReduceByKeyOp
{
    /// Wrapped reduction operator applied to the `value` fields.
    ReductionOpT op;

    /// Default constructor; value-initializes the wrapped operator so that
    /// scalar operator types (e.g. function pointers) are not left indeterminate.
    HIPCUB_HOST_DEVICE inline
    ReduceByKeyOp() : op()
    {
    }

    HIPCUB_HOST_DEVICE inline
    ReduceByKeyOp(ReductionOpT op) : op(op)
    {
    }

    template <typename KeyValuePairT>
    HIPCUB_HOST_DEVICE inline
    KeyValuePairT operator()(
        const KeyValuePairT &first,
        const KeyValuePairT &second)
    {
        KeyValuePairT retval = second;

        // Matching keys: fold the left-hand value into the right-hand one.
        if (first.key == second.key)
        {
            retval.value = op(first.value, retval.value);
        }
        return retval;
    }
};
249 
/// Binary functor adaptor that invokes the wrapped operator with its two
/// arguments swapped: `BinaryFlip(op)(t, u)` returns `op(u, t)`, perfectly
/// forwarding both arguments.
template <typename BinaryOpT>
struct BinaryFlip
{
    /// Wrapped binary operator.
    BinaryOpT binary_op;

    HIPCUB_HOST_DEVICE
    explicit BinaryFlip(BinaryOpT binary_op) : binary_op(binary_op)
    {
    }

    /// Calls `binary_op(u, t)`. Note: declared with `HIPCUB_DEVICE` only,
    /// unlike the constructor, which is `HIPCUB_HOST_DEVICE`.
    template <typename T, typename U>
    HIPCUB_DEVICE auto
    operator()(T &&t, U &&u) -> decltype(binary_op(std::forward<U>(u),
                                                   std::forward<T>(t)))
    {
        return binary_op(std::forward<U>(u), std::forward<T>(t));
    }
};
268 
269 template <typename BinaryOpT>
270 HIPCUB_HOST_DEVICE
271 BinaryFlip<BinaryOpT> MakeBinaryFlip(BinaryOpT binary_op)
272 {
273  return BinaryFlip<BinaryOpT>(binary_op);
274 }
275 
276 namespace detail
277 {
278 
279 // CUB uses value_type of OutputIteratorT (if not void) as a type of intermediate results in reduce,
280 // for example:
281 //
282 // /// The output value type
283 // typedef typename If<(Equals<typename std::iterator_traits<OutputIteratorT>::value_type, void>::VALUE), // OutputT = (if output iterator's value type is void) ?
284 // typename std::iterator_traits<InputIteratorT>::value_type, // ... then the input iterator's value type,
285 // typename std::iterator_traits<OutputIteratorT>::value_type>::Type OutputT; // ... else the output iterator's value type
286 //
287 // rocPRIM (as well as Thrust) uses result type of BinaryFunction instead (if not void):
288 //
289 // using input_type = typename std::iterator_traits<InputIterator>::value_type;
290 // using result_type = typename ::rocprim::detail::match_result_type<
291 // input_type, BinaryFunction
292 // >::type;
293 //
294 // Example: reducing short inputs into a float output with Sum():
295 // CUB: float Sum(float, float)
296 // rocPRIM: short Sum(short, short)
297 //
298 // This wrapper makes hipCUB follow CUB's convention for the intermediate result type.
template<
    class InputIteratorT,
    class OutputIteratorT,
    class BinaryFunction
>
struct convert_result_type_wrapper
{
    /// Value type of the input iterator.
    using input_type = typename std::iterator_traits<InputIteratorT>::value_type;
    /// Value type of the output iterator (may be `void`).
    using output_type = typename std::iterator_traits<OutputIteratorT>::value_type;
    /// Type every call result is converted to: the output iterator's value
    /// type, or the input iterator's value type when the former is `void`.
    using result_type = typename std::conditional<std::is_void<output_type>::value,
                                                  input_type,
                                                  output_type>::type;

    /// Wrapped binary operator.
    BinaryFunction op;

    convert_result_type_wrapper(BinaryFunction function) : op(function) {}

    /// Invokes the wrapped operator and converts its result to `result_type`.
    template<class T>
    HIPCUB_HOST_DEVICE inline
    constexpr result_type operator()(const T& lhs, const T& rhs) const
    {
        return static_cast<result_type>(op(lhs, rhs));
    }
};
324 
325 template<
326  class InputIteratorT,
327  class OutputIteratorT,
328  class BinaryFunction
329 >
330 inline
331 convert_result_type_wrapper<InputIteratorT, OutputIteratorT, BinaryFunction>
332 convert_result_type(BinaryFunction op)
333 {
334  return convert_result_type_wrapper<InputIteratorT, OutputIteratorT, BinaryFunction>(op);
335 }
336 
337 } // end detail namespace
338 
339 END_HIPCUB_NAMESPACE
340 
341 #endif // HIBCUB_ROCPRIM_THREAD_THREAD_OPERATORS_HPP_
SwizzleScanOp — Definition: thread_operators.hpp:168
ArgMax — Definition: thread_operators.hpp:126
ArgMin — Definition: thread_operators.hpp:141
BinaryFlip — Definition: thread_operators.hpp:252
CastOp — Definition: thread_operators.hpp:157
Difference — Definition: thread_operators.hpp:86
Division — Definition: thread_operators.hpp:96
Equality — Definition: thread_operators.hpp:40
InequalityWrapper — Definition: thread_operators.hpp:61
Inequality — Definition: thread_operators.hpp:50
Max — Definition: thread_operators.hpp:106
Min — Definition: thread_operators.hpp:116
ReduceByKeyOp — Definition: thread_operators.hpp:221
ReduceBySegmentOp — Definition: thread_operators.hpp:191
Sum — Definition: thread_operators.hpp:76