/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor/static_tensor.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor/static_tensor.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor/static_tensor.hpp Source File
static_tensor.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #ifndef CK_STATIC_TENSOR_HPP
5 #define CK_STATIC_TENSOR_HPP
6 
7 namespace ck {
8 
9 // StaticTensor for Scalar
10 template <AddressSpaceEnum AddressSpace,
11  typename T,
12  typename TensorDesc,
13  bool InvalidElementUseNumericalZeroValue,
14  typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
16 {
17  static constexpr auto desc_ = TensorDesc{};
18  static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
19  static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
20 
21  __host__ __device__ constexpr StaticTensor() : invalid_element_scalar_value_{0} {}
22 
23  __host__ __device__ constexpr StaticTensor(T invalid_element_value)
24  : invalid_element_scalar_value_{invalid_element_value}
25  {
26  }
27 
28  // read access
29  template <typename Idx,
30  typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
31  bool>::type = false>
32  __host__ __device__ constexpr const T& operator[](Idx) const
33  {
34  constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
35 
36  constexpr index_t offset = coord.GetOffset();
37 
38  constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
39 
40  if constexpr(is_valid)
41  {
42  return data_[Number<offset>{}];
43  }
44  else
45  {
46  if constexpr(InvalidElementUseNumericalZeroValue)
47  {
48  return zero_scalar_value_;
49  }
50  else
51  {
53  }
54  }
55  }
56 
57  // write access
58  template <typename Idx,
59  typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
60  bool>::type = false>
61  __host__ __device__ constexpr T& operator()(Idx)
62  {
63  constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
64 
65  constexpr index_t offset = coord.GetOffset();
66 
67  constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
68 
69  if constexpr(is_valid)
70  {
71  return data_(Number<offset>{});
72  }
73  else
74  {
76  }
77  }
78 
80  static constexpr T zero_scalar_value_ = T{0};
83 };
84 
85 // StaticTensor for vector
86 template <AddressSpaceEnum AddressSpace,
87  typename S,
88  index_t ScalarPerVector,
89  typename TensorDesc,
90  bool InvalidElementUseNumericalZeroValue,
91  typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
93 {
94  static constexpr auto desc_ = TensorDesc{};
95  static constexpr index_t ndim_ = TensorDesc::GetNumOfDimension();
96  static constexpr index_t element_space_size_ = desc_.GetElementSpaceSize();
97 
98  static constexpr index_t num_of_vector_ =
100 
102 
103  __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer()
105  {
106  }
107 
108  __host__ __device__ constexpr StaticTensorTupleOfVectorBuffer(S invalid_element_value)
109  : invalid_element_scalar_value_{invalid_element_value}
110  {
111  }
112 
113  // Get S
114  // Idx is for S, not V
115  template <typename Idx,
116  typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
117  bool>::type = false>
118  __host__ __device__ constexpr const S& operator[](Idx) const
119  {
120  constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
121 
122  constexpr index_t offset = coord.GetOffset();
123 
124  constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
125 
126  if constexpr(is_valid)
127  {
128  return data_[Number<offset>{}];
129  }
130  else
131  {
132  if constexpr(InvalidElementUseNumericalZeroValue)
133  {
134  return zero_scalar_value_;
135  }
136  else
137  {
139  }
140  }
141  }
142 
143  // Set S
144  // Idx is for S, not V
145  template <typename Idx,
146  typename enable_if<is_known_at_compile_time<Idx>::value && Idx::Size() == ndim_,
147  bool>::type = false>
148  __host__ __device__ constexpr S& operator()(Idx)
149  {
150  constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
151 
152  constexpr index_t offset = coord.GetOffset();
153 
154  constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
155 
156  if constexpr(is_valid)
157  {
158  return data_(Number<offset>{});
159  }
160  else
161  {
163  }
164  }
165 
166  // Get X
167  // Idx is for S, not X. Idx should be aligned with X
168  template <typename X,
169  typename Idx,
170  typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
172  bool>::type = false>
173  __host__ __device__ constexpr X GetAsType(Idx) const
174  {
175  constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
176 
177  constexpr index_t offset = coord.GetOffset();
178 
179  constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
180 
181  if constexpr(is_valid)
182  {
183  return data_.template GetAsType<X>(Number<offset>{});
184  }
185  else
186  {
187  if constexpr(InvalidElementUseNumericalZeroValue)
188  {
189  // TODO: is this right way to initialize a vector?
190  return X{0};
191  }
192  else
193  {
194  // TODO: is this right way to initialize a vector?
196  }
197  }
198  }
199 
200  // Set X
201  // Idx is for S, not X. Idx should be aligned with X
202  template <typename X,
203  typename Idx,
204  typename enable_if<(has_same_scalar_type<S, X>::value || !is_native_type<S>()) &&
206  bool>::type = false>
207  __host__ __device__ constexpr void SetAsType(Idx, X x)
208  {
209  constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
210 
211  constexpr index_t offset = coord.GetOffset();
212 
213  constexpr bool is_valid = coordinate_has_valid_offset(desc_, coord);
214 
215  if constexpr(is_valid)
216  {
217  data_.template SetAsType<X>(Number<offset>{}, x);
218  }
219  }
220 
221  // Get read access to V. No is_valid check
222  // Idx is for S, not V. Idx should be aligned with V
223  template <typename Idx>
224  __host__ __device__ constexpr const V& GetVectorTypeReference(Idx) const
225  {
226  constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
227 
228  constexpr index_t offset = coord.GetOffset();
229 
231  }
232 
233  // Get read access to V. No is_valid check
234  // Idx is for S, not V. Idx should be aligned with V
235  template <typename Idx>
236  __host__ __device__ constexpr V& GetVectorTypeReference(Idx)
237  {
238  constexpr auto coord = make_tensor_coordinate(desc_, to_multi_index(Idx{}));
239 
240  constexpr index_t offset = coord.GetOffset();
241 
243  }
244 
246  static constexpr S zero_scalar_value_ = S{0};
249 };
250 
251 template <AddressSpaceEnum AddressSpace,
252  typename T,
253  typename TensorDesc,
254  typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false>
255 __host__ __device__ constexpr auto make_static_tensor(TensorDesc)
256 {
258 }
259 
260 template <
261  AddressSpaceEnum AddressSpace,
262  typename T,
263  typename TensorDesc,
264  typename X,
265  typename enable_if<TensorDesc::IsKnownAtCompileTime(), bool>::type = false,
266  typename enable_if<is_same<remove_cvref_t<T>, remove_cvref_t<X>>::value, bool>::type = false>
267 __host__ __device__ constexpr auto make_static_tensor(TensorDesc, X invalid_element_value)
268 {
269  return StaticTensor<AddressSpace, T, TensorDesc, true>{invalid_element_value};
270 }
271 
272 } // namespace ck
273 #endif
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_static_tensor(TensorDesc)
Definition: static_tensor.hpp:255
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
AddressSpaceEnum
Definition: amd_address_space.hpp:15
__host__ constexpr __device__ auto make_tensor_coordinate(const TensorDesc &tensor_desc, const VisibleIndex &idx_visible)
Definition: tensor_descriptor.hpp:407
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:10
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ bool coordinate_has_valid_offset(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition: tensor_descriptor.hpp:587
__host__ constexpr __device__ const auto & GetVectorTypeReference(Number< I > i) const
Definition: static_buffer.hpp:156
Definition: static_tensor.hpp:16
T ignored_element_scalar_
Definition: static_tensor.hpp:82
static constexpr T zero_scalar_value_
Definition: static_tensor.hpp:80
__host__ constexpr __device__ StaticTensor(T invalid_element_value)
Definition: static_tensor.hpp:23
static constexpr index_t ndim_
Definition: static_tensor.hpp:18
__host__ constexpr __device__ T & operator()(Idx)
Definition: static_tensor.hpp:61
static constexpr index_t element_space_size_
Definition: static_tensor.hpp:19
static constexpr auto desc_
Definition: static_tensor.hpp:17
const T invalid_element_scalar_value_
Definition: static_tensor.hpp:81
__host__ constexpr __device__ const T & operator[](Idx) const
Definition: static_tensor.hpp:32
__host__ constexpr __device__ StaticTensor()
Definition: static_tensor.hpp:21
StaticBuffer< AddressSpace, T, element_space_size_, true > data_
Definition: static_tensor.hpp:79
Definition: static_tensor.hpp:93
static constexpr index_t num_of_vector_
Definition: static_tensor.hpp:98
__host__ constexpr __device__ S & operator()(Idx)
Definition: static_tensor.hpp:148
const S invalid_element_scalar_value_
Definition: static_tensor.hpp:247
__host__ constexpr __device__ void SetAsType(Idx, X x)
Definition: static_tensor.hpp:207
S ignored_element_scalar_
Definition: static_tensor.hpp:248
static constexpr S zero_scalar_value_
Definition: static_tensor.hpp:246
__host__ constexpr __device__ StaticTensorTupleOfVectorBuffer()
Definition: static_tensor.hpp:103
__host__ constexpr __device__ const S & operator[](Idx) const
Definition: static_tensor.hpp:118
__host__ constexpr __device__ const V & GetVectorTypeReference(Idx) const
Definition: static_tensor.hpp:224
StaticBufferTupleOfVector< AddressSpace, S, num_of_vector_, ScalarPerVector, true > data_
Definition: static_tensor.hpp:245
__host__ constexpr __device__ V & GetVectorTypeReference(Idx)
Definition: static_tensor.hpp:236
__host__ constexpr __device__ X GetAsType(Idx) const
Definition: static_tensor.hpp:173
static constexpr index_t element_space_size_
Definition: static_tensor.hpp:96
static constexpr auto desc_
Definition: static_tensor.hpp:94
__host__ constexpr __device__ StaticTensorTupleOfVectorBuffer(S invalid_element_value)
Definition: static_tensor.hpp:108
static constexpr index_t ndim_
Definition: static_tensor.hpp:95
Definition: integral_constant.hpp:10
Definition: is_known_at_compile_time.hpp:14
Definition: type.hpp:177
Definition: data_type.hpp:347