30 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_SCAN_HPP_
31 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_SCAN_HPP_
33 #include <type_traits>
35 #include "../../../config.hpp"
37 #include "../thread/thread_operators.hpp"
39 #include <rocprim/block/block_scan.hpp>
41 BEGIN_HIPCUB_NAMESPACE
46 typename std::underlying_type<::rocprim::block_scan_algorithm>::type
47 to_BlockScanAlgorithm_enum(::rocprim::block_scan_algorithm v)
49 using utype = std::underlying_type<::rocprim::block_scan_algorithm>::type;
50 return static_cast<utype
>(v);
54 enum BlockScanAlgorithm
57 = detail::to_BlockScanAlgorithm_enum(::rocprim::block_scan_algorithm::reduce_then_scan),
58 BLOCK_SCAN_RAKING_MEMOIZE
59 = detail::to_BlockScanAlgorithm_enum(::rocprim::block_scan_algorithm::reduce_then_scan),
61 = detail::to_BlockScanAlgorithm_enum(::rocprim::block_scan_algorithm::using_warp_scan)
// BlockScan: block-wide prefix-scan primitive.
// Visible template parameters include ALGORITHM (default BLOCK_SCAN_RAKING)
// and ARCH (default HIPCUB_ARCH). ARCH is not referenced anywhere in this
// listing, so it is presumably accepted only for CUB API compatibility; confirm.
// NOTE(review): this is a scraped, line-numbered listing with lines missing,
// e.g. `template<`, the T / BLOCK_DIM_* parameters and `class BlockScan`.
// The numbers embedded in each line show the gaps.
67 BlockScanAlgorithm ALGORITHM = BLOCK_SCAN_RAKING,
70 int ARCH = HIPCUB_ARCH
// Implemented by privately inheriting the rocPRIM block scan. The hipCUB
// ALGORITHM enumerator is cast back to ::rocprim::block_scan_algorithm.
73 :
private ::rocprim::block_scan<
76 static_cast<::rocprim::block_scan_algorithm>(ALGORITHM),
// Compile-time guard: the product of the block dimensions must be positive.
82 BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
83 "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
// This is presumably the right-hand side of the base_type alias used by every
// member below. The `using base_type =` line is missing from this listing.
// It names the same rocprim::block_scan specialization as the base class.
87 typename ::rocprim::block_scan<
90 static_cast<::rocprim::block_scan_algorithm
>(ALGORITHM),
// Scratch storage passed to every rocprim scan call made by this object.
// It is held by reference and bound once in the constructor.
96 typename base_type::storage_type& temp_storage_;
// Public name for the scratch-storage type callers provide.
// It is the rocprim base class's storage_type.
99 using TempStorage =
typename base_type::storage_type;
// Default constructor: binds temp_storage_ to the object returned by
// private_storage(), which is declared with HIPCUB_SHARED_MEMORY.
102 BlockScan() : temp_storage_(private_storage())
// Binds temp_storage_ to caller-provided storage. The storage is used by
// reference, so it must outlive this BlockScan. It is presumably expected to
// live in shared memory, as in CUB; confirm.
107 BlockScan(TempStorage& temp_storage) : temp_storage_(temp_storage)
// ---- InclusiveSum overloads ------------------------------------------------
// Thin wrappers that forward to base_type::inclusive_scan with temp_storage_.
// NOTE(review): the function qualifiers, braces and closing `);` lines are
// missing from this scraped listing.

