/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.6.0/hipcub/include/hipcub/backend/rocprim/warp/warp_load.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.6.0/hipcub/include/hipcub/backend/rocprim/warp/warp_load.hpp Source File

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.6.0/hipcub/include/hipcub/backend/rocprim/warp/warp_load.hpp Source File
warp_load.hpp
1 /******************************************************************************
2  * Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
4  * Modifications Copyright (c) 2017-2021, Advanced Micro Devices, Inc. All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  * * Redistributions in binary form must reproduce the above copyright
11  * notice, this list of conditions and the following disclaimer in the
12  * documentation and/or other materials provided with the distribution.
13  * * Neither the name of the NVIDIA CORPORATION nor the
14  * names of its contributors may be used to endorse or promote products
15  * derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  *
28  ******************************************************************************/
29 
30 #ifndef HIPCUB_ROCPRIM_WARP_WARP_LOAD_HPP_
31 #define HIPCUB_ROCPRIM_WARP_WARP_LOAD_HPP_
32 
33 #include "../../../config.hpp"
34 
35 #include "../util_type.hpp"
36 #include "../iterator/cache_modified_input_iterator.hpp"
37 #include "./warp_exchange.hpp"
38 
39 #include <rocprim/block/block_load_func.hpp>
40 
41 BEGIN_HIPCUB_NAMESPACE
42 
/// Data-movement strategy used by WarpLoad to read a warp's input range
/// into per-thread register arrays.
enum WarpLoadAlgorithm
{
    /// Blocked arrangement, read directly (rocprim::block_load_direct_blocked).
    WARP_LOAD_DIRECT = 0,

    /// Warp-striped arrangement, read directly
    /// (rocprim::block_load_direct_warp_striped).
    WARP_LOAD_STRIPED = 1,

    /// Blocked arrangement, using vectorized reads where the input allows it
    /// (rocprim::block_load_direct_blocked_vectorized).
    WARP_LOAD_VECTORIZE = 2,

