/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.0.2/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.0.2/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File#

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.0.2/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File
util_ptx.hpp
1 /******************************************************************************
2  * Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
4  * Modifications Copyright (c) 2017-2020, Advanced Micro Devices, Inc. All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  * * Redistributions in binary form must reproduce the above copyright
11  * notice, this list of conditions and the following disclaimer in the
12  * documentation and/or other materials provided with the distribution.
13  * * Neither the name of the NVIDIA CORPORATION nor the
14  * names of its contributors may be used to endorse or promote products
15  * derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  *
28  ******************************************************************************/
29 
30 #ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_
31 #define HIPCUB_ROCPRIM_UTIL_PTX_HPP_
32 
33 #include <cstdint>
34 #include <type_traits>
35 
36 #include "../../config.hpp"
37 
38 #include <rocprim/intrinsics/warp_shuffle.hpp>
39 
40 #define HIPCUB_WARP_THREADS ::rocprim::warp_size()
41 #define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size()
42 #define HIPCUB_HOST_WARP_THREADS ::rocprim::host_warp_size()
43 #define HIPCUB_ARCH 1 // ignored with rocPRIM backend
44 
45 
46 BEGIN_HIPCUB_NAMESPACE
47 
48 // Missing compared to CUB:
49 // * ThreadExit - not supported
50 // * ThreadTrap - not supported
51 // * FFMA_RZ, FMUL_RZ - not in CUB public API
52 // * WARP_SYNC - provided only as a no-op (see below), not CUB public API
53 // * CTA_SYNC_AND - not supported, not CUB public API
54 // * MatchAny - not in CUB public API
55 //
56 // Differences:
57 // * Warp thread masks (when used) are 64-bit unsigned integers
58 // * member_mask argument is ignored in WARP_[ALL|ANY|BALLOT] funcs
59 // * Arguments first_lane, last_lane, and member_mask are ignored
60 // in Shuffle* funcs
61 // * count in BAR is ignored, BAR works like CTA_SYNC
62 
63 // ID functions etc.
64 
65 HIPCUB_DEVICE inline
66 int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
67 {
68  return ((block_dim_z == 1) ? 0 : (hipThreadIdx_z * block_dim_x * block_dim_y))
69  + ((block_dim_y == 1) ? 0 : (hipThreadIdx_y * block_dim_x))
70  + hipThreadIdx_x;
71 }
72 
73 HIPCUB_DEVICE inline
74 unsigned int LaneId()
75 {
76  return ::rocprim::lane_id();
77 }
78 
79 HIPCUB_DEVICE inline
80 unsigned int WarpId()
81 {
82  return ::rocprim::warp_id();
83 }
84 
85 // Returns the warp lane mask of all lanes less than the calling thread
86 HIPCUB_DEVICE inline
87 uint64_t LaneMaskLt()
88 {
89  return (uint64_t(1) << LaneId()) - 1;
90 }
91 
92 // Returns the warp lane mask of all lanes less than or equal to the calling thread
93 HIPCUB_DEVICE inline
94 uint64_t LaneMaskLe()
95 {
96  return ((uint64_t(1) << LaneId()) << 1) - 1;
97 }
98 
99 // Returns the warp lane mask of all lanes greater than the calling thread
100 HIPCUB_DEVICE inline
101 uint64_t LaneMaskGt()
102 {
103  return uint64_t(-1)^LaneMaskLe();
104 }
105 
106 // Returns the warp lane mask of all lanes greater than or equal to the calling thread
107 HIPCUB_DEVICE inline
108 uint64_t LaneMaskGe()
109 {
110  return uint64_t(-1)^LaneMaskLt();
111 }
112 
113 // Shuffle funcs
114 
115 template <
116  int LOGICAL_WARP_THREADS,
117  typename T
118 >
119 HIPCUB_DEVICE inline
120 T ShuffleUp(T input,
121  int src_offset,
122  int first_thread,
123  unsigned int member_mask)
124 {
125  // Not supported in rocPRIM.
126  (void) first_thread;
127  // Member mask is not supported in rocPRIM, because it's
128  // not supported in ROCm.
129  (void) member_mask;
130  return ::rocprim::warp_shuffle_up(
131  input, src_offset, LOGICAL_WARP_THREADS
132  );
133 }
134 
135 template <
136  int LOGICAL_WARP_THREADS,
137  typename T
138 >
139 HIPCUB_DEVICE inline
140 T ShuffleDown(T input,
141  int src_offset,
142  int last_thread,
143  unsigned int member_mask)
144 {
145  // Not supported in rocPRIM.
146  (void) last_thread;
147  // Member mask is not supported in rocPRIM, because it's
148  // not supported in ROCm.
149  (void) member_mask;
150  return ::rocprim::warp_shuffle_down(
151  input, src_offset, LOGICAL_WARP_THREADS
152  );
153 }
154 
155 template <
156  int LOGICAL_WARP_THREADS,
157  typename T
158 >
159 HIPCUB_DEVICE inline
160 T ShuffleIndex(T input,
161  int src_lane,
162  unsigned int member_mask)
163 {
164  // Member mask is not supported in rocPRIM, because it's
165  // not supported in ROCm.
166  (void) member_mask;
167  return ::rocprim::warp_shuffle(
168  input, src_lane, LOGICAL_WARP_THREADS
169  );
170 }
171 
172 // Other
173 
// Shift-right then add: returns (x >> shift) + addend in unsigned arithmetic.
HIPCUB_DEVICE inline
unsigned int SHR_ADD(unsigned int x,
                     unsigned int shift,
                     unsigned int addend)
{
    const unsigned int shifted = x >> shift;
    return shifted + addend;
}
181 
// Shift-left then add: returns (x << shift) + addend in unsigned arithmetic.
HIPCUB_DEVICE inline
unsigned int SHL_ADD(unsigned int x,
                     unsigned int shift,
                     unsigned int addend)
{
    const unsigned int shifted = x << shift;
    return shifted + addend;
}
189 
namespace detail {

// Extracts num_bits bits from source starting at bit-offset bit_start.
// Overload selected for 8-byte source types.
//
// On __HIP_PLATFORM_AMD__ the work is delegated to __bitextract_u64.
// Otherwise a portable shift-and-mask is used which:
//   * returns 0 when num_bits == 0 or bit_start lies past the last bit,
//   * treats bits beyond the most significant bit as zero,
// so that no shift amount ever reaches the operand width (which would be
// undefined behaviour). For 1 <= num_bits and bit_start + num_bits <= 64 the
// result equals (source << (64 - bit_start - num_bits)) >> (64 - num_bits).
//
// The result is narrowed to unsigned int, so only the low 32 bits of a wider
// field are returned.
template <typename UnsignedBits>
HIPCUB_DEVICE inline
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<sizeof(UnsignedBits) == 8, unsigned int>::type
{
    #ifdef __HIP_PLATFORM_AMD__
        return __bitextract_u64(source, bit_start, num_bits);
    #else
        if (num_bits == 0 || bit_start >= 64)
        {
            return 0;
        }
        // All-ones mask of the requested width; a full-width request cannot
        // be built with a shift, so it is special-cased.
        const UnsignedBits field_mask = (num_bits >= 64)
            ? static_cast<UnsignedBits>(~UnsignedBits(0))
            : static_cast<UnsignedBits>((UnsignedBits(1) << num_bits) - 1);
        return static_cast<unsigned int>((source >> bit_start) & field_mask);
    #endif // __HIP_PLATFORM_AMD__
}

// Extracts num_bits bits from source starting at bit-offset bit_start.
// Overload selected for source types narrower than 8 bytes; the source is
// widened to unsigned int before extraction.
//
// On __HIP_PLATFORM_AMD__ the work is delegated to __bitextract_u32.
// Otherwise a portable shift-and-mask is used which:
//   * returns 0 when num_bits == 0 or bit_start lies past the last bit,
//   * treats bits beyond the most significant bit as zero,
// so that no shift amount ever reaches the operand width (which would be
// undefined behaviour). For 1 <= num_bits and bit_start + num_bits <= 32 the
// result equals (source << (32 - bit_start - num_bits)) >> (32 - num_bits).
template <typename UnsignedBits>
HIPCUB_DEVICE inline
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<sizeof(UnsignedBits) < 8, unsigned int>::type
{
    #ifdef __HIP_PLATFORM_AMD__
        return __bitextract_u32(source, bit_start, num_bits);
    #else
        if (num_bits == 0 || bit_start >= 32)
        {
            return 0;
        }
        const unsigned int widened = static_cast<unsigned int>(source);
        // All-ones mask of the requested width; a full-width request cannot
        // be built with a shift, so it is special-cased.
        const unsigned int field_mask = (num_bits >= 32)
            ? ~0u
            : ((1u << num_bits) - 1u);
        return (widened >> bit_start) & field_mask;
    #endif // __HIP_PLATFORM_AMD__
}

} // end namespace detail
221 
222 // Bitfield-extract.
223 // Extracts \p num_bits from \p source starting at bit-offset \p bit_start.
224 // The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type.
225 template <typename UnsignedBits>
226 HIPCUB_DEVICE inline
227 unsigned int BFE(UnsignedBits source,
228  unsigned int bit_start,
229  unsigned int num_bits)
230 {
231  static_assert(std::is_unsigned<UnsignedBits>::value, "UnsignedBits must be unsigned");
232  return detail::unsigned_bit_extract(source, bit_start, num_bits);
233 }
234 
// Bitfield insert.
// Inserts the \p num_bits least significant bits of \p y into \p x at
// bit-offset \p bit_start and stores the combined value in \p ret. Bits of
// \p x outside the inserted field are preserved.
//
// On __HIP_PLATFORM_AMD__ the work is delegated to
// __bitinsert_u32(x, y, bit_start, num_bits).
//
// The portable branch follows the same contract (y is inserted into x).
// It previously did the reverse (shifted x into y), which contradicted both
// the description above and the argument order of the AMD branch.
// All arithmetic is unsigned, and offset/width are reduced modulo 32 so that
// no shift amount can reach the operand width.
// NOTE(review): the modulo-32 reduction is assumed to match what
// __bitinsert_u32 does with out-of-range arguments - confirm against the
// HIP device headers.
HIPCUB_DEVICE inline
void BFI(unsigned int &ret,
         unsigned int x,
         unsigned int y,
         unsigned int bit_start,
         unsigned int num_bits)
{
    #ifdef __HIP_PLATFORM_AMD__
        ret = __bitinsert_u32(x, y, bit_start, num_bits);
    #else
        const unsigned int offset     = bit_start & 31u;
        const unsigned int width      = num_bits & 31u;
        // Right-aligned all-ones mask covering the field being inserted.
        const unsigned int field_mask = (1u << width) - 1u;
        const unsigned int kept_bits  = x & ~(field_mask << offset);
        const unsigned int new_bits   = (y & field_mask) << offset;
        ret = kept_bits | new_bits;
    #endif // __HIP_PLATFORM_AMD__
}
253 
// Three-operand add: returns x + y + z in unsigned arithmetic.
HIPCUB_DEVICE inline
unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
{
    const unsigned int partial_sum = x + y;
    return partial_sum + z;
}
259 
260 HIPCUB_DEVICE inline
261 int PRMT(unsigned int a, unsigned int b, unsigned int index)
262 {
263  return ::__byte_perm(a, b, index);
264 }
265 
266 HIPCUB_DEVICE inline
267 void BAR(int count)
268 {
269  (void) count;
270  __syncthreads();
271 }
272 
273 HIPCUB_DEVICE inline
274 void CTA_SYNC()
275 {
276  __syncthreads();
277 }
278 
279 HIPCUB_DEVICE inline
280 void WARP_SYNC(unsigned int member_mask)
281 {
282  // Does nothing, on ROCm threads in warp are always in sync
283  (void) member_mask;
284 }
285 
286 HIPCUB_DEVICE inline
287 int WARP_ANY(int predicate, uint64_t member_mask)
288 {
289  (void) member_mask;
290  return ::__any(predicate);
291 }
292 
293 HIPCUB_DEVICE inline
294 int WARP_ALL(int predicate, uint64_t member_mask)
295 {
296  (void) member_mask;
297  return ::__all(predicate);
298 }
299 
300 HIPCUB_DEVICE inline
301 int64_t WARP_BALLOT(int predicate, uint64_t member_mask)
302 {
303  (void) member_mask;
304  return __ballot(predicate);
305 }
306 
307 END_HIPCUB_NAMESPACE
308 
309 #endif // HIPCUB_ROCPRIM_UTIL_PTX_HPP_