53 #ifndef ROCRAND_THREEFRY4_IMPL_H_
54 #define ROCRAND_THREEFRY4_IMPL_H_
56 #include "rocrand/rocrand_common.h"
57 #include "rocrand/rocrand_threefry_common.h"
59 #include <hip/hip_runtime.h>
// Default values (20) for the THREEFRY4x32 / THREEFRY4x64 round-count macros.
// Each #ifndef guard means a definition made before this header is included
// takes precedence over the default given here.
// NOTE(review): the #endif lines that should close these guards are not
// present in this listing; presumably they were dropped with the other
// missing lines - confirm against the original header.
61 #ifndef THREEFRY4x32_DEFAULT_ROUNDS
62 #define THREEFRY4x32_DEFAULT_ROUNDS 20
65 #ifndef THREEFRY4x64_DEFAULT_ROUNDS
66 #define THREEFRY4x64_DEFAULT_ROUNDS 20
69 namespace rocrand_device
// Base class template for a counter-based engine that works on four-element
// vectors.
//
// Template parameters (as used by the members visible in this listing):
//   state_value - vector type with .x/.y/.z/.w members; it is the type of the
//                 counter, key and result values held in the engine state.
//   value       - type of one element of state_value (used for single draws
//                 and for the carry arithmetic on counter words).
//   Nrounds     - round count forwarded to rounds_4; a static_assert in
//                 threefry_rounds rejects values greater than 72.
72 template<
typename state_value,
typename value,
unsigned int Nrounds>
73 class threefry_engine4_base
// State record of the engine.
// NOTE(review): only the `substate` field is visible in this listing. Other
// members of the class read m_state.counter, m_state.key and m_state.result,
// so those fields are presumably declared on the lines missing here - confirm.
76 struct threefry_state_4
// Index of the next element of `result` to hand out. The code visible in this
// listing keeps it below 4 (it is reset to 0 when it reaches 4, and
// discard_impl subtracts 4 when it overflows).
81 unsigned int substate;
// Aliases for the state record type and for the four-element vector type.
83 using state_type = threefry_state_4;
84 using state_vector_type = state_value;
// Skips `offset` values of the sequence.
//
// discard_impl() moves the counter and substate forward; it does not touch
// m_state.result, so the result block is recomputed here from the updated
// counter and the current key.
87 __forceinline__ __device__ __host__
88 void discard(
unsigned long long offset)
90 this->discard_impl(offset);
91 this->m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Skips `subsequence` subsequences.
//
// discard_subsequence_impl() adds `subsequence` to the upper counter words
// (z and w); the result block is then recomputed from the updated counter and
// the current key.
99 __forceinline__ __device__ __host__
100 void discard_subsequence(
unsigned long long subsequence)
102 this->discard_subsequence_impl(subsequence);
103 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// NOTE(review): these qualifiers belong to a member whose signature and body
// are not present in this listing (the embedded numbering jumps from 106 to
// 112). Restore the missing lines from the original header before building.
106 __forceinline__ __device__ __host__
// Scalar draw: reads element `substate` of the current result block into
// `ret`. When substate has reached 4 the block is exhausted: substate is reset
// to 0, the counter is bumped and a fresh result block is computed.
//
// NOTE(review): the function signature, the #else/#endif of the platform
// switch, presumably a substate increment, and the return statement are absent
// from this listing - confirm against the original header.
112 __forceinline__ __device__ __host__
// AMD HIP platform: index the vector through the ROCRAND_HIPVEC_ACCESS macro.
115 #if defined(__HIP_PLATFORM_AMD__)
116 value ret = ROCRAND_HIPVEC_ACCESS(m_state.result)[m_state.substate];
// Presumably the non-AMD branch: index from the address of .x, which relies on
// the four components of state_value being laid out contiguously.
118 value ret = (&m_state.result.x)[m_state.substate];
// All four elements of the current block have been consumed.
121 if(m_state.substate == 4)
123 m_state.substate = 0;
124 m_state.counter = this->bump_counter(m_state.counter);
125 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Vector draw: returns four consecutive values at once.
//
// The current result block is kept in `ret`, the counter is bumped and the
// next block is computed into m_state.result. interleave() then assembles the
// returned vector from the tail of `ret` and the head of the new block,
// selected by the current substate. substate itself is not modified here.
//
// NOTE(review): the function signature is absent from this listing.
130 __forceinline__ __device__ __host__
133 state_value ret = m_state.result;
134 m_state.counter = this->bump_counter(m_state.counter);
135 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
137 return this->interleave(ret, m_state.result);
// Computes one four-element output block from `counter` and `key`.
// Static: it takes everything it needs as arguments and reads no engine state.
//
// NOTE(review): the declarations of `X` and `ks` and the statements between
// the parity assignment and the final return (embedded numbers 150-171) are
// absent from this listing, so how `X` and ks[0..3] are filled is not visible.
141 __forceinline__ __device__ __host__
142 static state_value threefry_rounds(state_value counter, state_value key)
// Compile-time limit on the round count template parameter.
147 static_assert(Nrounds <= 72,
"72 or less only supported in threefry rounds");
// The fifth schedule word starts from the constant returned by
// skein_ks_parity<value>().
149 ks[4] = skein_ks_parity<value>();
// The round function itself lives in rounds_4 (defined outside this listing).
172 return rounds_4<state_value, value, Nrounds>(X, ks);
// Moves the engine position forward by `offset` values.
// Only the counter and substate are updated; m_state.result is left stale and
// must be recomputed by the caller (discard() does so).
177 __forceinline__ __device__ __host__
178 void discard_impl(
unsigned long long offset)
// offset & 3 == offset % 4: the part of the skip that falls inside a block.
181 m_state.substate += offset & 3;
// Whole four-value blocks to skip.
182 unsigned long long counter_offset = offset / 4;
// If substate ran past the end of the block, that is one more block.
183 counter_offset += m_state.substate < 4 ? 0 : 1;
// Wrap substate back below 4. substate is unsigned, so the -4 is converted to
// unsigned by the += and the addition wraps to substate - 4.
184 m_state.substate += m_state.substate < 4 ? 0 : -4;
// Advance the counter by the number of whole blocks.
186 this->discard_state(counter_offset);
// Adds `subsequence` to the upper two counter words (z, then w with carry).
// m_state.result is not recomputed here; discard_subsequence() does that.
//
// NOTE(review): the declarations of `lo` and `hi` are absent from this
// listing; presumably split_ull stores the low and high parts of `subsequence`
// in them - confirm against rocrand_common.h.
190 __forceinline__ __device__ __host__
191 void discard_subsequence_impl(
unsigned long long subsequence)
194 ::rocrand_device::detail::split_ull(lo, hi, subsequence);
// Remember z so that wrap-around of the addition can be detected.
196 value old_counter = m_state.counter.z;
197 m_state.counter.z += lo;
// z became smaller than before only if the addition wrapped: carry into w.
198 m_state.counter.w += hi + (m_state.counter.z < old_counter ? 1 : 0);
// Adds `offset` to the counter, treated as one multi-word number with x as the
// lowest word and w as the highest, propagating carries upwards.
//
// NOTE(review): the declarations of `lo` and `hi` are absent from this
// listing; presumably split_ull stores the low and high parts of `offset` in
// them - confirm against rocrand_common.h.
203 __forceinline__ __device__ __host__
204 void discard_state(
unsigned long long offset)
207 ::rocrand_device::detail::split_ull(lo, hi, offset);
// Snapshot of the counter, used to detect wrap-around of each word.
209 state_value old_counter = m_state.counter;
210 m_state.counter.x += lo;
// A word that ends up smaller than its old value wrapped, so carry 1 upward.
// NOTE(review): if `hi + carry` itself wraps to 0 (hi all-ones with the carry
// set), y is left unchanged and the carry into z appears to be lost. The only
// caller visible here (discard_impl) passes offset / 4 (+1), which presumably
// keeps `hi` well below that value - confirm before calling this directly
// with larger offsets.
211 m_state.counter.y += hi + (m_state.counter.x < old_counter.x ? 1 : 0);
212 m_state.counter.z += (m_state.counter.y < old_counter.y ? 1 : 0);
213 m_state.counter.w += (m_state.counter.z < old_counter.z ? 1 : 0);
// Takes a counter by value and returns a state_value; next()/next4() use the
// result as the counter for the following block.
//
// The visible lines maintain a carry flag `add`: it is 1 only while every
// counter word examined so far is 0, otherwise it is forced to 0.
//
// NOTE(review): the statements that modify the counter words and the return
// statement (embedded numbers 219, 221, 223, 225-226) are absent from this
// listing. Presumably each word is incremented by the carry before the next
// comparison, making this a multi-word "+1" - confirm against the original.
216 __forceinline__ __device__ __host__
217 static state_value bump_counter(state_value counter)
220 value add = counter.x == 0 ? 1 : 0;
222 add = counter.y == 0 ? add : 0;
224 add = counter.z == 0 ? add : 0;
// Builds a four-element vector that starts at element `substate` of `prev`
// and continues with the leading elements of `next`:
//   substate 1 -> prev.y, prev.z, prev.w, next.x
//   substate 2 -> prev.z, prev.w, next.x, next.y
//   substate 3 -> prev.w, next.x, next.y, next.z
// Const: reads m_state.substate but modifies nothing.
//
// NOTE(review): no `case 0` label is visible in this listing (the embedded
// numbering skips 234). Presumably that line returns `prev` unchanged; if it
// is really missing, substate == 0 would reach __builtin_unreachable(), which
// is undefined behaviour - confirm against the original header.
229 __forceinline__ __device__ __host__
230 state_value interleave(
const state_value prev,
const state_value next)
const
232 switch(m_state.substate)
235 case 1:
return state_value{prev.y, prev.z, prev.w, next.x};
236 case 2:
return state_value{prev.z, prev.w, next.x, next.y};
237 case 3:
return state_value{prev.w, next.x, next.y, next.z};
// Tells the compiler that no other substate value is expected here.
239 __builtin_unreachable();
// The engine's state record; the member functions in this class read and
// update its counter, key, result and substate fields.
243 threefry_state_4 m_state;