    /// Warp-striped read followed by a WarpExchange striped-to-blocked
    /// transposition, so items end up in blocked arrangement.
    WARP_LOAD_TRANSPOSE = 3
};
50 
51 template<
52  class InputT,
53  int ITEMS_PER_THREAD,
54  WarpLoadAlgorithm ALGORITHM = WARP_LOAD_DIRECT,
55  int LOGICAL_WARP_THREADS = HIPCUB_DEVICE_WARP_THREADS,
56  int ARCH = HIPCUB_ARCH
57 >
58 class WarpLoad
59 {
60 private:
61  constexpr static bool IS_ARCH_WARP
62  = static_cast<unsigned>(LOGICAL_WARP_THREADS) == HIPCUB_DEVICE_WARP_THREADS;
63 
64  template <WarpLoadAlgorithm _POLICY>
65  struct LoadInternal;
66 
67  template <>
68  struct LoadInternal<WARP_LOAD_DIRECT>
69  {
70  using TempStorage = NullType;
71  int linear_tid;
72 
73  HIPCUB_DEVICE __forceinline__
74  LoadInternal(
75  TempStorage & /*temp_storage*/,
76  int linear_tid)
77  : linear_tid(linear_tid)
78  {
79  }
80 
81  template <typename InputIteratorT>
82  HIPCUB_DEVICE __forceinline__ void Load(
83  InputIteratorT block_itr,
84  InputT (&items)[ITEMS_PER_THREAD])
85  {
86  ::rocprim::block_load_direct_blocked(
87  static_cast<unsigned>(linear_tid),
88  block_itr,
89  items
90  );
91  }
92 
93  template <typename InputIteratorT>
94  HIPCUB_DEVICE __forceinline__ void Load(
95  InputIteratorT block_itr,
96  InputT (&items)[ITEMS_PER_THREAD],
97  int valid_items)
98  {
99  ::rocprim::block_load_direct_blocked(
100  static_cast<unsigned>(linear_tid),
101  block_itr,
102  items,
103  static_cast<unsigned>(valid_items)
104  );
105  }
106 
107  template <typename InputIteratorT, typename DefaultT>
108  HIPCUB_DEVICE __forceinline__ void Load(
109  InputIteratorT block_itr,
110  InputT (&items)[ITEMS_PER_THREAD],
111  int valid_items,
112  DefaultT oob_default)
113  {
114  ::rocprim::block_load_direct_blocked(
115  static_cast<unsigned>(linear_tid),
116  block_itr,
117  items,
118  static_cast<unsigned>(valid_items),
119  oob_default
120  );
121  }
122  };
123 
124  template <>
125  struct LoadInternal<WARP_LOAD_STRIPED>
126  {
127  using TempStorage = NullType;
128  int linear_tid;
129 
130  HIPCUB_DEVICE __forceinline__
131  LoadInternal(
132  TempStorage & /*temp_storage*/,
133  int linear_tid)
134  : linear_tid(linear_tid)
135  {
136  }
137 
138  template <typename InputIteratorT>
139  HIPCUB_DEVICE __forceinline__ void Load(
140  InputIteratorT block_itr,
141  InputT (&items)[ITEMS_PER_THREAD])
142  {
143  ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
144  static_cast<unsigned>(linear_tid),
145  block_itr,
146  items
147  );
148  }
149 
150  template <typename InputIteratorT>
151  HIPCUB_DEVICE __forceinline__ void Load(
152  InputIteratorT block_itr,
153  InputT (&items)[ITEMS_PER_THREAD],
154  int valid_items)
155  {
156  ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
157  static_cast<unsigned>(linear_tid),
158  block_itr,
159  items,
160  static_cast<unsigned>(valid_items)
161  );
162  }
163 
164  template <typename InputIteratorT, typename DefaultT>
165  HIPCUB_DEVICE __forceinline__ void Load(
166  InputIteratorT block_itr,
167  InputT (&items)[ITEMS_PER_THREAD],
168  int valid_items,
169  DefaultT oob_default)
170  {
171  ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
172  static_cast<unsigned>(linear_tid),
173  block_itr,
174  items,
175  static_cast<unsigned>(valid_items),
176  oob_default
177  );
178  }
179  };
180 
181  template <>
182  struct LoadInternal<WARP_LOAD_VECTORIZE>
183  {
184  using TempStorage = NullType;
185  int linear_tid;
186 
187  HIPCUB_DEVICE __forceinline__ LoadInternal(
188  TempStorage & /*temp_storage*/,
189  int linear_tid)
190  : linear_tid(linear_tid)
191  {
192  }
193 
194  template <typename InputIteratorT>
195  HIPCUB_DEVICE __forceinline__ void Load(
196  InputT *block_ptr,
197  InputT (&items)[ITEMS_PER_THREAD])
198  {
199  ::rocprim::block_load_direct_blocked_vectorized(
200  static_cast<unsigned>(linear_tid),
201  block_ptr,
202  items
203  );
204  }
205 
206  template <typename InputIteratorT>
207  HIPCUB_DEVICE __forceinline__ void Load(
208  const InputT *block_ptr,
209  InputT (&items)[ITEMS_PER_THREAD])
210  {
211  ::rocprim::block_load_direct_blocked_vectorized(
212  static_cast<unsigned>(linear_tid),
213  block_ptr,
214  items
215  );
216  }
217 
218  template<
219  CacheLoadModifier MODIFIER,
220  typename ValueType,
221  typename OffsetT
222  >
223  HIPCUB_DEVICE __forceinline__ void Load(
225  InputT (&items)[ITEMS_PER_THREAD])
226  {
227  ::rocprim::block_load_direct_blocked_vectorized(
228  static_cast<unsigned>(linear_tid),
229  block_itr,
230  items
231  );
232  }
233 
234  template <typename _InputIteratorT>
235  HIPCUB_DEVICE __forceinline__ void Load(
236  _InputIteratorT block_itr,
237  InputT (&items)[ITEMS_PER_THREAD])
238  {
239  ::rocprim::block_load_direct_blocked_vectorized(
240  static_cast<unsigned>(linear_tid),
241  block_itr,
242  items
243  );
244  }
245 
246  template <typename InputIteratorT>
247  HIPCUB_DEVICE __forceinline__ void Load(
248  InputIteratorT block_itr,
249  InputT (&items)[ITEMS_PER_THREAD],
250  int valid_items)
251  {
252  ::rocprim::block_load_direct_blocked_vectorized(
253  static_cast<unsigned>(linear_tid),
254  block_itr,
255  items,
256  static_cast<unsigned>(valid_items)
257  );
258  }
259 
260  template <typename InputIteratorT, typename DefaultT>
261  HIPCUB_DEVICE __forceinline__ void Load(
262  InputIteratorT block_itr,
263  InputT (&items)[ITEMS_PER_THREAD],
264  int valid_items,
265  DefaultT oob_default)
266  {
267  // vectorized overload does not exist
268  // fall back to direct blocked
269  ::rocprim::block_load_direct_blocked(
270  static_cast<unsigned>(linear_tid),
271  block_itr,
272  items,
273  static_cast<unsigned>(valid_items),
274  oob_default
275  );
276  }
277  };
278 
279  template <>
280  struct LoadInternal<WARP_LOAD_TRANSPOSE>
281  {
282  using WarpExchangeT = WarpExchange<
283  InputT,
284  ITEMS_PER_THREAD,
285  LOGICAL_WARP_THREADS,
286  ARCH
287  >;
288  using TempStorage = typename WarpExchangeT::TempStorage;
289  TempStorage& temp_storage;
290  int linear_tid;
291 
292  HIPCUB_DEVICE __forceinline__ LoadInternal(
293  TempStorage &temp_storage,
294  int linear_tid) :
295  temp_storage(temp_storage),
296  linear_tid(linear_tid)
297  {
298  }
299 
300  template <typename InputIteratorT>
301  HIPCUB_DEVICE __forceinline__ void Load(
302  InputIteratorT block_itr,
303  InputT (&items)[ITEMS_PER_THREAD])
304  {
305  ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
306  static_cast<unsigned>(linear_tid),
307  block_itr,
308  items
309  );
310  WarpExchangeT(temp_storage).StripedToBlocked(items, items);
311  }
312 
313  template <typename InputIteratorT>
314  HIPCUB_DEVICE __forceinline__ void Load(
315  InputIteratorT block_itr,
316  InputT (&items)[ITEMS_PER_THREAD],
317  int valid_items)
318  {
319  ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
320  static_cast<unsigned>(linear_tid),
321  block_itr,
322  items,
323  static_cast<unsigned>(valid_items)
324  );
325  WarpExchangeT(temp_storage).StripedToBlocked(items, items);
326  }
327 
328  template <typename InputIteratorT, typename DefaultT>
329  HIPCUB_DEVICE __forceinline__ void Load(
330  InputIteratorT block_itr,
331  InputT (&items)[ITEMS_PER_THREAD],
332  int valid_items,
333  DefaultT oob_default)
334  {
335  ::rocprim::block_load_direct_warp_striped<LOGICAL_WARP_THREADS>(
336  static_cast<unsigned>(linear_tid),
337  block_itr,
338  items,
339  static_cast<unsigned>(valid_items),
340  oob_default
341  );
342  WarpExchangeT(temp_storage).StripedToBlocked(items, items);
343  }
344  };
345 
346  using InternalLoad = LoadInternal<ALGORITHM>;
347 
348  using _TempStorage = typename InternalLoad::TempStorage;
349 
350  HIPCUB_DEVICE __forceinline__ _TempStorage &PrivateStorage()
351  {
352  __shared__ _TempStorage private_storage;
353  return private_storage;
354  }
355 
356  _TempStorage &temp_storage;
357  int linear_tid;
358 
359 public:
360  struct TempStorage : Uninitialized<_TempStorage>
361  {
362  };
363 
364  HIPCUB_DEVICE __forceinline__
365  WarpLoad() :
366  temp_storage(PrivateStorage()),
367  linear_tid(IS_ARCH_WARP ? ::rocprim::lane_id() : (::rocprim::lane_id() % LOGICAL_WARP_THREADS))
368  {
369  }
370 
371  HIPCUB_DEVICE __forceinline__
372  WarpLoad(TempStorage &temp_storage) :
373  temp_storage(temp_storage.Alias()),
374  linear_tid(IS_ARCH_WARP ? ::rocprim::lane_id() : (::rocprim::lane_id() % LOGICAL_WARP_THREADS))
375  {
376  }
377 
378  template <typename InputIteratorT>
379  HIPCUB_DEVICE __forceinline__ void Load(
380  InputIteratorT block_itr,
381  InputT (&items)[ITEMS_PER_THREAD])
382  {
383  InternalLoad(temp_storage, linear_tid)
384  .Load(block_itr, items);
385  }
386 
387  template <typename InputIteratorT>
388  HIPCUB_DEVICE __forceinline__ void Load(
389  InputIteratorT block_itr,
390  InputT (&items)[ITEMS_PER_THREAD],
391  int valid_items)
392  {
393  InternalLoad(temp_storage, linear_tid)
394  .Load(block_itr, items, valid_items);
395  }
396 
397  template <typename InputIteratorT,
398  typename DefaultT>
399  HIPCUB_DEVICE __forceinline__ void Load(
400  InputIteratorT block_itr,
401  InputT (&items)[ITEMS_PER_THREAD],
402  int valid_items,
403  DefaultT oob_default)
404  {
405  InternalLoad(temp_storage, linear_tid)
406  .Load(block_itr, items, valid_items, oob_default);
407  }
408 };
409 
410 END_HIPCUB_NAMESPACE
411 
412 #endif // HIPCUB_ROCPRIM_WARP_WARP_LOAD_HPP_
Definition: cache_modified_input_iterator.hpp:52
Definition: warp_exchange.hpp:47
Definition: warp_load.hpp:59
A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions.
Definition: util_type.hpp:363
Definition: warp_load.hpp:361