35 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
36 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
40 #include "../../../config.hpp"
41 #include "../../../util_type.hpp"
42 #include "../../../util_ptx.hpp"
44 #include "../thread/thread_reduce.hpp"
45 #include "../block/block_scan.hpp"
46 #include "../block/radix_rank_sort_operations.hpp"
48 BEGIN_HIPCUB_NAMESPACE
90 bool MEMOIZE_OUTER_SCAN =
false,
91 BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
92 hipSharedMemConfig SMEM_CONFIG = hipSharedMemBankSizeFourByte,
95 int ARCH = HIPCUB_ARCH >
105 typedef unsigned short DigitCounter;
108 typedef typename If<(SMEM_CONFIG == hipSharedMemBankSizeEightByte),
110 unsigned int>::Type PackedCounter;
115 BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
117 RADIX_DIGITS = 1 << RADIX_BITS,
120 WARP_THREADS = 1 << LOG_WARP_THREADS,
121 WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
123 BYTES_PER_COUNTER =
sizeof(DigitCounter),
126 PACKING_RATIO =
sizeof(PackedCounter) /
sizeof(DigitCounter),
129 LOG_COUNTER_LANES = rocprim::maximum<int>()((int(RADIX_BITS) - int(LOG_PACKING_RATIO)), 0),
130 COUNTER_LANES = 1 << LOG_COUNTER_LANES,
133 PADDED_COUNTER_LANES = COUNTER_LANES + 1,
134 RAKING_SEGMENT = PADDED_COUNTER_LANES,
142 BINS_TRACKED_PER_THREAD = rocprim::maximum<int>()(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
152 INNER_SCAN_ALGORITHM,
158 #ifndef DOXYGEN_SHOULD_SKIP_THIS
161 struct __align__(16) _TempStorage
165 DigitCounter digit_counters[PADDED_COUNTER_LANES * BLOCK_THREADS * PACKING_RATIO];
166 PackedCounter raking_grid[BLOCK_THREADS * RAKING_SEGMENT];
171 typename BlockScan::TempStorage block_scan;
181 _TempStorage &temp_storage;
184 unsigned int linear_tid;
187 PackedCounter cached_segment[RAKING_SEGMENT];
197 HIPCUB_DEVICE
inline _TempStorage& PrivateStorage()
199 __shared__ _TempStorage private_storage;
200 return private_storage;
207 HIPCUB_DEVICE
inline PackedCounter Upsweep()
209 PackedCounter *smem_raking_ptr = &temp_storage.aliasable.raking_grid[linear_tid * RAKING_SEGMENT];
210 PackedCounter *raking_ptr;
212 if (MEMOIZE_OUTER_SCAN)
216 for (
int i = 0; i < RAKING_SEGMENT; i++)
218 cached_segment[i] = smem_raking_ptr[i];
220 raking_ptr = cached_segment;
224 raking_ptr = smem_raking_ptr;
227 return internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, Sum());
232 HIPCUB_DEVICE
inline void ExclusiveDownsweep(
233 PackedCounter raking_partial)
235 PackedCounter *smem_raking_ptr = &temp_storage.aliasable.raking_grid[linear_tid * RAKING_SEGMENT];
237 PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ?
242 internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, Sum(), raking_partial);
244 if (MEMOIZE_OUTER_SCAN)
248 for (
int i = 0; i < RAKING_SEGMENT; i++)
250 smem_raking_ptr[i] = cached_segment[i];
259 HIPCUB_DEVICE
inline void ResetCounters()
263 for (
int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
266 for (
int SUB_COUNTER = 0; SUB_COUNTER < PACKING_RATIO; SUB_COUNTER++)
268 temp_storage.aliasable.digit_counters[(LANE * BLOCK_THREADS + linear_tid) * PACKING_RATIO + SUB_COUNTER] = 0;
277 struct PrefixCallBack
279 HIPCUB_DEVICE
inline PackedCounter operator()(PackedCounter block_aggregate)
281 PackedCounter block_prefix = 0;
285 for (
int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
287 block_prefix += block_aggregate << (
sizeof(DigitCounter) * 8 * PACKED);
298 HIPCUB_DEVICE
inline void ScanCounters()
301 PackedCounter raking_partial = Upsweep();
304 PackedCounter exclusive_partial;
305 PrefixCallBack prefix_call_back;
306 BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);
309 ExclusiveDownsweep(exclusive_partial);
328 temp_storage(PrivateStorage()),
329 linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
339 temp_storage(temp_storage.Alias()),
340 linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
354 typename UnsignedBits,
356 typename DigitExtractorT>
358 UnsignedBits (&keys)[KEYS_PER_THREAD],
359 int (&ranks)[KEYS_PER_THREAD],
360 DigitExtractorT digit_extractor)
362 DigitCounter thread_prefixes[KEYS_PER_THREAD];
363 DigitCounter* digit_counters[KEYS_PER_THREAD];
369 for (
int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
372 unsigned int digit = digit_extractor.Digit(keys[ITEM]);
375 unsigned int sub_counter = digit >> LOG_COUNTER_LANES;
378 unsigned int counter_lane = digit & (COUNTER_LANES - 1);
382 sub_counter = PACKING_RATIO - 1 - sub_counter;
383 counter_lane = COUNTER_LANES - 1 - counter_lane;
387 digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane * BLOCK_THREADS * PACKING_RATIO + linear_tid * PACKING_RATIO + sub_counter];
390 thread_prefixes[ITEM] = *digit_counters[ITEM];
393 *digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
396 ::rocprim::syncthreads();
401 ::rocprim::syncthreads();
405 for (
int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
408 ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
417 typename UnsignedBits,
419 typename DigitExtractorT>
421 UnsignedBits (&keys)[KEYS_PER_THREAD],
422 int (&ranks)[KEYS_PER_THREAD],
423 DigitExtractorT digit_extractor,
424 int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
427 RankKeys(keys, ranks, digit_extractor);
431 for (
int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
433 int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;
435 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
438 bin_idx = RADIX_DIGITS - bin_idx - 1;
442 unsigned int counter_lane = (bin_idx & (COUNTER_LANES - 1));
443 unsigned int sub_counter = bin_idx >> (LOG_COUNTER_LANES);
445 exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counter[counter_lane * BLOCK_THREADS * PACKING_RATIO + sub_counter];
462 BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
465 int ARCH = HIPCUB_ARCH>
474 typedef int32_t RankT;
475 typedef int32_t DigitCounterT;
480 BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
482 RADIX_DIGITS = 1 << RADIX_BITS,
485 WARP_THREADS = 1 << LOG_WARP_THREADS,
486 WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
488 PADDED_WARPS = ((WARPS & 0x1) == 0) ?
492 COUNTERS = PADDED_WARPS * RADIX_DIGITS,
493 RAKING_SEGMENT = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
494 PADDED_RAKING_SEGMENT = ((RAKING_SEGMENT & 0x1) == 0) ?
504 BINS_TRACKED_PER_THREAD = rocprim::maximum<int>()(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
513 INNER_SCAN_ALGORITHM,
519 #ifndef DOXYGEN_SHOULD_SKIP_THIS
521 struct __align__(16) _TempStorage
523 typename BlockScanT::TempStorage block_scan;
525 union __align__(16) Aliasable
527 volatile DigitCounterT warp_digit_counters[RADIX_DIGITS * PADDED_WARPS];
528 DigitCounterT raking_grid[BLOCK_THREADS * PADDED_RAKING_SEGMENT];
539 _TempStorage &temp_storage;
542 unsigned int linear_tid;
564 temp_storage(temp_storage.Alias()),
565 linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
579 typename UnsignedBits,
581 typename DigitExtractorT>
583 UnsignedBits (&keys)[KEYS_PER_THREAD],
584 int (&ranks)[KEYS_PER_THREAD],
585 DigitExtractorT digit_extractor)
590 for (
int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
591 temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM] = 0;
593 ::rocprim::syncthreads();
597 volatile DigitCounterT *digit_counters[KEYS_PER_THREAD];
598 uint32_t warp_id = linear_tid >> LOG_WARP_THREADS;
599 uint32_t lane_mask_lt = LaneMaskLt();
602 for (
int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
605 uint32_t digit = digit_extractor.Digit(keys[ITEM]);
608 digit = RADIX_DIGITS - digit - 1;
611 uint32_t peer_mask = rocprim::MatchAny<RADIX_BITS>(digit);
614 digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit * PADDED_WARPS + warp_id];
617 DigitCounterT warp_digit_prefix = *digit_counters[ITEM];
620 WARP_SYNC(0xFFFFFFFF);
623 int32_t digit_count = __popc(peer_mask);
626 int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);
628 if (peer_digit_prefix == 0)
631 *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
635 WARP_SYNC(0xFFFFFFFF);
638 ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
641 ::rocprim::syncthreads();
645 DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];
648 for (
int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
649 scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM];
651 BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);
654 for (
int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
655 temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM] = scan_counters[ITEM];
657 ::rocprim::syncthreads();
661 for (
int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
662 ranks[ITEM] += *digit_counters[ITEM];
670 typename UnsignedBits,
672 typename DigitExtractorT>
674 UnsignedBits (&keys)[KEYS_PER_THREAD],
675 int (&ranks)[KEYS_PER_THREAD],
676 DigitExtractorT digit_extractor,
677 int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
679 RankKeys(keys, ranks, digit_extractor);
683 for (
int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
685 int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;
687 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
690 bin_idx = RADIX_DIGITS - bin_idx - 1;
692 exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx * PADDED_WARPS];
Definition: block_radix_rank.hpp:467
__device__ BlockRadixRankMatch(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage.
Definition: block_radix_rank.hpp:561
__device__ __forceinline__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor, int(&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
Rank keys. Additionally, for the lower RADIX_DIGITS threads, the exclusive prefix (starting rank) of each tracked digit bin is returned in that thread's exclusive_digit_prefix.
Definition: block_radix_rank.hpp:673
__device__ __forceinline__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
Rank keys.
Definition: block_radix_rank.hpp:582
BlockRadixRank provides operations for ranking unsigned integer types within a thread block.
Definition: block_radix_rank.hpp:97
__device__ BlockRadixRank(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage.
Definition: block_radix_rank.hpp:336
__device__ BlockRadixRank()
Collective constructor using a private static allocation of shared memory as temporary storage.
Definition: block_radix_rank.hpp:326
__device__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor, int(&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
Rank keys. Additionally, for the lower RADIX_DIGITS threads, the exclusive prefix (starting rank) of each tracked digit bin is returned in that thread's exclusive_digit_prefix.
Definition: block_radix_rank.hpp:420
__device__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
Rank keys.
Definition: block_radix_rank.hpp:357
Definition: block_scan.hpp:80
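The operations exposed by BlockRadixRankMatch require a temporary memory allocation of this nested type for thread communication.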
Definition: block_radix_rank.hpp:549
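The operations exposed by BlockRadixRank require a temporary memory allocation of this nested type for thread communication.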
Definition: block_radix_rank.hpp:315
Definition: util_type.hpp:54
Definition: util_type.hpp:101
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:363