Composable Kernel: include/ck/library/utility/host_tensor.hpp Source File

host_tensor.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <algorithm>
7 #include <cassert>
8 #include <iostream>
9 #include <numeric>
10 #include <thread>
11 #include <utility>
12 #include <vector>
13 
14 #include "ck/utility/data_type.hpp"
15 #include "ck/utility/span.hpp"
16 #include "ck/utility/type_convert.hpp"
17 
18 #include "ck/library/utility/algorithm.hpp"
19 #include "ck/library/utility/ranges.hpp"
20 
21 template <typename Range>
22 std::ostream& LogRange(std::ostream& os, Range&& range, std::string delim)
23 {
24  bool first = true;
25  for(auto&& v : range)
26  {
27  if(first)
28  first = false;
29  else
30  os << delim;
31  os << v;
32  }
33  return os;
34 }
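For reference, a minimal usage sketch of LogRange; the vector and stream here are illustrative and not part of this header:

    std::vector<int> lens{4, 8, 16};
    LogRange(std::cout, lens, ", ") << std::endl;   // prints "4, 8, 16"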
35 
36 template <typename T, typename Range>
37 std::ostream& LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
38 {
39  bool first = true;
40  for(auto&& v : range)
41  {
42  if(first)
43  first = false;
44  else
45  os << delim;
46 
47  using RangeType = ck::remove_cvref_t<decltype(v)>;
48  if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
49  std::is_same_v<RangeType, ck::bhalf_t>)
50  {
51  os << ck::type_convert<float>(v);
52  }
53  else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
54  {
55  const auto packed_floats = ck::type_convert<ck::float2_t>(v);
56  const ck::vector_type<float, 2> vector_of_floats{packed_floats};
57  os << vector_of_floats.template AsType<float>()[ck::Number<0>{}] << delim
58  << vector_of_floats.template AsType<float>()[ck::Number<1>{}];
59  }
60  else
61  {
62  os << static_cast<T>(v);
63  }
64  }
65  return os;
66 }
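LogRangeAsType is useful when the element type does not print as a number by default (for example int8_t) or needs conversion through ck::type_convert (f8_t, bf8_t, bhalf_t, packed int4). An illustrative call, not taken from the header:

    std::vector<int8_t> q{1, 2, 3};
    LogRangeAsType<int>(std::cout, q, ", ") << std::endl;   // prints "1, 2, 3" rather than raw characters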
67 
68 template <typename F, typename T, std::size_t... Is>
69 auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
70 {
71  return f(std::get<Is>(args)...);
72 }
73 
74 template <typename F, typename T>
75 auto call_f_unpack_args(F f, T args)
76 {
77  constexpr std::size_t N = std::tuple_size<T>{};
78 
79  return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
80 }
81 
82 template <typename F, typename T, std::size_t... Is>
83 auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
84 {
85  return F(std::get<Is>(args)...);
86 }
87 
88 template <typename F, typename T>
89 auto construct_f_unpack_args(F, T args)
90 {
91  constexpr std::size_t N = std::tuple_size<T>{};
92 
93  return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
94 }
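These helpers expand a tuple-like pack into an argument list; ParallelTensorFunctor below uses call_f_unpack_args to invoke its stored functor with an N-dimensional index, and construct_f_unpack_args does the same for constructors. A small sketch with a hypothetical lambda and std::array:

    auto volume = [](std::size_t n, std::size_t c, std::size_t h, std::size_t w) { return n * c * h * w; };
    std::array<std::size_t, 4> dims{2, 3, 4, 5};
    std::size_t v = call_f_unpack_args(volume, dims);   // volume(2, 3, 4, 5) == 120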
95 
96 struct HostTensorDescriptor
97 {
98  HostTensorDescriptor() = default;
99 
100  void CalculateStrides();
101 
102  template <typename X, typename = std::enable_if_t<std::is_convertible_v<X, std::size_t>>>
103  HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
104  {
105  this->CalculateStrides();
106  }
107 
108  HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens)
109  : mLens(lens.begin(), lens.end())
110  {
111  this->CalculateStrides();
112  }
113 
114  template <typename Lengths,
115  typename = std::enable_if_t<
116  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> ||
117  std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t>>>
118  HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
119  {
120  this->CalculateStrides();
121  }
122 
123  template <typename X,
124  typename Y,
125  typename = std::enable_if_t<std::is_convertible_v<X, std::size_t> &&
126  std::is_convertible_v<Y, std::size_t>>>
127  HostTensorDescriptor(const std::initializer_list<X>& lens,
128  const std::initializer_list<Y>& strides)
129  : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
130  {
131  }
132 
133  HostTensorDescriptor(const std::initializer_list<ck::long_index_t>& lens,
134  const std::initializer_list<ck::long_index_t>& strides)
135  : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
136  {
137  }
138 
139  template <typename Lengths,
140  typename Strides,
141  typename = std::enable_if_t<
142  (std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
143  std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>) ||
144  (std::is_convertible_v<ck::ranges::range_value_t<Lengths>, ck::long_index_t> &&
145  std::is_convertible_v<ck::ranges::range_value_t<Strides>, ck::long_index_t>)>>
146  HostTensorDescriptor(const Lengths& lens, const Strides& strides)
147  : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
148  {
149  }
150 
151  std::size_t GetNumOfDimension() const;
152  std::size_t GetElementSize() const;
153  std::size_t GetElementSpaceSize() const;
154 
155  const std::vector<std::size_t>& GetLengths() const;
156  const std::vector<std::size_t>& GetStrides() const;
157 
158  template <typename... Is>
159  std::size_t GetOffsetFromMultiIndex(Is... is) const
160  {
161  assert(sizeof...(Is) == this->GetNumOfDimension());
162  std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...};
163  return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
164  }
165 
166  std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
167  {
168  return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
169  }
170 
171  friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
172 
173  private:
174  std::vector<std::size_t> mLens;
175  std::vector<std::size_t> mStrides;
176 };
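Typical descriptor construction, assuming CalculateStrides() produces packed row-major strides (the values are illustrative):

    HostTensorDescriptor packed({2, 3, 4});                       // strides become {12, 4, 1}
    std::size_t off = packed.GetOffsetFromMultiIndex(1, 2, 3);    // 1*12 + 2*4 + 3*1 = 23
    HostTensorDescriptor padded({2, 3, 4}, {16, 4, 1});           // explicit strides, e.g. a padded outer dimension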
177 
178 template <typename New2Old>
179 HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor& a,
180  const New2Old& new2old)
181 {
182  std::vector<std::size_t> new_lengths(a.GetNumOfDimension());
183  std::vector<std::size_t> new_strides(a.GetNumOfDimension());
184 
185  for(std::size_t i = 0; i < a.GetNumOfDimension(); i++)
186  {
187  new_lengths[i] = a.GetLengths()[new2old[i]];
188  new_strides[i] = a.GetStrides()[new2old[i]];
189  }
190 
191  return HostTensorDescriptor(new_lengths, new_strides);
192 }
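new2old maps each new dimension to the old dimension it takes its length and stride from. For example, permuting a hypothetical NCHW descriptor into NHWC order:

    HostTensorDescriptor nchw({1, 16, 32, 32});
    auto nhwc = transpose_host_tensor_descriptor_given_new2old(nchw, std::array<std::size_t, 4>{0, 2, 3, 1});
    // nhwc.GetLengths() == {1, 32, 32, 16}; strides are permuted the same way, so no data moves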
193 
194 struct joinable_thread : std::thread
195 {
196  template <typename... Xs>
197  joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
198  {
199  }
200 
201  joinable_thread(joinable_thread&&) = default;
202  joinable_thread& operator=(joinable_thread&&) = default;
203 
204  ~joinable_thread()
205  {
206  if(this->joinable())
207  this->join();
208  }
209 };
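joinable_thread simply joins in its destructor, so a thread that goes out of scope blocks until its work is done instead of calling std::terminate as a plain std::thread would. A minimal sketch:

    {
        joinable_thread worker([] { /* do work */ });
    }   // joined here automatically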
210 
211 template <typename F, typename... Xs>
212 struct ParallelTensorFunctor
213 {
214  F mF;
215  static constexpr std::size_t NDIM = sizeof...(Xs);
216  std::array<std::size_t, NDIM> mLens;
217  std::array<std::size_t, NDIM> mStrides;
218  std::size_t mN1d;
219 
220  ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast<std::size_t>(xs)...})
221  {
222  mStrides.back() = 1;
223  std::partial_sum(mLens.rbegin(),
224  mLens.rend() - 1,
225  mStrides.rbegin() + 1,
226  std::multiplies<std::size_t>());
227  mN1d = mStrides[0] * mLens[0];
228  }
229 
230  std::array<std::size_t, NDIM> GetNdIndices(std::size_t i) const
231  {
232  std::array<std::size_t, NDIM> indices;
233 
234  for(std::size_t idim = 0; idim < NDIM; ++idim)
235  {
236  indices[idim] = i / mStrides[idim];
237  i -= indices[idim] * mStrides[idim];
238  }
239 
240  return indices;
241  }
242 
243  void operator()(std::size_t num_thread = 1) const
244  {
245  std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
246 
247  std::vector<joinable_thread> threads(num_thread);
248 
249  for(std::size_t it = 0; it < num_thread; ++it)
250  {
251  std::size_t iw_begin = it * work_per_thread;
252  std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d);
253 
254  auto f = [=] {
255  for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
256  {
257  call_f_unpack_args(mF, GetNdIndices(iw));
258  }
259  };
260  threads[it] = joinable_thread(f);
261  }
262  }
263 };
264 
265 template <typename F, typename... Xs>
266 auto make_ParallelTensorFunctor(F f, Xs... xs)
267 {
268  return ParallelTensorFunctor<F, Xs...>(f, xs...);
269 }
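The functor linearizes the index space (row-major, innermost dimension fastest), splits the 1-D range across num_thread joinable_threads, and rebuilds the N-D index with GetNdIndices before calling f. A sketch that fills a 2 x 3 table using four threads; the lambda and table are illustrative:

    std::vector<int> table(2 * 3);
    auto fill = [&](std::size_t i, std::size_t j) { table[i * 3 + j] = static_cast<int>(10 * i + j); };
    make_ParallelTensorFunctor(fill, 2, 3)(4);   // blocks until all worker threads have joined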
270 
271 template <typename T>
272 struct Tensor
273 {
274  using Descriptor = HostTensorDescriptor;
275  using Data = std::vector<T>;
276 
277  template <typename X>
278  Tensor(std::initializer_list<X> lens) : mDesc(lens), mData(GetElementSpaceSize())
279  {
280  }
281 
282  template <typename X, typename Y>
283  Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
284  : mDesc(lens, strides), mData(GetElementSpaceSize())
285  {
286  }
287 
288  template <typename Lengths>
289  Tensor(const Lengths& lens) : mDesc(lens), mData(GetElementSpaceSize())
290  {
291  }
292 
293  template <typename Lengths, typename Strides>
294  Tensor(const Lengths& lens, const Strides& strides)
295  : mDesc(lens, strides), mData(GetElementSpaceSize())
296  {
297  }
298 
299  Tensor(const Descriptor& desc) : mDesc(desc), mData(GetElementSpaceSize()) {}
300 
301  template <typename OutT>
302  Tensor<OutT> CopyAsType() const
303  {
304  Tensor<OutT> ret(mDesc);
305 
306  ck::ranges::transform(
307  mData, ret.mData.begin(), [](auto value) { return ck::type_convert<OutT>(value); });
308 
309  return ret;
310  }
311 
312  Tensor() = delete;
313  Tensor(const Tensor&) = default;
314  Tensor(Tensor&&) = default;
315 
316  ~Tensor() = default;
317 
318  Tensor& operator=(const Tensor&) = default;
319  Tensor& operator=(Tensor&&) = default;
320 
321  template <typename FromT>
322  explicit Tensor(const Tensor<FromT>& other) : Tensor(other.template CopyAsType<T>())
323  {
324  }
325 
326  decltype(auto) GetLengths() const { return mDesc.GetLengths(); }
327 
328  decltype(auto) GetStrides() const { return mDesc.GetStrides(); }
329 
330  std::size_t GetNumOfDimension() const { return mDesc.GetNumOfDimension(); }
331 
332  std::size_t GetElementSize() const { return mDesc.GetElementSize(); }
333 
334  std::size_t GetElementSpaceSize() const
335  {
336  if constexpr(std::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
337  {
338  return (mDesc.GetElementSpaceSize() + 1) / 2;
339  }
340  else
341  {
342  return mDesc.GetElementSpaceSize();
343  }
344  }
345 
346  std::size_t GetElementSpaceSizeInBytes() const { return sizeof(T) * GetElementSpaceSize(); }
347 
348  void SetZero() { ck::ranges::fill<T>(mData, T{0}); }
349 
350  template <typename F>
351  void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
352  {
353  if(rank == mDesc.GetNumOfDimension())
354  {
355  f(*this, idx);
356  return;
357  }
358  // else
359  for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
360  {
361  idx[rank] = i;
362  ForEach_impl(std::forward<F>(f), idx, rank + 1);
363  }
364  }
365 
366  template <typename F>
367  void ForEach(F&& f)
368  {
369  std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
370  ForEach_impl(std::forward<F>(f), idx, size_t(0));
371  }
372 
373  template <typename F>
374  void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
375  {
376  if(rank == mDesc.GetNumOfDimension())
377  {
378  f(*this, idx);
379  return;
380  }
381  // else
382  for(size_t i = 0; i < mDesc.GetLengths()[rank]; i++)
383  {
384  idx[rank] = i;
385  ForEach_impl(std::forward<const F>(f), idx, rank + 1);
386  }
387  }
388 
389  template <typename F>
390  void ForEach(const F&& f) const
391  {
392  std::vector<size_t> idx(mDesc.GetNumOfDimension(), 0);
393  ForEach_impl(std::forward<const F>(f), idx, size_t(0));
394  }
395 
396  template <typename G>
397  void GenerateTensorValue(G g, std::size_t num_thread = 1)
398  {
399  switch(mDesc.GetNumOfDimension())
400  {
401  case 1: {
402  auto f = [&](auto i) { (*this)(i) = g(i); };
403  make_ParallelTensorFunctor(f, mDesc.GetLengths()[0])(num_thread);
404  break;
405  }
406  case 2: {
407  auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); };
408  make_ParallelTensorFunctor(f, mDesc.GetLengths()[0], mDesc.GetLengths()[1])(num_thread);
409  break;
410  }
411  case 3: {
412  auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
413  make_ParallelTensorFunctor(
414  f, mDesc.GetLengths()[0], mDesc.GetLengths()[1], mDesc.GetLengths()[2])(num_thread);
415  break;
416  }
417  case 4: {
418  auto f = [&](auto i0, auto i1, auto i2, auto i3) {
419  (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
420  };
421  make_ParallelTensorFunctor(f,
422  mDesc.GetLengths()[0],
423  mDesc.GetLengths()[1],
424  mDesc.GetLengths()[2],
425  mDesc.GetLengths()[3])(num_thread);
426  break;
427  }
428  case 5: {
429  auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) {
430  (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
431  };
432  make_ParallelTensorFunctor(f,
433  mDesc.GetLengths()[0],
434  mDesc.GetLengths()[1],
435  mDesc.GetLengths()[2],
436  mDesc.GetLengths()[3],
437  mDesc.GetLengths()[4])(num_thread);
438  break;
439  }
440  case 6: {
441  auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
442  (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
443  };
444  make_ParallelTensorFunctor(f,
445  mDesc.GetLengths()[0],
446  mDesc.GetLengths()[1],
447  mDesc.GetLengths()[2],
448  mDesc.GetLengths()[3],
449  mDesc.GetLengths()[4],
450  mDesc.GetLengths()[5])(num_thread);
451  break;
452  }
453  case 12: {
454  auto f = [&](auto i0,
455  auto i1,
456  auto i2,
457  auto i3,
458  auto i4,
459  auto i5,
460  auto i6,
461  auto i7,
462  auto i8,
463  auto i9,
464  auto i10,
465  auto i11) {
466  (*this)(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11) =
467  g(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11);
468  };
469  make_ParallelTensorFunctor(f,
470  mDesc.GetLengths()[0],
471  mDesc.GetLengths()[1],
472  mDesc.GetLengths()[2],
473  mDesc.GetLengths()[3],
474  mDesc.GetLengths()[4],
475  mDesc.GetLengths()[5],
476  mDesc.GetLengths()[6],
477  mDesc.GetLengths()[7],
478  mDesc.GetLengths()[8],
479  mDesc.GetLengths()[9],
480  mDesc.GetLengths()[10],
481  mDesc.GetLengths()[11])(num_thread);
482  break;
483  }
484  default: throw std::runtime_error("unsupported dimension");
485  }
486  }
487 
488  template <typename... Is>
489  std::size_t GetOffsetFromMultiIndex(Is... is) const
490  {
491  if constexpr(std::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
492  {
493  return mDesc.GetOffsetFromMultiIndex(is...) / 2;
494  }
495  else
496  {
497  return mDesc.GetOffsetFromMultiIndex(is...);
498  }
499  }
500 
501  template <typename... Is>
502  T& operator()(Is... is)
503  {
504  if constexpr(std::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
505  {
506  return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
507  }
508  else
509  {
510  return mData[mDesc.GetOffsetFromMultiIndex(is...)];
511  }
512  }
513 
514  template <typename... Is>
515  const T& operator()(Is... is) const
516  {
517  if constexpr(std::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
518  {
519  return mData[mDesc.GetOffsetFromMultiIndex(is...) / 2];
520  }
521  else
522  {
523  return mData[mDesc.GetOffsetFromMultiIndex(is...)];
524  }
525  }
526 
527  T& operator()(std::vector<std::size_t> idx)
528  {
529  if constexpr(std::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
530  {
531  return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
532  }
533  else
534  {
535  return mData[mDesc.GetOffsetFromMultiIndex(idx)];
536  }
537  }
538 
539  const T& operator()(std::vector<std::size_t> idx) const
540  {
541  if constexpr(std::is_same_v<ck::remove_cvref_t<T>, ck::pk_i4_t>)
542  {
543  return mData[mDesc.GetOffsetFromMultiIndex(idx) / 2];
544  }
545  else
546  {
547  return mData[mDesc.GetOffsetFromMultiIndex(idx)];
548  }
549  }
550 
551  typename Data::iterator begin() { return mData.begin(); }
552 
553  typename Data::iterator end() { return mData.end(); }
554 
555  typename Data::pointer data() { return mData.data(); }
556 
557  typename Data::const_iterator begin() const { return mData.begin(); }
558 
559  typename Data::const_iterator end() const { return mData.end(); }
560 
561  typename Data::const_pointer data() const { return mData.data(); }
562 
563  typename Data::size_type size() const { return mData.size(); }
564 
565  template <typename U = T>
566  auto AsSpan() const
567  {
568  constexpr std::size_t FromSize = sizeof(T);
569  constexpr std::size_t ToSize = sizeof(U);
570 
571  using Element = std::add_const_t<std::remove_reference_t<U>>;
572  return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
573  }
574 
575  template <typename U = T>
576  auto AsSpan()
577  {
578  constexpr std::size_t FromSize = sizeof(T);
579  constexpr std::size_t ToSize = sizeof(U);
580 
581  using Element = std::remove_reference_t<U>;
582  return ck::span<Element>{reinterpret_cast<Element*>(data()), size() * FromSize / ToSize};
583  }
584 
585  Descriptor mDesc;
586  Data mData;
587 };
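Putting the pieces together, a hedged sketch of typical host-side usage; the values and lambdas are illustrative, not taken from the header:

    Tensor<float> a({2, 3, 4});                                   // lengths only: packed strides, storage allocated
    a.GenerateTensorValue([](auto i0, auto i1, auto i2) { return static_cast<float>(i0 + i1 + i2); }, 2);
    a.ForEach([](auto& self, auto idx) { self(idx) *= 2.f; });    // visit every multi-index
    float x   = a(1, 2, 3);                                       // element access through the descriptor strides
    auto span = a.AsSpan<const float>();                          // view the storage without copying
    auto half = a.CopyAsType<ck::half_t>();                       // element-wise ck::type_convert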