// Inclusive prefix sum of one item per thread. No operator is passed, so the
// base class's default is used. That is presumably addition; confirm in
// rocprim::block_scan.
112 void InclusiveSum(T input, T& output)
114 base_type::inclusive_scan(input, output, temp_storage_);
// As above, and also writes the block-wide aggregate to block_aggregate.
118 void InclusiveSum(T input, T& output, T& block_aggregate)
120 base_type::inclusive_scan(input, output, block_aggregate, temp_storage_);
// Inclusive prefix sum with ::hipcub::Sum() passed explicitly. The
// block_prefix_callback_op is taken by non-const reference and forwarded to
// the base class. As in CUB, it presumably supplies a running block prefix;
// confirm.
123 template<
typename BlockPrefixCallbackOp>
125 void InclusiveSum(T input, T& output, BlockPrefixCallbackOp& block_prefix_callback_op)
127 base_type::inclusive_scan(
128 input, output, temp_storage_, block_prefix_callback_op, ::
hipcub::Sum()
// Multi-item overload: each thread passes ITEMS_PER_THREAD inputs and
// receives ITEMS_PER_THREAD outputs through fixed-size array references.
132 template<
int ITEMS_PER_THREAD>
134 void InclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD])
136 base_type::inclusive_scan(input, output, temp_storage_);
// Multi-item overload that also reports block_aggregate.
// NOTE(review): the signature is truncated in this listing; the embedded line
// numbers jump from 141 to 144. The body reads block_aggregate, so the missing
// line presumably declares `T& block_aggregate)`. Confirm against the real header.
139 template<
int ITEMS_PER_THREAD>
141 void InclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
144 base_type::inclusive_scan(input, output, block_aggregate, temp_storage_);
// Multi-item overload with a block prefix callback and ::hipcub::Sum().
147 template<
int ITEMS_PER_THREAD,
typename BlockPrefixCallbackOp>
149 void InclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
150 BlockPrefixCallbackOp& block_prefix_callback_op)
152 base_type::inclusive_scan(
153 input, output, temp_storage_, block_prefix_callback_op, ::
hipcub::Sum()
// ---- InclusiveScan overloads -----------------------------------------------
// Same shapes as InclusiveSum, but with a caller-supplied binary scan_op
// (taken by value) forwarded to base_type::inclusive_scan.

// Inclusive scan of one item per thread using scan_op.
157 template<
typename ScanOp>
159 void InclusiveScan(T input, T& output, ScanOp scan_op)
161 base_type::inclusive_scan(input, output, temp_storage_, scan_op);
// As above, and also writes the block-wide aggregate to block_aggregate.
// Note that the base-class argument order differs from the public signature:
// block_aggregate comes before temp_storage_ and scan_op.
164 template<
typename ScanOp>
166 void InclusiveScan(T input, T& output, ScanOp scan_op, T& block_aggregate)
168 base_type::inclusive_scan(input, output, block_aggregate, temp_storage_, scan_op);
// Inclusive scan with a block prefix callback. The callback is forwarded by
// reference, ahead of scan_op.
171 template<
typename ScanOp,
typename BlockPrefixCallbackOp>
173 void InclusiveScan(T input, T& output, ScanOp scan_op, BlockPrefixCallbackOp& block_prefix_callback_op)
175 base_type::inclusive_scan(
176 input, output, temp_storage_, block_prefix_callback_op, scan_op
// Multi-item overload: ITEMS_PER_THREAD inputs and outputs per thread.
180 template<
int ITEMS_PER_THREAD,
typename ScanOp>
182 void InclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD], ScanOp scan_op)
184 base_type::inclusive_scan(input, output, temp_storage_, scan_op);
// Multi-item overload that also reports block_aggregate.
187 template<
int ITEMS_PER_THREAD,
typename ScanOp>
189 void InclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
190 ScanOp scan_op, T& block_aggregate)
192 base_type::inclusive_scan(input, output, block_aggregate, temp_storage_, scan_op);
// Multi-item overload with a block prefix callback.
195 template<
int ITEMS_PER_THREAD,
typename ScanOp,
typename BlockPrefixCallbackOp>
197 void InclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
198 ScanOp scan_op, BlockPrefixCallbackOp& block_prefix_callback_op)
200 base_type::inclusive_scan(
201 input, output, temp_storage_, block_prefix_callback_op, scan_op
// ---- ExclusiveSum overloads ------------------------------------------------
// Forward to base_type::exclusive_scan. The non-callback overloads pass T(0)
// as the initial value, so T must be constructible from the literal 0. That
// value is assumed to be the identity for the sum.

// Exclusive prefix sum of one item per thread, starting from T(0).
206 void ExclusiveSum(T input, T& output)
208 base_type::exclusive_scan(input, output, T(0), temp_storage_);
// As above, and also writes the block-wide aggregate to block_aggregate.
212 void ExclusiveSum(T input, T& output, T& block_aggregate)
214 base_type::exclusive_scan(input, output, T(0), block_aggregate, temp_storage_);
// Exclusive prefix sum with a block prefix callback and ::hipcub::Sum().
// No initial value is passed in this overload.
217 template<
typename BlockPrefixCallbackOp>
219 void ExclusiveSum(T input, T& output, BlockPrefixCallbackOp& block_prefix_callback_op)
221 base_type::exclusive_scan(
222 input, output, temp_storage_, block_prefix_callback_op, ::
hipcub::Sum()
// Multi-item overload: ITEMS_PER_THREAD inputs and outputs per thread,
// starting from T(0).
226 template<
int ITEMS_PER_THREAD>
228 void ExclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD])
230 base_type::exclusive_scan(input, output, T(0), temp_storage_);
// Multi-item overload that also reports block_aggregate.
// NOTE(review): the signature is truncated in this listing; the embedded line
// numbers jump from 235 to 238. The body reads block_aggregate, so the missing
// line presumably declares `T& block_aggregate)`. Confirm against the real header.
233 template<
int ITEMS_PER_THREAD>
235 void ExclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
238 base_type::exclusive_scan(input, output, T(0), block_aggregate, temp_storage_);
// Multi-item overload with a block prefix callback and ::hipcub::Sum().
241 template<
int ITEMS_PER_THREAD,
typename BlockPrefixCallbackOp>
243 void ExclusiveSum(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
244 BlockPrefixCallbackOp& block_prefix_callback_op)
246 base_type::exclusive_scan(
247 input, output, temp_storage_, block_prefix_callback_op, ::
hipcub::Sum()
// ---- ExclusiveScan overloads -----------------------------------------------
// Forward to base_type::exclusive_scan with a caller-supplied scan_op.
// The non-callback overloads take an explicit initial_value. The
// prefix-callback overloads take no initial_value.

// Exclusive scan of one item per thread, seeded with initial_value.
251 template<
typename ScanOp>
253 void ExclusiveScan(T input, T& output, T initial_value, ScanOp scan_op)
255 base_type::exclusive_scan(input, output, initial_value, temp_storage_, scan_op);
// As above, and also writes the block-wide aggregate to block_aggregate.
258 template<
typename ScanOp>
260 void ExclusiveScan(T input, T& output, T initial_value,
261 ScanOp scan_op, T& block_aggregate)
263 base_type::exclusive_scan(
264 input, output, initial_value, block_aggregate, temp_storage_, scan_op
// Exclusive scan with a block prefix callback, forwarded by reference ahead
// of scan_op.
268 template<
typename ScanOp,
typename BlockPrefixCallbackOp>
270 void ExclusiveScan(T input, T& output, ScanOp scan_op,
271 BlockPrefixCallbackOp& block_prefix_callback_op)
273 base_type::exclusive_scan(
274 input, output, temp_storage_, block_prefix_callback_op, scan_op
// Multi-item overload: ITEMS_PER_THREAD inputs and outputs per thread,
// seeded with initial_value.
278 template<
int ITEMS_PER_THREAD,
typename ScanOp>
280 void ExclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
281 T initial_value, ScanOp scan_op)
283 base_type::exclusive_scan(input, output, initial_value, temp_storage_, scan_op);
// Multi-item overload that also reports block_aggregate.
286 template<
int ITEMS_PER_THREAD,
typename ScanOp>
288 void ExclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
289 T initial_value, ScanOp scan_op, T& block_aggregate)
291 base_type::exclusive_scan(
292 input, output, initial_value, block_aggregate, temp_storage_, scan_op
// Multi-item overload with a block prefix callback.
296 template<
int ITEMS_PER_THREAD,
typename ScanOp,
typename BlockPrefixCallbackOp>
298 void ExclusiveScan(T(&input)[ITEMS_PER_THREAD], T(&output)[ITEMS_PER_THREAD],
299 ScanOp scan_op, BlockPrefixCallbackOp& block_prefix_callback_op)
301 base_type::exclusive_scan(
302 input, output, temp_storage_, block_prefix_callback_op, scan_op
// Returns scratch storage declared inside this function with
// HIPCUB_SHARED_MEMORY. The default BlockScan constructor binds temp_storage_
// to it.
// NOTE(review): this assumes HIPCUB_SHARED_MEMORY expands to __shared__.
// If so, every call for a given BlockScan instantiation returns the same
// block-wide object. Two default-constructed BlockScan objects of the same
// type would then alias each other's storage and need a barrier between
// uses. Confirm.
// The local variable has the same name as this function; inside the body,
// `private_storage` refers to the variable.
308 TempStorage& private_storage()
310 HIPCUB_SHARED_MEMORY TempStorage private_storage;
311 return private_storage;
Definition: block_scan.hpp:80
Definition: thread_operators.hpp:76