dynamic_buffer.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "enable_if.hpp"
#include "ck/utility/amd_buffer_addressing.hpp"
#include "ck/utility/c_style_pointer_cast.hpp"
#include "ck/utility/generic_memory_space_atomic.hpp"

namespace ck {

// T may be scalar or vector
// X may be scalar or vector
// T and X have same scalar type
// X contains multiple T
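// Illustrative example (not in the original comment): T = float with
// X = float4_t means one X access moves four consecutive T elements;
// Get/Set below statically assert this divisibility via
// scalar_type<X>::vector_size % scalar_type<T>::vector_size == 0.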
template <AddressSpaceEnum BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;
    T invalid_element_value_ = T{0};

    static constexpr index_t PackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<T>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();
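
    // Illustrative note: pk_i4_t packs two 4-bit values per byte, so with
    // T = pk_i4_t an element_space_size_ of 1024 is presented to the buffer
    // intrinsics below as 1024 / PackedSize == 512 addressable elements; for
    // every other T, PackedSize == 1 and the size passes through unchanged.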

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ constexpr DynamicBuffer(T* p_data,
                                                ElementSpaceSize element_space_size,
                                                T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
                                     !is_native_type<X>(),
                                 bool>::type = false>
    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_LOAD
        bool constexpr use_amd_buffer_addressing = true;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                   t_per_x,
                                                                   coherence>(
                    p_data_, i, is_valid_element, element_space_size_ / PackedSize);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                               t_per_x,
                                                                               coherence>(
                    p_data_,
                    i,
                    is_valid_element,
                    element_space_size_ / PackedSize,
                    invalid_element_value_);
            }
        }
        else
        {
            if(is_valid_element)
            {
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;

                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));

                return tmp;
#else
                return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
            }
            else
            {
                if constexpr(InvalidElementUseNumericalZeroValue)
                {
                    return X{0};
                }
                else
                {
                    return X{invalid_element_value_};
                }
            }
        }
    }
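
    // Illustrative usage: for a float buffer, Get<float>(i, i < n) returns
    // p_data_[i] when the predicate holds, and X{0} (or invalid_element_value_)
    // otherwise, so callers can read past a tile edge without branching.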

    template <InMemoryDataOperationEnum Op,
              typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void Update(index_t i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == InMemoryDataOperationEnum::Set)
        {
            this->template Set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
        {
            this->template AtomicAdd<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
        {
            this->template AtomicMax<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::Add)
        {
            auto tmp = this->template Get<X>(i, is_valid_element);
            using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
            // handle bfloat addition
            if constexpr(is_same_v<scalar_t, bhalf_t>)
            {
                if constexpr(is_scalar_type<X>::value)
                {
                    // Scalar type
                    auto result =
                        type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
                    this->template Set<X>(i, is_valid_element, result);
                }
                else
                {
                    // Vector type
                    constexpr auto vector_size = scalar_type<remove_cvref_t<X>>::vector_size;
                    const vector_type<scalar_t, vector_size> a_vector{tmp};
                    const vector_type<scalar_t, vector_size> b_vector{x};
                    static_for<0, vector_size, 1>{}([&](auto idx) {
                        auto result = type_convert<scalar_t>(
                            type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
                            type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
                        this->template Set<scalar_t>(i + idx, is_valid_element, result);
                    });
                }
            }
            else
            {
                this->template Set<X>(i, is_valid_element, x + tmp);
            }
        }
    }
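
    // Note: bhalf_t is stored as ushort, so adding raw values would sum bit
    // patterns, not numbers; the Add path above therefore widens each operand
    // to float, adds, and narrows back: result = bhalf(float(a) + float(b))
    // per lane.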

    template <typename DstBuffer, index_t NumElemsPerThread>
    __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
                                             index_t src_offset,
                                             index_t dst_offset,
                                             bool is_valid_element) const
    {
        // Copy data from global to LDS memory using direct loads.
        static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
                      "Source data must come from a global memory buffer.");
        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "Destination data must be stored in an LDS memory buffer.");

        amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
                                                            src_offset,
                                                            dst_buf.p_data_,
                                                            dst_offset,
                                                            is_valid_element,
                                                            element_space_size_ / PackedSize);
    }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value ||
                                     !is_native_type<X>(),
                                 bool>::type = false>
    __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_STORE
        bool constexpr use_amd_buffer_addressing = true;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
        bool constexpr workaround_int8_ds_write_issue = true;
#else
        bool constexpr workaround_int8_ds_write_issue = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
                          is_same<typename scalar_type<remove_cvref_t<T>>::type, int8_t>::value &&
                          workaround_int8_ds_write_issue)
        {
            if(is_valid_element)
            {
                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
                // ds_write_b128
                // TODO: remove this after compiler fix
                static_assert((is_same<remove_cvref_t<T>, int8_t>::value &&
                               is_same<remove_cvref_t<X>, int8_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x2_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x16_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x4_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x8_t>::value) ||
                                  (is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                   is_same<remove_cvref_t<X>, int8x16_t>::value),
                              "wrong! not implemented for this combination, please add "
                              "implementation");

                if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                             is_same<remove_cvref_t<X>, int8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x2_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
            }
        }
        else
        {
            if(is_valid_element)
            {
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;

                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
            }
        }
    }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x)
    {
        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;

        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
            is_same_v<remove_cvref_t<scalar_t>, float> ||
            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
            (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
        bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            is_same_v<remove_cvref_t<scalar_t>, float> ||
            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
            (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else
        {
            if(is_valid_element)
            {
                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
            }
        }
    }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void AtomicMax(index_t i, bool is_valid_element, const X& x)
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;

        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        static_assert(GetAddressSpace() == AddressSpaceEnum::Global, "only support global mem");

#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
        using scalar_t = typename scalar_type<remove_cvref_t<T>>::type;
        bool constexpr use_amd_buffer_addressing = is_same_v<remove_cvref_t<scalar_t>, double>;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_ / PackedSize);
        }
        else if(is_valid_element)
        {
            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
        }
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpaceEnum BufferAddressSpace,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true, coherence>{
        p, element_space_size};
}

template <
    AddressSpaceEnum BufferAddressSpace,
    AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence,
    typename T,
    typename ElementSpaceSize,
    typename X,
    typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false, coherence>{
        p, element_space_size, invalid_element_value};
}

} // namespace ck
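
Usage sketch. The pieces above compose as follows: make_dynamic_buffer wraps a raw pointer and an element count, and Get/Set perform predicated, optionally vectorized accesses. The short kernel below is a minimal sketch, assuming a HIP compilation context and CK's float4_t vector alias from data_type.hpp; the kernel name, launch shape, and the 4-wide access are illustrative choices, not requirements of this header.

// Minimal sketch (not part of the header): guarded, vectorized
// read-modify-write through DynamicBuffer. Each thread loads four floats,
// doubles them, and stores them back; out-of-range lanes read X{0} and their
// stores are predicated off by is_valid_element.
#include <hip/hip_runtime.h>

#include "ck/utility/dynamic_buffer.hpp"

__global__ void scale_by_two(float* p_in, float* p_out, ck::index_t n)
{
    using namespace ck;

    // Wrap raw global pointers; invalid reads return numerical zero.
    auto in_buf  = make_dynamic_buffer<AddressSpaceEnum::Global>(p_in, n);
    auto out_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_out, n);

    const index_t i = 4 * (blockIdx.x * blockDim.x + threadIdx.x);

    const bool is_valid = (i + 3) < n;

    // X = float4_t contains four T = float, satisfying the static_asserts.
    const auto v = in_buf.Get<float4_t>(i, is_valid);

    out_buf.Set<float4_t>(i, is_valid, v * 2.f);
}

// Variant: the three-argument overload makes invalid reads return a
// customized value instead of zero, e.g.
//   auto padded_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_in, n, -1.f);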