30 #ifndef HIPCUB_ROCPRIM_WARP_WARP_SCAN_HPP_
31 #define HIPCUB_ROCPRIM_WARP_WARP_SCAN_HPP_
33 #include "../../../config.hpp"
35 #include "../util_ptx.hpp"
36 #include "../thread/thread_operators.hpp"
38 #include <rocprim/warp/warp_scan.hpp>
40 BEGIN_HIPCUB_NAMESPACE
// Template parameters visible in this listing: LOGICAL_WARP_THREADS (defaults to
// HIPCUB_DEVICE_WARP_THREADS) and ARCH (defaults to HIPCUB_ARCH). The members below
// also use a type named T; its declaration is not part of the visible lines.
44 int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS,
45 int ARCH = HIPCUB_ARCH>
// WarpScan privately inherits from ::rocprim::warp_scan<T, LOGICAL_WARP_THREADS>;
// its scan and broadcast members forward to that base class.
// NOTE(review): ARCH is not referenced anywhere else in the visible code —
// presumably kept only for API compatibility; confirm.
46 class WarpScan :
private ::rocprim::warp_scan<T, LOGICAL_WARP_THREADS>
// Compile-time guard: a non-positive LOGICAL_WARP_THREADS is rejected with the message below.
48 static_assert(LOGICAL_WARP_THREADS > 0,
"LOGICAL_WARP_THREADS must be greater than 0");
// Shorthand for the rocPRIM base class, used to qualify every forwarded call.
49 using base_type = typename ::rocprim::warp_scan<T, LOGICAL_WARP_THREADS>;
// Reference to the storage object received by the constructor; it is passed along
// to the base_type calls made by the member functions.
51 typename base_type::storage_type &temp_storage_;
// Storage type expected by the constructor; an alias of base_type::storage_type.
54 using TempStorage =
typename base_type::storage_type;
// Constructor: binds the temp_storage_ reference member to the caller-provided
// TempStorage object. Only a reference is kept (no copy is made), so the referenced
// storage has to stay alive for as long as this WarpScan object is used.
57 WarpScan(TempStorage& temp_storage) : temp_storage_(temp_storage)
// Inclusive scan of `input` without an explicit operator.
//   input            - value to scan.
//   inclusive_output - written by base_type::inclusive_scan.
// No scan operator is passed, so the base class's default operator applies
// (presumably addition, as the method name suggests — confirm against rocprim::warp_scan).
62 void InclusiveSum(T input, T& inclusive_output)
64 base_type::inclusive_scan(input, inclusive_output, temp_storage_);
// Inclusive scan of `input` without an explicit operator, also producing an aggregate.
//   input            - value to scan.
//   inclusive_output - written by base_type::inclusive_scan.
//   warp_aggregate   - written by base_type::inclusive_scan (presumably the reduction
//                      over all participating lanes — confirm against rocprim::warp_scan).
// The base class's default scan operator is used (presumably addition — confirm).
68 void InclusiveSum(T input, T& inclusive_output, T& warp_aggregate)
70 base_type::inclusive_scan(input, inclusive_output, warp_aggregate, temp_storage_);
// Exclusive scan of `input` without an explicit operator.
//   input            - value to scan.
//   exclusive_output - written by base_type::exclusive_scan.
// T(0) is passed in the argument position that the ExclusiveScan overloads in this
// class use for `initial_value`. The base class's default scan operator is used
// (presumably addition — confirm against rocprim::warp_scan).
74 void ExclusiveSum(T input, T& exclusive_output)
76 base_type::exclusive_scan(input, exclusive_output, T(0), temp_storage_);
// Exclusive scan of `input` without an explicit operator, also producing an aggregate.
//   input            - value to scan.
//   exclusive_output - written by base_type::exclusive_scan.
//   warp_aggregate   - written by base_type::exclusive_scan (presumably the reduction
//                      over all participating lanes — confirm against rocprim::warp_scan).
// T(0) is passed in the initial-value argument position; the base class's default
// scan operator is used (presumably addition — confirm).
80 void ExclusiveSum(T input, T& exclusive_output, T& warp_aggregate)
82 base_type::exclusive_scan(input, exclusive_output, T(0), warp_aggregate, temp_storage_);
// Inclusive scan of `input` with a caller-supplied operator.
//   ScanOp           - type of the scan operator.
//   input            - value to scan.
//   inclusive_output - written by base_type::inclusive_scan.
//   scan_op          - operator forwarded to the base class (passed by value).
85 template<
typename ScanOp>
87 void InclusiveScan(T input, T& inclusive_output, ScanOp scan_op)
89 base_type::inclusive_scan(input, inclusive_output, temp_storage_, scan_op);
// Inclusive scan of `input` with a caller-supplied operator, also producing an aggregate.
//   ScanOp           - type of the scan operator.
//   input            - value to scan.
//   inclusive_output - written by base_type::inclusive_scan.
//   scan_op          - operator forwarded to the base class.
//   warp_aggregate   - written by base_type::inclusive_scan (presumably the reduction
//                      over all participating lanes — confirm against rocprim::warp_scan).
// Note the argument order differs from this method's signature: the base call takes
// warp_aggregate before the storage and the operator.
92 template<
typename ScanOp>
94 void InclusiveScan(T input, T& inclusive_output, ScanOp scan_op, T& warp_aggregate)
96 base_type::inclusive_scan(
97 input, inclusive_output, warp_aggregate,
98 temp_storage_, scan_op
// Exclusive scan of `input` with a caller-supplied operator and no initial value.
//   ScanOp           - type of the scan operator.
//   input            - value to scan.
//   exclusive_output - receives the final result.
//   scan_op          - operator forwarded to base_type::inclusive_scan.
// Implemented in two steps: an inclusive scan is first written into exclusive_output,
// then base_type::to_exclusive converts it in place (the same variable is passed as
// both its input and output argument).
// NOTE(review): with no initial value supplied, what to_exclusive produces for the
// first lane is not visible here — confirm before relying on that element.
102 template<
typename ScanOp>
104 void ExclusiveScan(T input, T& exclusive_output, ScanOp scan_op)
106 base_type::inclusive_scan(input, exclusive_output, temp_storage_, scan_op);
107 base_type::to_exclusive(exclusive_output, exclusive_output, temp_storage_);
// Exclusive scan of `input` with a caller-supplied operator and an initial value.
//   ScanOp           - type of the scan operator.
//   input            - value to scan.
//   exclusive_output - written by base_type::exclusive_scan.
//   initial_value    - forwarded unchanged to base_type::exclusive_scan.
//   scan_op          - operator forwarded to the base class.
110 template<
typename ScanOp>
112 void ExclusiveScan(T input, T& exclusive_output, T initial_value, ScanOp scan_op)
114 base_type::exclusive_scan(
115 input, exclusive_output, initial_value,
116 temp_storage_, scan_op
// Exclusive scan of `input` with a caller-supplied operator, no initial value, and
// an aggregate output.
//   ScanOp           - type of the scan operator.
//   input            - value to scan.
//   exclusive_output - receives the final result.
//   scan_op          - operator forwarded to base_type::inclusive_scan.
//   warp_aggregate   - written by base_type::inclusive_scan (presumably the reduction
//                      over all participating lanes — confirm against rocprim::warp_scan).
// Implemented in two steps: the inclusive scan (which also fills warp_aggregate) is
// written into exclusive_output, then base_type::to_exclusive converts it in place.
// NOTE(review): with no initial value supplied, what to_exclusive produces for the
// first lane is not visible here — confirm before relying on that element.
120 template<
typename ScanOp>
122 void ExclusiveScan(T input, T& exclusive_output, ScanOp scan_op, T& warp_aggregate)
124 base_type::inclusive_scan(
125 input, exclusive_output, warp_aggregate, temp_storage_, scan_op
127 base_type::to_exclusive(exclusive_output, exclusive_output, temp_storage_);
// Exclusive scan of `input` with a caller-supplied operator, an initial value, and
// an aggregate output.
//   ScanOp           - type of the scan operator.
//   input            - value to scan.
//   exclusive_output - written by base_type::exclusive_scan.
//   initial_value    - forwarded unchanged to base_type::exclusive_scan.
//   scan_op          - operator forwarded to the base class.
//   warp_aggregate   - written by base_type::exclusive_scan.
// NOTE(review): whether warp_aggregate includes initial_value is decided by the base
// class and is not visible here — confirm against rocprim::warp_scan.
130 template<
typename ScanOp>
132 void ExclusiveScan(T input, T& exclusive_output, T initial_value, ScanOp scan_op, T& warp_aggregate)
134 base_type::exclusive_scan(
135 input, exclusive_output, initial_value, warp_aggregate,
136 temp_storage_, scan_op
// Produces both the inclusive and the exclusive scan of `input`, with no initial value.
//   ScanOp           - type of the scan operator.
//   input            - value to scan.
//   inclusive_output - written by base_type::inclusive_scan.
//   exclusive_output - written by base_type::to_exclusive, derived from inclusive_output.
//   scan_op          - operator forwarded to base_type::inclusive_scan.
// NOTE(review): with no initial value supplied, what to_exclusive writes to
// exclusive_output for the first lane is not visible here — confirm.
140 template<
typename ScanOp>
142 void Scan(T input, T& inclusive_output, T& exclusive_output, ScanOp scan_op)
144 base_type::inclusive_scan(input, inclusive_output, temp_storage_, scan_op);
145 base_type::to_exclusive(inclusive_output, exclusive_output, temp_storage_);
// Produces both the inclusive and the exclusive scan of `input`, with an initial value.
//   ScanOp           - type of the scan operator.
//   input            - value to scan.
//   inclusive_output - receives the inclusive result, then combined with initial_value.
//   exclusive_output - output argument of the forwarded call.
//   initial_value    - forwarded to the call and also applied to inclusive_output below.
//   scan_op          - operator forwarded to the call and used for the final combine.
148 template<
typename ScanOp>
150 void Scan(T input, T& inclusive_output, T& exclusive_output, T initial_value, ScanOp scan_op)
// The line naming the callee (listing line 152) is absent from this listing; the two
// lines below are its argument list (presumably a base_type member — confirm in the
// original file).
153 input, inclusive_output, exclusive_output, initial_value,
154 temp_storage_, scan_op
// After the call, initial_value is combined into inclusive_output with scan_op, with
// initial_value as the left operand.
// NOTE(review): this implies the call above leaves initial_value out of
// inclusive_output — confirm against the callee's documentation.
160 inclusive_output = scan_op(initial_value, inclusive_output);
// Returns the result of base_type::broadcast for the given arguments.
//   input    - value offered by the caller.
//   src_lane - lane index forwarded to the base class (presumably the lane whose
//              `input` is returned to every caller — confirm against rocprim::warp_scan).
// Returns: whatever base_type::broadcast returns, as a T.
164 T Broadcast(T input,
unsigned int src_lane)
166 return base_type::broadcast(input, src_lane, temp_storage_);
Definition: warp_scan.hpp:47