35 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
36 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
40 #include "../../../config.hpp"
41 #include "../../../util_type.hpp"
42 #include "../../../util_ptx.hpp"
44 #include "../block/block_scan.hpp"
45 #include "../block/radix_rank_sort_operations.hpp"
46 #include "../thread/thread_reduce.hpp"
47 #include "../thread/thread_scan.hpp"
49 BEGIN_HIPCUB_NAMESPACE
// Tail of the BlockRadixRank template parameter list (earlier parameters are on elided lines).
// MEMOIZE_OUTER_SCAN: when true, Upsweep() copies this thread's raking segment into
// cached_segment and ExclusiveDownsweep() copies it back after the scan.
91 bool MEMOIZE_OUTER_SCAN =
false,
// INNER_SCAN_ALGORITHM: block-scan algorithm selector; defaults to BLOCK_SCAN_WARP_SCANS.
92 BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
// SMEM_CONFIG: compared against hipSharedMemBankSizeEightByte to select the PackedCounter type.
93 hipSharedMemConfig SMEM_CONFIG = hipSharedMemBankSizeFourByte,
// ARCH: defaults to HIPCUB_ARCH; not referenced by any of the lines visible in this listing.
96 int ARCH = HIPCUB_ARCH >
// Integer type of a single per-digit counter.
106 typedef unsigned short DigitCounter;
// Word type that packs several DigitCounters. It is `unsigned int` unless SMEM_CONFIG equals
// hipSharedMemBankSizeEightByte; the type used in that case is on an elided line.
109 typedef typename std::conditional<(SMEM_CONFIG == hipSharedMemBankSizeEightByte),
111 unsigned int>::type PackedCounter;
// Total number of threads in the block.
116 BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
// Number of distinct digit values representable in RADIX_BITS bits.
118 RADIX_DIGITS = 1 << RADIX_BITS,
// Threads per warp, and the number of warps needed to cover the block (rounded up).
121 WARP_THREADS = 1 << LOG_WARP_THREADS,
122 WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
// Size in bytes of one DigitCounter.
124 BYTES_PER_COUNTER =
sizeof(DigitCounter),
// Number of DigitCounters that fit in one PackedCounter word.
127 PACKING_RATIO =
sizeof(PackedCounter) /
sizeof(DigitCounter),
// log2 of the number of counter lanes: RADIX_BITS minus LOG_PACKING_RATIO, clamped at zero.
// LOG_PACKING_RATIO itself is defined on an elided line.
130 LOG_COUNTER_LANES = rocprim::maximum<int>()((int(RADIX_BITS) - int(LOG_PACKING_RATIO)), 0),
131 COUNTER_LANES = 1 << LOG_COUNTER_LANES,
// One extra lane on top of COUNTER_LANES; each thread's raking segment spans the padded count.
134 PADDED_COUNTER_LANES = COUNTER_LANES + 1,
135 RAKING_SEGMENT = PADDED_COUNTER_LANES,
// Number of digit bins each thread reports: ceil(RADIX_DIGITS / BLOCK_THREADS), but at least 1.
143 BINS_TRACKED_PER_THREAD = rocprim::maximum<int>()(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
// Template argument of a declaration whose other lines are elided
// (presumably the BlockScan typedef used by ScanCounters -- confirm against the full header).
153 INNER_SCAN_ALGORITHM,
159 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Backing storage for the BlockRadixRank collective, aligned to 16 bytes.
// The methods of this class reach digit_counters and raking_grid through a member named
// `aliasable`, so the two arrays presumably live in a nested union declared on elided lines
// -- confirm against the full header.
162 struct __align__(16) _TempStorage
// Per-thread digit counters, flattened in [lane][thread][sub_counter] order
// (see the index arithmetic in ResetCounters and RankKeys).
166 DigitCounter digit_counters[PADDED_COUNTER_LANES * BLOCK_THREADS * PACKING_RATIO];
// Packed-word view used for raking: RAKING_SEGMENT words per thread.
167 PackedCounter raking_grid[BLOCK_THREADS * RAKING_SEGMENT];
// Storage handed to BlockScan in ScanCounters().
172 typename BlockScan::TempStorage block_scan;
// Reference to the storage this instance works on (bound in the constructors).
182 _TempStorage &temp_storage;
// Linear thread index, initialised from RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z).
185 unsigned int linear_tid;
// Per-thread copy of the raking segment; only filled when MEMOIZE_OUTER_SCAN is true.
188 PackedCounter cached_segment[RAKING_SEGMENT];
// Returns a reference to a _TempStorage object declared __shared__ inside this function.
// The default constructor binds temp_storage to the result.
198 HIPCUB_DEVICE
inline _TempStorage& PrivateStorage()
200 __shared__ _TempStorage private_storage;
201 return private_storage;
// Reduces this thread's RAKING_SEGMENT packed counters with Sum() and returns the result.
208 HIPCUB_DEVICE
inline PackedCounter Upsweep()
// Start of the segment of raking_grid owned by this thread.
210 PackedCounter *smem_raking_ptr = &temp_storage.aliasable.raking_grid[linear_tid * RAKING_SEGMENT];
211 PackedCounter *raking_ptr;
// With MEMOIZE_OUTER_SCAN the segment is first copied into cached_segment and the reduction
// reads that copy, which ExclusiveDownsweep() later scans in place.
213 if (MEMOIZE_OUTER_SCAN)
217 for (
int i = 0; i < RAKING_SEGMENT; i++)
219 cached_segment[i] = smem_raking_ptr[i];
221 raking_ptr = cached_segment;
// Alternative path: reduce directly from temp storage. The `else` that introduces it is on
// an elided line.
225 raking_ptr = smem_raking_ptr;
228 return internal::ThreadReduce<RAKING_SEGMENT>(raking_ptr, Sum());
// Runs an exclusive Sum() scan over this thread's raking segment, seeded with raking_partial,
// writing the results over the inputs.
//   raking_partial: seed value for the scan (ScanCounters() passes the block-scan result).
233 HIPCUB_DEVICE
inline void ExclusiveDownsweep(
234 PackedCounter raking_partial)
// Start of the segment of raking_grid owned by this thread.
236 PackedCounter *smem_raking_ptr = &temp_storage.aliasable.raking_grid[linear_tid * RAKING_SEGMENT];
// Scan source/destination is chosen by MEMOIZE_OUTER_SCAN; both arms of the conditional are on
// elided lines (presumably cached_segment vs. smem_raking_ptr, mirroring Upsweep -- confirm).
238 PackedCounter *raking_ptr = (MEMOIZE_OUTER_SCAN) ?
243 internal::ThreadScanExclusive<RAKING_SEGMENT>(raking_ptr, raking_ptr, Sum(), raking_partial);
// When memoizing, publish the scanned values held in cached_segment back to raking_grid.
245 if (MEMOIZE_OUTER_SCAN)
249 for (
int i = 0; i < RAKING_SEGMENT; i++)
251 smem_raking_ptr[i] = cached_segment[i];
// Zeroes every digit counter belonging to this thread: all PADDED_COUNTER_LANES lanes
// (including the extra padding lane) and all PACKING_RATIO sub-counters in each lane.
260 HIPCUB_DEVICE
inline void ResetCounters()
264 for (
int LANE = 0; LANE < PADDED_COUNTER_LANES; LANE++)
267 for (
int SUB_COUNTER = 0; SUB_COUNTER < PACKING_RATIO; SUB_COUNTER++)
// Flattened [LANE][linear_tid][SUB_COUNTER] index.
269 temp_storage.aliasable.digit_counters[(LANE * BLOCK_THREADS + linear_tid) * PACKING_RATIO + SUB_COUNTER] = 0;
// Functor passed to BlockScan::ExclusiveSum in ScanCounters().
// operator() receives the block-wide aggregate and accumulates into block_prefix one copy of
// it per PACKED in 1..PACKING_RATIO-1, each shifted left by PACKED DigitCounter widths
// (sizeof(DigitCounter) * 8 bits).
// The return statement is on an elided line (presumably returns block_prefix -- confirm).
278 struct PrefixCallBack
280 HIPCUB_DEVICE
inline PackedCounter operator()(PackedCounter block_aggregate)
282 PackedCounter block_prefix = 0;
286 for (
int PACKED = 1; PACKED < PACKING_RATIO; PACKED++)
288 block_prefix += block_aggregate << (
sizeof(DigitCounter) * 8 * PACKED);
// Scans the packed counters in three steps: per-thread reduction (Upsweep), block-wide
// exclusive sum of the per-thread partials, then per-thread exclusive scan (ExclusiveDownsweep).
299 HIPCUB_DEVICE
inline void ScanCounters()
// Reduce this thread's raking segment to a single packed partial.
302 PackedCounter raking_partial = Upsweep();
// Exclusive sum across the block; PrefixCallBack supplies the block prefix from the aggregate.
305 PackedCounter exclusive_partial;
306 PrefixCallBack prefix_call_back;
307 BlockScan(temp_storage.block_scan).ExclusiveSum(raking_partial, exclusive_partial, prefix_call_back);
// Scan this thread's segment, seeded with its exclusive partial.
310 ExclusiveDownsweep(exclusive_partial);
// Member initialisers of the default constructor: bind temp_storage to the function-local
// __shared__ object returned by PrivateStorage() and compute the linear thread index.
329 temp_storage(PrivateStorage()),
330 linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
// Member initialisers of the constructor taking caller-provided storage: the member is bound to
// the result of Alias() on the constructor argument, which has the same name as the member.
340 temp_storage(temp_storage.Alias()),
341 linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
// RankKeys: computes a rank for each of this thread's keys.
//   keys            : this thread's KEYS_PER_THREAD keys.
//   ranks           : output, one rank per key.
//   digit_extractor : object whose Digit(key) yields the digit to rank by.
// The `template <`, the remaining template parameters and the function name line are elided
// from this listing.
355 typename UnsignedBits,
357 typename DigitExtractorT>
359 UnsignedBits (&keys)[KEYS_PER_THREAD],
360 int (&ranks)[KEYS_PER_THREAD],
361 DigitExtractorT digit_extractor)
// For each key: the counter value seen before this key was counted, and the counter's address.
363 DigitCounter thread_prefixes[KEYS_PER_THREAD];
364 DigitCounter* digit_counters[KEYS_PER_THREAD];
// NOTE(review): the counters are presumably reset on an elided line before this loop
// (ResetCounters) -- confirm against the full header.
370 for (
int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
373 unsigned int digit = digit_extractor.Digit(keys[ITEM]);
// Split the digit: high bits pick the sub-counter within a packed word, low bits pick the lane.
376 unsigned int sub_counter = digit >> LOG_COUNTER_LANES;
379 unsigned int counter_lane = digit & (COUNTER_LANES - 1);
// Reverse both indices. The condition guarding these two lines is elided
// (presumably a descending-order flag -- confirm).
383 sub_counter = PACKING_RATIO - 1 - sub_counter;
384 counter_lane = COUNTER_LANES - 1 - counter_lane;
// Address of this thread's counter for the digit: flattened [counter_lane][linear_tid][sub_counter].
388 digit_counters[ITEM] = &temp_storage.aliasable.digit_counters[counter_lane * BLOCK_THREADS * PACKING_RATIO + linear_tid * PACKING_RATIO + sub_counter];
// Remember the count so far for this digit, then record this key.
391 thread_prefixes[ITEM] = *digit_counters[ITEM];
394 *digit_counters[ITEM] = thread_prefixes[ITEM] + 1;
// Barriers around an elided statement (presumably the ScanCounters() call -- confirm).
397 ::rocprim::syncthreads();
402 ::rocprim::syncthreads();
// Final rank = this thread's earlier keys with the same digit + the value now stored in the counter.
406 for (
int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
409 ranks[ITEM] = thread_prefixes[ITEM] + *digit_counters[ITEM];
// RankKeys overload that additionally reports digit prefixes for the bins tracked by this thread.
//   keys                   : this thread's KEYS_PER_THREAD keys.
//   ranks                  : output, one rank per key.
//   digit_extractor        : object whose Digit(key) yields the digit to rank by.
//   exclusive_digit_prefix : output, one entry per tracked bin; entries whose bin index is
//                            outside [0, RADIX_DIGITS) are left untouched.
// The `template <`, the remaining template parameters and the function name line are elided
// from this listing.
418 typename UnsignedBits,
420 typename DigitExtractorT>
422 UnsignedBits (&keys)[KEYS_PER_THREAD],
423 int (&ranks)[KEYS_PER_THREAD],
424 DigitExtractorT digit_extractor,
425 int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
// Rank the keys first; this leaves the scanned counters in temp storage.
428 RankKeys(keys, ranks, digit_extractor);
432 for (
int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
// Each thread owns BINS_TRACKED_PER_THREAD consecutive bins.
434 int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;
// Skip bins past the last digit; the first test lets the check fold away at compile time
// when there is exactly one bin per thread.
436 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
// Mirror the bin index. The condition guarding this line is elided
// (presumably a descending-order flag -- confirm).
439 bin_idx = RADIX_DIGITS - bin_idx - 1;
// Same digit -> (lane, sub-counter) split as in the ranking loop.
443 unsigned int counter_lane = (bin_idx & (COUNTER_LANES - 1));
444 unsigned int sub_counter = bin_idx >> (LOG_COUNTER_LANES);
// Read from `digit_counters`, the array declared in _TempStorage. The index has no linear_tid
// term, so it addresses the thread-0 slot of the lane.
446 exclusive_digit_prefix[track] = temp_storage.aliasable.digit_counters[counter_lane * BLOCK_THREADS * PACKING_RATIO + sub_counter];
// Tail of the BlockRadixRankMatch template parameter list (earlier parameters are elided).
// INNER_SCAN_ALGORITHM: block-scan algorithm selector; defaults to BLOCK_SCAN_WARP_SCANS.
463 BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
// ARCH: defaults to HIPCUB_ARCH; not referenced by any of the lines visible in this listing.
466 int ARCH = HIPCUB_ARCH>
// Rank value type (32-bit signed).
475 typedef int32_t RankT;
// Type of one per-warp, per-digit counter (32-bit signed).
476 typedef int32_t DigitCounterT;
// Total number of threads in the block.
481 BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
// Number of distinct digit values representable in RADIX_BITS bits.
483 RADIX_DIGITS = 1 << RADIX_BITS,
// Threads per warp, and the number of warps needed to cover the block (rounded up).
486 WARP_THREADS = 1 << LOG_WARP_THREADS,
487 WARPS = (BLOCK_THREADS + WARP_THREADS - 1) / WARP_THREADS,
// Warp count after padding, selected by whether WARPS is even; both alternatives are on
// elided lines.
489 PADDED_WARPS = ((WARPS & 0x1) == 0) ?
// One counter per (digit, padded warp) pair; each thread rakes ceil(COUNTERS / BLOCK_THREADS).
493 COUNTERS = PADDED_WARPS * RADIX_DIGITS,
494 RAKING_SEGMENT = (COUNTERS + BLOCK_THREADS - 1) / BLOCK_THREADS,
// Raking segment length after padding, selected by whether RAKING_SEGMENT is even; both
// alternatives are on elided lines.
495 PADDED_RAKING_SEGMENT = ((RAKING_SEGMENT & 0x1) == 0) ?
// Number of digit bins each thread reports: ceil(RADIX_DIGITS / BLOCK_THREADS), but at least 1.
505 BINS_TRACKED_PER_THREAD = rocprim::maximum<int>()(1, (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS),
// Template argument of a declaration whose other lines are elided
// (presumably the BlockScanT typedef used by RankKeys -- confirm against the full header).
514 INNER_SCAN_ALGORITHM,
520 #ifndef DOXYGEN_SHOULD_SKIP_THIS
// Backing storage for the BlockRadixRankMatch collective, aligned to 16 bytes.
522 struct __align__(16) _TempStorage
// Storage handed to BlockScanT in RankKeys().
524 typename BlockScanT::TempStorage block_scan;
// The two arrays below are members of one union, i.e. two views of the same memory:
// RankKeys() clears and scans it through raking_grid and counts through warp_digit_counters.
526 union __align__(16) Aliasable
// Counter view, indexed as [digit * PADDED_WARPS + warp_id].
528 volatile DigitCounterT warp_digit_counters[RADIX_DIGITS * PADDED_WARPS];
// Raking view: PADDED_RAKING_SEGMENT entries per thread.
529 DigitCounterT raking_grid[BLOCK_THREADS * PADDED_RAKING_SEGMENT];
// Reference to the storage this instance works on (bound in the constructor).
540 _TempStorage &temp_storage;
// Linear thread index, initialised from RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z).
543 unsigned int linear_tid;
// Member initialisers of the constructor taking caller-provided storage: the member is bound to
// the result of Alias() on the constructor argument, which has the same name as the member.
565 temp_storage(temp_storage.Alias()),
566 linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
// RankKeys: computes a rank for each of this thread's keys using warp-level digit matching.
//   keys            : this thread's KEYS_PER_THREAD keys.
//   ranks           : output, one rank per key.
//   digit_extractor : object whose Digit(key) yields the digit to rank by.
// The `template <`, the remaining template parameters and the function name line are elided
// from this listing.
580 typename UnsignedBits,
582 typename DigitExtractorT>
584 UnsignedBits (&keys)[KEYS_PER_THREAD],
585 int (&ranks)[KEYS_PER_THREAD],
586 DigitExtractorT digit_extractor)
// Step 1: each thread zeroes its slice of raking_grid. warp_digit_counters is declared in the
// same union, so this is what clears the counters used below.
591 for (
int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
592 temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM] = 0;
594 ::rocprim::syncthreads();
// Step 2: per-warp counting. For every key, remember the address of its (digit, warp) counter.
598 volatile DigitCounterT *digit_counters[KEYS_PER_THREAD];
599 uint32_t warp_id = linear_tid >> LOG_WARP_THREADS;
600 uint32_t lane_mask_lt = LaneMaskLt();
603 for (
int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
606 uint32_t digit = digit_extractor.Digit(keys[ITEM]);
// Mirror the digit. The condition guarding this line is elided
// (presumably a descending-order flag -- confirm).
609 digit = RADIX_DIGITS - digit - 1;
// Result of MatchAny for this digit; used below as a bit mask of lanes
// (presumably the lanes in the warp holding the same digit -- confirm against rocprim).
612 uint32_t peer_mask = rocprim::MatchAny<RADIX_BITS>(digit);
615 digit_counters[ITEM] = &temp_storage.aliasable.warp_digit_counters[digit * PADDED_WARPS + warp_id];
// Read the counter value before it is updated for this item.
618 DigitCounterT warp_digit_prefix = *digit_counters[ITEM];
// Warp-level sync between the counter read above and the counter write below.
621 WARP_SYNC(0xFFFFFFFF);
// Number of bits set in peer_mask, and number of those belonging to lower-numbered lanes.
624 int32_t digit_count = __popc(peer_mask);
627 int32_t peer_digit_prefix = __popc(peer_mask & lane_mask_lt);
// Only the lane with no lower-numbered peer writes the updated count.
629 if (peer_digit_prefix == 0)
632 *digit_counters[ITEM] = DigitCounterT(warp_digit_prefix + digit_count);
// Warp-level sync after the counter update, before the next item reads counters again.
636 WARP_SYNC(0xFFFFFFFF);
// Warp-local rank: counter value before this item + position among the peers.
639 ranks[ITEM] = warp_digit_prefix + DigitCounterT(peer_digit_prefix);
642 ::rocprim::syncthreads();
// Step 3: block-wide exclusive sum over all counters, accessed through the raking_grid view.
646 DigitCounterT scan_counters[PADDED_RAKING_SEGMENT];
649 for (
int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
650 scan_counters[ITEM] = temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM];
652 BlockScanT(temp_storage.block_scan).ExclusiveSum(scan_counters, scan_counters);
655 for (
int ITEM = 0; ITEM < PADDED_RAKING_SEGMENT; ++ITEM)
656 temp_storage.aliasable.raking_grid[linear_tid * PADDED_RAKING_SEGMENT + ITEM] = scan_counters[ITEM];
658 ::rocprim::syncthreads();
// Step 4: add the scanned value now stored in each key's counter to its warp-local rank.
662 for (
int ITEM = 0; ITEM < KEYS_PER_THREAD; ++ITEM)
663 ranks[ITEM] += *digit_counters[ITEM];
// RankKeys overload that additionally reports digit prefixes for the bins tracked by this thread.
//   keys                   : this thread's KEYS_PER_THREAD keys.
//   ranks                  : output, one rank per key.
//   digit_extractor        : object whose Digit(key) yields the digit to rank by.
//   exclusive_digit_prefix : output, one entry per tracked bin; entries whose bin index is
//                            outside [0, RADIX_DIGITS) are left untouched.
// The `template <`, the remaining template parameters and the function name line are elided
// from this listing.
671 typename UnsignedBits,
673 typename DigitExtractorT>
675 UnsignedBits (&keys)[KEYS_PER_THREAD],
676 int (&ranks)[KEYS_PER_THREAD],
677 DigitExtractorT digit_extractor,
678 int (&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
// Rank the keys first; this leaves the scanned counters in temp storage.
680 RankKeys(keys, ranks, digit_extractor);
684 for (
int track = 0; track < BINS_TRACKED_PER_THREAD; ++track)
// Each thread owns BINS_TRACKED_PER_THREAD consecutive bins.
686 int bin_idx = (linear_tid * BINS_TRACKED_PER_THREAD) + track;
// Skip bins past the last digit.
688 if ((BLOCK_THREADS == RADIX_DIGITS) || (bin_idx < RADIX_DIGITS))
// Mirror the bin index. The condition guarding this line is elided
// (presumably a descending-order flag -- confirm).
691 bin_idx = RADIX_DIGITS - bin_idx - 1;
// Counters are indexed [digit * PADDED_WARPS + warp_id]; with no warp term this reads the
// warp-0 entry for the bin.
693 exclusive_digit_prefix[track] = temp_storage.aliasable.warp_digit_counters[bin_idx * PADDED_WARPS];
Definition: block_radix_rank.hpp:468
__device__ BlockRadixRankMatch(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage.
Definition: block_radix_rank.hpp:562
__device__ __forceinline__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor, int(&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
Rank keys. For the lower RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
Definition: block_radix_rank.hpp:674
__device__ __forceinline__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
Rank keys.
Definition: block_radix_rank.hpp:583
BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
Definition: block_radix_rank.hpp:98
__device__ BlockRadixRank(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage.
Definition: block_radix_rank.hpp:337
__device__ BlockRadixRank()
Collective constructor using a private static allocation of shared memory as temporary storage.
Definition: block_radix_rank.hpp:327
__device__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor, int(&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
Rank keys. For the lower RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
Definition: block_radix_rank.hpp:421
__device__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
Rank keys.
Definition: block_radix_rank.hpp:358
Definition: block_scan.hpp:80
\smemstorage{BlockScan}
Definition: block_radix_rank.hpp:550
\smemstorage{BlockScan}
Definition: block_radix_rank.hpp:316
Definition: util_type.hpp:101
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:363