53 #ifndef ROCRAND_THREEFRY2_IMPL_H_
54 #define ROCRAND_THREEFRY2_IMPL_H_
56 #include "rocrand/rocrand_common.h"
57 #include "rocrand/rocrand_threefry_common.h"
59 #include <hip/hip_runtime.h>
// Compile-time default for THREEFRY2x32_DEFAULT_ROUNDS (20). It is guarded by #ifndef,
// so code that includes this header can predefine the macro to override the default.
// Presumably used as the round count of the 2x32 engine elsewhere; confirm at the use site.
61 #ifndef THREEFRY2x32_DEFAULT_ROUNDS
62 #define THREEFRY2x32_DEFAULT_ROUNDS 20
// Same overridable default (20) for THREEFRY2x64_DEFAULT_ROUNDS.
65 #ifndef THREEFRY2x64_DEFAULT_ROUNDS
66 #define THREEFRY2x64_DEFAULT_ROUNDS 20
69 namespace rocrand_device
// Primary template of the per-round rotation-constant lookup.
// It is deleted, so only the explicit specializations for `unsigned int` and
// `unsigned long long` that follow can be called. Any other word type fails to compile.
73 __forceinline__ __device__ __host__
int threefry_rotation_array(
int index) =
delete;
// Rotation constants for 32-bit words (`unsigned int`).
// Returns entry `index` of an 8-entry table. `index` is not range-checked,
// so callers must pass a value in [0, 8); threefry_rounds passes `round_idx & 7u`.
76 __forceinline__ __device__ __host__
int threefry_rotation_array<unsigned int>(
int index)
// `static constexpr`: one constant-initialized table shared by every call.
82 static constexpr
int THREEFRY_ROTATION_32_2[8] = {13, 15, 26, 6, 17, 29, 16, 24};
83 return THREEFRY_ROTATION_32_2[index];
// Rotation constants for 64-bit words (`unsigned long long`).
// Returns entry `index` of an 8-entry table. `index` is not range-checked,
// so callers must pass a value in [0, 8); threefry_rounds passes `round_idx & 7u`.
87 __forceinline__ __device__ __host__
int threefry_rotation_array<unsigned long long>(
int index)
// `static constexpr`: one constant-initialized table shared by every call.
93 static constexpr
int THREEFRY_ROTATION_64_2[8] = {16, 42, 12, 31, 16, 32, 24, 21};
94 return THREEFRY_ROTATION_64_2[index];
97 template<
typename state_value,
typename value,
unsigned int Nrounds>
98 class threefry_engine2_base
// Per-engine state. The methods below also access m_state.counter, m_state.key
// and m_state.result (presumably each a state_value; confirm against the struct).
101 struct threefry_state_2
// Index (0 or 1) of the next word of `result` that next() hands out.
// next() resets it to 0 once it reaches 2.
106 unsigned int substate;
// Aliases for the state struct and for the two-word vector type used for
// the counter/key/result blocks.
108 using state_type = threefry_state_2;
109 using state_vector_type = state_value;
// Skips `offset` outputs. discard_impl() moves the counter and substate forward,
// then the cached block m_state.result is regenerated from the new counter and the key.
111 __forceinline__ __device__ __host__
void discard(
unsigned long long offset)
113 this->discard_impl(offset);
114 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Recomputes m_state.result from the current counter and key. Neither the
// counter nor substate is changed, so despite the name nothing is skipped.
// NOTE(review): presumably used to (re)build the cached block after the
// counter/key have been set; confirm with callers.
117 __forceinline__ __device__ __host__
void discard()
119 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Skips `subsequence` subsequences. discard_subsequence_impl() adds the count to
// the high counter word (counter.y), then m_state.result is regenerated from the
// updated counter and the key.
127 __forceinline__ __device__ __host__
void discard_subsequence(
unsigned long long subsequence)
129 this->discard_subsequence_impl(subsequence);
130 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Function-call operator yielding one `value`.
// NOTE(review): presumably equivalent to calling next(); confirm against the body.
133 __forceinline__ __device__ __host__ value operator()()
// Returns the next `value` output, which is word `substate` of the cached block
// m_state.result. Once substate reaches 2 (both words of the block consumed),
// substate is reset to 0, the counter is bumped and a new block is generated.
138 __forceinline__ __device__ __host__ value next()
// HIP/AMD path: index the vector's components through ROCRAND_HIPVEC_ACCESS.
140 #if defined(__HIP_PLATFORM_AMD__)
141 value ret = ROCRAND_HIPVEC_ACCESS(m_state.result)[m_state.substate];
// Non-AMD path (presumably the #else branch): treat x and y as consecutive
// array elements starting at &result.x.
// NOTE(review): indexing past &x assumes y immediately follows x in memory.
// Formally this is out-of-bounds pointer arithmetic in ISO C++.
143 value ret = (&m_state.result.x)[m_state.substate];
146 if(m_state.substate == 2)
148 m_state.substate = 0;
149 m_state.counter = this->bump_counter(m_state.counter);
150 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Returns two outputs at once as a state_value and advances the counter by one block.
// interleave() picks words from the old block (`ret`) and the newly generated one
// according to substate. For substate 1 the pair is {old.y, new.x}, the same words
// two successive next() calls would return. substate is not modified here, so the
// read position carries over into the new block.
155 __forceinline__ __device__ __host__ state_value next2()
157 state_value ret = m_state.result;
158 m_state.counter = this->bump_counter(m_state.counter);
159 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
161 return this->interleave(ret, m_state.result);
// Threefry-2 block function. It mixes the two-word `counter` under the two-word key
// (the second argument; callers pass m_state.key) for Nrounds rounds and produces
// the output block. Each round rotates X.y by a table-driven amount, and every
// fourth round is followed by a key injection.
165 __forceinline__ __device__ __host__
static state_value threefry_rounds(state_value counter,
// The round count is limited to at most 32 at compile time.
171 static_assert(Nrounds <= 32,
"32 or less only supported in threefry rounds");
// ks[2] starts from the Skein key-schedule parity constant. Presumably the key
// words are stored in ks[0..1] and XORed into ks[2] afterwards; confirm.
173 ks[2] = skein_ks_parity<value>();
188 for(
unsigned int round_idx = 0; round_idx < Nrounds; round_idx++)
// The rotation amount cycles through the 8-entry table for this word size.
191 X.y = rotl<value>(X.y, threefry_rotation_array<value>(round_idx & 7u));
// Key injection happens when round_idx is 3, 7, 11, ... (every fourth round).
194 if((round_idx & 3u) == 3)
196 unsigned int inject_idx = round_idx / 4;
// Injection number s = 1 + inject_idx. Subkeys ks[s % 3] and ks[(s + 1) % 3]
// are added to the two words, and s itself is also added to the second word.
198 X.x += ks[(1 + inject_idx) % 3];
199 X.y += ks[(2 + inject_idx) % 3];
200 X.y += 1 + inject_idx;
// Advances the counter and substate by `offset` outputs. m_state.result is not
// recomputed here; discard() regenerates it afterwards.
// Each counter value provides two outputs (substate 0 and 1), so the skip is
// offset / 2 whole counter steps plus offset & 1 extra words added to substate.
// If that pushes substate to 2, it is wrapped back by 2 and one more counter
// step is taken.
209 __forceinline__ __device__ __host__
void discard_impl(
unsigned long long offset)
212 m_state.substate += offset & 1;
213 unsigned long long counter_offset = offset / 2;
214 counter_offset += m_state.substate < 2 ? 0 : 1;
// substate is unsigned. Adding the int -2 wraps modulo UINT_MAX + 1,
// which is the same as subtracting 2.
215 m_state.substate += m_state.substate < 2 ? 0 : -2;
217 this->discard_state(counter_offset);
// Skips `subsequence` subsequences by adding the count to the high counter word
// (counter.y). m_state.result is not recomputed here.
// NOTE(review): when `value` is unsigned int, the compound assignment reduces the
// 64-bit count modulo UINT_MAX + 1, so larger counts wrap.
221 __forceinline__ __device__ __host__
void
222 discard_subsequence_impl(
unsigned long long subsequence)
224 m_state.counter.y += subsequence;
// Adds `offset` to the counter, viewed as one double-width integer with counter.x
// as the low word and counter.y as the high word (the carry goes from x to y).
// split_ull (defined elsewhere) presumably writes the low and high `value`-sized
// parts of `offset` into lo/hi; confirm.
229 __forceinline__ __device__ __host__
void discard_state(
unsigned long long offset)
232 ::rocrand_device::detail::split_ull(lo, hi, offset);
234 value old_counter = m_state.counter.x;
235 m_state.counter.x += lo;
// Unsigned overflow of the low word shows up as a result smaller than the old
// value. That carry is propagated into the high word.
236 m_state.counter.y += hi + (m_state.counter.x < old_counter ? 1 : 0);
// Advances a two-word counter, passed by value, to the next block (presumably +1).
// Callers assign the returned value back to m_state.counter.
// `add` is the carry into counter.y. It is 1 exactly when counter.x is 0 at this
// point, presumably because the low word has just wrapped on the preceding
// increment; confirm against the full body.
239 __forceinline__ __device__ __host__
static state_value bump_counter(state_value counter)
242 value add = counter.x == 0 ? 1 : 0;
// Builds the two-output pair returned by next2() from the previous block `prev` and
// the freshly generated block `next`, depending on m_state.substate.
// For substate 1 it returns {prev.y, next.x}: the unread word of the old block
// followed by the first word of the new one. The substate 0 case presumably
// returns `prev`; confirm.
// Any other substate reaches __builtin_unreachable(), which is undefined behaviour.
// next() and discard_impl() keep substate at 0 or 1.
247 __forceinline__ __device__ __host__ state_value interleave(
const state_value prev,
248 const state_value next)
const
250 switch(m_state.substate)
253 case 1:
return state_value{prev.y, next.x};
255 __builtin_unreachable();
// The engine's complete state (counter, key, cached result block and substate).
// All of the methods above read and update it.
259 threefry_state_2 m_state;