/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.0.2/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.0.2/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp Source File#

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.0.2/hipcub/include/hipcub/backend/rocprim/block/block_merge_sort.hpp Source File
block_merge_sort.hpp
1 /******************************************************************************
2 * Copyright (c) 2011-2021, NVIDIA CORPORATION. All rights reserved.
3 * Modifications Copyright (c) 2021, Advanced Micro Devices, Inc. All rights reserved.
4 *
5 * Redistribution and use in source and binary forms, with or without
6 * modification, are permitted provided that the following conditions are met:
7 * * Redistributions of source code must retain the above copyright
8 * notice, this list of conditions and the following disclaimer.
9 * * Redistributions in binary form must reproduce the above copyright
10 * notice, this list of conditions and the following disclaimer in the
11 * documentation and/or other materials provided with the distribution.
12 * * Neither the name of the NVIDIA CORPORATION nor the
13 * names of its contributors may be used to endorse or promote products
14 * derived from this software without specific prior written permission.
15 *
16 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
20 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 *
27 ******************************************************************************/
28 
29 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
30 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
31 
32 #include "../util_math.hpp"
33 #include "../util_type.hpp"
34 
35 #include <rocprim/functional.hpp>
36 
37 BEGIN_HIPCUB_NAMESPACE
38 
39 
40 // Implementation of the MergePath algorithm, as described in:
41 // Odeh et al, "Merge Path - Parallel Merging Made Simple"
42 // doi:10.1109/IPDPSW.2012.202
43 template <typename KeyT,
44  typename KeyIteratorT,
45  typename OffsetT,
46  typename BinaryPred>
47 __device__ __forceinline__ OffsetT MergePath(KeyIteratorT keys1,
48  KeyIteratorT keys2,
49  OffsetT keys1_count,
50  OffsetT keys2_count,
51  OffsetT diag,
52  BinaryPred binary_pred)
53 {
54  OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count;
55  OffsetT keys1_end = (::rocprim::min)(diag, keys1_count);
56 
57  while (keys1_begin < keys1_end)
58  {
59  OffsetT mid = hipcub::MidPoint<OffsetT>(keys1_begin, keys1_end);
60  KeyT key1 = keys1[mid];
61  KeyT key2 = keys2[diag - 1 - mid];
62  bool pred = binary_pred(key2, key1);
63 
64  if (pred)
65  {
66  keys1_end = mid;
67  }
68  else
69  {
70  keys1_begin = mid + 1;
71  }
72  }
73  return keys1_begin;
74 }
75 
76 template <typename KeyT, typename CompareOp, int ITEMS_PER_THREAD>
77 __device__ __forceinline__ void SerialMerge(KeyT *keys_shared,
78  int keys1_beg,
79  int keys2_beg,
80  int keys1_count,
81  int keys2_count,
82  KeyT (&output)[ITEMS_PER_THREAD],
83  int (&indices)[ITEMS_PER_THREAD],
84  CompareOp compare_op)
85 {
86  int keys1_end = keys1_beg + keys1_count;
87  int keys2_end = keys2_beg + keys2_count;
88 
89  KeyT key1 = keys_shared[keys1_beg];
90  KeyT key2 = keys_shared[keys2_beg];
91 
92 #pragma unroll
93  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
94  {
95  bool p = (keys2_beg < keys2_end) &&
96  ((keys1_beg >= keys1_end)
97  || compare_op(key2, key1));
98 
99  output[item] = p ? key2 : key1;
100  indices[item] = p ? keys2_beg++ : keys1_beg++;
101 
102  if (p)
103  {
104  key2 = keys_shared[keys2_beg];
105  }
106  else
107  {
108  key1 = keys_shared[keys1_beg];
109  }
110  }
111 }
112 
178 template <
179  typename KeyT,
180  int BLOCK_DIM_X,
181  int ITEMS_PER_THREAD,
182  typename ValueT = NullType,
183  int BLOCK_DIM_Y = 1,
184  int BLOCK_DIM_Z = 1>
186 {
187  private:
188 
189  // The thread block size in threads
190  static constexpr int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z;
191  static constexpr int ITEMS_PER_TILE = ITEMS_PER_THREAD * BLOCK_THREADS;
192 
193  // Whether or not there are values to be trucked along with keys
194  static constexpr bool KEYS_ONLY = ::rocprim::Equals<ValueT, NullType>::VALUE;
195 
197  union _TempStorage
198  {
199  KeyT keys_shared[ITEMS_PER_TILE + 1];
200  ValueT items_shared[ITEMS_PER_TILE + 1];
201  }; // union TempStorage
202 
204  __device__ __forceinline__ _TempStorage& PrivateStorage()
205  {
206  __shared__ _TempStorage private_storage;
207  return private_storage;
208  }
209 
211  _TempStorage &temp_storage;
212 
214  unsigned int linear_tid;
215 
216  public:
217 
219  struct TempStorage : Uninitialized<_TempStorage> {};
220 
221  __device__ __forceinline__ BlockMergeSort()
222  : temp_storage(PrivateStorage())
223  , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
224  {}
225 
226  __device__ __forceinline__ BlockMergeSort(TempStorage &temp_storage)
227  : temp_storage(temp_storage.Alias())
228  , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
229  {}
230 
231  private:
232 
233  template <typename T>
234  __device__ __forceinline__ void Swap(T &lhs, T &rhs)
235  {
236  T temp = lhs;
237  lhs = rhs;
238  rhs = temp;
239  }
240 
241  template <typename CompareOp>
242  __device__ __forceinline__ void
243  StableOddEvenSort(KeyT (&keys)[ITEMS_PER_THREAD],
244  ValueT (&items)[ITEMS_PER_THREAD],
245  CompareOp compare_op)
246  {
247 #pragma unroll
248  for (int i = 0; i < ITEMS_PER_THREAD; ++i)
249  {
250 #pragma unroll
251  for (int j = 1 & i; j < ITEMS_PER_THREAD - 1; j += 2)
252  {
253  if (compare_op(keys[j + 1], keys[j]))
254  {
255  Swap(keys[j], keys[j + 1]);
256  if (!KEYS_ONLY)
257  {
258  Swap(items[j], items[j + 1]);
259  }
260  }
261  } // inner loop
262  } // outer loop
263  }
264 
265  public:
266 
278  template <typename CompareOp>
279  __device__ __forceinline__ void
280  Sort(KeyT (&keys)[ITEMS_PER_THREAD],
281  CompareOp compare_op)
284  {
285  ValueT items[ITEMS_PER_THREAD];
286  Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
287  }
288 
306  template <typename CompareOp>
307  __device__ __forceinline__ void
308  Sort(KeyT (&keys)[ITEMS_PER_THREAD],
309  CompareOp compare_op,
310  int valid_items,
311  KeyT oob_default)
312  {
313  ValueT items[ITEMS_PER_THREAD];
314  Sort<CompareOp, true>(keys, items, compare_op, valid_items, oob_default);
315  }
316 
328  template <typename CompareOp>
329  __device__ __forceinline__ void
330  Sort(KeyT (&keys)[ITEMS_PER_THREAD],
331  ValueT (&items)[ITEMS_PER_THREAD],
332  CompareOp compare_op)
333  {
334  Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
335  }
336 
355  template <typename CompareOp,
356  bool IS_LAST_TILE = true>
357  __device__ __forceinline__ void
358  Sort(KeyT (&keys)[ITEMS_PER_THREAD],
359  ValueT (&items)[ITEMS_PER_THREAD],
360  CompareOp compare_op,
361  int valid_items,
362  KeyT oob_default)
363  {
364  if (IS_LAST_TILE)
365  {
366  // if last tile, find valid max_key
367  // and fill the remaining keys with it
368  //
369  KeyT max_key = oob_default;
370 #pragma unroll
371  for (int item = 1; item < ITEMS_PER_THREAD; ++item)
372  {
373  if (ITEMS_PER_THREAD * linear_tid + item < valid_items)
374  {
375  max_key = compare_op(max_key, keys[item]) ? keys[item] : max_key;
376  }
377  else
378  {
379  keys[item] = max_key;
380  }
381  }
382  }
383 
384  // if first element of thread is in input range, stable sort items
385  //
386  if (!IS_LAST_TILE || ITEMS_PER_THREAD * linear_tid < valid_items)
387  {
388  StableOddEvenSort(keys, items, compare_op);
389  }
390 
391  // each thread has sorted keys
392  // merge sort keys in shared memory
393  //
394 #pragma unroll
395  for (int target_merged_threads_number = 2;
396  target_merged_threads_number <= BLOCK_THREADS;
397  target_merged_threads_number *= 2)
398  {
399  int merged_threads_number = target_merged_threads_number / 2;
400  int mask = target_merged_threads_number - 1;
401 
402  CTA_SYNC();
403 
404  // store keys in shmem
405  //
406 #pragma unroll
407  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
408  {
409  int idx = ITEMS_PER_THREAD * linear_tid + item;
410  temp_storage.keys_shared[idx] = keys[item];
411  }
412 
413  CTA_SYNC();
414 
415  int indices[ITEMS_PER_THREAD];
416 
417  int first_thread_idx_in_thread_group_being_merged = ~mask & linear_tid;
418  int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged;
419  int size = ITEMS_PER_THREAD * merged_threads_number;
420 
421  int thread_idx_in_thread_group_being_merged = mask & linear_tid;
422 
423  int diag =
424  (rocprim::min)(valid_items,
425  ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged);
426 
427  int keys1_beg = (rocprim::min)(valid_items, start);
428  int keys1_end = (rocprim::min)(valid_items, keys1_beg + size);
429  int keys2_beg = keys1_end;
430  int keys2_end = (rocprim::min)(valid_items, keys2_beg + size);
431 
432  int keys1_count = keys1_end - keys1_beg;
433  int keys2_count = keys2_end - keys2_beg;
434 
435  int partition_diag = MergePath<KeyT>(&temp_storage.keys_shared[keys1_beg],
436  &temp_storage.keys_shared[keys2_beg],
437  keys1_count,
438  keys2_count,
439  diag,
440  compare_op);
441 
442  int keys1_beg_loc = keys1_beg + partition_diag;
443  int keys1_end_loc = keys1_end;
444  int keys2_beg_loc = keys2_beg + diag - partition_diag;
445  int keys2_end_loc = keys2_end;
446  int keys1_count_loc = keys1_end_loc - keys1_beg_loc;
447  int keys2_count_loc = keys2_end_loc - keys2_beg_loc;
448  SerialMerge(&temp_storage.keys_shared[0],
449  keys1_beg_loc,
450  keys2_beg_loc,
451  keys1_count_loc,
452  keys2_count_loc,
453  keys,
454  indices,
455  compare_op);
456 
457  if (!KEYS_ONLY)
458  {
459  CTA_SYNC();
460 
461  // store keys in shmem
462  //
463 #pragma unroll
464  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
465  {
466  int idx = ITEMS_PER_THREAD * linear_tid + item;
467  temp_storage.items_shared[idx] = items[item];
468  }
469 
470  CTA_SYNC();
471 
472  // gather items from shmem
473  //
474 #pragma unroll
475  for (int item = 0; item < ITEMS_PER_THREAD; ++item)
476  {
477  items[item] = temp_storage.items_shared[indices[item]];
478  }
479  }
480  }
481  } // func block_merge_sort
482 
495  template <typename CompareOp>
496  __device__ __forceinline__ void
497  StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
498  CompareOp compare_op)
499  {
500  Sort(keys, compare_op);
501  }
502 
515  template <typename CompareOp>
516  __device__ __forceinline__ void
517  StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
518  ValueT (&items)[ITEMS_PER_THREAD],
519  CompareOp compare_op)
520  {
521  Sort(keys, items, compare_op);
522  }
523 
542  template <typename CompareOp>
543  __device__ __forceinline__ void
544  StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
545  CompareOp compare_op,
546  int valid_items,
547  KeyT oob_default)
548  {
549  Sort(keys, compare_op, valid_items, oob_default);
550  }
551 
571  template <typename CompareOp,
572  bool IS_LAST_TILE = true>
573  __device__ __forceinline__ void
574  StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
575  ValueT (&items)[ITEMS_PER_THREAD],
576  CompareOp compare_op,
577  int valid_items,
578  KeyT oob_default)
579  {
580  Sort<CompareOp, IS_LAST_TILE>(keys,
581  items,
582  compare_op,
583  valid_items,
584  oob_default);
585  }
586 };
587 
588 END_HIPCUB_NAMESPACE
589 
590 #endif // HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
The BlockMergeSort class provides methods for sorting items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:186
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:574
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:308
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:544
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:517
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:497
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:330
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:280
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:358
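The operations exposed by BlockMergeSort require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the __shared__ keyword. Alternatively, it can be aliased to externally allocated memory (shared or global) or union'd with other storage allocation types to facilitate memory reuse.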
Definition: block_merge_sort.hpp:219
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:359