30 #ifndef HIPCUB_CONFIG_HPP_
31 #define HIPCUB_CONFIG_HPP_
33 #include <hip/hip_runtime.h>
// Name of the library namespace; this macro expands to `hipcub`.
35 #define HIPCUB_NAMESPACE hipcub
// BEGIN_HIPCUB_NAMESPACE / END_HIPCUB_NAMESPACE: each definition ends in a line
// continuation, and the continued body is not visible in this excerpt.
// NOTE(review): presumably they open and close the namespace named above - confirm.
37 #define BEGIN_HIPCUB_NAMESPACE \
40 #define END_HIPCUB_NAMESPACE \
43 #ifdef __HIP_PLATFORM_AMD__
// AMD platform branch: flag that the rocPRIM-based API is selected.
44 #define HIPCUB_ROCPRIM_API 1
// On this platform the runtime-function qualifier is host-only.
45 #define HIPCUB_RUNTIME_FUNCTION __host__
// NOTE(review): presumably provides the ::rocprim warp-size helpers used by the
// code and macros that follow in this branch - confirm against the header.
47 #include <rocprim/device/config_types.hpp>
// Host-side helper that obtains a device id via hipGetDevice, asks
// ::rocprim::host_warp_size for that device's warp size, and returns the value
// written to `host_warp_size` when the query reports hipSuccess.
// NOTE(review): several lines of this function are missing from this excerpt
// (the braces, the declaration of `device_id`, and the bodies of both `if`
// statements), so what is returned after a failed query cannot be stated here.
51 inline unsigned int host_warp_size_wrapper()
// Output variable for the warp-size query; initialised to 0.
54 unsigned int host_warp_size = 0;
// Fetch a device id from the HIP runtime and keep the status for checking.
55 hipError_t error = hipGetDevice(&device_id);
56 if(error != hipSuccess)
// Report the failure on stderr: numeric error code, this source line, and the
// runtime's text for the error.
58 fprintf(stderr,
"HIP error: %d line: %d: %s\n", error, __LINE__, hipGetErrorString(error));
// Query the warp size for `device_id`; the branch taken on a non-success status
// is not visible in this excerpt.
61 if(::rocprim::host_warp_size(device_id, host_warp_size) != hipSuccess)
65 return host_warp_size;
// Warp-size macros for the rocPRIM branch. Each one expands to a function call
// rather than to a literal.
69 #define HIPCUB_WARP_THREADS ::rocprim::warp_size()
70 #define HIPCUB_DEVICE_WARP_THREADS ::rocprim::device_warp_size()
// Unlike the two macros above, this one names `detail::` without a leading `::`,
// so the name is looked up relative to the scope where the macro is expanded.
71 #define HIPCUB_HOST_WARP_THREADS detail::host_warp_size_wrapper()
73 #elif defined(__HIP_PLATFORM_NVIDIA__)
// NVIDIA platform branch: flag that the CUB-based API is selected and forward
// the runtime-function qualifier to CUB's own macro.
74 #define HIPCUB_CUB_API 1
75 #define HIPCUB_RUNTIME_FUNCTION CUB_RUNTIME_FUNCTION
// NOTE(review): presumably the source of the CUB_PTX_* macros used below - confirm.
77 #include <cub/util_arch.cuh>
// In this branch all three warp-size macros map to the same CUB macro.
78 #define HIPCUB_WARP_THREADS CUB_PTX_WARP_THREADS
79 #define HIPCUB_DEVICE_WARP_THREADS CUB_PTX_WARP_THREADS
80 #define HIPCUB_HOST_WARP_THREADS CUB_PTX_WARP_THREADS
// Architecture macro forwarded from CUB.
81 #define HIPCUB_ARCH CUB_PTX_ARCH
// NOTE(review): the lines that follow this macro use are not visible in this excerpt.
82 BEGIN_HIPCUB_NAMESPACE
// Named warp sizes as unsigned literals; the maximum is defined as the 64 value.
88 #define HIPCUB_WARP_SIZE_32 32u
89 #define HIPCUB_WARP_SIZE_64 64u
90 #define HIPCUB_MAX_WARP_SIZE HIPCUB_WARP_SIZE_64
// Library-prefixed aliases for the __host__, __device__ and __shared__ qualifiers.
92 #define HIPCUB_HOST __host__
93 #define HIPCUB_DEVICE __device__
94 #define HIPCUB_HOST_DEVICE __host__ __device__
95 #define HIPCUB_SHARED_MEMORY __shared__
// Stringifies its argument and hands it to _Pragma, so that a pragma can be
// assembled from macro arguments.
99 #define HIPCUB_PRAGMA_TO_STR(x) _Pragma(#x)
// Active set: push the clang diagnostic state, ignore warning `w`, and pop the
// state again. The WITH_PUSH form performs the push and the ignore in one step.
100 #define HIPCUB_CLANG_SUPPRESS_WARNING_PUSH _Pragma("clang diagnostic push")
101 #define HIPCUB_CLANG_SUPPRESS_WARNING(w) HIPCUB_PRAGMA_TO_STR(clang diagnostic ignored w)
102 #define HIPCUB_CLANG_SUPPRESS_WARNING_POP _Pragma("clang diagnostic pop")
103 #define HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH(w) \
104 HIPCUB_CLANG_SUPPRESS_WARNING_PUSH HIPCUB_CLANG_SUPPRESS_WARNING(w)
// Empty set: the same four macro names, each expanding to nothing.
// NOTE(review): the preprocessor conditional that selects between the two sets
// is not visible in this excerpt - presumably a clang check; confirm.
106 #define HIPCUB_CLANG_SUPPRESS_WARNING_PUSH
107 #define HIPCUB_CLANG_SUPPRESS_WARNING(w)
108 #define HIPCUB_CLANG_SUPPRESS_WARNING_POP
109 #define HIPCUB_CLANG_SUPPRESS_WARNING_WITH_PUSH(w)
112 BEGIN_HIPCUB_NAMESPACE
// Define HIPCUB_STDERR automatically when DEBUG or _DEBUG is defined, unless it
// has already been defined.
115 #if (defined(DEBUG) || defined(_DEBUG)) && !defined(HIPCUB_STDERR)
116 #define HIPCUB_STDERR
// NOTE(review): the next lines are a fragment of an error-reporting function
// whose signature, braces and remaining statements are not visible in this
// excerpt. The visible part takes a `filename` parameter and prints the error
// code, file name, line and the runtime's error text to stderr.
// Presumably this is the hipcub::Debug function used by HipcubDebug below - confirm.
122 const char* filename,
130 fprintf(stderr,
"HIP error %d [%s, %d]: %s\n", error, filename, line, hipGetErrorString(error));
// Convenience macro: casts `e` to hipError_t and forwards it to hipcub::Debug
// together with the file and line of the call site.
138 #define HipcubDebug(e) hipcub::Debug((hipError_t) (e), __FILE__, __LINE__)
// HIPCUB_IF_CONSTEXPR expands to `constexpr` when the __cpp_if_constexpr
// feature-test macro is non-zero.
// NOTE(review): the #else/#endif lines and the condition nested inside the MSVC
// check are not visible in this excerpt. From the visible lines, the fallback
// path yields `constexpr` in one MSVC-without-clang case and nothing otherwise.
141 #if __cpp_if_constexpr
142 #define HIPCUB_IF_CONSTEXPR constexpr
144 #if defined(_MSC_VER) && !defined(__clang__)
148 #define HIPCUB_IF_CONSTEXPR constexpr
150 #define HIPCUB_IF_CONSTEXPR