/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.1.0/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.1.0/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File#

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.1.0/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File
util_ptx.hpp
1 /******************************************************************************
2  * Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
4  * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  * * Redistributions in binary form must reproduce the above copyright
11  * notice, this list of conditions and the following disclaimer in the
12  * documentation and/or other materials provided with the distribution.
13  * * Neither the name of the NVIDIA CORPORATION nor the
14  * names of its contributors may be used to endorse or promote products
15  * derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  *
28  ******************************************************************************/
29 
30 #ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_
31 #define HIPCUB_ROCPRIM_UTIL_PTX_HPP_
32 
33 #include <cstdint>
34 #include <type_traits>
35 
36 #include "../../config.hpp"
37 
38 #include <rocprim/intrinsics/warp_shuffle.hpp>
39 
// Warp-size helpers. Each one simply expands to a call of the rocPRIM
// function named in its replacement text.
#define HIPCUB_WARP_THREADS        ::rocprim::warp_size()
#define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size()
#define HIPCUB_HOST_WARP_THREADS   ::rocprim::host_warp_size()
// Placeholder architecture value: ignored with rocPRIM backend.
#define HIPCUB_ARCH 1
44 
45 
46 BEGIN_HIPCUB_NAMESPACE
47 
// Missing compared to CUB:
// * ThreadExit - not supported
// * ThreadTrap - not supported
// * FFMA_RZ, FMUL_RZ - not in CUB public API
// * CTA_SYNC_AND - not supported, not CUB public API
// * MatchAny - not in CUB public API
//
// Differences:
// * Warp thread masks (when used) are 64-bit unsigned integers
// * WARP_SYNC is provided as a wave barrier; its member_mask argument
//   is ignored
// * member_mask argument is ignored in WARP_[ALL|ANY|BALLOT] funcs
// * Arguments first_thread, last_thread, and member_mask are ignored
//   in Shuffle* funcs
// * count in BAR is ignored, BAR works like CTA_SYNC
62 
63 // ID functions etc.
64 
65 HIPCUB_DEVICE inline
66 int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
67 {
68  return ((block_dim_z == 1) ? 0 : (hipThreadIdx_z * block_dim_x * block_dim_y))
69  + ((block_dim_y == 1) ? 0 : (hipThreadIdx_y * block_dim_x))
70  + hipThreadIdx_x;
71 }
72 
73 HIPCUB_DEVICE inline
74 unsigned int LaneId()
75 {
76  return ::rocprim::lane_id();
77 }
78 
79 HIPCUB_DEVICE inline
80 unsigned int WarpId()
81 {
82  return ::rocprim::warp_id();
83 }
84 
85 template <int LOGICAL_WARP_THREADS, int /* ARCH */ = 0>
86 HIPCUB_DEVICE inline
87 uint64_t WarpMask(unsigned int warp_id) {
88  constexpr bool is_pow_of_two = ::rocprim::detail::is_power_of_two(LOGICAL_WARP_THREADS);
89  constexpr bool is_arch_warp =
90  LOGICAL_WARP_THREADS == ::rocprim::device_warp_size();
91 
92  uint64_t member_mask = uint64_t(-1) >> (64 - LOGICAL_WARP_THREADS);
93 
94  if (is_pow_of_two && !is_arch_warp) {
95  member_mask <<= warp_id * LOGICAL_WARP_THREADS;
96  }
97 
98  return member_mask;
99 }
100 
101 // Returns the warp lane mask of all lanes less than the calling thread
102 HIPCUB_DEVICE inline
103 uint64_t LaneMaskLt()
104 {
105  return (uint64_t(1) << LaneId()) - 1;
106 }
107 
108 // Returns the warp lane mask of all lanes less than or equal to the calling thread
109 HIPCUB_DEVICE inline
110 uint64_t LaneMaskLe()
111 {
112  return ((uint64_t(1) << LaneId()) << 1) - 1;
113 }
114 
115 // Returns the warp lane mask of all lanes greater than the calling thread
116 HIPCUB_DEVICE inline
117 uint64_t LaneMaskGt()
118 {
119  return uint64_t(-1)^LaneMaskLe();
120 }
121 
122 // Returns the warp lane mask of all lanes greater than or equal to the calling thread
123 HIPCUB_DEVICE inline
124 uint64_t LaneMaskGe()
125 {
126  return uint64_t(-1)^LaneMaskLt();
127 }
128 
129 // Shuffle funcs
130 
131 template <
132  int LOGICAL_WARP_THREADS,
133  typename T
134 >
135 HIPCUB_DEVICE inline
136 T ShuffleUp(T input,
137  int src_offset,
138  int first_thread,
139  unsigned int member_mask)
140 {
141  // Not supported in rocPRIM.
142  (void) first_thread;
143  // Member mask is not supported in rocPRIM, because it's
144  // not supported in ROCm.
145  (void) member_mask;
146  return ::rocprim::warp_shuffle_up(
147  input, src_offset, LOGICAL_WARP_THREADS
148  );
149 }
150 
151 template <
152  int LOGICAL_WARP_THREADS,
153  typename T
154 >
155 HIPCUB_DEVICE inline
156 T ShuffleDown(T input,
157  int src_offset,
158  int last_thread,
159  unsigned int member_mask)
160 {
161  // Not supported in rocPRIM.
162  (void) last_thread;
163  // Member mask is not supported in rocPRIM, because it's
164  // not supported in ROCm.
165  (void) member_mask;
166  return ::rocprim::warp_shuffle_down(
167  input, src_offset, LOGICAL_WARP_THREADS
168  );
169 }
170 
171 template <
172  int LOGICAL_WARP_THREADS,
173  typename T
174 >
175 HIPCUB_DEVICE inline
176 T ShuffleIndex(T input,
177  int src_lane,
178  unsigned int member_mask)
179 {
180  // Member mask is not supported in rocPRIM, because it's
181  // not supported in ROCm.
182  (void) member_mask;
183  return ::rocprim::warp_shuffle(
184  input, src_lane, LOGICAL_WARP_THREADS
185  );
186 }
187 
188 // Other
189 
190 HIPCUB_DEVICE inline
191 unsigned int SHR_ADD(unsigned int x,
192  unsigned int shift,
193  unsigned int addend)
194 {
195  return (x >> shift) + addend;
196 }
197 
198 HIPCUB_DEVICE inline
199 unsigned int SHL_ADD(unsigned int x,
200  unsigned int shift,
201  unsigned int addend)
202 {
203  return (x << shift) + addend;
204 }
205 
206 namespace detail {
207 
208 template <typename UnsignedBits>
209 HIPCUB_DEVICE inline
210 auto unsigned_bit_extract(UnsignedBits source,
211  unsigned int bit_start,
212  unsigned int num_bits)
213  -> typename std::enable_if<sizeof(UnsignedBits) == 8, unsigned int>::type
214 {
215  #ifdef __HIP_PLATFORM_AMD__
216  return __bitextract_u64(source, bit_start, num_bits);
217  #else
218  return (source << (64 - bit_start - num_bits)) >> (64 - num_bits);
219  #endif // __HIP_PLATFORM_AMD__
220 }
221 
222 template <typename UnsignedBits>
223 HIPCUB_DEVICE inline
224 auto unsigned_bit_extract(UnsignedBits source,
225  unsigned int bit_start,
226  unsigned int num_bits)
227  -> typename std::enable_if<sizeof(UnsignedBits) < 8, unsigned int>::type
228 {
229  #ifdef __HIP_PLATFORM_AMD__
230  return __bitextract_u32(source, bit_start, num_bits);
231  #else
232  return (static_cast<unsigned int>(source) << (32 - bit_start - num_bits)) >> (32 - num_bits);
233  #endif // __HIP_PLATFORM_AMD__
234 }
235 
236 } // end namespace detail
237 
238 // Bitfield-extract.
239 // Extracts \p num_bits from \p source starting at bit-offset \p bit_start.
240 // The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type.
241 template <typename UnsignedBits>
242 HIPCUB_DEVICE inline
243 unsigned int BFE(UnsignedBits source,
244  unsigned int bit_start,
245  unsigned int num_bits)
246 {
247  static_assert(std::is_unsigned<UnsignedBits>::value, "UnsignedBits must be unsigned");
248  return detail::unsigned_bit_extract(source, bit_start, num_bits);
249 }
250 
251 // Bitfield insert.
252 // Inserts the \p num_bits least significant bits of \p y into \p x at bit-offset \p bit_start.
253 HIPCUB_DEVICE inline
254 void BFI(unsigned int &ret,
255  unsigned int x,
256  unsigned int y,
257  unsigned int bit_start,
258  unsigned int num_bits)
259 {
260  #ifdef __HIP_PLATFORM_AMD__
261  ret = __bitinsert_u32(x, y, bit_start, num_bits);
262  #else
263  x <<= bit_start;
264  unsigned int MASK_X = ((1 << num_bits) - 1) << bit_start;
265  unsigned int MASK_Y = ~MASK_X;
266  ret = (y & MASK_Y) | (x & MASK_X);
267  #endif // __HIP_PLATFORM_AMD__
268 }
269 
270 HIPCUB_DEVICE inline
271 unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
272 {
273  return x + y + z;
274 }
275 
276 HIPCUB_DEVICE inline
277 int PRMT(unsigned int a, unsigned int b, unsigned int index)
278 {
279  return ::__byte_perm(a, b, index);
280 }
281 
282 HIPCUB_DEVICE inline
283 void BAR(int count)
284 {
285  (void) count;
286  __syncthreads();
287 }
288 
289 HIPCUB_DEVICE inline
290 void CTA_SYNC()
291 {
292  __syncthreads();
293 }
294 
295 HIPCUB_DEVICE inline
296 void WARP_SYNC(unsigned int member_mask)
297 {
298  (void) member_mask;
299  ::rocprim::wave_barrier();
300 }
301 
302 HIPCUB_DEVICE inline
303 int WARP_ANY(int predicate, uint64_t member_mask)
304 {
305  (void) member_mask;
306  return ::__any(predicate);
307 }
308 
309 HIPCUB_DEVICE inline
310 int WARP_ALL(int predicate, uint64_t member_mask)
311 {
312  (void) member_mask;
313  return ::__all(predicate);
314 }
315 
316 HIPCUB_DEVICE inline
317 int64_t WARP_BALLOT(int predicate, uint64_t member_mask)
318 {
319  (void) member_mask;
320  return __ballot(predicate);
321 }
322 
323 END_HIPCUB_NAMESPACE
324 
325 #endif // HIPCUB_ROCPRIM_UTIL_PTX_HPP_