53 #ifndef ROCRAND_THREEFRY4_IMPL_H_
54 #define ROCRAND_THREEFRY4_IMPL_H_
57 #define FQUALIFIERS __forceinline__ __device__
60 #include "rocrand/rocrand_threefry_common.h"
61 #include <rocrand/rocrand_common.h>
63 #ifndef THREEFRY4x32_DEFAULT_ROUNDS
64 #define THREEFRY4x32_DEFAULT_ROUNDS 20
67 #ifndef THREEFRY4x64_DEFAULT_ROUNDS
68 #define THREEFRY4x64_DEFAULT_ROUNDS 20
73 static constexpr __device__
int THREEFRY_ROTATION_64_4[8][2] = {
89 static constexpr __device__
int THREEFRY_ROTATION_32_4[8][2] = {
100 namespace rocrand_device
// Looks up a Threefry rotation constant for the word type `value`.
// This is only the declaration of the primary template; the bodies are
// supplied by the specializations for `unsigned int` and
// `unsigned long long`, each of which indexes an [8][2] constant table.
//   indexX - first (row) index into the table
//   indexY - second (column) index into the table
// Returns the rotation amount stored at [indexX][indexY].
103 template<
class value>
104 FQUALIFIERS int threefry_rotation_array(
int indexX,
int indexY);
107 FQUALIFIERS int threefry_rotation_array<unsigned int>(
int indexX,
int indexY)
109 return THREEFRY_ROTATION_32_4[indexX][indexY];
113 FQUALIFIERS int threefry_rotation_array<unsigned long long>(
int indexX,
int indexY)
115 return THREEFRY_ROTATION_64_4[indexX][indexY];
// Base class for a four-word Threefry engine.
//   state_value - four-component type accessed through .x/.y/.z/.w; used for
//                 the counter, the key and the cached result block
//   value       - type of one component (word) of state_value
//   Nrounds     - number of mixing rounds applied per block (must be <= 72)
// NOTE(review): this listing omits brace-only lines, some signatures and some
// statements; the comments below describe only what is visible and flag the
// gaps instead of guessing at them.
118 template<
typename state_value,
typename value,
unsigned int Nrounds>
119 class threefry_engine4_base
// Engine state. Only `substate` is visible here; the methods below also use
// m_state.counter, m_state.key and m_state.result, so those members are
// presumably declared on lines missing from this listing - TODO confirm.
122 struct threefry_state_4
// Index of the component of m_state.result that is handed out next; it is
// wrapped back to 0 when it reaches 4.
127 unsigned int substate;
// Skips `offset` values: discard_impl() moves substate/counter forward, then
// the cached result block is recomputed for the new counter.
131 FQUALIFIERS void discard(
unsigned long long offset)
133 this->discard_impl(offset);
134 this->m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// Skips `subsequence` subsequences: discard_subsequence_impl() adds the amount
// to the upper counter words (z/w), then the cached result block is
// recomputed for the new counter.
142 FQUALIFIERS void discard_subsequence(
unsigned long long subsequence)
144 this->discard_subsequence_impl(subsequence);
145 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// NOTE(review): the signature of the function owning the next lines is not
// visible in this listing (presumably the single-value generator - confirm).
// It reads the component of the cached result selected by substate: the
// AMD HIP branch indexes the `.data` member, the other visible line indexes
// from the address of `.x`. The #else/#endif lines and the statement that
// advances substate are not visible here.
155 #if defined(__HIP_PLATFORM_AMD__)
156 value ret = m_state.result.data[m_state.substate];
158 value ret = (&m_state.result.x)[m_state.substate];
// All four components consumed: wrap substate, bump the counter and compute
// the next result block.
161 if(m_state.substate == 4)
163 m_state.substate = 0;
164 m_state.counter = this->bump_counter(m_state.counter);
165 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
// NOTE(review): the signature of the function owning the next lines is not
// visible in this listing (presumably the four-value generator - confirm).
// It keeps the current result block, bumps the counter, computes the next
// block, and returns interleave(previous, next), which picks four components
// according to substate.
172 state_value ret = m_state.result;
173 m_state.counter = this->bump_counter(m_state.counter);
174 m_state.result = this->threefry_rounds(m_state.counter, m_state.key);
176 return this->interleave(ret, m_state.result);
// Computes one output block from `counter` and `key` using Nrounds rounds.
// Static: it does not touch m_state.
// NOTE(review): the declarations of X and ks, the initial key/counter setup,
// the add/xor steps around each rotation and the return statement are not
// visible in this listing.
180 FQUALIFIERS static state_value threefry_rounds(state_value counter, state_value key)
// Compile-time limit on the supported round count.
185 static_assert(Nrounds <= 72,
"72 or less only supported in threefry rounds");
// ks[4] starts from the value returned by skein_ks_parity<value>() (helper
// defined elsewhere).
187 ks[4] = skein_ks_parity<value>();
210 for(
unsigned int round_idx = 0; round_idx < Nrounds; round_idx++)
// The rotation table has 8 rows, so the row index cycles with round_idx & 7.
212 int rot_0 = threefry_rotation_array<value>(round_idx & 7u, 0);
213 int rot_1 = threefry_rotation_array<value>(round_idx & 7u, 1);
// Bit 1 of round_idx selects the pairing: when clear, X.y is rotated by rot_0
// and X.w by rot_1; when set, X.w is rotated by rot_0 and X.y by rot_1.
214 if((round_idx & 2u) == 0)
217 X.y = rotl<value>(X.y, rot_0);
220 X.w = rotl<value>(X.w, rot_1);
226 X.w = rotl<value>(X.w, rot_0);
229 X.y = rotl<value>(X.y, rot_1);
// After every fourth round (round_idx = 3, 7, 11, ...) add key words to X.
233 if((round_idx & 3u) == 3)
235 unsigned int inject_idx = round_idx / 4;
// The five ks words are used in rotating order (index taken modulo 5), and
// X.w additionally receives the 1-based injection number.
237 X.x += ks[(1 + inject_idx) % 5];
238 X.y += ks[(2 + inject_idx) % 5];
239 X.z += ks[(3 + inject_idx) % 5];
240 X.w += ks[(4 + inject_idx) % 5];
241 X.w += 1 + inject_idx;
// Moves substate and counter forward by `offset` values. It does not
// recompute m_state.result itself; discard() does that after calling this.
250 FQUALIFIERS void discard_impl(
unsigned long long offset)
// offset & 3 is the part of the offset that falls within a four-value block.
253 m_state.substate += offset & 3;
// offset / 4 is the number of whole blocks to skip.
254 unsigned long long counter_offset = offset / 4;
// If substate ran past the end of the block, carry one block into the counter
// offset and bring substate back into range. substate is unsigned, so adding
// -4 relies on unsigned wraparound to subtract 4.
255 counter_offset += m_state.substate < 4 ? 0 : 1;
256 m_state.substate += m_state.substate < 4 ? 0 : -4;
258 this->discard_state(counter_offset);
// Adds `subsequence` to the upper two counter words (z low, w high). It does
// not recompute m_state.result.
// NOTE(review): the declarations of lo and hi are not visible in this
// listing; split_ull presumably fills them with the low/high parts of
// `subsequence` for the word type - confirm against its definition.
262 FQUALIFIERS void discard_subsequence_impl(
unsigned long long subsequence)
265 ::rocrand_device::detail::split_ull(lo, hi, subsequence);
267 value old_counter = m_state.counter.z;
268 m_state.counter.z += lo;
// A carry out of z is detected by z having wrapped below its previous value.
269 m_state.counter.w += hi + (m_state.counter.z < old_counter ? 1 : 0);
// Adds `offset` blocks to the counter, treating x/y/z/w as one multi-word
// integer with x as the lowest word. It does not recompute m_state.result.
// NOTE(review): the declarations of lo and hi are not visible in this
// listing; split_ull presumably fills them with the low/high parts of
// `offset` - confirm against its definition.
274 FQUALIFIERS void discard_state(
unsigned long long offset)
277 ::rocrand_device::detail::split_ull(lo, hi, offset);
// Snapshot the counter so each word can be compared with its old value to
// detect wraparound.
279 state_value old_counter = m_state.counter;
280 m_state.counter.x += lo;
281 m_state.counter.y += hi + (m_state.counter.x < old_counter.x ? 1 : 0);
282 m_state.counter.z += (m_state.counter.y < old_counter.y ? 1 : 0);
283 m_state.counter.w += (m_state.counter.z < old_counter.z ? 1 : 0);
// Takes a counter by value and tracks a carry flag `add` across its words.
// Static: it does not touch m_state; callers assign the returned value.
// NOTE(review): the statements that modify counter.x/.y/.z/.w and the return
// statement are not visible in this listing; presumably this increments the
// counter by one with carry - confirm.
286 FQUALIFIERS static state_value bump_counter(state_value counter)
// The carry is 1 when x is 0 at this point.
289 value add = counter.x == 0 ? 1 : 0;
// The carry only continues into the next word while the current word is 0.
291 add = counter.y == 0 ? add : 0;
293 add = counter.z == 0 ? add : 0;
// Builds a four-component block from two consecutive result blocks, starting
// at component `substate` of `prev` and continuing into `next`. Reads
// m_state.substate only (const member).
// NOTE(review): the `case 0` line is not visible in this listing.
298 FQUALIFIERS state_value interleave(
const state_value prev,
const state_value next)
const
300 switch(m_state.substate)
303 case 1:
return state_value{prev.y, prev.z, prev.w, next.x};
304 case 2:
return state_value{prev.z, prev.w, next.x, next.y};
305 case 3:
return state_value{prev.w, next.x, next.y, next.z};
// Tells the compiler that control never falls out of the switch.
307 __builtin_unreachable();
// The engine's state instance used by all methods above.
311 threefry_state_4 m_state;
// FQUALIFIERS: shorthand for commonly used function qualifiers (definition: rocrand_uniform.h:31).