35 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
36 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
40 #include "../../../config.hpp"
41 #include "../../../util_type.hpp"
42 #include "../../../util_ptx.hpp"
44 #include "../block/block_scan.hpp"
45 #include "../block/radix_rank_sort_operations.hpp"
46 #include "../thread/thread_reduce.hpp"
47 #include "../thread/thread_scan.hpp"
51 BEGIN_HIPCUB_NAMESPACE
55 template<
typename DigitExtractorT,
typename Un
signedBits,
int RADIX_BITS,
bool IS_DESCENDING>
56 struct DigitExtractorAdopter
58 DigitExtractorT& digit_extractor_;
60 HIPCUB_DEVICE DigitExtractorAdopter(DigitExtractorT& digit_extractor)
61 : digit_extractor_(digit_extractor)
64 HIPCUB_DEVICE
inline UnsignedBits operator()(
const UnsignedBits key)
66 UnsignedBits digit = digit_extractor_.Digit(key);
70 digit ^= (1 << RADIX_BITS) - 1;
111 template<
int BLOCK_DIM_X,
114 bool MEMOIZE_OUTER_SCAN =
false,
115 BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
116 hipSharedMemConfig SMEM_CONFIG = hipSharedMemBankSizeFourByte,
119 int ARCH = HIPCUB_ARCH >
121 :
private ::rocprim::block_radix_rank<BLOCK_DIM_X,
124 ? ::rocprim::block_radix_rank_algorithm::basic_memoize
125 : ::rocprim::block_radix_rank_algorithm::basic,
129 static_assert(BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
130 "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0");
133 = ::rocprim::block_radix_rank<BLOCK_DIM_X,
136 ? ::rocprim::block_radix_rank_algorithm::basic_memoize
137 : ::rocprim::block_radix_rank_algorithm::basic,
142 using TempStorage =
typename base_type::storage_type;
146 TempStorage& temp_storage_;
148 HIPCUB_DEVICE
inline TempStorage& PrivateStorage()
150 HIPCUB_SHARED_MEMORY TempStorage private_storage;
151 return private_storage;
158 BINS_TRACKED_PER_THREAD = base_type::digits_per_thread,
177 : temp_storage_(temp_storage)
189 template<
typename UnsignedBits,
191 typename DigitExtractorT>
193 UnsignedBits (&keys)[KEYS_PER_THREAD],
194 int (&ranks)[KEYS_PER_THREAD],
195 DigitExtractorT digit_extractor)
197 detail::DigitExtractorAdopter<DigitExtractorT, UnsignedBits, RADIX_BITS, IS_DESCENDING>
198 digit_extractor_adopter(digit_extractor);
199 base_type::rank_keys(keys,
200 reinterpret_cast<unsigned int(&)[KEYS_PER_THREAD]
>(ranks),
202 digit_extractor_adopter);
208 template<
typename UnsignedBits,
210 typename DigitExtractorT>
212 UnsignedBits (&keys)[KEYS_PER_THREAD],
215 DigitExtractorT digit_extractor,
216 int (&exclusive_digit_prefix)
217 [BINS_TRACKED_PER_THREAD])
219 unsigned int counts[BINS_TRACKED_PER_THREAD];
220 detail::DigitExtractorAdopter<DigitExtractorT, UnsignedBits, RADIX_BITS, IS_DESCENDING>
221 digit_extractor_adopter(digit_extractor);
222 base_type::rank_keys(
224 reinterpret_cast<unsigned int(&)[KEYS_PER_THREAD]
>(ranks),
226 digit_extractor_adopter,
227 reinterpret_cast<unsigned int(&)[BINS_TRACKED_PER_THREAD]
>(exclusive_digit_prefix),
235 template<
int BLOCK_DIM_X,
238 BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
241 int ARCH = HIPCUB_ARCH>
243 :
private ::rocprim::block_radix_rank<BLOCK_DIM_X,
245 ::rocprim::block_radix_rank_algorithm::match,
249 static_assert(BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
250 "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0");
252 using base_type = ::rocprim::block_radix_rank<BLOCK_DIM_X,
254 ::rocprim::block_radix_rank_algorithm::match,
259 using TempStorage =
typename base_type::storage_type;
263 TempStorage& temp_storage_;
265 HIPCUB_DEVICE
inline TempStorage& PrivateStorage()
267 HIPCUB_SHARED_MEMORY TempStorage private_storage;
268 return private_storage;
275 BINS_TRACKED_PER_THREAD = base_type::digits_per_thread,
294 : temp_storage_(temp_storage)
306 template<
typename UnsignedBits,
308 typename DigitExtractorT>
310 UnsignedBits (&keys)[KEYS_PER_THREAD],
311 int (&ranks)[KEYS_PER_THREAD],
312 DigitExtractorT digit_extractor)
314 detail::DigitExtractorAdopter<DigitExtractorT, UnsignedBits, RADIX_BITS, IS_DESCENDING>
315 digit_extractor_adopter(digit_extractor);
316 base_type::rank_keys(keys,
317 reinterpret_cast<unsigned int(&)[KEYS_PER_THREAD]
>(ranks),
319 digit_extractor_adopter);
325 template<
typename UnsignedBits,
327 typename DigitExtractorT>
329 UnsignedBits (&keys)[KEYS_PER_THREAD],
332 DigitExtractorT digit_extractor,
333 int (&exclusive_digit_prefix)
334 [BINS_TRACKED_PER_THREAD])
336 unsigned int counts[BINS_TRACKED_PER_THREAD];
337 detail::DigitExtractorAdopter<DigitExtractorT, UnsignedBits, RADIX_BITS, IS_DESCENDING>
338 digit_extractor_adopter(digit_extractor);
339 base_type::rank_keys(
341 reinterpret_cast<unsigned int(&)[KEYS_PER_THREAD]
>(ranks),
343 digit_extractor_adopter,
344 reinterpret_cast<unsigned int(&)[BINS_TRACKED_PER_THREAD]
>(exclusive_digit_prefix),
Definition: block_radix_rank.hpp:248
__device__ BlockRadixRankMatch(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage.
Definition: block_radix_rank.hpp:291
__device__ BlockRadixRankMatch()
Collective constructor using a private static allocation of shared memory as temporary storage.
Definition: block_radix_rank.hpp:286
__device__ __forceinline__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor, int(&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
Rank keys. For the lower RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
Definition: block_radix_rank.hpp:328
__device__ __forceinline__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
Rank keys.
Definition: block_radix_rank.hpp:309
BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
Definition: block_radix_rank.hpp:128
__device__ BlockRadixRank(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage.
Definition: block_radix_rank.hpp:174
__device__ BlockRadixRank()
Collective constructor using a private static allocation of shared memory as temporary storage.
Definition: block_radix_rank.hpp:169
__device__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor, int(&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
Rank keys. For the lower RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
Definition: block_radix_rank.hpp:211
__device__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
Rank keys.
Definition: block_radix_rank.hpp:192