/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.7.0/hipcub/include/hipcub/backend/rocprim/block/block_radix_rank.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.7.0/hipcub/include/hipcub/backend/rocprim/block/block_radix_rank.hpp Source File#

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.7.0/hipcub/include/hipcub/backend/rocprim/block/block_radix_rank.hpp Source File
block_radix_rank.hpp
Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
4  * Modifications Copyright (c) 2021-2022, Advanced Micro Devices, Inc. All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  * * Redistributions in binary form must reproduce the above copyright
11  * notice, this list of conditions and the following disclaimer in the
12  * documentation and/or other materials provided with the distribution.
13  * * Neither the name of the NVIDIA CORPORATION nor the
14  * names of its contributors may be used to endorse or promote products
15  * derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  *
28  ******************************************************************************/
29 
35  #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
36  #define HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
37 
38 #include <stdint.h>
39 
40 #include "../../../config.hpp"
41 #include "../../../util_type.hpp"
42 #include "../../../util_ptx.hpp"
43 
44 #include "../block/block_scan.hpp"
45 #include "../block/radix_rank_sort_operations.hpp"
46 #include "../thread/thread_reduce.hpp"
47 #include "../thread/thread_scan.hpp"
48 
50 
51 BEGIN_HIPCUB_NAMESPACE
52 
53 namespace detail
54 {
55 template<typename DigitExtractorT, typename UnsignedBits, int RADIX_BITS, bool IS_DESCENDING>
56 struct DigitExtractorAdopter
57 {
58  DigitExtractorT& digit_extractor_;
59 
60  HIPCUB_DEVICE DigitExtractorAdopter(DigitExtractorT& digit_extractor)
61  : digit_extractor_(digit_extractor)
62  {}
63 
64  HIPCUB_DEVICE inline UnsignedBits operator()(const UnsignedBits key)
65  {
66  UnsignedBits digit = digit_extractor_.Digit(key);
67  if(IS_DESCENDING)
68  {
69  // Flip digit bits
70  digit ^= (1 << RADIX_BITS) - 1;
71  }
72  return digit;
73  }
74 };
75 } // namespace detail
76 
111 template<int BLOCK_DIM_X,
112  int RADIX_BITS,
113  bool IS_DESCENDING,
114  bool MEMOIZE_OUTER_SCAN = false,
115  BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
116  hipSharedMemConfig SMEM_CONFIG = hipSharedMemBankSizeFourByte,
117  int BLOCK_DIM_Y = 1,
118  int BLOCK_DIM_Z = 1,
119  int ARCH = HIPCUB_ARCH /* ignored */>
121  : private ::rocprim::block_radix_rank<BLOCK_DIM_X,
122  RADIX_BITS,
123  MEMOIZE_OUTER_SCAN
124  ? ::rocprim::block_radix_rank_algorithm::basic_memoize
125  : ::rocprim::block_radix_rank_algorithm::basic,
126  BLOCK_DIM_Y,
127  BLOCK_DIM_Z>
128 {
129  static_assert(BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
130  "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0");
131 
132  using base_type
133  = ::rocprim::block_radix_rank<BLOCK_DIM_X,
134  RADIX_BITS,
135  MEMOIZE_OUTER_SCAN
136  ? ::rocprim::block_radix_rank_algorithm::basic_memoize
137  : ::rocprim::block_radix_rank_algorithm::basic,
138  BLOCK_DIM_Y,
139  BLOCK_DIM_Z>;
140 
141 public:
142  using TempStorage = typename base_type::storage_type;
143 
144 private:
145  // Reference to temporary storage (usually shared memory)
146  TempStorage& temp_storage_;
147 
148  HIPCUB_DEVICE inline TempStorage& PrivateStorage()
149  {
150  HIPCUB_SHARED_MEMORY TempStorage private_storage;
151  return private_storage;
152  }
153 
154 public:
155  enum
156  {
158  BINS_TRACKED_PER_THREAD = base_type::digits_per_thread,
159  };
160 
161  /******************************************************************/
165 
169  HIPCUB_DEVICE inline BlockRadixRank() : temp_storage_(PrivateStorage()) {}
170 
174  HIPCUB_DEVICE inline BlockRadixRank(
175  TempStorage&
176  temp_storage)
177  : temp_storage_(temp_storage)
178  {}
179 
181  /******************************************************************/
185 
189  template<typename UnsignedBits,
190  int KEYS_PER_THREAD,
191  typename DigitExtractorT>
192  HIPCUB_DEVICE inline void RankKeys(
193  UnsignedBits (&keys)[KEYS_PER_THREAD],
194  int (&ranks)[KEYS_PER_THREAD],
195  DigitExtractorT digit_extractor)
196  {
197  detail::DigitExtractorAdopter<DigitExtractorT, UnsignedBits, RADIX_BITS, IS_DESCENDING>
198  digit_extractor_adopter(digit_extractor);
199  base_type::rank_keys(keys,
200  reinterpret_cast<unsigned int(&)[KEYS_PER_THREAD]>(ranks),
201  temp_storage_,
202  digit_extractor_adopter);
203  }
204 
208  template<typename UnsignedBits,
209  int KEYS_PER_THREAD,
210  typename DigitExtractorT>
211  HIPCUB_DEVICE inline void RankKeys(
212  UnsignedBits (&keys)[KEYS_PER_THREAD],
213  int (&ranks)
214  [KEYS_PER_THREAD],
215  DigitExtractorT digit_extractor,
216  int (&exclusive_digit_prefix)
217  [BINS_TRACKED_PER_THREAD])
218  {
219  unsigned int counts[BINS_TRACKED_PER_THREAD];
220  detail::DigitExtractorAdopter<DigitExtractorT, UnsignedBits, RADIX_BITS, IS_DESCENDING>
221  digit_extractor_adopter(digit_extractor);
222  base_type::rank_keys(
223  keys,
224  reinterpret_cast<unsigned int(&)[KEYS_PER_THREAD]>(ranks),
225  temp_storage_,
226  digit_extractor_adopter,
227  reinterpret_cast<unsigned int(&)[BINS_TRACKED_PER_THREAD]>(exclusive_digit_prefix),
228  counts);
229  }
230 };
231 
235 template<int BLOCK_DIM_X,
236  int RADIX_BITS,
237  bool IS_DESCENDING,
238  BlockScanAlgorithm INNER_SCAN_ALGORITHM = BLOCK_SCAN_WARP_SCANS,
239  int BLOCK_DIM_Y = 1,
240  int BLOCK_DIM_Z = 1,
241  int ARCH = HIPCUB_ARCH>
243  : private ::rocprim::block_radix_rank<BLOCK_DIM_X,
244  RADIX_BITS,
245  ::rocprim::block_radix_rank_algorithm::match,
246  BLOCK_DIM_Y,
247  BLOCK_DIM_Z>
248 {
249  static_assert(BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z > 0,
250  "BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z must be greater than 0");
251 
252  using base_type = ::rocprim::block_radix_rank<BLOCK_DIM_X,
253  RADIX_BITS,
254  ::rocprim::block_radix_rank_algorithm::match,
255  BLOCK_DIM_Y,
256  BLOCK_DIM_Z>;
257 
258 public:
259  using TempStorage = typename base_type::storage_type;
260 
261 private:
262  // Reference to temporary storage (usually shared memory)
263  TempStorage& temp_storage_;
264 
265  HIPCUB_DEVICE inline TempStorage& PrivateStorage()
266  {
267  HIPCUB_SHARED_MEMORY TempStorage private_storage;
268  return private_storage;
269  }
270 
271 public:
272  enum
273  {
275  BINS_TRACKED_PER_THREAD = base_type::digits_per_thread,
276  };
277 
278  /******************************************************************/
282 
286  HIPCUB_DEVICE inline BlockRadixRankMatch() : temp_storage_(PrivateStorage()) {}
287 
291  HIPCUB_DEVICE inline BlockRadixRankMatch(
292  TempStorage&
293  temp_storage)
294  : temp_storage_(temp_storage)
295  {}
296 
298  /******************************************************************/
302 
306  template<typename UnsignedBits,
307  int KEYS_PER_THREAD,
308  typename DigitExtractorT>
309  __device__ __forceinline__ void RankKeys(
310  UnsignedBits (&keys)[KEYS_PER_THREAD],
311  int (&ranks)[KEYS_PER_THREAD],
312  DigitExtractorT digit_extractor)
313  {
314  detail::DigitExtractorAdopter<DigitExtractorT, UnsignedBits, RADIX_BITS, IS_DESCENDING>
315  digit_extractor_adopter(digit_extractor);
316  base_type::rank_keys(keys,
317  reinterpret_cast<unsigned int(&)[KEYS_PER_THREAD]>(ranks),
318  temp_storage_,
319  digit_extractor_adopter);
320  }
321 
325  template<typename UnsignedBits,
326  int KEYS_PER_THREAD,
327  typename DigitExtractorT>
328  __device__ __forceinline__ void RankKeys(
329  UnsignedBits (&keys)[KEYS_PER_THREAD],
330  int (&ranks)
331  [KEYS_PER_THREAD],
332  DigitExtractorT digit_extractor,
333  int (&exclusive_digit_prefix)
334  [BINS_TRACKED_PER_THREAD])
335  {
336  unsigned int counts[BINS_TRACKED_PER_THREAD];
337  detail::DigitExtractorAdopter<DigitExtractorT, UnsignedBits, RADIX_BITS, IS_DESCENDING>
338  digit_extractor_adopter(digit_extractor);
339  base_type::rank_keys(
340  keys,
341  reinterpret_cast<unsigned int(&)[KEYS_PER_THREAD]>(ranks),
342  temp_storage_,
343  digit_extractor_adopter,
344  reinterpret_cast<unsigned int(&)[BINS_TRACKED_PER_THREAD]>(exclusive_digit_prefix),
345  counts);
346  }
347 };
348 
349 END_HIPCUB_NAMESPACE
350 
351 #endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_RADIX_RANK_HPP_
Definition: block_radix_rank.hpp:248
__device__ BlockRadixRankMatch(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage.
Definition: block_radix_rank.hpp:291
__device__ BlockRadixRankMatch()
Collective constructor using a private static allocation of shared memory as temporary storage.
Definition: block_radix_rank.hpp:286
__device__ __forceinline__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor, int(&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
Rank keys. For the lower RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
Definition: block_radix_rank.hpp:328
__device__ __forceinline__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
Rank keys.
Definition: block_radix_rank.hpp:309
BlockRadixRank provides operations for ranking unsigned integer types within a CUDA thread block.
Definition: block_radix_rank.hpp:128
__device__ BlockRadixRank(TempStorage &temp_storage)
Collective constructor using the specified memory allocation as temporary storage.
Definition: block_radix_rank.hpp:174
__device__ BlockRadixRank()
Collective constructor using a private static allocation of shared memory as temporary storage.
Definition: block_radix_rank.hpp:169
__device__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor, int(&exclusive_digit_prefix)[BINS_TRACKED_PER_THREAD])
Rank keys. For the lower RADIX_DIGITS threads, digit counts for each digit are provided for the corresponding thread.
Definition: block_radix_rank.hpp:211
__device__ void RankKeys(UnsignedBits(&keys)[KEYS_PER_THREAD], int(&ranks)[KEYS_PER_THREAD], DigitExtractorT digit_extractor)
Rank keys.
Definition: block_radix_rank.hpp:192