// ...

#if __clang_major__ >= 20
#include "amd_buffer_addressing_builtins.hpp"
#else
#include "amd_buffer_addressing.hpp"
#endif

// ...

template <AddressSpaceEnum BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          bool InvalidElementUseNumericalZeroValue,
          AmdBufferCoherenceEnum coherence = AmdBufferCoherenceEnum::DefaultCoherence>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;
    ElementSpaceSize element_space_size_;
    T invalid_element_value_;

    // ... (IndexType, PackedSize, and other compile-time definitions elided)

    __host__ __device__ constexpr DynamicBuffer(T* p_data, ElementSpaceSize element_space_size)
        : p_data_{p_data}, element_space_size_{element_space_size}
    {
    }

    __host__ __device__ constexpr DynamicBuffer(T* p_data,
                                                ElementSpaceSize element_space_size,
                                                T invalid_element_value)
        : p_data_{p_data},
          element_space_size_{element_space_size},
          invalid_element_value_{invalid_element_value}
    {
    }

    __host__ __device__ static constexpr AddressSpaceEnum GetAddressSpace()
    {
        return BufferAddressSpace;
    }

    __host__ __device__ constexpr const T& operator[](IndexType i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(IndexType i) { return p_data_[i]; }

    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ constexpr auto Get(IndexType i, bool is_valid_element) const
    {
        // X may be a vector whose scalars are the buffer's element type T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_LOAD
        bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            if constexpr(InvalidElementUseNumericalZeroValue)
            {
                return amd_buffer_load_invalid_element_return_zero<remove_cvref_t<T>,
                                                                   t_per_x,
                                                                   coherence>(
                    p_data_, i, is_valid_element, element_space_size_);
            }
            else
            {
                return amd_buffer_load_invalid_element_return_customized_value<remove_cvref_t<T>,
                                                                               t_per_x,
                                                                               coherence>(
                    p_data_, i, is_valid_element, element_space_size_, invalid_element_value_);
            }
        }
        else
        {
            if(is_valid_element)
            {
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp;
                __builtin_memcpy(&tmp, &(p_data_[i]), sizeof(X));
                return tmp;
#else
                return *c_style_pointer_cast<const X*>(&p_data_[i]);
#endif
            }
            else
            {
                if constexpr(InvalidElementUseNumericalZeroValue)
                {
                    return X{0};
                }
                else
                {
                    return X{invalid_element_value_};
                }
            }
        }
    }
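    // Usage sketch (illustration only, not part of this header): a vectorized,
    // predicated load from a global buffer of float. float4_t, i, and
    // is_valid are assumed names for the example.
    //
    //     auto buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_global, n);
    //     auto v   = buf.template Get<float4_t>(i, is_valid);
    //
    // Because that factory sets InvalidElementUseNumericalZeroValue to true,
    // lanes with is_valid == false come back as numerical zero.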
 
    template <InMemoryDataOperationEnum Op, typename X>
    __host__ __device__ void Update(IndexType i, bool is_valid_element, const X& x)
    {
        if constexpr(Op == InMemoryDataOperationEnum::Set)
        {
            this->template Set<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicAdd)
        {
            this->template AtomicAdd<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::AtomicMax)
        {
            this->template AtomicMax<X>(i, is_valid_element, x);
        }
        else if constexpr(Op == InMemoryDataOperationEnum::Add)
        {
            auto tmp = this->template Get<X>(i, is_valid_element);

            using scalar_t = typename scalar_type<remove_cvref_t<X>>::type;

            // bhalf_t is stored as a raw ushort, so x + tmp would add bit
            // patterns; widen to float, add, then narrow back instead
            if constexpr(is_same_v<scalar_t, bhalf_t>)
            {
                constexpr index_t vector_size = scalar_type<remove_cvref_t<X>>::vector_size;

                if constexpr(vector_size == 1)
                {
                    auto result =
                        type_convert<X>(type_convert<float>(x) + type_convert<float>(tmp));
                    this->template Set<X>(i, is_valid_element, result);
                }
                else
                {
                    const vector_type<scalar_t, vector_size> a_vector{x};
                    const vector_type<scalar_t, vector_size> b_vector{tmp};

                    static_for<0, vector_size, 1>{}([&](auto idx) {
                        auto result = type_convert<scalar_t>(
                            type_convert<float>(a_vector.template AsType<scalar_t>()[idx]) +
                            type_convert<float>(b_vector.template AsType<scalar_t>()[idx]));
                        this->template Set<scalar_t>(i + idx, is_valid_element, result);
                    });
                }
            }
            else
            {
                this->template Set<X>(i, is_valid_element, x + tmp);
            }
        }
    }
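    // A minimal scalar sketch of the float round trip used above for bhalf_t
    // (values chosen to be exactly representable in bf16):
    //
    //     bhalf_t a = type_convert<bhalf_t>(1.5f);
    //     bhalf_t b = type_convert<bhalf_t>(2.25f);
    //     bhalf_t c = type_convert<bhalf_t>(type_convert<float>(a) +
    //                                       type_convert<float>(b)); // 3.75f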
 
    template <typename DstBuffer, index_t NumElemsPerThread>
    __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
                                             IndexType src_offset,
                                             IndexType dst_offset,
                                             bool is_valid_element) const
    {
        static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
                      "Source data must come from a global memory buffer.");
        static_assert(remove_cvref_t<DstBuffer>::GetAddressSpace() == AddressSpaceEnum::Lds,
                      "Destination data must be stored in an LDS memory buffer.");

        amd_direct_load_global_to_lds<T, NumElemsPerThread>(p_data_,
                                                            src_offset,
                                                            dst_buf.p_data_,
                                                            dst_offset,
                                                            is_valid_element,
                                                            element_space_size_);
    }
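    // Usage sketch (illustration only): issuing a direct global->LDS copy of
    // 4 elements per thread. src_buf, lds_buf, and the offsets are assumed
    // names; the two static_asserts above reject any source that is not a
    // global buffer and any destination that is not an LDS buffer.
    //
    //     src_buf.template DirectCopyToLds<decltype(lds_buf), 4>(
    //         lds_buf, src_offset, dst_offset, is_valid);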
 
    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                             typename scalar_type<remove_cvref_t<T>>::type>::value ||
                                     !is_native_type<X>(),
                                 bool>::type = false>
    __host__ __device__ void Set(IndexType i, bool is_valid_element, const X& x)
    {
        // X may be a vector whose scalars are the buffer's element type T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_LOAD
        bool constexpr use_amd_buffer_addressing = sizeof(IndexType) <= sizeof(int32_t);
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

#if CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
        bool constexpr workaround_int8_ds_write_issue = true;
#else
        bool constexpr workaround_int8_ds_write_issue = false;
#endif

        if constexpr(GetAddressSpace() == AddressSpaceEnum::Global && use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_store<remove_cvref_t<T>, t_per_x, coherence>(
                x, p_data_, i, is_valid_element, element_space_size_);
        }
        else if constexpr(GetAddressSpace() == AddressSpaceEnum::Lds &&
                          workaround_int8_ds_write_issue)
        {
            if(is_valid_element)
            {
                // write int8 data through equally sized integer types to work
                // around the ds_write issue; this is a pure bit-pattern copy
                static_assert(/* ... enumeration of the (T, X) pairs handled below ... */,
                              "wrong! not implemented for this combination, please add "
                              "implementation");

                if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                             is_same<remove_cvref_t<X>, int8_t>::value)
                {
                    // 1 byte
                    *c_style_pointer_cast<int8_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int8_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x2_t>::value)
                {
                    // 2 bytes
                    *c_style_pointer_cast<int16_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int16_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    // 4 bytes
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    // 8 bytes
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    // 16 bytes
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x4_t>::value)
                {
                    *c_style_pointer_cast<int32_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x8_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x8_t>::value)
                {
                    *c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x2_t*>(&x);
                }
                else if constexpr(is_same<remove_cvref_t<T>, int8x16_t>::value &&
                                  is_same<remove_cvref_t<X>, int8x16_t>::value)
                {
                    *c_style_pointer_cast<int32x4_t*>(&p_data_[i]) =
                        *c_style_pointer_cast<const int32x4_t*>(&x);
                }
            }
        }
        else
        {
            if(is_valid_element)
            {
#if CK_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
                X tmp = x;
                __builtin_memcpy(&(p_data_[i]), &tmp, sizeof(X));
#else
                *c_style_pointer_cast<X*>(&p_data_[i]) = x;
#endif
            }
        }
    }
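    // Note on the workaround above: routing int8 stores through an equally
    // sized integer type changes only the type the store is emitted with, not
    // the bytes written. Sketch of the 4-byte case (names hypothetical):
    //
    //     int8x4_t x    = /* ... */;
    //     int32_t  bits = *c_style_pointer_cast<const int32_t*>(&x);
    //     // 'bits' holds the same 4 bytes; one 32-bit store writes them all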
 
    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void AtomicAdd(IndexType i, bool is_valid_element, const X& x)
    {
        using scalar_t = typename scalar_type<remove_cvref_t<X>>::type;

        // X may be a vector whose scalars are the buffer's element type T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

        // buffer atomics are only usable for certain scalar types; the exact
        // set depends on which CK_USE_AMD_BUFFER_ATOMIC_ADD_* macros are on
#if CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            is_same_v<remove_cvref_t<scalar_t>, int32_t> ||
            is_same_v<remove_cvref_t<scalar_t>, float> ||
            (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
            (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0);
#elif CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER && (!CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT)
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, int32_t>;
#elif(!CK_USE_AMD_BUFFER_ATOMIC_ADD_INTEGER) && CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) &&
            (is_same_v<remove_cvref_t<scalar_t>, float> ||
             (is_same_v<remove_cvref_t<scalar_t>, half_t> && scalar_per_x_vector % 2 == 0) ||
             (is_same_v<remove_cvref_t<scalar_t>, bhalf_t> && scalar_per_x_vector % 2 == 0));
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_add<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_);
        }
        else
        {
            if(is_valid_element)
            {
                atomic_add<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
            }
        }
    }
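    // Usage sketch (illustration only): accumulating per-workgroup partial
    // results into a global float buffer, as in a split-K reduction. float2_t
    // and the names are assumed for the example.
    //
    //     auto c_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(p_c, c_size);
    //     c_buf.template AtomicAdd<float2_t>(offset, is_valid, partial);
    //
    // With CK_USE_AMD_BUFFER_ATOMIC_ADD_FLOAT enabled this maps to buffer
    // atomics; otherwise it falls back to the generic atomic_add above.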
 
    template <typename X,
              typename enable_if<is_same<typename scalar_type<remove_cvref_t<X>>::type,
                                         typename scalar_type<remove_cvref_t<T>>::type>::value,
                                 bool>::type = false>
    __host__ __device__ void AtomicMax(IndexType i, bool is_valid_element, const X& x)
    {
        // X may be a vector whose scalars are the buffer's element type T
        constexpr index_t scalar_per_t_vector = scalar_type<remove_cvref_t<T>>::vector_size;
        constexpr index_t scalar_per_x_vector = scalar_type<remove_cvref_t<X>>::vector_size;

        static_assert(scalar_per_x_vector % scalar_per_t_vector == 0,
                      "wrong! X should contain multiple T");

#if CK_USE_AMD_BUFFER_ATOMIC_MAX_FLOAT64
        using scalar_t = typename scalar_type<remove_cvref_t<X>>::type;

        // buffer-based atomic max is only used for double with 32-bit offsets
        bool constexpr use_amd_buffer_addressing =
            sizeof(IndexType) <= sizeof(int32_t) && is_same_v<remove_cvref_t<scalar_t>, double>;
#else
        bool constexpr use_amd_buffer_addressing = false;
#endif

        if constexpr(use_amd_buffer_addressing)
        {
            constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

            amd_buffer_atomic_max<remove_cvref_t<T>, t_per_x>(
                x, p_data_, i, is_valid_element, element_space_size_);
        }
        else if(is_valid_element)
        {
            atomic_max<X>(c_style_pointer_cast<X*>(&p_data_[i]), x);
        }
    }
 
    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <AddressSpaceEnum BufferAddressSpace,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, true>{p, element_space_size};
}

template <AddressSpaceEnum BufferAddressSpace,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_long_dynamic_buffer(T* p,
                                                            ElementSpaceSize element_space_size)
{
    // same as make_dynamic_buffer, but for buffers indexed with a 64-bit
    // offset (returned buffer type elided)
    return /* ... */{p, element_space_size};
}

template <AddressSpaceEnum BufferAddressSpace,
          typename T,
          typename ElementSpaceSize,
          typename X,
          typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
__host__ __device__ constexpr auto
make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value)
{
    return DynamicBuffer<BufferAddressSpace, T, ElementSpaceSize, false>{
        p, element_space_size, invalid_element_value};
}
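
// A minimal end-to-end sketch of the factories above (illustration only, not
// part of this header): a predicated, vectorized device copy. The kernel
// scaffolding and the float4_t vector type are assumed from the CK headers.
__global__ void example_vector_copy(float* p_src, float* p_dst, index_t n)
{
    auto src = make_dynamic_buffer<AddressSpaceEnum::Global>(p_src, n);
    auto dst = make_dynamic_buffer<AddressSpaceEnum::Global>(p_dst, n);

    // each thread moves one float4_t; the tail past n is predicated off, so
    // invalid lanes read numerical zero and their stores are squashed
    const index_t i = (blockIdx.x * blockDim.x + threadIdx.x) * 4;
    const bool is_valid = (i + 4 <= n);

    auto v = src.template Get<float4_t>(i, is_valid);
    dst.template Set<float4_t>(i, is_valid, v);
}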
 