53 #ifndef ROCRAND_THREEFRY4_IMPL_H_
54 #define ROCRAND_THREEFRY4_IMPL_H_
56 #include "rocrand/rocrand_threefry_common.h"
57 #include <rocrand/rocrand_common.h>
59 #ifndef THREEFRY4x32_DEFAULT_ROUNDS
60 #define THREEFRY4x32_DEFAULT_ROUNDS 20
63 #ifndef THREEFRY4x64_DEFAULT_ROUNDS
64 #define THREEFRY4x64_DEFAULT_ROUNDS 20
67 namespace rocrand_device
// Rotation-constant lookup used by the Threefry mixing rounds.
// The generic version is deleted, so only the explicit specializations for
// `unsigned int` (4x32) and `unsigned long long` (4x64) can be called;
// instantiating it for any other word type is a compile-time error.
71 __forceinline__ __device__ __host__
int threefry_rotation_array(
int indexX,
int indexY) =
delete;
// Rotation amounts for 32-bit words (Threefry-4x32), read from an 8x2 table.
// indexX selects the row (threefry_rounds passes round_idx & 7) and indexY
// the column (0 or 1, one per word pair mixed in a round). Callers must
// keep indexX in [0, 8) and indexY in [0, 2).
74 __forceinline__ __device__ __host__
int threefry_rotation_array<unsigned int>(
int indexX,
82 static constexpr
int THREEFRY_ROTATION_32_4[8][2] = {
92 return THREEFRY_ROTATION_32_4[indexX][indexY];
// Rotation amounts for 64-bit words (Threefry-4x64), read from an 8x2 table.
// indexX selects the row (threefry_rounds passes round_idx & 7) and indexY
// the column (0 or 1, one per word pair mixed in a round). Callers must
// keep indexX in [0, 8) and indexY in [0, 2).
96 __forceinline__ __device__ __host__
int threefry_rotation_array<unsigned long long>(
int indexX,
101 static constexpr
int THREEFRY_ROTATION_64_4[8][2] = {
111 return THREEFRY_ROTATION_64_4[indexX][indexY];
114 template<
typename state_value,
typename value,
unsigned int Nrounds>
115 class threefry_engine4_base
118 struct threefry_state_4
123 unsigned int substate;
// Public aliases: state_type is the engine's state struct;
// state_vector_type is the 4-component vector type (`state_value`) used for
// the counter and result blocks.
125 using state_type = threefry_state_4;
126 using state_vector_type = state_value;
// Skips `offset` single output values.
// discard_impl() moves the substate/counter position forward. The result
// block for the new counter is then regenerated, because next() reads its
// output from m_state.result.
129 __forceinline__ __device__ __host__
void discard(
unsigned long long offset)
131 this->discard_impl(offset);
132 this->m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Jumps ahead by `subsequence` subsequences.
// discard_subsequence_impl() adds the amount to the upper counter words
// (z, w). The result block is then regenerated for the new counter so that
// next()/next4() read from the correct position.
140 __forceinline__ __device__ __host__
void discard_subsequence(
unsigned long long subsequence)
142 this->discard_subsequence_impl(subsequence);
143 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
146 __forceinline__ __device__ __host__ value operator()()
151 __forceinline__ __device__ __host__ value next()
153 #if defined(__HIP_PLATFORM_AMD__)
154 value ret = m_state.result.data[m_state.substate];
156 value ret = (&m_state.result.x)[m_state.substate];
159 if(m_state.substate == 4)
161 m_state.substate = 0;
162 m_state.counter = this->bump_counter(m_state.counter);
163 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Returns the next four values of the stream as one vector.
// The current result block is saved, the counter is bumped and a fresh block
// is generated. m_state.substate is left as-is here, so the position within
// a block does not change and the stream advances by exactly four values.
168 __forceinline__ __device__ __host__ state_value next4()
170 state_value ret = m_state.result;
171 m_state.counter = this->bump_counter(m_state.counter);
172 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// interleave() takes the four values starting at the current substate,
// spanning the tail of the old block (ret) and the head of the new one.
174 return this->interleave(ret, m_state.result);
178 __forceinline__ __device__ __host__
static state_value threefry_rounds(state_value counter,
184 static_assert(Nrounds <= 72,
"72 or less only supported in threefry rounds");
186 ks[4] = skein_ks_parity<value>();
209 for(
unsigned int round_idx = 0; round_idx < Nrounds; round_idx++)
211 int rot_0 = threefry_rotation_array<value>(round_idx & 7u, 0);
212 int rot_1 = threefry_rotation_array<value>(round_idx & 7u, 1);
213 if((round_idx & 2u) == 0)
216 X.y = rotl<value>(X.y, rot_0);
219 X.w = rotl<value>(X.w, rot_1);
225 X.w = rotl<value>(X.w, rot_0);
228 X.y = rotl<value>(X.y, rot_1);
232 if((round_idx & 3u) == 3)
234 unsigned int inject_idx = round_idx / 4;
236 X.x += ks[(1 + inject_idx) % 5];
237 X.y += ks[(2 + inject_idx) % 5];
238 X.z += ks[(3 + inject_idx) % 5];
239 X.w += ks[(4 + inject_idx) % 5];
240 X.w += 1 + inject_idx;
// Advances the stream position by `offset` single values without
// regenerating m_state.result; callers such as discard() do that afterwards.
// Each counter value yields a block of 4 outputs:
// - the low 2 bits of `offset` move the substate (the index into the block);
// - offset / 4 whole blocks are added to the counter via discard_state();
// - one more block is added if the substate ran past 3.
249 __forceinline__ __device__ __host__
void discard_impl(
unsigned long long offset)
252 m_state.substate += offset & 3;
253 unsigned long long counter_offset = offset / 4;
// Assuming substate started in [0, 3], it is now in [0, 6]; values >= 4
// roll one whole block into the counter.
254 counter_offset += m_state.substate < 4 ? 0 : 1;
// substate is unsigned, so adding -4 wraps modulo 2^N, i.e. subtracts 4,
// bringing it back into [0, 3].
255 m_state.substate += m_state.substate < 4 ? 0 : -4;
257 this->discard_state(counter_offset);
// Adds the 64-bit `subsequence` to the upper half of the counter (words z
// and w), carrying out of z into w. Any overflow out of w is dropped. The
// lower words x/y and the substate are not modified here.
// split_ull() presumably puts the low/high parts of `subsequence` into
// lo/hi (declared on a line not shown here) — TODO confirm.
261 __forceinline__ __device__ __host__
void
262 discard_subsequence_impl(
unsigned long long subsequence)
265 ::rocrand_device::detail::split_ull(lo, hi, subsequence);
267 value old_counter = m_state.counter.z;
268 m_state.counter.z += lo;
// Unsigned wrap-around detection: z ended up smaller than before iff the
// addition overflowed, in which case one is carried into w.
269 m_state.counter.w += hi + (m_state.counter.z < old_counter ? 1 : 0);
// Adds the 64-bit `offset` to the full 4-word counter, rippling carries from
// x up to w. Each carry is detected by unsigned wrap-around (new word < old
// word). Overflow out of w is dropped. Does not touch substate or result.
//
// NOTE(review): this assumes split_ull() puts the upper 32 bits of `offset`
// in hi when `value` is 32-bit. In that case, if hi == 0xFFFFFFFF and the add
// into x carries, then `hi + 1` wraps to 0, y is left unchanged, and the carry
// into z is silently lost.
// Offsets coming from discard_impl() are at most offset / 4 + 1 (< 2^62), so
// hi never gets that large on that path. Confirm no other caller passes
// offsets >= 2^64 - 2^32 before relying on this for arbitrary inputs.
274 __forceinline__ __device__ __host__
void discard_state(
unsigned long long offset)
277 ::rocrand_device::detail::split_ull(lo, hi, offset);
279 state_value old_counter = m_state.counter;
280 m_state.counter.x += lo;
281 m_state.counter.y += hi + (m_state.counter.x < old_counter.x ? 1 : 0);
// z and w only ever receive a carry of 0 or 1 from the word below.
282 m_state.counter.z += (m_state.counter.y < old_counter.y ? 1 : 0);
283 m_state.counter.w += (m_state.counter.z < old_counter.z ? 1 : 0);
286 __forceinline__ __device__ __host__
static state_value bump_counter(state_value counter)
289 value add = counter.x == 0 ? 1 : 0;
291 add = counter.y == 0 ? add : 0;
293 add = counter.z == 0 ? add : 0;
298 __forceinline__ __device__ __host__ state_value interleave(
const state_value prev,
299 const state_value next)
const
301 switch(m_state.substate)
304 case 1:
return state_value{prev.y, prev.z, prev.w, next.x};
305 case 2:
return state_value{prev.z, prev.w, next.x, next.y};
306 case 3:
return state_value{prev.w, next.x, next.y, next.z};
308 __builtin_unreachable();
312 threefry_state_4 m_state;