/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.5.1/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.5.1/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp Source File

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.5.1/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp Source File
block_merge_sort.hpp
1 /******************************************************************************
2 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
3 * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * * Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * * Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * * Neither the name of the NVIDIA CORPORATION nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 *
27 ******************************************************************************/
28 
29 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
30 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
31 
32 #include "../thread/thread_sort.hpp"
33 #include "../util_math.hpp"
34 #include "../util_type.hpp"
35 
36 #include <rocprim/detail/various.hpp>
37 #include <rocprim/functional.hpp>
38 
39 BEGIN_HIPCUB_NAMESPACE
40 
41 
42 // Additional details of the Merge-Path Algorithm can be found in:
43 // S. Odeh, O. Green, Z. Mwassi, O. Shmueli, Y. Birk, "Merge Path - Parallel
44 // Merging Made Simple", Multithreaded Architectures and Applications (MTAAP)
45 // Workshop, IEEE 26th International Parallel & Distributed Processing
46 // Symposium (IPDPS), 2012
47 template <typename KeyT,
48  typename KeyIteratorT,
49  typename OffsetT,
50  typename BinaryPred>
51 HIPCUB_DEVICE __forceinline__ OffsetT MergePath(KeyIteratorT keys1,
52  KeyIteratorT keys2,
53  OffsetT keys1_count,
54  OffsetT keys2_count,
55  OffsetT diag,
56  BinaryPred binary_pred)
57 {
58  OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count;
59  OffsetT keys1_end = (::rocprim::min)(diag, keys1_count);
60 
61  while (keys1_begin < keys1_end)
62  {
63  OffsetT mid = hipcub::MidPoint<OffsetT>(keys1_begin, keys1_end);
64  KeyT key1 = keys1[mid];
65  KeyT key2 = keys2[diag - 1 - mid];
66  bool pred = binary_pred(key2, key1);
67 
68  if (pred)
69  {
70  keys1_end = mid;
71  }
72  else
73  {
74  keys1_begin = mid + 1;
75  }
76  }
77  return keys1_begin;
78 }
79 
80 template <typename KeyT, typename CompareOp, int ITEMS_PER_THREAD>
81 HIPCUB_DEVICE __forceinline__ void SerialMerge(KeyT *keys_shared,
82  int keys1_beg,
83  int keys2_beg,
84  int keys1_count,
85  int keys2_count,
86  KeyT (&output)[ITEMS_PER_THREAD],
87  int (&indices)[ITEMS_PER_THREAD],
88  CompareOp compare_op)
89 {
90  int keys1_end = keys1_beg + keys1_count;
91  int keys2_end = keys2_beg + keys2_count;
92 
93  KeyT key1 = keys_shared[keys1_beg];
94  KeyT key2 = keys_shared[keys2_beg];
95 
96 #pragma unroll
97  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
98  {
99  bool p = (keys2_beg < keys2_end) &&
100  ((keys1_beg >= keys1_end)
101  || compare_op(key2, key1));
102 
103  output[item] = p ? key2 : key1;
104  indices[item] = p ? keys2_beg++ : keys1_beg++;
105 
106  if (p)
107  {
108  key2 = keys_shared[keys2_beg];
109  }
110  else
111  {
112  key1 = keys_shared[keys1_beg];
113  }
114  }
115 }
116 
169 template <typename KeyT,
170  typename ValueT,
171  int NUM_THREADS,
172  int ITEMS_PER_THREAD,
173  typename SynchronizationPolicy>
175 {
176  static_assert(PowerOfTwo<NUM_THREADS>::VALUE,
177  "NUM_THREADS must be a power of two");
178 
179 private:
180 
181  static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * NUM_THREADS;
182 
183  // Whether or not there are values to be trucked along with keys
184  static constexpr bool KEYS_ONLY = ::rocprim::Equals<ValueT, NullType>::VALUE;
185 
187  union _TempStorage
188  {
189  KeyT keys_shared[ITEMS_PER_TILE + 1];
190  ValueT items_shared[ITEMS_PER_TILE + 1];
191  }; // union TempStorage
192 
194  _TempStorage &temp_storage;
195 
197  HIPCUB_DEVICE __forceinline__ _TempStorage& PrivateStorage()
198  {
199  __shared__ _TempStorage private_storage;
200  return private_storage;
201  }
202 
203  const unsigned int linear_tid;
204 
205 public:
207  struct TempStorage : Uninitialized<_TempStorage> {};
208 
209  BlockMergeSortStrategy() = delete;
210  explicit HIPCUB_DEVICE __forceinline__
211  BlockMergeSortStrategy(unsigned int linear_tid)
212  : temp_storage(PrivateStorage())
213  , linear_tid(linear_tid)
214  {}
215 
216  HIPCUB_DEVICE __forceinline__ BlockMergeSortStrategy(TempStorage &temp_storage,
217  unsigned int linear_tid)
218  : temp_storage(temp_storage.Alias())
219  , linear_tid(linear_tid)
220  {}
221 
222  HIPCUB_DEVICE __forceinline__ unsigned int get_linear_tid() const
223  {
224  return linear_tid;
225  }
226 
249  template <typename CompareOp>
250  HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
251  CompareOp compare_op)
252  {
253  ValueT items[ITEMS_PER_THREAD];
254  Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
255  }
256 
291  template <typename CompareOp>
292  HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
293  CompareOp compare_op,
294  int valid_items,
295  KeyT oob_default)
296  {
297  ValueT items[ITEMS_PER_THREAD];
298  Sort<CompareOp, true>(keys, items, compare_op, valid_items, oob_default);
299  }
300 
325  template <typename CompareOp>
326  HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
327  ValueT (&items)[ITEMS_PER_THREAD],
328  CompareOp compare_op)
329  {
330  Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
331  }
332 
373  template <typename CompareOp,
374  bool IS_LAST_TILE = true>
375  HIPCUB_DEVICE __forceinline__ void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
376  ValueT (&items)[ITEMS_PER_THREAD],
377  CompareOp compare_op,
378  int valid_items,
379  KeyT oob_default)
380  {
381  if (IS_LAST_TILE)
382  {
383  // if last tile, find valid max_key
384  // and fill the remaining keys with it
385  //
386  KeyT max_key = oob_default;
387 
388  #pragma unroll
389  for (int item = 1; item < ITEMS_PER_THREAD; ++item)
390  {
391  if (ITEMS_PER_THREAD * static_cast<int>(linear_tid) + item < valid_items)
392  {
393  max_key = compare_op(max_key, keys[item]) ? keys[item] : max_key;
394  }
395  else
396  {
397  keys[item] = max_key;
398  }
399  }
400  }
401 
402  // if first element of thread is in input range, stable sort items
403  //
404  if (!IS_LAST_TILE || ITEMS_PER_THREAD * static_cast<int>(linear_tid) < valid_items)
405  {
406  StableOddEvenSort(keys, items, compare_op);
407  }
408 
409  // each thread has sorted keys
410  // merge sort keys in shared memory
411  //
412  #pragma unroll
413  for (int target_merged_threads_number = 2;
414  target_merged_threads_number <= NUM_THREADS;
415  target_merged_threads_number *= 2)
416  {
417  int merged_threads_number = target_merged_threads_number / 2;
418  int mask = target_merged_threads_number - 1;
419 
420  Sync();
421 
422  // store keys in shmem
423  //
424  #pragma unroll
425  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
426  {
427  int idx = ITEMS_PER_THREAD * linear_tid + item;
428  temp_storage.keys_shared[idx] = keys[item];
429  }
430 
431  Sync();
432 
433  int indices[ITEMS_PER_THREAD];
434 
435  int first_thread_idx_in_thread_group_being_merged = ~mask & linear_tid;
436  int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged;
437  int size = ITEMS_PER_THREAD * merged_threads_number;
438 
439  int thread_idx_in_thread_group_being_merged = mask & linear_tid;
440 
441  int diag =
442  (::rocprim::min)(valid_items,
443  ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged);
444 
445  int keys1_beg = (::rocprim::min)(valid_items, start);
446  int keys1_end = (::rocprim::min)(valid_items, keys1_beg + size);
447  int keys2_beg = keys1_end;
448  int keys2_end = (::rocprim::min)(valid_items, keys2_beg + size);
449 
450  int keys1_count = keys1_end - keys1_beg;
451  int keys2_count = keys2_end - keys2_beg;
452 
453  int partition_diag = MergePath<KeyT>(&temp_storage.keys_shared[keys1_beg],
454  &temp_storage.keys_shared[keys2_beg],
455  keys1_count,
456  keys2_count,
457  diag,
458  compare_op);
459 
460  int keys1_beg_loc = keys1_beg + partition_diag;
461  int keys1_end_loc = keys1_end;
462  int keys2_beg_loc = keys2_beg + diag - partition_diag;
463  int keys2_end_loc = keys2_end;
464  int keys1_count_loc = keys1_end_loc - keys1_beg_loc;
465  int keys2_count_loc = keys2_end_loc - keys2_beg_loc;
466  SerialMerge(&temp_storage.keys_shared[0],
467  keys1_beg_loc,
468  keys2_beg_loc,
469  keys1_count_loc,
470  keys2_count_loc,
471  keys,
472  indices,
473  compare_op);
474 
475  if (!KEYS_ONLY)
476  {
477  Sync();
478 
479  // store keys in shmem
480  //
481  #pragma unroll
482  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
483  {
484  int idx = ITEMS_PER_THREAD * linear_tid + item;
485  temp_storage.items_shared[idx] = items[item];
486  }
487 
488  Sync();
489 
490  // gather items from shmem
491  //
492  #pragma unroll
493  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
494  {
495  items[item] = temp_storage.items_shared[indices[item]];
496  }
497  }
498  }
499  } // func block_merge_sort
500 
524  template <typename CompareOp>
525  HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
526  CompareOp compare_op)
527  {
528  Sort(keys, compare_op);
529  }
530 
557  template <typename CompareOp>
558  HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
559  ValueT (&items)[ITEMS_PER_THREAD],
560  CompareOp compare_op)
561  {
562  Sort(keys, items, compare_op);
563  }
564 
601  template <typename CompareOp>
602  HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
603  CompareOp compare_op,
604  int valid_items,
605  KeyT oob_default)
606  {
607  Sort(keys, compare_op, valid_items, oob_default);
608  }
609 
651  template <typename CompareOp,
652  bool IS_LAST_TILE = true>
653  HIPCUB_DEVICE __forceinline__ void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
654  ValueT (&items)[ITEMS_PER_THREAD],
655  CompareOp compare_op,
656  int valid_items,
657  KeyT oob_default)
658  {
659  Sort<CompareOp, IS_LAST_TILE>(keys,
660  items,
661  compare_op,
662  valid_items,
663  oob_default);
664  }
665 
666 private:
667  HIPCUB_DEVICE __forceinline__ void Sync() const
668  {
669  static_cast<const SynchronizationPolicy*>(this)->SyncImplementation();
670  }
671 };
672 
673 
754 template <typename KeyT,
755  int BLOCK_DIM_X,
756  int ITEMS_PER_THREAD,
757  typename ValueT = NullType,
758  int BLOCK_DIM_Y = 1,
759  int BLOCK_DIM_Z = 1>
761  : public BlockMergeSortStrategy<KeyT,
762  ValueT,
763  BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
764  ITEMS_PER_THREAD,
765  BlockMergeSort<KeyT,
766  BLOCK_DIM_X,
767  ITEMS_PER_THREAD,
768  ValueT,
769  BLOCK_DIM_Y,
770  BLOCK_DIM_Z>>
771 {
772 private:
773  // The thread block size in threads
774  static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z;
775  static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * BLOCK_THREADS;
776 
779  ValueT,
780  BLOCK_THREADS,
781  ITEMS_PER_THREAD,
783 
784 public:
785  HIPCUB_DEVICE __forceinline__ BlockMergeSort()
787  RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
788  {}
789 
790  HIPCUB_DEVICE __forceinline__ explicit BlockMergeSort(
791  typename BlockMergeSortStrategyT::TempStorage &temp_storage)
793  temp_storage,
794  RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
795  {}
796 
797 private:
798  HIPCUB_DEVICE __forceinline__ void SyncImplementation() const
799  {
800  CTA_SYNC();
801  }
802 
804 };
805 
806 END_HIPCUB_NAMESPACE
807 
808 #endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
Generalized merge sort algorithm.
Definition: block_merge_sort.hpp:175
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:558
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:602
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:525
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:250
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:375
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:326
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:653
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:292
The BlockMergeSort class provides methods for sorting items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:771
\smemstorage{BlockMergeSort}
Definition: block_merge_sort.hpp:207
Definition: util_type.hpp:78
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:363