30 #ifndef HIPCUB_ROCPRIM_WARP_WARP_LOAD_HPP_
31 #define HIPCUB_ROCPRIM_WARP_WARP_LOAD_HPP_
33 #include "../../../config.hpp"
35 #include "../util_type.hpp"
36 #include "../iterator/cache_modified_input_iterator.hpp"
37 #include "./warp_exchange.hpp"
39 #include <rocprim/block/block_load_func.hpp>
41 BEGIN_HIPCUB_NAMESPACE
// Enumeration type used to pick a loading strategy. The values referenced in
// this file are WARP_LOAD_DIRECT, WARP_LOAD_STRIPED, WARP_LOAD_VECTORIZE and
// WARP_LOAD_TRANSPOSE.
43 enum WarpLoadAlgorithm
// Defaulted template parameters: the algorithm falls back to WARP_LOAD_DIRECT,
// the logical warp width to HIPCUB_DEVICE_WARP_THREADS and ARCH to HIPCUB_ARCH.
// NOTE(review): presumably these belong to the WarpLoad class template whose
// constructor appears further down - confirm against the complete header.
54 WarpLoadAlgorithm ALGORITHM = WARP_LOAD_DIRECT,
55 int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS,
56 int ARCH = HIPCUB_ARCH
// Compile-time flag: true exactly when LOGICAL_WARP_THREADS, converted to
// unsigned, equals HIPCUB_DEVICE_WARP_THREADS. The constructors consult it to
// decide whether the lane id has to be reduced modulo LOGICAL_WARP_THREADS.
61 constexpr
static bool IS_ARCH_WARP
62 =
static_cast<unsigned>(LOGICAL_WARP_THREADS) == HIPCUB_DEVICE_WARP_THREADS;
// Template head parameterised on a WarpLoadAlgorithm value.
// NOTE(review): presumably the primary declaration of LoadInternal, which is
// specialised once per algorithm below - the declaration itself is not shown.
64 template <WarpLoadAlgorithm _POLICY>
// LoadInternal specialisation for WARP_LOAD_DIRECT. Each Load overload
// delegates to ::rocprim::block_load_direct_blocked and passes linear_tid
// converted to unsigned.
68 struct LoadInternal<WARP_LOAD_DIRECT>
// Constructor: initialises the linear_tid member.
73 HIPCUB_DEVICE __forceinline__
77 : linear_tid(linear_tid)
// Load without an item bound: takes the source iterator and this thread's
// output array of ITEMS_PER_THREAD elements.
81 template <
typename InputIteratorT>
82 HIPCUB_DEVICE __forceinline__
void Load(
83 InputIteratorT block_itr,
84 InputT (&items)[ITEMS_PER_THREAD])
86 ::rocprim::block_load_direct_blocked(
87 static_cast<unsigned>(linear_tid),
// Bounded load: same as above, but valid_items is additionally forwarded to
// the rocPRIM call after conversion to unsigned.
93 template <
typename InputIteratorT>
94 HIPCUB_DEVICE __forceinline__
void Load(
95 InputIteratorT block_itr,
96 InputT (&items)[ITEMS_PER_THREAD],
99 ::rocprim::block_load_direct_blocked(
100 static_cast<unsigned>(linear_tid),
103 static_cast<unsigned>(valid_items)
// Bounded load that also accepts oob_default; valid_items is again forwarded
// as unsigned.
// NOTE(review): oob_default is presumably the value written to slots at or
// beyond valid_items - confirm against the rocPRIM documentation.
107 template <
typename InputIteratorT,
typename DefaultT>
108 HIPCUB_DEVICE __forceinline__
void Load(
109 InputIteratorT block_itr,
110 InputT (&items)[ITEMS_PER_THREAD],
112 DefaultT oob_default)
114 ::rocprim::block_load_direct_blocked(
115 static_cast<unsigned>(linear_tid),
118 static_cast<unsigned>(valid_items),
// LoadInternal specialisation for WARP_LOAD_STRIPED. Each Load overload
// delegates to ::rocprim::block_load_direct_warp_striped, instantiated with
// LOGICAL_WARP_THREADS, and passes linear_tid converted to unsigned.
125 struct LoadInternal<WARP_LOAD_STRIPED>
// Constructor: initialises the linear_tid member.
130 HIPCUB_DEVICE __forceinline__
134 : linear_tid(linear_tid)
// Load without an item bound: takes the source iterator and this thread's
// output array of ITEMS_PER_THREAD elements.
138 template <
typename InputIteratorT>
139 HIPCUB_DEVICE __forceinline__
void Load(
140 InputIteratorT block_itr,
141 InputT (&items)[ITEMS_PER_THREAD])
143 ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
144 static_cast<unsigned>(linear_tid),
// Bounded load: valid_items is additionally forwarded to the rocPRIM call
// after conversion to unsigned.
150 template <
typename InputIteratorT>
151 HIPCUB_DEVICE __forceinline__
void Load(
152 InputIteratorT block_itr,
153 InputT (&items)[ITEMS_PER_THREAD],
156 ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
157 static_cast<unsigned>(linear_tid),
160 static_cast<unsigned>(valid_items)
// Bounded load that also accepts oob_default; valid_items is again forwarded
// as unsigned.
// NOTE(review): oob_default is presumably the value written to slots at or
// beyond valid_items - confirm against the rocPRIM documentation.
164 template <
typename InputIteratorT,
typename DefaultT>
165 HIPCUB_DEVICE __forceinline__
void Load(
166 InputIteratorT block_itr,
167 InputT (&items)[ITEMS_PER_THREAD],
169 DefaultT oob_default)
171 ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
172 static_cast<unsigned>(linear_tid),
175 static_cast<unsigned>(valid_items),
// LoadInternal specialisation for WARP_LOAD_VECTORIZE. The overloads without
// oob_default call ::rocprim::block_load_direct_blocked_vectorized; the
// oob_default overload calls ::rocprim::block_load_direct_blocked instead.
182 struct LoadInternal<WARP_LOAD_VECTORIZE>
// Constructor: initialises the linear_tid member.
187 HIPCUB_DEVICE __forceinline__ LoadInternal(
190 : linear_tid(linear_tid)
// Unbounded load whose first parameter is not shown here.
// NOTE(review): likely the non-const pointer counterpart of the const-pointer
// overload that follows - confirm against the complete header.
194 template <
typename InputIteratorT>
195 HIPCUB_DEVICE __forceinline__
void Load(
197 InputT (&items)[ITEMS_PER_THREAD])
199 ::rocprim::block_load_direct_blocked_vectorized(
200 static_cast<unsigned>(linear_tid),
// Unbounded load from a pointer to const InputT.
// NOTE(review): InputIteratorT does not occur in this overload's parameter
// list, so it cannot be deduced from the call arguments; verify that this
// overload is actually reachable from the public Load wrappers.
206 template <
typename InputIteratorT>
207 HIPCUB_DEVICE __forceinline__
void Load(
208 const InputT *block_ptr,
209 InputT (&items)[ITEMS_PER_THREAD])
211 ::rocprim::block_load_direct_blocked_vectorized(
212 static_cast<unsigned>(linear_tid),
// Unbounded load templated on a CacheLoadModifier value.
// NOTE(review): presumably takes a CacheModifiedInputIterator (its header is
// included by this file); the parameter itself is not shown - confirm.
219 CacheLoadModifier MODIFIER,
223 HIPCUB_DEVICE __forceinline__
void Load(
225 InputT (&items)[ITEMS_PER_THREAD])
227 ::rocprim::block_load_direct_blocked_vectorized(
228 static_cast<unsigned>(linear_tid),
// Unbounded load for an arbitrary iterator type _InputIteratorT.
234 template <
typename _InputIteratorT>
235 HIPCUB_DEVICE __forceinline__
void Load(
236 _InputIteratorT block_itr,
237 InputT (&items)[ITEMS_PER_THREAD])
239 ::rocprim::block_load_direct_blocked_vectorized(
240 static_cast<unsigned>(linear_tid),
// Bounded load: forwards valid_items (as unsigned) to the vectorized rocPRIM
// function.
// NOTE(review): the oob_default overload below uses block_load_direct_blocked
// rather than the vectorized function; confirm that rocPRIM provides a
// block_load_direct_blocked_vectorized overload accepting a valid-item count.
246 template <
typename InputIteratorT>
247 HIPCUB_DEVICE __forceinline__
void Load(
248 InputIteratorT block_itr,
249 InputT (&items)[ITEMS_PER_THREAD],
252 ::rocprim::block_load_direct_blocked_vectorized(
253 static_cast<unsigned>(linear_tid),
256 static_cast<unsigned>(valid_items)
// Bounded load with oob_default: delegates to the non-vectorized
// ::rocprim::block_load_direct_blocked, forwarding valid_items as unsigned.
260 template <
typename InputIteratorT,
typename DefaultT>
261 HIPCUB_DEVICE __forceinline__
void Load(
262 InputIteratorT block_itr,
263 InputT (&items)[ITEMS_PER_THREAD],
265 DefaultT oob_default)
269 ::rocprim::block_load_direct_blocked(
270 static_cast<unsigned>(linear_tid),
273 static_cast<unsigned>(valid_items),
// LoadInternal specialisation for WARP_LOAD_TRANSPOSE. Each Load overload
// first calls ::rocprim::block_load_direct_warp_striped (instantiated with
// LOGICAL_WARP_THREADS) and then rearranges the result in place through
// WarpExchangeT::StripedToBlocked, using temp_storage.
280 struct LoadInternal<WARP_LOAD_TRANSPOSE>
// NOTE(review): presumably one of the template arguments of the WarpExchangeT
// alias; the rest of that alias is not shown - confirm.
285 LOGICAL_WARP_THREADS,
// The storage this specialisation needs is the one WarpExchangeT requires.
288 using TempStorage =
typename WarpExchangeT::TempStorage;
// Constructor: initialises the temp_storage and linear_tid members.
292 HIPCUB_DEVICE __forceinline__ LoadInternal(
295 temp_storage(temp_storage),
296 linear_tid(linear_tid)
// Load without an item bound: warp-striped load followed by the in-place
// StripedToBlocked exchange (items is both source and destination).
300 template <
typename InputIteratorT>
301 HIPCUB_DEVICE __forceinline__
void Load(
302 InputIteratorT block_itr,
303 InputT (&items)[ITEMS_PER_THREAD])
305 ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
306 static_cast<unsigned>(linear_tid),
310 WarpExchangeT(temp_storage).StripedToBlocked(items, items);
// Bounded load: valid_items is forwarded (as unsigned) to the striped load;
// the StripedToBlocked exchange then runs on items.
313 template <
typename InputIteratorT>
314 HIPCUB_DEVICE __forceinline__
void Load(
315 InputIteratorT block_itr,
316 InputT (&items)[ITEMS_PER_THREAD],
319 ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
320 static_cast<unsigned>(linear_tid),
323 static_cast<unsigned>(valid_items)
325 WarpExchangeT(temp_storage).StripedToBlocked(items, items);
// Bounded load that also accepts oob_default; valid_items is forwarded as
// unsigned and the StripedToBlocked exchange follows the load.
// NOTE(review): oob_default is presumably the value written to slots at or
// beyond valid_items - confirm against the rocPRIM documentation.
328 template <
typename InputIteratorT,
typename DefaultT>
329 HIPCUB_DEVICE __forceinline__
void Load(
330 InputIteratorT block_itr,
331 InputT (&items)[ITEMS_PER_THREAD],
333 DefaultT oob_default)
335 ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
336 static_cast<unsigned>(linear_tid),
339 static_cast<unsigned>(valid_items),
342 WarpExchangeT(temp_storage).StripedToBlocked(items, items);
// The LoadInternal specialisation selected by the ALGORITHM template argument.
346 using InternalLoad = LoadInternal<ALGORITHM>;
// Storage type required by the selected specialisation.
348 using _TempStorage =
typename InternalLoad::TempStorage;
// Returns a reference to a __shared__ _TempStorage object declared inside this
// function; the default constructor binds temp_storage to it.
350 HIPCUB_DEVICE __forceinline__ _TempStorage &PrivateStorage()
352 __shared__ _TempStorage private_storage;
353 return private_storage;
// Reference to the storage handed to InternalLoad on every Load call; bound by
// the constructors either to PrivateStorage() or to the caller's storage.
356 _TempStorage &temp_storage;
// Constructor initialiser list that uses the function-local shared storage
// returned by PrivateStorage().
// linear_tid is ::rocprim::lane_id() when IS_ARCH_WARP is true, otherwise the
// lane id modulo LOGICAL_WARP_THREADS.
// NOTE(review): the constructor head is not shown; presumably the default
// constructor of WarpLoad - confirm.
364 HIPCUB_DEVICE __forceinline__
366 temp_storage(PrivateStorage()),
367 linear_tid(IS_ARCH_WARP ? ::rocprim::lane_id() : (::rocprim::lane_id() % LOGICAL_WARP_THREADS))
// Constructs a WarpLoad that works on caller-provided storage: the
// temp_storage member is bound to the result of temp_storage.Alias().
// linear_tid is ::rocprim::lane_id() when IS_ARCH_WARP is true, otherwise the
// lane id modulo LOGICAL_WARP_THREADS.
371 HIPCUB_DEVICE __forceinline__
372 WarpLoad(TempStorage &temp_storage) :
373 temp_storage(temp_storage.Alias()),
374 linear_tid(IS_ARCH_WARP ? ::rocprim::lane_id() : (::rocprim::lane_id() % LOGICAL_WARP_THREADS))
// Public load without an item bound: builds an InternalLoad from temp_storage
// and linear_tid and forwards block_itr and items to its Load.
378 template <
typename InputIteratorT>
379 HIPCUB_DEVICE __forceinline__
void Load(
380 InputIteratorT block_itr,
381 InputT (&items)[ITEMS_PER_THREAD])
383 InternalLoad(temp_storage, linear_tid)
384 .Load(block_itr, items);
// Public bounded load: builds an InternalLoad from temp_storage and linear_tid
// and forwards block_itr, items and valid_items to its Load.
387 template <
typename InputIteratorT>
388 HIPCUB_DEVICE __forceinline__
void Load(
389 InputIteratorT block_itr,
390 InputT (&items)[ITEMS_PER_THREAD],
393 InternalLoad(temp_storage, linear_tid)
394 .Load(block_itr, items, valid_items);
// Public bounded load with a default value: builds an InternalLoad from
// temp_storage and linear_tid and forwards block_itr, items, valid_items and
// oob_default to its Load.
397 template <
typename InputIteratorT,
399 HIPCUB_DEVICE __forceinline__
void Load(
400 InputIteratorT block_itr,
401 InputT (&items)[ITEMS_PER_THREAD],
403 DefaultT oob_default)
405 InternalLoad(temp_storage, linear_tid)
406 .Load(block_itr, items, valid_items, oob_default);
Definition: warp_exchange.hpp:45
Definition: warp_load.hpp:59
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:363
Definition: warp_load.hpp:361