30 #ifndef HIPCUB_ROCPRIM_WARP_WARP_EXCHANGE_HPP_
31 #define HIPCUB_ROCPRIM_WARP_WARP_EXCHANGE_HPP_
33 #include "../../../config.hpp"
34 #include "../util_type.hpp"
36 BEGIN_HIPCUB_NAMESPACE
41 int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS,
42 int ARCH = HIPCUB_ARCH
47 "LOGICAL_WARP_THREADS must be a power of two");
// Number of shared-memory (LDS) banks, obtained from rocPRIM at compile time.
49 constexpr
static int SMEM_BANKS = ::rocprim::detail::get_lds_banks_no();
// NOTE(review): the initializer expression of HAS_BANK_CONFLICTS is not present
// in this listing (original lines 52-53 are missing) — confirm the condition upstream.
51 constexpr
static bool HAS_BANK_CONFLICTS =
// Extra tile elements added only when HAS_BANK_CONFLICTS holds:
// ITEMS_PER_THREAD / SMEM_BANKS (integer division, so 0 whenever
// ITEMS_PER_THREAD < SMEM_BANKS); otherwise 0.
// NOTE(review): none of the index computations in the exchange methods visible
// in this chunk add this padding, so it only enlarges items_shared — confirm intended.
54 constexpr
static int BANK_CONFLICTS_PADDING =
55 HAS_BANK_CONFLICTS ? (ITEMS_PER_THREAD / SMEM_BANKS) : 0;
// Size of the shared tile: one slot for every item of every lane in the
// logical warp, plus the optional padding above.
57 constexpr
static int ITEMS_PER_TILE =
58 ITEMS_PER_THREAD * LOGICAL_WARP_THREADS + BANK_CONFLICTS_PADDING;
// True when the logical warp width equals HIPCUB_DEVICE_WARP_THREADS
// (which is also the default value of LOGICAL_WARP_THREADS).
60 constexpr
static bool IS_ARCH_WARP = LOGICAL_WARP_THREADS ==
61 HIPCUB_DEVICE_WARP_THREADS;
// The exchange tile: ITEMS_PER_TILE elements of InputT.
65 InputT items_shared[ITEMS_PER_TILE];
// Reference to the caller-provided storage, bound in the constructor via Alias().
68 _TempStorage &temp_storage;
// Constructor: binds this instance to the caller-provided temp storage (through
// its Alias() view) and caches the calling thread's lane index within its logical warp.
// NOTE(review): the constructor's name/parameter line (original line 77) is
// missing from this listing.
76 explicit HIPCUB_DEVICE __forceinline__
78 temp_storage(temp_storage.Alias()),
// For a logical warp narrower than HIPCUB_DEVICE_WARP_THREADS, lanes are
// renumbered modulo LOGICAL_WARP_THREADS so lane_id stays in [0, LOGICAL_WARP_THREADS).
// Presumably each logical warp must therefore be given its own TempStorage,
// since sub-warps would otherwise index the same tile slots — confirm with callers.
79 lane_id(IS_ARCH_WARP ? LaneId() : LaneId() % LOGICAL_WARP_THREADS)
// BlockedToStriped: converts a blocked arrangement (each lane holds
// ITEMS_PER_THREAD consecutive items) into a striped one across the logical warp.
//   input_items  - this lane's items in blocked order.
//   output_items - receives this lane's items in striped order; each value is
//                  assigned from InputT to OutputT.
// Write phase: lane L stores item i at tile slot L * ITEMS_PER_THREAD + i.
// Read phase:  lane L loads output item i from slot i * LOGICAL_WARP_THREADS + L.
// NOTE(review): original lines 93-97 (between the two loops) are missing from
// this listing. A warp-level barrier is required there so the reads observe
// the other lanes' writes — confirm it is present in the real source.
83 template <
typename OutputT>
84 HIPCUB_DEVICE __forceinline__
85 void BlockedToStriped(
86 const InputT (&input_items)[ITEMS_PER_THREAD],
87 OutputT (&output_items)[ITEMS_PER_THREAD])
// Write phase: lane-contiguous (blocked) placement into the shared tile.
89 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
91 const int idx = ITEMS_PER_THREAD * lane_id + item;
92 temp_storage.items_shared[idx] = input_items[item];
// Read phase: warp-strided (striped) gather from the shared tile.
98 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
100 const int idx = LOGICAL_WARP_THREADS * item + lane_id;
101 output_items[item] = temp_storage.items_shared[idx];
// StripedToBlocked: inverse of BlockedToStriped. Converts a striped arrangement
// across the logical warp into a blocked one (each lane ends up with
// ITEMS_PER_THREAD consecutive items).
//   input_items  - this lane's items in striped order.
//   output_items - receives this lane's items in blocked order; each value is
//                  assigned from InputT to OutputT.
// Write phase: lane L stores item i at tile slot i * LOGICAL_WARP_THREADS + L.
// Read phase:  lane L loads output item i from slot L * ITEMS_PER_THREAD + i.
// NOTE(review): original lines 115-119 (between the two loops) are missing from
// this listing. A warp-level barrier is required there — confirm it is present.
105 template <
typename OutputT>
106 HIPCUB_DEVICE __forceinline__
107 void StripedToBlocked(
108 const InputT (&input_items)[ITEMS_PER_THREAD],
109 OutputT (&output_items)[ITEMS_PER_THREAD])
// Write phase: warp-strided (striped) placement into the shared tile.
111 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
113 const int idx = LOGICAL_WARP_THREADS * item + lane_id;
114 temp_storage.items_shared[idx] = input_items[item];
// Read phase: lane-contiguous (blocked) gather from the shared tile.
120 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
122 const int idx = ITEMS_PER_THREAD * lane_id + item;
123 output_items[item] = temp_storage.items_shared[idx];
// ScatterToStriped (in-place overload): scatters `items` to the tile positions
// given by `ranks`, then reads them back in striped order into `items` itself.
// Delegates to the three-argument overload with the same array as input and
// output. This is safe because that overload reads every input item before it
// writes any output item.
127 template <
typename OffsetT>
128 HIPCUB_DEVICE __forceinline__
129 void ScatterToStriped(
130 InputT (&items)[ITEMS_PER_THREAD],
131 OffsetT (&ranks)[ITEMS_PER_THREAD])
133 ScatterToStriped(items, items, ranks);
// ScatterToStriped: writes each input item to tile slot ranks[item], then reads
// the tile back in striped order (lane L receives slot
// item * LOGICAL_WARP_THREADS + L as output item `item`).
//   input_items  - this lane's items to scatter.
//   output_items - receives the striped result; may alias input_items, because
//                  all input reads happen in the first loop.
//   ranks        - destination tile slot for each input item. Presumably each
//                  rank must be unique across the logical warp and lie in
//                  [0, ITEMS_PER_TILE); this is not checked here — confirm with callers.
// NOTE(review): this listing is missing original line 137 (presumably the
// `typename OffsetT>` template parameter used by `ranks`) and lines 149-153
// between the loops, where a warp-level barrier is required — confirm both upstream.
136 template <
typename OutputT,
138 HIPCUB_DEVICE __forceinline__
139 void ScatterToStriped(
140 const InputT (&input_items)[ITEMS_PER_THREAD],
141 OutputT (&output_items)[ITEMS_PER_THREAD],
142 OffsetT (&ranks)[ITEMS_PER_THREAD])
// Scatter phase: each item goes to the caller-chosen rank.
145 for (
int item = 0; item < ITEMS_PER_THREAD; ++item)
147 temp_storage.items_shared[ranks[item]] = input_items[item];
// Gather phase: warp-strided (striped) read of the tile.
154 for (
int item = 0; item < ITEMS_PER_THREAD; item++)
156 int item_offset = (item * LOGICAL_WARP_THREADS) + lane_id;
157 output_items[item] = temp_storage.items_shared[item_offset];
Definition: warp_exchange.hpp:45
Definition: util_type.hpp:78
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:363
Definition: warp_exchange.hpp:72