// Scalar add helper: convert both operands to ComputeType, add, then convert the
// result back to T. Low-precision types (bf16, fp8, bf8) accumulate in float so
// the sum is rounded only once.
template <typename T, typename ComputeType>
CK_TILE_HOST_DEVICE T add(const T& a, const T& b)
{
    return type_convert<T>(type_convert<ComputeType>(a) + type_convert<ComputeType>(b));
}
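// Example (illustrative, not part of the original source): adding two bf16 values
// through a float accumulator; both operands are widened, summed exactly in float,
// and rounded back to bf16 once.
//   bf16_t s = add<bf16_t, float>(type_convert<bf16_t>(1.5f), type_convert<bf16_t>(0.25f));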
// Element-wise add of two packed bf16 pairs, accumulated in float.
CK_TILE_HOST_DEVICE bf16x2_t add_bf16x2_t(const bf16x2_t& a, const bf16x2_t& b)
{
    bf16x2_t rtn;
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    return rtn;
}

// Element-wise add of two packed bf16 quadruples.
CK_TILE_HOST_DEVICE bf16x4_t add_bf16x4_t(const bf16x4_t& a, const bf16x4_t& b)
{
    bf16x4_t rtn;
    rtn[0] = add<bf16_t, float>(a[0], b[0]);
    rtn[1] = add<bf16_t, float>(a[1], b[1]);
    rtn[2] = add<bf16_t, float>(a[2], b[2]);
    rtn[3] = add<bf16_t, float>(a[3], b[3]);
    return rtn;
}

// Element-wise add of two packed fp8 quadruples.
CK_TILE_HOST_DEVICE fp8x4_t add_fp8x4_t(const fp8x4_t& a, const fp8x4_t& b)
{
    fp8x4_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    return rtn;
}

// Element-wise add of two packed fp8 octets.
CK_TILE_HOST_DEVICE fp8x8_t add_fp8x8_t(const fp8x8_t& a, const fp8x8_t& b)
{
    fp8x8_t rtn;
    rtn[0] = add<fp8_t, float>(a[0], b[0]);
    rtn[1] = add<fp8_t, float>(a[1], b[1]);
    rtn[2] = add<fp8_t, float>(a[2], b[2]);
    rtn[3] = add<fp8_t, float>(a[3], b[3]);
    rtn[4] = add<fp8_t, float>(a[4], b[4]);
    rtn[5] = add<fp8_t, float>(a[5], b[5]);
    rtn[6] = add<fp8_t, float>(a[6], b[6]);
    rtn[7] = add<fp8_t, float>(a[7], b[7]);
    return rtn;
}

// Element-wise add of two packed bf8 quadruples.
CK_TILE_HOST_DEVICE bf8x4_t add_bf8x4_t(const bf8x4_t& a, const bf8x4_t& b)
{
    bf8x4_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    return rtn;
}

// Element-wise add of two packed bf8 octets.
CK_TILE_HOST_DEVICE bf8x8_t add_bf8x8_t(const bf8x8_t& a, const bf8x8_t& b)
{
    bf8x8_t rtn;
    rtn[0] = add<bf8_t, float>(a[0], b[0]);
    rtn[1] = add<bf8_t, float>(a[1], b[1]);
    rtn[2] = add<bf8_t, float>(a[2], b[2]);
    rtn[3] = add<bf8_t, float>(a[3], b[3]);
    rtn[4] = add<bf8_t, float>(a[4], b[4]);
    rtn[5] = add<bf8_t, float>(a[5], b[5]);
    rtn[6] = add<bf8_t, float>(a[6], b[6]);
    rtn[7] = add<bf8_t, float>(a[7], b[7]);
    return rtn;
}
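// The element-wise helpers above are the combine step for the atomic_add
// specializations below. The hardware exposes no packed bf16/fp8/bf8 atomics,
// so each specialization emulates one with an atomicCAS retry loop: snapshot the
// destination word, combine it with x, and try to swap the result in; if another
// thread raced ahead, retry from the value atomicCAS returned. A minimal sketch
// of the pattern (pack/unpack are hypothetical names for the union reinterpretation):
//
//   uint32_t cur = *p_u32;                        // snapshot destination dword
//   uint32_t old_v, new_v;
//   do
//   {
//       old_v = cur;                              // value we expect to find
//       new_v = pack(add_bf16x2_t(unpack(old_v), x));
//       cur   = atomicCAS(p_u32, old_v, new_v);   // returns the observed value
//   } while(cur != old_v);                        // lost the race -> retry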
// Atomic add of a packed bf16 pair, emulated with a 32-bit atomicCAS loop.
template <>
CK_TILE_DEVICE void atomic_add<bf16x2_t>(bf16x2_t* p_dst, const bf16x2_t& x)
{
    // Alias the bf16x2 destination as a 32-bit word so atomicCAS can operate on it.
    union U32BF162_ADDR { uint32_t* u32_a; bf16x2_t* bf162_a; };
    union U32BF162      { uint32_t u32;    bf16x2_t bf162;    };

    U32BF162_ADDR dword_addr;
    U32BF162 cur_v, new_v_union;
    uint32_t old_v, new_v;
    dword_addr.bf162_a = p_dst;
    cur_v.u32          = *dword_addr.u32_a;

    do
    {
        old_v             = cur_v.u32;
        new_v_union.bf162 = add_bf16x2_t(cur_v.bf162, x);
        new_v             = new_v_union.u32;
        cur_v.u32         = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}
// Atomic add of a packed bf16 quadruple, emulated with a 64-bit atomicCAS loop.
template <>
CK_TILE_DEVICE void atomic_add<bf16x4_t>(bf16x4_t* p_dst, bf16x4_t const& x)
{
    union U64BF164_ADDR { uint64_t* u64_a; bf16x4_t* bf164_a; };
    union U64BF164      { uint64_t u64;    bf16x4_t bf164;    };

    U64BF164_ADDR addr;
    U64BF164 cur_v;
    addr.bf164_a = p_dst;
    cur_v.u64    = *addr.u64_a;

    U64BF164 new_v_union;
    uint64_t old_v, new_v;

    do
    {
        old_v             = cur_v.u64;
        new_v_union.bf164 = add_bf16x4_t(cur_v.bf164, x);
        new_v             = new_v_union.u64;
        cur_v.u64         = atomicCAS(addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}
// Atomic add of a packed fp8 quadruple, emulated with a 32-bit atomicCAS loop.
template <>
CK_TILE_DEVICE void atomic_add<fp8x4_t>(fp8x4_t* p_dst, const fp8x4_t& x)
{
    union U32FP84_ADDR { uint32_t* u32_a; fp8x4_t* fp84_a; };
    union U32FP84      { uint32_t u32;    fp8x4_t fp84;    };

    U32FP84_ADDR dword_addr;
    U32FP84 cur_v, new_v_union;
    uint32_t old_v, new_v;
    dword_addr.fp84_a = p_dst;
    cur_v.u32         = *dword_addr.u32_a;

    do
    {
        old_v            = cur_v.u32;
        new_v_union.fp84 = add_fp8x4_t(cur_v.fp84, x);
        new_v            = new_v_union.u32;
        cur_v.u32        = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}
// Atomic add of a packed bf8 quadruple, emulated with a 32-bit atomicCAS loop.
template <>
CK_TILE_DEVICE void atomic_add<bf8x4_t>(bf8x4_t* p_dst, const bf8x4_t& x)
{
    union U32BF84_ADDR { uint32_t* u32_a; bf8x4_t* bf84_a; };
    union U32BF84      { uint32_t u32;    bf8x4_t bf84;    };

    U32BF84_ADDR dword_addr;
    U32BF84 cur_v, new_v_union;
    uint32_t old_v, new_v;
    dword_addr.bf84_a = p_dst;
    cur_v.u32         = *dword_addr.u32_a;

    do
    {
        old_v            = cur_v.u32;
        new_v_union.bf84 = add_bf8x4_t(cur_v.bf84, x);
        new_v            = new_v_union.u32;
        cur_v.u32        = atomicCAS(dword_addr.u32_a, old_v, new_v);
    } while(cur_v.u32 != old_v);
}
// Atomic add of a packed fp8 octet, emulated with a 64-bit atomicCAS loop.
template <>
CK_TILE_DEVICE void atomic_add<fp8x8_t>(fp8x8_t* p_dst, fp8x8_t const& x)
{
    union U64FP88_ADDR { uint64_t* u64_a; fp8x8_t* fp88_a; };
    union U64FP88      { uint64_t u64;    fp8x8_t fp88;    };

    U64FP88_ADDR dword_addr;
    U64FP88 cur_v, new_v_union;
    uint64_t old_v, new_v;
    dword_addr.fp88_a = p_dst;
    cur_v.u64         = *dword_addr.u64_a;

    do
    {
        old_v            = cur_v.u64;
        new_v_union.fp88 = add_fp8x8_t(cur_v.fp88, x);
        new_v            = new_v_union.u64;
        cur_v.u64        = atomicCAS(dword_addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}
// Atomic add of a packed bf8 octet, emulated with a 64-bit atomicCAS loop.
template <>
CK_TILE_DEVICE void atomic_add<bf8x8_t>(bf8x8_t* p_dst, bf8x8_t const& x)
{
    union U64BF88_ADDR { uint64_t* u64_a; bf8x8_t* bf88_a; };
    union U64BF88      { uint64_t u64;    bf8x8_t bf88;    };

    U64BF88_ADDR dword_addr;
    U64BF88 cur_v, new_v_union;
    uint64_t old_v, new_v;
    dword_addr.bf88_a = p_dst;
    cur_v.u64         = *dword_addr.u64_a;

    do
    {
        old_v            = cur_v.u64;
        new_v_union.bf88 = add_bf8x8_t(cur_v.bf88, x);
        new_v            = new_v_union.u64;
        cur_v.u64        = atomicCAS(dword_addr.u64_a, old_v, new_v);
    } while(cur_v.u64 != old_v);
}
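// Example (illustrative): calling a specialization directly from a kernel; many
// threads may target the same destination and the CAS loop serializes them safely.
//   __global__ void scatter_add(bf16x2_t* dst, const bf16x2_t* src)
//   {
//       atomic_add<bf16x2_t>(&dst[blockIdx.x], src[threadIdx.x]);
//   }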
// Atomically add a register-resident thread_buffer<T, N> into global memory,
// dispatching to the widest atomic operation the hardware supports for T.
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_add_g(T* p_dst, const thread_buffer<T, N>& x)
{
    static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
                      (std::is_same<T, uint32_t>::value && (N == 1)) ||
                      (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, double>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, bf16_t>::value && (N == 2 || N == 4 || N == 8)) ||
                      (std::is_same<T, fp8_t>::value && (N == 4 || N == 8 || N == 16)) ||
                      (std::is_same<T, bf8_t>::value && (N == 4 || N == 8 || N == 16)),
                  "The granularity of the thread buffer is unsupported on the hardware!");

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};

    if constexpr(std::is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<float>(x));
        }
        else if constexpr(N == 2)
        {
            atomicAdd(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
            atomicAdd(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, double>::value)
    {
        if constexpr(N == 1)
        {
            atomicAdd(p_dst, bit_cast<double>(x));
        }
        else if constexpr(N == 2)
        {
            atomicAdd(c_style_pointer_cast<double*>(p_dst), x.template get_as<double>()[I0]);
            atomicAdd(c_style_pointer_cast<double*>(p_dst) + 1, x.template get_as<double>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, int32_t>::value)
    {
        // N == 1 is guaranteed by the static_assert above.
        atomicAdd(p_dst, bit_cast<int32_t>(x));
    }
    else if constexpr(std::is_same<T, uint32_t>::value)
    {
        atomicAdd(p_dst, bit_cast<uint32_t>(x));
    }
    else if constexpr(std::is_same<T, bf16_t>::value)
    {
        if constexpr(N == 2)
        {
            atomic_add(c_style_pointer_cast<bf16x2_t*>(p_dst), bit_cast<bf16x2_t>(x));
        }
        else if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
        }
        else if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst), x.template get_as<bf16x4_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf16x4_t*>(p_dst) + 1,
                       x.template get_as<bf16x4_t>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, fp8_t>::value)
    {
        if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<fp8x4_t*>(p_dst), x.template get_as<fp8x4_t>()[I0]);
        }
        else if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
        }
        else if constexpr(N == 16)
        {
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst), x.template get_as<fp8x8_t>()[I0]);
            atomic_add(c_style_pointer_cast<fp8x8_t*>(p_dst) + 1,
                       x.template get_as<fp8x8_t>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, bf8_t>::value)
    {
        if constexpr(N == 4)
        {
            atomic_add(c_style_pointer_cast<bf8x4_t*>(p_dst), x.template get_as<bf8x4_t>()[I0]);
        }
        else if constexpr(N == 8)
        {
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
        }
        else if constexpr(N == 16)
        {
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst), x.template get_as<bf8x8_t>()[I0]);
            atomic_add(c_style_pointer_cast<bf8x8_t*>(p_dst) + 1,
                       x.template get_as<bf8x8_t>()[I1]);
        }
    }
}
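// Example (illustrative): flushing a per-thread partial-sum buffer into a global
// accumulator; for float with N == 2 this lowers to two independent atomicAdd's.
//   thread_buffer<float, 2> acc = ...; // partial sums held in registers
//   atomic_add_g(p_out + row_offset, acc);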
// Atomically take the element-wise maximum of a thread_buffer<T, N> into global
// memory.
template <typename T, index_t N>
CK_TILE_DEVICE void atomic_max_g(T* p_dst, const thread_buffer<T, N>& x)
{
    static_assert((std::is_same<T, int32_t>::value && (N == 1)) ||
                      (std::is_same<T, uint32_t>::value && (N == 1)) ||
                      (std::is_same<T, float>::value && (N == 1 || N == 2)) ||
                      (std::is_same<T, double>::value && (N == 1)),
                  "The granularity of the thread buffer is unsupported on the hardware!");

    constexpr auto I0 = number<0>{};
    constexpr auto I1 = number<1>{};

    if constexpr(std::is_same<T, float>::value)
    {
        if constexpr(N == 1)
        {
            atomicMax(p_dst, bit_cast<float>(x));
        }
        else if constexpr(N == 2)
        {
            atomicMax(c_style_pointer_cast<float*>(p_dst), x.template get_as<float>()[I0]);
            atomicMax(c_style_pointer_cast<float*>(p_dst) + 1, x.template get_as<float>()[I1]);
        }
    }
    else if constexpr(std::is_same<T, double>::value)
    {
        atomicMax(p_dst, bit_cast<double>(x));
    }
    else if constexpr(std::is_same<T, int32_t>::value)
    {
        atomicMax(p_dst, bit_cast<int32_t>(x));
    }
    else if constexpr(std::is_same<T, uint32_t>::value)
    {
        atomicMax(p_dst, bit_cast<uint32_t>(x));
    }
}
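// Example (illustrative): maintaining a running row maximum (e.g. for a streaming
// softmax) with a single 32-bit atomicMax per thread.
//   thread_buffer<float, 1> m = ...; // this thread's local maximum
//   atomic_max_g(p_row_max + row, m);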