29 #ifndef HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
30 #define HIPCUB_ROCPRIM_BLOCK_BLOCK_MERGE_SORT_HPP_
32 #include "../util_math.hpp"
33 #include "../util_type.hpp"
35 #include <rocprim/functional.hpp>
37 BEGIN_HIPCUB_NAMESPACE
// MergePath: binary search over positions in keys1 along merge diagonal `diag`.
// NOTE(review): this listing drops several original lines (the remaining
// template and function parameters, braces, the branch on `pred`, and the
// return statement); only the visible statements are described below.
// Presumably the final keys1_begin is returned -- confirm against the full file.
43 template <
typename KeyT,
44 typename KeyIteratorT,
47 __device__ __forceinline__ OffsetT MergePath(KeyIteratorT keys1,
52 BinaryPred binary_pred)
// Search window inside keys1: the lower bound is diag - keys2_count clamped
// at 0, the upper bound is min(diag, keys1_count).
54 OffsetT keys1_begin = diag < keys2_count ? 0 : diag - keys2_count;
55 OffsetT keys1_end = (::rocprim::min)(diag, keys1_count);
// Iterate until the window is empty.
57 while (keys1_begin < keys1_end)
// Probe the midpoint of the window in keys1 and the keys2 element at index
// diag - 1 - mid, so the two probed indices always sum to diag - 1.
59 OffsetT mid = hipcub::MidPoint<OffsetT>(keys1_begin, keys1_end);
60 KeyT key1 = keys1[mid];
61 KeyT key2 = keys2[diag - 1 - mid];
// The predicate is applied with the keys2 element as its first argument.
62 bool pred = binary_pred(key2, key1);
// Discards the lower half of the window, including mid (the condition that
// guards this assignment is not visible in this listing).
70 keys1_begin = mid + 1;
// SerialMerge: merges two runs stored in keys_shared, producing exactly
// ITEMS_PER_THREAD outputs plus the keys_shared index each output came from.
// NOTE(review): some parameters, the braces and the conditions around the
// reloads at the end of the loop body are missing from this listing.
76 template <
typename KeyT,
typename CompareOp,
int ITEMS_PER_THREAD>
77 __device__ __forceinline__
void SerialMerge(KeyT *keys_shared,
82 KeyT (&output)[ITEMS_PER_THREAD],
83 int (&indices)[ITEMS_PER_THREAD],
// Exclusive end positions of the two runs inside keys_shared.
86 int keys1_end = keys1_beg + keys1_count;
87 int keys2_end = keys2_beg + keys2_count;
// Preload the head element of each run.
89 KeyT key1 = keys_shared[keys1_beg];
90 KeyT key2 = keys_shared[keys2_beg];
// One merged element is emitted per iteration.
93 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
// p == true means "take from run 2": run 2 must still have elements, and
// either run 1 is exhausted or compare_op(key2, key1) holds. When
// compare_op(key2, key1) is false and run 1 is not exhausted, the run-1
// element is taken.
95 bool p = (keys2_beg < keys2_end) &&
96 ((keys1_beg >= keys1_end)
97 || compare_op(key2, key1));
// Store the selected key and the position it was read from, then advance
// that run's cursor (post-increment).
99 output[item] = p ? key2 : key1;
100 indices[item] = p ? keys2_beg++ : keys1_beg++;
// Reload the head of a run from keys_shared (the if/else selecting which of
// the two reloads executes is not visible in this listing).
104 key2 = keys_shared[keys2_beg];
108 key1 = keys_shared[keys1_beg];
181 int ITEMS_PER_THREAD,
182 typename ValueT = NullType,
// Number of threads in the block: product of the three block dimensions.
190 static constexpr
int BLOCK_THREADS = BLOCK_DIM_X * BLOCK_DIM_Y * BLOCK_DIM_Z;
// Number of items the whole block handles: ITEMS_PER_THREAD per thread.
191 static constexpr
int ITEMS_PER_TILE = ITEMS_PER_THREAD * BLOCK_THREADS;
// True when ValueT is NullType (the default ValueT), i.e. no values accompany
// the keys.
194 static constexpr
bool KEYS_ONLY = ::rocprim::Equals<ValueT, NullType>::VALUE;
// Key and value arrays used by the sort; each has one element more than
// ITEMS_PER_TILE. (The enclosing type declaration is not visible in this
// listing.)
199 KeyT keys_shared[ITEMS_PER_TILE + 1];
200 ValueT items_shared[ITEMS_PER_TILE + 1];
// Returns a reference to a function-local __shared__ _TempStorage instance;
// used by the constructor that is not given caller-provided storage.
204 __device__ __forceinline__ _TempStorage& PrivateStorage()
206 __shared__ _TempStorage private_storage;
207 return private_storage;
// Reference to the storage this instance reads and writes during sorting.
211 _TempStorage &temp_storage;
// Thread index within the block, initialised by the constructors from
// RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z).
214 unsigned int linear_tid;
// Initialiser list of a constructor whose signature is not visible in this
// listing: binds temp_storage to the instance returned by PrivateStorage().
222 : temp_storage(PrivateStorage())
223 , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
// Constructor taking caller-provided storage: binds temp_storage to
// temp_storage.Alias().
226 __device__ __forceinline__
BlockMergeSort(TempStorage &temp_storage)
227 : temp_storage(temp_storage.Alias())
228 , linear_tid(RowMajorTid(BLOCK_DIM_X, BLOCK_DIM_Y, BLOCK_DIM_Z))
233 template <
typename T>
234 __device__ __forceinline__
void Swap(T &lhs, T &rhs)
// StableOddEvenSort: sorts one thread's keys (and items) in place with
// ITEMS_PER_THREAD passes of adjacent compare-and-swap.
// NOTE(review): braces and any guard around the items swap are missing from
// this listing.
241 template <
typename CompareOp>
242 __device__ __forceinline__
void
243 StableOddEvenSort(KeyT (&keys)[ITEMS_PER_THREAD],
244 ValueT (&items)[ITEMS_PER_THREAD],
245 CompareOp compare_op)
// One pass per element.
248 for (
int i = 0; i < ITEMS_PER_THREAD; ++i)
// `1 & i` makes j start at 0 on even passes and 1 on odd passes; j steps by 2,
// so each pass visits disjoint adjacent pairs (j, j + 1).
251 for (
int j = 1 & i; j < ITEMS_PER_THREAD - 1; j += 2)
// Swap only when compare_op(keys[j + 1], keys[j]) holds; pairs for which it
// is false keep their order.
253 if (compare_op(keys[j + 1], keys[j]))
255 Swap(keys[j], keys[j + 1]);
// Apply the same swap to the items so they follow their keys (whether this
// line is additionally guarded, e.g. by KEYS_ONLY, is not visible here).
258 Swap(items[j], items[j + 1]);
// Sort (keys only, full tile): forwards to the general overload with
// IS_LAST_TILE = false, valid_items = ITEMS_PER_TILE and keys[0] passed as the
// oob_default argument.
278 template <
typename CompareOp>
279 __device__ __forceinline__
void
280 Sort(KeyT (&keys)[ITEMS_PER_THREAD],
281 CompareOp compare_op)
// Local items array passed only to satisfy the general overload's signature;
// it is not initialised here.
285 ValueT items[ITEMS_PER_THREAD];
286 Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
// Sort (keys only, partial tile): forwards to the general overload with
// IS_LAST_TILE = true, passing valid_items and oob_default through.
// NOTE(review): the valid_items / oob_default parameter lines are missing
// from this listing.
306 template <
typename CompareOp>
307 __device__ __forceinline__
void
308 Sort(KeyT (&keys)[ITEMS_PER_THREAD],
309 CompareOp compare_op,
// Local items array passed only to satisfy the general overload's signature;
// it is not initialised here.
313 ValueT items[ITEMS_PER_THREAD];
314 Sort<CompareOp, true>(keys, items, compare_op, valid_items, oob_default);
// Sort (keys and items, full tile): forwards to the general overload with
// IS_LAST_TILE = false, valid_items = ITEMS_PER_TILE and keys[0] passed as the
// oob_default argument.
328 template <
typename CompareOp>
329 __device__ __forceinline__
void
330 Sort(KeyT (&keys)[ITEMS_PER_THREAD],
331 ValueT (&items)[ITEMS_PER_THREAD],
332 CompareOp compare_op)
334 Sort<CompareOp, false>(keys, items, compare_op, ITEMS_PER_TILE, keys[0]);
// Sort (general overload): per-thread sort followed by rounds of pairwise
// merging of thread groups through temp_storage.
// NOTE(review): this listing omits many original lines (the valid_items and
// oob_default parameters, braces, several if-conditions, and any block
// barriers between the stores to temp_storage and the later reads). The
// comments below describe only the visible statements.
355 template <
typename CompareOp,
356 bool IS_LAST_TILE =
true>
357 __device__ __forceinline__
void
358 Sort(KeyT (&keys)[ITEMS_PER_THREAD],
359 ValueT (&items)[ITEMS_PER_THREAD],
360 CompareOp compare_op,
// Running key seeded with oob_default.
369 KeyT max_key = oob_default;
// Starts at item 1: slot 0 of the thread is not visited by this loop.
371 for (
int item = 1; item < ITEMS_PER_THREAD; ++item)
// For slots whose global index is below valid_items, max_key is replaced by
// keys[item] whenever compare_op(max_key, keys[item]) holds.
373 if (ITEMS_PER_THREAD * linear_tid + item < valid_items)
375 max_key = compare_op(max_key, keys[item]) ? keys[item] : max_key;
// Overwrites the slot with max_key. Presumably this is the else-branch for
// slots at or beyond valid_items -- the branch structure is not visible
// here; confirm against the full file.
379 keys[item] = max_key;
// The per-thread sort runs when this is not the last tile, or when the
// thread's first slot index is below valid_items.
386 if (!IS_LAST_TILE || ITEMS_PER_THREAD * linear_tid < valid_items)
388 StableOddEvenSort(keys, items, compare_op);
// Merge rounds: the target group size starts at 2 threads and doubles while
// it does not exceed BLOCK_THREADS.
395 for (
int target_merged_threads_number = 2;
396 target_merged_threads_number <= BLOCK_THREADS;
397 target_merged_threads_number *= 2)
// Size (in threads) of each of the two groups being merged, and a bit mask
// with the low bits of the target group size set (the size is a power of two
// because it starts at 2 and doubles).
399 int merged_threads_number = target_merged_threads_number / 2;
400 int mask = target_merged_threads_number - 1;
// Each thread stores its keys to temp_storage.keys_shared at
// ITEMS_PER_THREAD * linear_tid + item.
407 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
409 int idx = ITEMS_PER_THREAD * linear_tid + item;
410 temp_storage.keys_shared[idx] = keys[item];
// Source positions filled in by SerialMerge; used below to gather the items.
415 int indices[ITEMS_PER_THREAD];
// Clearing the mask bits of linear_tid gives the first thread of the target
// group; `start` is that thread's first slot and `size` the number of slots
// in one of the two groups being merged.
417 int first_thread_idx_in_thread_group_being_merged = ~mask & linear_tid;
418 int start = ITEMS_PER_THREAD * first_thread_idx_in_thread_group_being_merged;
419 int size = ITEMS_PER_THREAD * merged_threads_number;
// Keeping only the mask bits gives this thread's index inside the target group.
421 int thread_idx_in_thread_group_being_merged = mask & linear_tid;
// This thread's output offset within the group, clamped to valid_items. The
// declaration receiving this value is not visible in this listing
// (presumably the `diag` used further down -- confirm).
424 (rocprim::min)(valid_items,
425 ITEMS_PER_THREAD * thread_idx_in_thread_group_being_merged);
// Bounds of the two adjacent runs in keys_shared, each clamped to valid_items;
// run 2 begins where run 1 ends.
427 int keys1_beg = (rocprim::min)(valid_items, start);
428 int keys1_end = (rocprim::min)(valid_items, keys1_beg + size);
429 int keys2_beg = keys1_end;
430 int keys2_end = (rocprim::min)(valid_items, keys2_beg + size);
432 int keys1_count = keys1_end - keys1_beg;
433 int keys2_count = keys2_end - keys2_beg;
// MergePath result: an offset into run 1 (the remaining call arguments are
// not visible in this listing).
435 int partition_diag = MergePath<KeyT>(&temp_storage.keys_shared[keys1_beg],
436 &temp_storage.keys_shared[keys2_beg],
// Sub-ranges this thread merges: run 1 from partition_diag onward, run 2 from
// (diag - partition_diag) onward, both up to the end of their run.
442 int keys1_beg_loc = keys1_beg + partition_diag;
443 int keys1_end_loc = keys1_end;
444 int keys2_beg_loc = keys2_beg + diag - partition_diag;
445 int keys2_end_loc = keys2_end;
446 int keys1_count_loc = keys1_end_loc - keys1_beg_loc;
447 int keys2_count_loc = keys2_end_loc - keys2_beg_loc;
// Serial merge over keys_shared (the remaining arguments are not visible in
// this listing).
448 SerialMerge(&temp_storage.keys_shared[0],
// Each thread stores its items to temp_storage.items_shared using the same
// index layout as the keys (any KEYS_ONLY guard is not visible here).
464 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
466 int idx = ITEMS_PER_THREAD * linear_tid + item;
467 temp_storage.items_shared[idx] = items[item];
// Gather the items from the positions recorded in `indices`.
475 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
477 items[item] = temp_storage.items_shared[indices[item]];
// Keys-only overload whose visible body forwards directly to
// Sort(keys, compare_op).
// NOTE(review): the line carrying the function name and the keys parameter is
// missing from this listing.
495 template <
typename CompareOp>
496 __device__ __forceinline__
void
498 CompareOp compare_op)
500 Sort(keys, compare_op);
// Key-value overload whose visible body forwards directly to
// Sort(keys, items, compare_op).
// NOTE(review): the line carrying the function name and the keys parameter is
// missing from this listing.
515 template <
typename CompareOp>
516 __device__ __forceinline__
void
518 ValueT (&items)[ITEMS_PER_THREAD],
519 CompareOp compare_op)
521 Sort(keys, items, compare_op);
// Keys-only overload for a partial tile whose visible body forwards directly
// to Sort(keys, compare_op, valid_items, oob_default).
// NOTE(review): the lines carrying the function name and the keys,
// valid_items and oob_default parameters are missing from this listing.
542 template <
typename CompareOp>
543 __device__ __forceinline__
void
545 CompareOp compare_op,
549 Sort(keys, compare_op, valid_items, oob_default);
// Key-value overload with an IS_LAST_TILE template flag (default true); the
// visible body forwards to Sort<CompareOp, IS_LAST_TILE>.
// NOTE(review): the function-name line, part of the parameter list and the
// remaining call arguments are missing from this listing.
571 template <
typename CompareOp,
572 bool IS_LAST_TILE =
true>
573 __device__ __forceinline__
void
575 ValueT (&items)[ITEMS_PER_THREAD],
576 CompareOp compare_op,
580 Sort<CompareOp, IS_LAST_TILE>(keys,
The BlockMergeSort class provides methods for sorting items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:186
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:574
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:308
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:544
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:517
__device__ __forceinline__ void StableSort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:497
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:330
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], CompareOp compare_op)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:280
__device__ __forceinline__ void Sort(KeyT(&keys)[ITEMS_PER_THREAD], ValueT(&items)[ITEMS_PER_THREAD], CompareOp compare_op, int valid_items, KeyT oob_default)
Sorts items partitioned across a CUDA thread block using a merge sorting method.
Definition: block_merge_sort.hpp:358
\smemstorage{BlockMergeSort}
Definition: block_merge_sort.hpp:219
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:359