/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.7.0/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.7.0/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File#

hipCUB: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipcub/checkouts/docs-5.7.0/hipcub/include/hipcub/backend/rocprim/util_ptx.hpp Source File
util_ptx.hpp
1 /******************************************************************************
2  * Copyright (c) 2010-2011, Duane Merrill. All rights reserved.
3  * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved.
4  * Modifications Copyright (c) 2017-2023, Advanced Micro Devices, Inc. All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  * * Redistributions of source code must retain the above copyright
9  * notice, this list of conditions and the following disclaimer.
10  * * Redistributions in binary form must reproduce the above copyright
11  * notice, this list of conditions and the following disclaimer in the
12  * documentation and/or other materials provided with the distribution.
13  * * Neither the name of the NVIDIA CORPORATION nor the
14  * names of its contributors may be used to endorse or promote products
15  * derived from this software without specific prior written permission.
16  *
17  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
18  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
19  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
20  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
21  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
22  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
23  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
24  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27  *
28  ******************************************************************************/
29 
30 #ifndef HIPCUB_ROCPRIM_UTIL_PTX_HPP_
31 #define HIPCUB_ROCPRIM_UTIL_PTX_HPP_
32 
33 #include <cstdint>
34 #include <type_traits>
35 
36 #include "../../config.hpp"
37 
38 #include <rocprim/intrinsics/warp_shuffle.hpp>
39 
40 BEGIN_HIPCUB_NAMESPACE
41 
// Missing compared to CUB:
// * ThreadExit - not supported
// * ThreadTrap - not supported
// * FFMA_RZ, FMUL_RZ - not in CUB public API
// * CTA_SYNC_AND - not supported, not CUB public API
// * MatchAny - not in CUB public API
//
// Differences:
// * Warp thread masks (when used) are 64-bit unsigned integers
// * member_mask argument is ignored in WARP_[SYNC|ALL|ANY|BALLOT] funcs;
//   WARP_SYNC (not CUB public API) is provided as a plain wave barrier
// * Arguments first_thread, last_thread, and member_mask are ignored
//   in Shuffle* funcs
// * count in BAR is ignored, BAR works like CTA_SYNC
57 // ID functions etc.
58 
59 HIPCUB_DEVICE inline
60 int RowMajorTid(int block_dim_x, int block_dim_y, int block_dim_z)
61 {
62  return ((block_dim_z == 1) ? 0 : (hipThreadIdx_z * block_dim_x * block_dim_y))
63  + ((block_dim_y == 1) ? 0 : (hipThreadIdx_y * block_dim_x))
64  + hipThreadIdx_x;
65 }
66 
67 HIPCUB_DEVICE inline
68 unsigned int LaneId()
69 {
70  return ::rocprim::lane_id();
71 }
72 
73 HIPCUB_DEVICE inline
74 unsigned int WarpId()
75 {
76  return ::rocprim::warp_id();
77 }
78 
79 template <int LOGICAL_WARP_THREADS, int /* ARCH */ = 0>
80 HIPCUB_DEVICE inline
81 uint64_t WarpMask(unsigned int warp_id) {
82  constexpr bool is_pow_of_two = ::rocprim::detail::is_power_of_two(LOGICAL_WARP_THREADS);
83  constexpr bool is_arch_warp =
84  LOGICAL_WARP_THREADS == ::rocprim::device_warp_size();
85 
86  uint64_t member_mask = uint64_t(-1) >> (64 - LOGICAL_WARP_THREADS);
87 
88  if (is_pow_of_two && !is_arch_warp) {
89  member_mask <<= warp_id * LOGICAL_WARP_THREADS;
90  }
91 
92  return member_mask;
93 }
94 
95 // Returns the warp lane mask of all lanes less than the calling thread
96 HIPCUB_DEVICE inline
97 uint64_t LaneMaskLt()
98 {
99  return (uint64_t(1) << LaneId()) - 1;
100 }
101 
102 // Returns the warp lane mask of all lanes less than or equal to the calling thread
103 HIPCUB_DEVICE inline
104 uint64_t LaneMaskLe()
105 {
106  return ((uint64_t(1) << LaneId()) << 1) - 1;
107 }
108 
109 // Returns the warp lane mask of all lanes greater than the calling thread
110 HIPCUB_DEVICE inline
111 uint64_t LaneMaskGt()
112 {
113  return uint64_t(-1)^LaneMaskLe();
114 }
115 
116 // Returns the warp lane mask of all lanes greater than or equal to the calling thread
117 HIPCUB_DEVICE inline
118 uint64_t LaneMaskGe()
119 {
120  return uint64_t(-1)^LaneMaskLt();
121 }
122 
123 // Shuffle funcs
124 
125 template <
126  int LOGICAL_WARP_THREADS,
127  typename T
128 >
129 HIPCUB_DEVICE inline
130 T ShuffleUp(T input,
131  int src_offset,
132  int first_thread,
133  unsigned int member_mask)
134 {
135  // Not supported in rocPRIM.
136  (void) first_thread;
137  // Member mask is not supported in rocPRIM, because it's
138  // not supported in ROCm.
139  (void) member_mask;
140  return ::rocprim::warp_shuffle_up(
141  input, src_offset, LOGICAL_WARP_THREADS
142  );
143 }
144 
145 template <
146  int LOGICAL_WARP_THREADS,
147  typename T
148 >
149 HIPCUB_DEVICE inline
150 T ShuffleDown(T input,
151  int src_offset,
152  int last_thread,
153  unsigned int member_mask)
154 {
155  // Not supported in rocPRIM.
156  (void) last_thread;
157  // Member mask is not supported in rocPRIM, because it's
158  // not supported in ROCm.
159  (void) member_mask;
160  return ::rocprim::warp_shuffle_down(
161  input, src_offset, LOGICAL_WARP_THREADS
162  );
163 }
164 
165 template <
166  int LOGICAL_WARP_THREADS,
167  typename T
168 >
169 HIPCUB_DEVICE inline
170 T ShuffleIndex(T input,
171  int src_lane,
172  unsigned int member_mask)
173 {
174  // Member mask is not supported in rocPRIM, because it's
175  // not supported in ROCm.
176  (void) member_mask;
177  return ::rocprim::warp_shuffle(
178  input, src_lane, LOGICAL_WARP_THREADS
179  );
180 }
181 
182 // Other
183 
// Shift-right then add. Returns (\p x >> \p shift) + \p addend, evaluated in
// unsigned int arithmetic.
HIPCUB_DEVICE inline
unsigned int SHR_ADD(unsigned int x, unsigned int shift, unsigned int addend)
{
    const unsigned int shifted = x >> shift;
    return shifted + addend;
}
191 
// Shift-left then add. Returns (\p x << \p shift) + \p addend, evaluated in
// unsigned int arithmetic.
HIPCUB_DEVICE inline
unsigned int SHL_ADD(unsigned int x, unsigned int shift, unsigned int addend)
{
    const unsigned int shifted = x << shift;
    return shifted + addend;
}
199 
namespace detail {

// Extracts a \p num_bits wide field of \p source starting at bit \p bit_start.
// Overload selected for 8-byte types.
//
// On AMD platforms this forwards to __bitextract_u64. The portable fallback
// returns 0 for a zero-width field (the shift-based formula would otherwise
// shift by 64, which is undefined) and expects \p bit_start + \p num_bits <= 64.
// The result is returned as unsigned int, so only its low bits survive when
// the extracted field does not fit that type.
template <typename UnsignedBits>
HIPCUB_DEVICE inline
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<sizeof(UnsignedBits) == 8, unsigned int>::type
{
    #ifdef __HIP_PLATFORM_AMD__
        return __bitextract_u64(source, bit_start, num_bits);
    #else
        if (num_bits == 0)
        {
            return 0;
        }
        // Push the field to the top of the word, then pull it down to bit 0.
        return (source << (64 - bit_start - num_bits)) >> (64 - num_bits);
    #endif // __HIP_PLATFORM_AMD__
}

// Extracts a \p num_bits wide field of \p source starting at bit \p bit_start.
// Overload selected for types narrower than 8 bytes; the fallback widens
// \p source to unsigned int first.
//
// On AMD platforms this forwards to __bitextract_u32. The portable fallback
// returns 0 for a zero-width field (the shift-based formula would otherwise
// shift by 32, which is undefined) and expects \p bit_start + \p num_bits <= 32.
template <typename UnsignedBits>
HIPCUB_DEVICE inline
auto unsigned_bit_extract(UnsignedBits source,
                          unsigned int bit_start,
                          unsigned int num_bits)
    -> typename std::enable_if<sizeof(UnsignedBits) < 8, unsigned int>::type
{
    #ifdef __HIP_PLATFORM_AMD__
        return __bitextract_u32(source, bit_start, num_bits);
    #else
        if (num_bits == 0)
        {
            return 0;
        }
        // Push the field to the top of the word, then pull it down to bit 0.
        return (static_cast<unsigned int>(source) << (32 - bit_start - num_bits))
               >> (32 - num_bits);
    #endif // __HIP_PLATFORM_AMD__
}

} // end namespace detail
231 
232 // Bitfield-extract.
233 // Extracts \p num_bits from \p source starting at bit-offset \p bit_start.
234 // The input \p source may be an 8b, 16b, 32b, or 64b unsigned integer type.
235 template <typename UnsignedBits>
236 HIPCUB_DEVICE inline
237 unsigned int BFE(UnsignedBits source,
238  unsigned int bit_start,
239  unsigned int num_bits)
240 {
241  static_assert(std::is_unsigned<UnsignedBits>::value, "UnsignedBits must be unsigned");
242  return detail::unsigned_bit_extract(source, bit_start, num_bits);
243 }
244 
// Bitfield insert.
// Inserts the \p num_bits least significant bits of \p y into \p x at
// bit-offset \p bit_start and stores the result in \p ret. All other bits of
// \p ret are copied from \p x.
//
// On AMD platforms this forwards to __bitinsert_u32(x, y, bit_start, num_bits).
// The portable fallback performs the same insertion direction (y into x).
// It reduces \p bit_start and \p num_bits modulo 32 so that every shift is
// well defined.
// NOTE(review): the modulo-32 handling is intended to mirror
// __bitinsert_u32 - confirm against the HIP device headers.
HIPCUB_DEVICE inline
void BFI(unsigned int &ret,
         unsigned int x,
         unsigned int y,
         unsigned int bit_start,
         unsigned int num_bits)
{
    #ifdef __HIP_PLATFORM_AMD__
        ret = __bitinsert_u32(x, y, bit_start, num_bits);
    #else
        const unsigned int offset     = bit_start & 31u;
        const unsigned int width      = num_bits & 31u;
        // Unsigned literal: a signed `1 << width` could overflow.
        const unsigned int field_mask = (1u << width) - 1u;

        // Clear the destination field in x, then drop y's low bits into it.
        ret = (x & ~(field_mask << offset)) | ((y & field_mask) << offset);
    #endif // __HIP_PLATFORM_AMD__
}
263 
// Three-operand add. Returns \p x + \p y + \p z, evaluated in unsigned int
// arithmetic.
HIPCUB_DEVICE inline
unsigned int IADD3(unsigned int x, unsigned int y, unsigned int z)
{
    const unsigned int partial_sum = x + y;
    return partial_sum + z;
}
269 
270 HIPCUB_DEVICE inline
271 int PRMT(unsigned int a, unsigned int b, unsigned int index)
272 {
273  return ::__byte_perm(a, b, index);
274 }
275 
276 HIPCUB_DEVICE inline
277 void BAR(int count)
278 {
279  (void) count;
280  __syncthreads();
281 }
282 
283 HIPCUB_DEVICE inline
284 void CTA_SYNC()
285 {
286  __syncthreads();
287 }
288 
289 HIPCUB_DEVICE inline
290 void WARP_SYNC(unsigned int member_mask)
291 {
292  (void) member_mask;
293  ::rocprim::wave_barrier();
294 }
295 
296 HIPCUB_DEVICE inline
297 int WARP_ANY(int predicate, uint64_t member_mask)
298 {
299  (void) member_mask;
300  return ::__any(predicate);
301 }
302 
303 HIPCUB_DEVICE inline
304 int WARP_ALL(int predicate, uint64_t member_mask)
305 {
306  (void) member_mask;
307  return ::__all(predicate);
308 }
309 
310 HIPCUB_DEVICE inline
311 int64_t WARP_BALLOT(int predicate, uint64_t member_mask)
312 {
313  (void) member_mask;
314  return __ballot(predicate);
315 }
316 
317 END_HIPCUB_NAMESPACE
318 
319 #endif // HIPCUB_ROCPRIM_UTIL_PTX_HPP_