30 #ifndef HIPCUB_ROCPRIM_WARP_WARP_REDUCE_HPP_
31 #define HIPCUB_ROCPRIM_WARP_WARP_REDUCE_HPP_
33 #include "../../../config.hpp"
35 #include "../util_ptx.hpp"
36 #include "../thread/thread_operators.hpp"
38 #include <rocprim/warp/warp_reduce.hpp>
40 BEGIN_HIPCUB_NAMESPACE
// WarpReduce: hipCUB-style warp reduction facade implemented on top of
// ::rocprim::warp_reduce<T, LOGICAL_WARP_THREADS>, which it inherits privately.
// Every operation below forwards to a base_type member function, passing
// `input` as both the first and the second argument together with the
// temp_storage_ reference captured at construction.
//
// NOTE(review): this listing appears to be incomplete. The embedded line
// numbers skip (e.g. 46 -> 48, 57 -> 64), and the opening of the template
// parameter list, the braces, any access specifiers/function qualifiers and
// the return statements are not visible here. Confirm against the upstream
// header before relying on this text.
//
// Template parameters visible here:
//   LOGICAL_WARP_THREADS - forwarded to the rocprim base class; defaults to
//                          HIPCUB_DEVICE_WARP_THREADS.
//   ARCH                 - defaults to HIPCUB_ARCH; not referenced by any
//                          line visible in this listing.
44 int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS,
45 int ARCH = HIPCUB_ARCH>
46 class WarpReduce :
private ::rocprim::warp_reduce<T, LOGICAL_WARP_THREADS>
// Compile-time guard: reject zero or negative logical warp sizes.
48 static_assert(LOGICAL_WARP_THREADS > 0,
"LOGICAL_WARP_THREADS must be greater than 0");
// Shorthand for the rocprim implementation that all calls are forwarded to.
49 using base_type = typename ::rocprim::warp_reduce<T, LOGICAL_WARP_THREADS>;
// Held by reference, not by value: the storage object handed to the
// constructor must stay alive for as long as this WarpReduce is used.
51 typename base_type::storage_type &temp_storage_;
// Storage type exposed under the name TempStorage; it is exactly the base
// class's storage_type.
54 using TempStorage =
typename base_type::storage_type;
// Constructor: binds temp_storage_ to the caller-provided storage object.
57 WarpReduce(TempStorage& temp_storage) : temp_storage_(temp_storage)
// Reduction with neither an explicit operator nor an item count.
// NOTE(review): the signature enclosing this call is not present in this
// listing; presumably it is the single-argument Sum overload - confirm.
// NOTE(review): presumably the second `input` argument is the output
// parameter of rocprim's reduce, and the default operator is addition -
// verify against the rocprim warp_reduce documentation.
64 base_type::reduce(input, input, temp_storage_);
// Sum overload that additionally forwards `valid_items` to base_type::reduce.
// NOTE(review): presumably valid_items limits how many lanes contribute -
// confirm against rocprim.
69 T
Sum(T input,
int valid_items)
71 base_type::reduce(input, input, valid_items, temp_storage_);
// Segmented variant without an explicit operator: forwards `head_flag` to
// base_type::head_segmented_reduce. FlagT is whatever flag type the caller
// passes; it is handed through unchanged.
75 template<
typename FlagT>
77 T HeadSegmentedSum(T input, FlagT head_flag)
79 base_type::head_segmented_reduce(input, input, head_flag, temp_storage_);
// Same as HeadSegmentedSum, but forwards `tail_flag` to
// base_type::tail_segmented_reduce.
83 template<
typename FlagT>
85 T TailSegmentedSum(T input, FlagT tail_flag)
87 base_type::tail_segmented_reduce(input, input, tail_flag, temp_storage_);
// Reduction with a caller-supplied operator: `reduce_op` is passed to
// base_type::reduce after the storage argument.
91 template<
typename ReduceOp>
93 T Reduce(T input, ReduceOp reduce_op)
95 base_type::reduce(input, input, temp_storage_, reduce_op);
// Reduce overload that forwards both `valid_items` and `reduce_op`. Note the
// argument order expected by the base call: valid_items precedes the storage,
// reduce_op follows it.
99 template<
typename ReduceOp>
101 T Reduce(T input, ReduceOp reduce_op,
int valid_items)
103 base_type::reduce(input, input, valid_items, temp_storage_, reduce_op);
// Segmented reduction with a caller-supplied operator, forwarding `head_flag`.
// The public parameter order is (input, head_flag, reduce_op); the base call
// receives the flag before the storage and the operator last.
107 template<
typename ReduceOp,
typename FlagT>
109 T HeadSegmentedReduce(T input, FlagT head_flag, ReduceOp reduce_op)
111 base_type::head_segmented_reduce(
112 input, input, head_flag, temp_storage_, reduce_op
// Segmented reduction with a caller-supplied operator, forwarding `tail_flag`
// to base_type::tail_segmented_reduce with the same argument layout as the
// head-flag variant.
117 template<
typename ReduceOp,
typename FlagT>
119 T TailSegmentedReduce(T input, FlagT tail_flag, ReduceOp reduce_op)
121 base_type::tail_segmented_reduce(
122 input, input, tail_flag, temp_storage_, reduce_op
Definition: warp_reduce.hpp:47
Definition: thread_operators.hpp:76