53 #ifndef ROCRAND_THREEFRY2_IMPL_H_
54 #define ROCRAND_THREEFRY2_IMPL_H_
56 #include "rocrand/rocrand_common.h"
57 #include "rocrand/rocrand_threefry_common.h"
59 #include <hip/hip_runtime.h>
// Default number of Threefry rounds for the 2x32 variant. The #ifndef guard lets a
// translation unit override it by defining the macro before including this header.
// NOTE(review): presumably passed as the Nrounds template argument of the concrete
// 2x32 engine; if so, an override must stay <= 32 (threefry_rounds static_asserts
// Nrounds <= 32) — confirm at the use site.
61 #ifndef THREEFRY2x32_DEFAULT_ROUNDS
62 #define THREEFRY2x32_DEFAULT_ROUNDS 20
// Default number of Threefry rounds for the 2x64 variant; overridable the same way.
65 #ifndef THREEFRY2x64_DEFAULT_ROUNDS
66 #define THREEFRY2x64_DEFAULT_ROUNDS 20
69 namespace rocrand_device
/// Shared implementation of a two-word Threefry counter-based engine.
///
/// \tparam state_value two-component vector type with members .x and .y; used for the
///                     counter, the key and each generated output block.
/// \tparam value       scalar type of one component of state_value (one output value).
/// \tparam Nrounds     number of Threefry rounds applied per block; must be <= 32
///                     (enforced by a static_assert in threefry_rounds).
///
/// Each evaluation of threefry_rounds(counter, key) yields one block of two values.
/// The block is cached in m_state.result and handed out one component at a time, with
/// m_state.substate tracking the position inside the block, or two values at a time.
72 template<
typename state_value,
typename value,
unsigned int Nrounds>
73 class threefry_engine2_base
/// Complete engine state.
/// Besides substate, the engine reads and writes m_state.counter, m_state.key and
/// m_state.result, each of type state_value:
///  - counter: current block counter (bumped once the cached block is exhausted),
///  - key:     key passed to threefry_rounds,
///  - result:  cached output block, threefry_rounds(counter, key).
76 struct threefry_state_2
// Index (0 or 1) of the next component of result to hand out. Reaching 2 means the
// cached block is exhausted and a new one must be generated.
81 unsigned int substate;
/// Public alias for the engine state struct.
83 using state_type = threefry_state_2;
/// Two-component vector type used for the counter, the key and each output block.
84 using state_vector_type = state_value;
/// Skips \p offset values of the output stream.
/// discard_impl() only moves the counter/substate position, so the cached block
/// (m_state.result) is regenerated afterwards for the new counter.
86 __forceinline__ __device__ __host__
87 void discard(
unsigned long long offset)
89 this->discard_impl(offset);
// Refresh the cached block for the counter discard_impl() arrived at.
90 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
/// Regenerates the cached block m_state.result from the current counter and key.
/// The counter and substate are not modified by this statement.
93 __forceinline__ __device__ __host__
96 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
/// Skips \p subsequence subsequences by advancing the high counter word (via
/// discard_subsequence_impl), then regenerates the cached block.
/// substate is left unchanged, so the position inside the block is preserved.
104 __forceinline__ __device__ __host__
105 void discard_subsequence(
unsigned long long subsequence)
107 this->discard_subsequence_impl(subsequence);
// discard_subsequence_impl() does not refresh the cached block, so do it here.
108 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
111 __forceinline__ __device__ __host__
/// Returns the component of the cached block selected by m_state.substate.
/// substate is then checked against 2: when the block is exhausted, substate resets
/// to 0, the counter is bumped and a fresh block is generated, so the next read
/// starts at the first value of the new block.
117 __forceinline__ __device__ __host__
120 #if defined(__HIP_PLATFORM_AMD__)
// AMD path: index the vector through the project macro ROCRAND_HIPVEC_ACCESS
// (presumably exposes the HIP vector as an indexable array).
121 value ret = ROCRAND_HIPVEC_ACCESS(m_state.result)[m_state.substate];
// Non-AMD path: index .x/.y as an array; relies on x and y being stored contiguously.
123 value ret = (&m_state.result.x)[m_state.substate];
// Block exhausted: move to the first value of the next block.
126 if(m_state.substate == 2)
128 m_state.substate = 0;
129 m_state.counter = this->bump_counter(m_state.counter);
130 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
/// Returns two consecutive output values and advances the stream by one full block.
/// The current block is saved, the counter is bumped and the following block is
/// generated. interleave() then selects, from the old/new block pair, the two values
/// starting at the current substate. substate itself is not changed.
135 __forceinline__ __device__ __host__
// Keep the block being consumed; part of it may still be needed by interleave().
138 state_value ret = m_state.result;
139 m_state.counter = this->bump_counter(m_state.counter);
140 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
142 return this->interleave(ret, m_state.result);
/// Threefry block function: maps (\p counter, \p key) to one output block using
/// Nrounds rounds. It is a static member and does not touch m_state. The round
/// sequence itself is delegated to rounds_2<state_value, value, Nrounds>(X, ks).
146 __forceinline__ __device__ __host__
147 static state_value threefry_rounds(state_value counter, state_value key)
// The round implementation supports at most 32 rounds.
152 static_assert(Nrounds <= 32,
"32 or less only supported in threefry rounds");
// ks[2] is seeded with skein_ks_parity<value>(), presumably the Skein/Threefish
// key-schedule parity constant that is combined with the key words to form the
// extended key schedule.
154 ks[2] = skein_ks_parity<value>();
169 return rounds_2<state_value, value, Nrounds>(X, ks);
/// Advances the stream position by \p offset values without regenerating
/// m_state.result; the public discard(offset) wrapper recomputes it afterwards.
/// Each block holds two values, so offset splits into offset / 2 whole blocks plus
/// offset & 1 extra positions inside the current block.
174 __forceinline__ __device__ __host__
175 void discard_impl(
unsigned long long offset)
// Move inside the current block by the odd remainder.
178 m_state.substate += offset & 1;
179 unsigned long long counter_offset = offset / 2;
// If substate spilled past the block (reached 2), carry one extra block...
180 counter_offset += m_state.substate < 2 ? 0 : 1;
// ...and wrap substate back into [0, 2). substate is unsigned, so adding -2 wraps
// modulo (UINT_MAX + 1) and is equivalent to subtracting 2.
181 m_state.substate += m_state.substate < 2 ? 0 : -2;
183 this->discard_state(counter_offset);
/// Skips \p subsequence subsequences by adding it to the high counter word
/// (counter.y). Does not regenerate m_state.result and leaves substate unchanged.
/// NOTE(review): when `value` is narrower than unsigned long long (assumed unsigned),
/// the addition wraps modulo the range of `value`, so large subsequence counts are
/// reduced — confirm this is the intended contract.
187 __forceinline__ __device__ __host__
188 void discard_subsequence_impl(
unsigned long long subsequence)
190 m_state.counter.y += subsequence;
/// Advances the counter by \p offset blocks without regenerating m_state.result.
/// The counter is treated as one two-word integer: counter.x is the low word and
/// counter.y the high word.
195 __forceinline__ __device__ __host__
196 void discard_state(
unsigned long long offset)
// Split offset into a low part (lo) and a high part (hi) with the project helper
// split_ull.
199 ::rocrand_device::detail::split_ull(lo, hi, offset);
// Add the low part; if the low word wrapped around (new x < old x), carry 1 into y.
201 value old_counter = m_state.counter.x;
202 m_state.counter.x += lo;
203 m_state.counter.y += hi + (m_state.counter.x < old_counter ? 1 : 0);
/// Returns \p counter advanced to the next block. It is used whenever the cached
/// block is exhausted. counter.x is the low word and counter.y the high word.
206 __forceinline__ __device__ __host__
207 static state_value bump_counter(state_value counter)
// add == 1 exactly when counter.x is 0, i.e. the low word wrapped around; it is
// presumably carried into counter.y.
210 value add = counter.x == 0 ? 1 : 0;
/// Combines the previous block (\p prev) and the following block (\p next) into the
/// pair of consecutive values that starts at the current m_state.substate.
/// A substate value with no matching case label reaches __builtin_unreachable(),
/// which is undefined behavior. The engine keeps substate below 2 elsewhere.
215 __forceinline__ __device__ __host__
216 state_value interleave(
const state_value prev,
const state_value next)
const
218 switch(m_state.substate)
// Position 1: last value of prev followed by the first value of next.
221 case 1:
return state_value{prev.y, next.x};
223 __builtin_unreachable();
// Complete engine state: counter, key, cached output block and position inside it.
227 threefry_state_2 m_state;