/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/utility/philox_rand.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/utility/philox_rand.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/utility/philox_rand.hpp Source File
philox_rand.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 // Reference: https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/src/philox.cuh
11 class philox
12 {
13  public:
14  CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
15  : seed(reinterpret_cast<const uint2&>(seed_))
16  {
17 
18  ull2* tmp = reinterpret_cast<ull2*>(&counter);
19  tmp->x = offset_;
20  }
21 
22  CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
23  {
24 
25  uint4 counter_ = counter;
26  ull2* tmp = reinterpret_cast<ull2*>(&counter_);
27  tmp->y = subsequence;
28 
29  uint2 key_ = seed;
30 // 7-round philox
31 #pragma unroll
32  for(int i = 0; i < 6; i++)
33  {
34  counter_ = philox_single_round(counter_, key_);
35  key_.x += kPhilox10A;
36  key_.y += kPhilox10B;
37  }
38  uint4 output = philox_single_round(counter_, key_);
39  return output;
40  }
41 
43  const unsigned long long subsequence) const
44  {
45  uint4 tmp_ph;
46  tmp_ph = get_philox_4x32(subsequence);
47 
48  uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
49 
50  out_tmp[0] = tmp_ph.x;
51  out_tmp[1] = tmp_ph.y;
52  out_tmp[2] = tmp_ph.z;
53  out_tmp[3] = tmp_ph.w;
54  }
55 
57  const unsigned long long subsequence,
58  const index_t start_idx) const
59  {
60  uint4 tmp_ph;
61  tmp_ph = get_philox_4x32(subsequence);
62 
63  uint32x4_t tmp;
64  tmp[0] = tmp_ph.x;
65  tmp[1] = tmp_ph.y;
66  tmp[2] = tmp_ph.z;
67  tmp[3] = tmp_ph.w;
68  uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
69  out_tmp[0] = tmp[start_idx];
70  out_tmp[1] = tmp[start_idx + 2];
71  }
72 
74  const unsigned long long subsequence,
75  const index_t start_idx) const
76  {
77  uint4 tmp_ph;
78  tmp_ph = get_philox_4x32(subsequence);
79 
80  uint32x4_t tmp;
81  tmp[0] = tmp_ph.x;
82  tmp[1] = tmp_ph.y;
83  tmp[2] = tmp_ph.z;
84  tmp[3] = tmp_ph.w;
85  uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
86  out_tmp[0] = tmp[start_idx];
87  }
88 
89  private:
90  struct ull2
91  {
92  uint64_t x;
93  uint64_t y;
94  };
95  uint4 counter;
96  const uint2 seed;
97 
98  CK_TILE_HOST_DEVICE uint2 mulhilo32(const unsigned int a, const unsigned int b) const
99  {
100  uint2* res;
101  unsigned long long tmp;
102  tmp = static_cast<unsigned long long>(a) * b;
103  res = reinterpret_cast<uint2*>(&tmp);
104  return *res;
105  }
106 
107  CK_TILE_HOST_DEVICE uint4 philox_single_round(const uint4 ctr, const uint2 key) const
108  {
109 
110  uint2 res0 = mulhilo32(kPhiloxSA, ctr.x);
111  uint2 res1 = mulhilo32(kPhiloxSB, ctr.z);
112  uint4 ret = {res1.y ^ ctr.y ^ key.x, res1.x, res0.y ^ ctr.w ^ key.y, res0.x};
113  return ret;
114  }
115 
116  static const unsigned long kPhilox10A = 0x9E3779B9;
117  static const unsigned long kPhilox10B = 0xBB67AE85;
118  static const unsigned long kPhiloxSA = 0xD2511F53;
119  static const unsigned long kPhiloxSB = 0xCD9E8D57;
120 };
121 
122 } // namespace ck_tile
ck_tile::philox
Definition: philox_rand.hpp:12
CK_TILE_HOST_DEVICE uint4 get_philox_4x32(const unsigned long long subsequence) const
Definition: philox_rand.hpp:22
CK_TILE_HOST_DEVICE philox(unsigned long long seed_, unsigned long long offset_)
Definition: philox_rand.hpp:14
CK_TILE_HOST_DEVICE void get_random_16x8(uint8_t *out, const unsigned long long subsequence) const
Definition: philox_rand.hpp:42
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t *out, const unsigned long long subsequence, const index_t start_idx) const
Definition: philox_rand.hpp:56
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t *out, const unsigned long long subsequence, const index_t start_idx) const
Definition: philox_rand.hpp:73
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
uint32_t uint32x4_t
Definition: vector_type.hpp:123