29 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
30 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
32 #include "../thread/thread_sort.hpp"
33 #include "../util_math.hpp"
34 #include "../util_type.hpp"
36 #include <rocprim/detail/various.hpp>
37 #include <rocprim/functional.hpp>
39 BEGIN_HIPCUB_NAMESPACE
47 template <
typename KeyT,
48 typename KeyIteratorT,
51 HIPCUB_DEVICE __forceinline__ OffsetT MergePath(KeyIteratorT keys1,
56 BinaryPred binary_pred)
58 OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count;
59 OffsetT keys1_end = (::rocprim::min)(diag, keys1_count);
61 while (keys1_begin < keys1_end)
63 OffsetT mid = hipcub::MidPoint<OffsetT>(keys1_begin, keys1_end);
64 KeyT key1 = keys1[mid];
65 KeyT key2 = keys2[diag - 1 - mid];
66 bool pred = binary_pred(key2, key1);
74 keys1_begin = mid + 1;
80 template <
typename KeyT,
typename CompareOp,
int ITEMS_PER_THREAD>
81 HIPCUB_DEVICE __forceinline__
void SerialMerge(KeyT *keys_shared,
86 KeyT (&output)[ITEMS_PER_THREAD],
87 int (&indices)[ITEMS_PER_THREAD],
90 int keys1_end = keys1_beg + keys1_count;
91 int keys2_end = keys2_beg + keys2_count;
93 KeyT key1 = keys_shared[keys1_beg];
94 KeyT key2 = keys_shared[keys2_beg];
97 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
99 bool p = (keys2_beg < keys2_end) &&
100 ((keys1_beg >= keys1_end)
101 || compare_op(key2, key1));
103 output[item] = p ? key2 : key1;
104 indices[item] = p ? keys2_beg++ : keys1_beg++;
108 key2 = keys_shared[keys2_beg];
112 key1 = keys_shared[keys1_beg];
169 template <
typename KeyT,
172 int ITEMS_PER_THREAD,
173 typename SynchronizationPolicy>
177 "NUM_THREADS must be a power of two");
181 static constexpr
int ITEMS_PER_TILE = ITEMS_PER_THREAD * NUM_THREADS;
184 static constexpr
bool KEYS_ONLY = ::rocprim::Equals<ValueT, NullType>::VALUE;
189 KeyT keys_shared[ITEMS_PER_TILE + 1];
190 ValueT items_shared[ITEMS_PER_TILE + 1];
194 _TempStorage &temp_storage;
197 HIPCUB_DEVICE __forceinline__ _TempStorage& PrivateStorage()
199 __shared__ _TempStorage private_storage;
200 return private_storage;
203 const unsigned int linear_tid;
210 explicit HIPCUB_DEVICE __forceinline__
212 : temp_storage(PrivateStorage())
213 , linear_tid(linear_tid)
217 unsigned int linear_tid)
218 : temp_storage(temp_storage.Alias())
219 , linear_tid(linear_tid)
222 HIPCUB_DEVICE __forceinline__
unsigned int get_linear_tid()
const
249 template <
typename CompareOp>
250 HIPCUB_DEVICE __forceinline__
void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
251 CompareOp compare_op)
253 ValueT items[ITEMS_PER_THREAD];
254 Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
291 template <
typename CompareOp>
292 HIPCUB_DEVICE __forceinline__
void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
293 CompareOp compare_op,
297 ValueT items[ITEMS_PER_THREAD];
298 Sort<CompareOp, true>(keys, items, compare_op, valid_items, oob_default);
325 template <
typename CompareOp>
326 HIPCUB_DEVICE __forceinline__
void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
327 ValueT (&items)[ITEMS_PER_THREAD],
328 CompareOp compare_op)
330 Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
373 template <
typename CompareOp,
374 bool IS_LAST_TILE =
true>
375 HIPCUB_DEVICE __forceinline__
void Sort(KeyT (&keys)[ITEMS_PER_THREAD],
376 ValueT (&items)[ITEMS_PER_THREAD],
377 CompareOp compare_op,
386 KeyT max_key = oob_default;
389 for (
int item = 1; item < ITEMS_PER_THREAD; ++item)
391 if (ITEMS_PER_THREAD *
static_cast<int>(linear_tid) + item < valid_items)
393 max_key = compare_op(max_key, keys[item]) ? keys[item] : max_key;
397 keys[item] = max_key;
404 if (!IS_LAST_TILE || ITEMS_PER_THREAD *
static_cast<int>(linear_tid) < valid_items)
406 StableOddEvenSort(keys, items, compare_op);
413 for (
int target_merged_threads_number = 2;
414 target_merged_threads_number <= NUM_THREADS;
415 target_merged_threads_number *= 2)
417 int merged_threads_number = target_merged_threads_number / 2;
418 int mask = target_merged_threads_number - 1;
425 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
427 int idx = ITEMS_PER_THREAD * linear_tid + item;
428 temp_storage.keys_shared[idx] = keys[item];
433 int indices[ITEMS_PER_THREAD];
435 int first_thread_idx_in_thread_group_being_merged = ~mask & linear_tid;
436 int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged;
437 int size = ITEMS_PER_THREAD * merged_threads_number;
439 int thread_idx_in_thread_group_being_merged = mask & linear_tid;
442 (::rocprim::min)(valid_items,
443 ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged);
445 int keys1_beg = (::rocprim::min)(valid_items, start);
446 int keys1_end = (::rocprim::min)(valid_items, keys1_beg + size);
447 int keys2_beg = keys1_end;
448 int keys2_end = (::rocprim::min)(valid_items, keys2_beg + size);
450 int keys1_count = keys1_end - keys1_beg;
451 int keys2_count = keys2_end - keys2_beg;
453 int partition_diag = MergePath<KeyT>(&temp_storage.keys_shared[keys1_beg],
454 &temp_storage.keys_shared[keys2_beg],
460 int keys1_beg_loc = keys1_beg + partition_diag;
461 int keys1_end_loc = keys1_end;
462 int keys2_beg_loc = keys2_beg + diag - partition_diag;
463 int keys2_end_loc = keys2_end;
464 int keys1_count_loc = keys1_end_loc - keys1_beg_loc;
465 int keys2_count_loc = keys2_end_loc - keys2_beg_loc;
466 SerialMerge(&temp_storage.keys_shared[0],
482 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
484 int idx = ITEMS_PER_THREAD * linear_tid + item;
485 temp_storage.items_shared[idx] = items[item];
493 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
495 items[item] = temp_storage.items_shared[indices[item]];
524 template <
typename CompareOp>
525 HIPCUB_DEVICE __forceinline__
void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
526 CompareOp compare_op)
528 Sort(keys, compare_op);
557 template <
typename CompareOp>
558 HIPCUB_DEVICE __forceinline__
void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
559 ValueT (&items)[ITEMS_PER_THREAD],
560 CompareOp compare_op)
562 Sort(keys, items, compare_op);
601 template <
typename CompareOp>
602 HIPCUB_DEVICE __forceinline__
void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
603 CompareOp compare_op,
607 Sort(keys, compare_op, valid_items, oob_default);
651 template <
typename CompareOp,
652 bool IS_LAST_TILE =
true>
653 HIPCUB_DEVICE __forceinline__
void StableSort(KeyT (&keys)[ITEMS_PER_THREAD],
654 ValueT (&items)[ITEMS_PER_THREAD],
655 CompareOp compare_op,
659 Sort<CompareOp, IS_LAST_TILE>(keys,
667 HIPCUB_DEVICE __forceinline__
void Sync()
const
669 static_cast<const SynchronizationPolicy*
>(
this)->SyncImplementation();
754 template <
typename KeyT,
756 int ITEMS_PER_THREAD,
757 typename ValueT = NullType,
763 BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z,
774 static constexpr
int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z;
775 static constexpr
int ITEMS_PER_TILE = ITEMS_PER_THREAD * BLOCK_THREADS;
787 RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
794 RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
798 HIPCUB_DEVICE __forceinline__
void SyncImplementation()
const
Generalized merge sort algorithm.
Definition: block_merge_sort.hpp:175
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:558
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:602
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:525
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:250
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:375
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:326
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:653
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:292
The BlockMergeSort class provides methods for sorting items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:771
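The operations exposed by BlockMergeSort require a temporary memory allocation of this nested type for thread communication. This opaque storage can be allocated directly using the __shared__ keyword, or it can be aliased to externally allocated memory or union'd with other storage allocation types to facilitate memory reuse.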
Definition: block_merge_sort.hpp:207
Definition: util_type.hpp:78
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:363