30 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
31 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_SORT_HPP_
33 #include "../../../config.hpp"
35 #include "../util_type.hpp"
37 #include <rocprim/functional.hpp>
38 #include <rocprim/block/block_radix_sort.hpp>
40 #include "block_scan.hpp"
42 BEGIN_HIPCUB_NAMESPACE
48 typename ValueT = NullType,
50 bool MEMOIZE_OUTER_SCAN =
true,
51 BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
52 hipSharedMemConfig SMEM_CONFIG = hipSharedMemBankSizeFourByte,
55 int PTX_ARCH = HIPCUB_ARCH
58 :
private ::rocprim::block_radix_sort<
68 BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
69 "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0"
73 typename ::rocprim::block_radix_sort<
83 typename base_type::storage_type& temp_storage_;
86 using TempStorage =
typename base_type::storage_type;
94 BlockRadixSort(TempStorage& temp_storage) : temp_storage_(temp_storage)
99 void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
101 int end_bit =
sizeof(KeyT) * 8)
103 base_type::sort(keys, temp_storage_, begin_bit, end_bit);
107 void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
108 ValueT (&values)[ITEMS_PER_THREAD],
110 int end_bit =
sizeof(KeyT) * 8)
112 base_type::sort(keys, values, temp_storage_, begin_bit, end_bit);
116 void SortDescending(KeyT (&keys)[ITEMS_PER_THREAD],
118 int end_bit =
sizeof(KeyT) * 8)
120 base_type::sort_desc(keys, temp_storage_, begin_bit, end_bit);
124 void SortDescending(KeyT (&keys)[ITEMS_PER_THREAD],
125 ValueT (&values)[ITEMS_PER_THREAD],
127 int end_bit =
sizeof(KeyT) * 8)
129 base_type::sort_desc(keys, values, temp_storage_, begin_bit, end_bit);
133 void SortBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
135 int end_bit =
sizeof(KeyT) * 8)
137 base_type::sort_to_striped(keys, temp_storage_, begin_bit, end_bit);
141 void SortBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
142 ValueT (&values)[ITEMS_PER_THREAD],
144 int end_bit =
sizeof(KeyT) * 8)
146 base_type::sort_to_striped(keys, values, temp_storage_, begin_bit, end_bit);
150 void SortDescendingBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
152 int end_bit =
sizeof(KeyT) * 8)
154 base_type::sort_desc_to_striped(keys, temp_storage_, begin_bit, end_bit);
158 void SortDescendingBlockedToStriped(KeyT (&keys)[ITEMS_PER_THREAD],
159 ValueT (&values)[ITEMS_PER_THREAD],
161 int end_bit =
sizeof(KeyT) * 8)
163 base_type::sort_desc_to_striped(keys, values, temp_storage_, begin_bit, end_bit);
168 TempStorage& private_storage()
170 HIPCUB_SHARED_MEMORY TempStorage private_storage;
171 return private_storage;
Definition: block_radix_sort.hpp:66