Composable Kernel: include/ck_tile/core/arch/generic_memory_space_atomic.hpp (source listing)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// NOTE: the include lines were elided in the listing; the paths below are
// assumed from the headers this file references (CK_TILE_* config macros,
// packed vector types, and thread_buffer).
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/thread_buffer.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"

namespace ck_tile {

template <typename T, typename ComputeType>
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
{
    return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
}
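
// Example (illustrative): add<bf16_t, float>(x, y) converts both bf16 operands
// to float, performs the addition at float precision, and rounds the sum back
// to bf16 exactly once, rather than accumulating in the narrow format.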

CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
    bf16x2_t rtn;
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    return rtn;
}

CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t rtn;
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    rtn[2] = add<bf16_t, float>(a[2], b[2]);
    rtn[3] = add<bf16_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    rtn[4] = add<fp8_t, float>(a[4], b[4]);
    rtn[5] = add<fp8_t, float>(a[5], b[5]);
    rtn[6] = add<fp8_t, float>(a[6], b[6]);
    rtn[7] = add<fp8_t, float>(a[7], b[7]);
    return rtn;
}

CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    return rtn;
}

CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
    bf8x8_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    rtn[4] = add<bf8_t, float>(a[4], b[4]);
    rtn[5] = add<bf8_t, float>(a[5], b[5]);
    rtn[6] = add<bf8_t, float>(a[6], b[6]);
    rtn[7] = add<bf8_t, float>(a[7], b[7]);
    return rtn;
}

// Caution: DO NOT REMOVE.
// This template intentionally has a declaration but no definition, so that any
// attempt to instantiate it fails to compile. The goal is to force an explicit
// atomic_add implementation for each supported data type.
template <typename X>
CK_TILE_DEVICE void atomic_add(X* p_dst, const X& x);
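
// Each specialization below emulates a packed atomic add with a
// compare-and-swap (CAS) loop: read the containing 32- or 64-bit word, add
// elementwise in float, then try to publish the result with atomicCAS. If
// another thread updated the word in the meantime, atomicCAS returns the
// fresh value and the loop retries with it.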

template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
{
    // Union to treat the address as either uint32_t* or bf16x2_t*
    union U32BF162_ADDR
    {
        uint32_t* u32_a;
        bf16x2_t* bf162_a;
    };

    // Union to treat the data as either a 32-bit integer or bf16x2_t
    union U32BF162
    {
        uint32_t u32;
        bf16x2_t bf162;
    };

    U32BF162_ADDR dword_addr;
    U32BF162 cur_v;
    U32BF162 new_;
    uint32_t old_v, new_v;
    dword_addr.bf162_a = p_dst;
    cur_v.u32          = *dword_addr.u32_a;

    do
    {
        old_v      = cur_v.u32;
        new_.bf162 = add_bf16x2_t(cur_v.bf162, x);
        new_v      = new_.u32;
        cur_v.u32  = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}

template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
    // Union to treat the pointer as either bf16x4_t* or uint64_t*
    union U64BF164_ADDR
    {
        uint64_t* u64_a;
        bf16x4_t* bf164_a;
    };

    // Union to treat the data as either bf16x4_t or a 64-bit integer
    union U64BF164
    {
        uint64_t u64;
        bf16x4_t bf164;
    };

    U64BF164_ADDR addr;
    addr.bf164_a = p_dst; // interpret p_dst as a 64-bit location

    // First (non-atomic) read of the old value
    U64BF164 cur_v;
    cur_v.u64 = *addr.u64_a;

    U64BF164 new_v_union;
    uint64_t old_v, new_v;

    do
    {
        // old 64 bits
        old_v = cur_v.u64;

        // Add elementwise in bf16
        new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
        new_v             = new_v_union.u64;

        // Attempt the 64-bit CAS
        cur_v.u64 = atomicCAS(addr.u64_a, old_v, new_v);

    } while(cur_v.u64 != old_v);
}
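
// Note: the 64-bit CAS used here (and in the 8-element fp8/bf8 specializations
// below) assumes p_dst is 8-byte aligned, just as the 32-bit variants assume
// 4-byte alignment; hardware atomics require naturally aligned addresses.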

template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
    // Union to treat the address as either uint32_t* or fp8x4_t*
    union U32FP84_ADDR
    {
        uint32_t* u32_a;
        fp8x4_t* fp84_a;
    };

    // Union to treat the data as either a 32-bit integer or fp8x4_t
    union U32FP84
    {
        uint32_t u32;
        fp8x4_t fp84;
    };

    U32FP84_ADDR dword_addr;
    U32FP84 cur_v;
    U32FP84 new_;
    uint32_t old_v, new_v;

    dword_addr.fp84_a = p_dst;
    cur_v.u32         = *dword_addr.u32_a;

    do
    {
        old_v     = cur_v.u32;
        new_.fp84 = add_fp8x4_t(cur_v.fp84, x);
        new_v     = new_.u32;
        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}

template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
    // Union to treat the address as either uint32_t* or bf8x4_t*
    union U32BF84_ADDR
    {
        uint32_t* u32_a;
        bf8x4_t* bf84_a;
    };

    // Union to treat the data as either a 32-bit integer or bf8x4_t
    union U32BF84
    {
        uint32_t u32;
        bf8x4_t bf84;
    };

    U32BF84_ADDR dword_addr;
    U32BF84 cur_v;
    U32BF84 new_;
    uint32_t old_v, new_v;

    dword_addr.bf84_a = p_dst;
    cur_v.u32         = *dword_addr.u32_a;

    do
    {
        old_v     = cur_v.u32;
        new_.bf84 = add_bf8x4_t(cur_v.bf84, x);
        new_v     = new_.u32;
        cur_v.u32 = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}

//
// Atomic add for fp8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
    // Union for addressing 64 bits as either fp8x8_t or a 64-bit integer
    union U64FP88_ADDR
    {
        uint64_t* u64_a; // pointer to 64-bit integer
        fp8x8_t* fp88_a; // pointer to fp8x8_t
    };

    union U64FP88
    {
        uint64_t u64;
        fp8x8_t fp88;
    };

    U64FP88_ADDR dword_addr;
    U64FP88 cur_v;
    U64FP88 new_v_union;
    uint64_t old_v, new_v;

    // Point to the destination as both fp8x8_t* and uint64_t*
    dword_addr.fp88_a = p_dst;
    // Initial read of 64 bits from memory
    cur_v.u64 = *dword_addr.u64_a;

    do
    {
        old_v = cur_v.u64;
        // Add each fp8 element via add_fp8x8_t
        new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
        new_v            = new_v_union.u64;

        // Attempt the 64-bit CAS
        cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}

//
// Atomic add for bf8x8_t
//
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
    union U64BF88_ADDR
    {
        uint64_t* u64_a;
        bf8x8_t* bf88_a;
    };

    union U64BF88
    {
        uint64_t u64;
        bf8x8_t bf88;
    };

    U64BF88_ADDR dword_addr;
    U64BF88 cur_v;
    U64BF88 new_v_union;
    uint64_t old_v, new_v;

    dword_addr.bf88_a = p_dst;
    // Read the original 64 bits
    cur_v.u64 = *dword_addr.u64_a;

    do
    {
        old_v = cur_v.u64;
        // Add each bf8 element via add_bf8x8_t
        new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
        new_v            = new_v_union.u64;

        // 64-bit CAS loop
        cur_v.u64 = atomicCAS(dword_addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}

template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
    static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
                      (std::is_same<T, uint32_t>::value && (N == 1)) ||
                      (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
                      (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
                      (std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
                  "The granularity of the thread buffer is unsupported on the hardware!");

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};

    if constexpr(std::is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<float>(x));
        }
        else if constexpr(N == 2)
        {
            atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
            atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, double>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<double>(x));
        }
        else if constexpr(N == 2)
        {
            atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
            atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<int32_t>(x));
        }
    }
    else if constexpr(std::is_same<T, uint32_t>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<uint32_t>(x));
        }
    }
    else if constexpr(std::is_same<T, bf16_t>::value)
    {
        if constexpr(N == 2)
        {
            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), bit_cast<bf16x2_t>(x));
        }
        else if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
        }
        else if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
                       x.template get_as<bf16x4_t>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, fp8_t>::value)
    {
        if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
        }
        else if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
        }
        else if constexpr(N == 16)
        {
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1, x.template get_as<fp8x8_t>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, bf8_t>::value)
    {
        if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
        }
        else if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
        }
        else if constexpr(N == 16)
        {
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1, x.template get_as<bf8x8_t>()[I1]);
        }
    }
}
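
// Usage sketch (hypothetical; assumes the caller has already filled a
// thread_buffer with this thread's partial results):
//
//   ck_tile::thread_buffer<float, 2> partial = ...; // per-thread partial sums
//   ck_tile::atomic_add_g(p_out, partial);          // two 32-bit float atomicAdds
//
// For bf16/fp8/bf8 buffers the call lowers to the CAS-loop specializations above.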

template <typename T, index_t N>
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
{
    static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
                      (std::is_same<T, uint32_t>::value && (N == 1)) ||
                      (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, double>::value && (N == 1)),
                  "atomic_max_g is not implemented for this type/size combination!");

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};

    if constexpr(std::is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {
            atomicMax(p_dst, bit_cast<float>(x));
        }
        else if constexpr(N == 2)
        {
            atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
            atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, double>::value)
    {
        if constexpr(N == 1)
        {
            atomicMax(p_dst, bit_cast<double>(x));
        }
    }
    else if constexpr(std::is_same<T, int32_t>::value)
    {
        if constexpr(N == 1)
        {
            atomicMax(p_dst, bit_cast<int32_t>(x));
        }
    }
    else if constexpr(std::is_same<T, uint32_t>::value)
    {
        if constexpr(N == 1)
        {
            atomicMax(p_dst, bit_cast<uint32_t>(x));
        }
    }
}
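
// Usage sketch (hypothetical): publish a per-thread maximum to a global slot.
//
//   ck_tile::thread_buffer<float, 1> m = ...; // this thread's local maximum
//   ck_tile::atomic_max_g(p_max, m);          // single float atomicMax on *p_max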

} // namespace ck_tile
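
For context, a minimal kernel sketch of the packed bf16 path; the kernel name and the way the value is supplied are illustrative assumptions, not part of this header:

    #include "ck_tile/core/arch/generic_memory_space_atomic.hpp"

    __global__ void scatter_add_bf16(ck_tile::bf16x2_t* p_out, ck_tile::bf16x2_t v)
    {
        // Lowers to the atomic_add<bf16x2_t> CAS loop above: both bf16 lanes
        // are added in float and published with a single 32-bit atomicCAS.
        ck_tile::atomic_add(p_out + threadIdx.x, v);
    }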