/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/host/host_tensor.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/host/host_tensor.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/host/host_tensor.hpp Source File
host_tensor.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
15 
16 #include "ck_tile/core.hpp"
18 #include "ck_tile/host/ranges.hpp"
19 
20 namespace ck_tile {
21 
22 template <typename Range>
23 CK_TILE_HOST std::ostream& LogRange(std::ostream& os,
24  Range&& range,
25  std::string delim,
26  int precision = std::cout.precision(),
27  int width = 0)
28 {
29  bool first = true;
30  for(auto&& v : range)
31  {
32  if(first)
33  first = false;
34  else
35  os << delim;
36  os << std::setw(width) << std::setprecision(precision) << v;
37  }
38  return os;
39 }
40 
41 template <typename T, typename Range>
42 CK_TILE_HOST std::ostream& LogRangeAsType(std::ostream& os,
43  Range&& range,
44  std::string delim,
45  int precision = std::cout.precision(),
46  int width = 0)
47 {
48  bool first = true;
49  for(auto&& v : range)
50  {
51  if(first)
52  first = false;
53  else
54  os << delim;
55  os << std::setw(width) << std::setprecision(precision) << static_cast<T>(v);
56  }
57  return os;
58 }
59 
60 template <typename F, typename T, std::size_t... Is>
61 CK_TILE_HOST auto call_f_unpack_args_impl(F f, T args, std::index_sequence<Is...>)
62 {
63  return f(std::get<Is>(args)...);
64 }
65 
66 template <typename F, typename T>
68 {
69  constexpr std::size_t N = std::tuple_size<T>{};
70 
71  return call_f_unpack_args_impl(f, args, std::make_index_sequence<N>{});
72 }
73 
74 template <typename F, typename T, std::size_t... Is>
75 CK_TILE_HOST auto construct_f_unpack_args_impl(T args, std::index_sequence<Is...>)
76 {
77  return F(std::get<Is>(args)...);
78 }
79 
80 template <typename F, typename T>
82 {
83  constexpr std::size_t N = std::tuple_size<T>{};
84 
85  return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
86 }
87 
89 {
90  HostTensorDescriptor() = default;
91 
93  {
94  mStrides.clear();
95  mStrides.resize(mLens.size(), 0);
96  if(mStrides.empty())
97  return;
98 
99  mStrides.back() = 1;
100  std::partial_sum(mLens.rbegin(),
101  mLens.rend() - 1,
102  mStrides.rbegin() + 1,
103  std::multiplies<std::size_t>());
104  }
105 
106  template <typename X, typename = std::enable_if_t<std::is_convertible_v<X, std::size_t>>>
107  HostTensorDescriptor(const std::initializer_list<X>& lens) : mLens(lens.begin(), lens.end())
108  {
109  this->CalculateStrides();
110  }
111 
112  template <typename Lengths,
113  typename = std::enable_if_t<
114  std::is_convertible_v<ck_tile::ranges::range_value_t<Lengths>, std::size_t>>>
115  HostTensorDescriptor(const Lengths& lens) : mLens(lens.begin(), lens.end())
116  {
117  this->CalculateStrides();
118  }
119 
120  template <typename X,
121  typename Y,
122  typename = std::enable_if_t<std::is_convertible_v<X, std::size_t> &&
123  std::is_convertible_v<Y, std::size_t>>>
124  HostTensorDescriptor(const std::initializer_list<X>& lens,
125  const std::initializer_list<Y>& strides)
126  : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
127  {
128  }
129 
130  template <typename Lengths,
131  typename Strides,
132  typename = std::enable_if_t<
133  std::is_convertible_v<ck_tile::ranges::range_value_t<Lengths>, std::size_t> &&
134  std::is_convertible_v<ck_tile::ranges::range_value_t<Strides>, std::size_t>>>
135  HostTensorDescriptor(const Lengths& lens, const Strides& strides)
136  : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
137  {
138  }
139 
140  std::size_t get_num_of_dimension() const { return mLens.size(); }
141  std::size_t get_element_size() const
142  {
143  assert(mLens.size() == mStrides.size());
144  return std::accumulate(
145  mLens.begin(), mLens.end(), std::size_t{1}, std::multiplies<std::size_t>());
146  }
147  std::size_t get_element_space_size() const
148  {
149  std::size_t space = 1;
150  for(std::size_t i = 0; i < mLens.size(); ++i)
151  {
152  if(mLens[i] == 0)
153  continue;
154 
155  space += (mLens[i] - 1) * mStrides[i];
156  }
157  return space;
158  }
159 
160  std::size_t get_length(std::size_t dim) const { return mLens[dim]; }
161 
162  const std::vector<std::size_t>& get_lengths() const { return mLens; }
163 
164  std::size_t get_stride(std::size_t dim) const { return mStrides[dim]; }
165 
166  const std::vector<std::size_t>& get_strides() const { return mStrides; }
167 
168  template <typename... Is>
169  std::size_t GetOffsetFromMultiIndex(Is... is) const
170  {
171  assert(sizeof...(Is) == this->get_num_of_dimension());
172  std::initializer_list<std::size_t> iss{static_cast<std::size_t>(is)...};
173  return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
174  }
175 
176  std::size_t GetOffsetFromMultiIndex(std::vector<std::size_t> iss) const
177  {
178  return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
179  }
180 
181  friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
182  {
183  os << "dim " << desc.get_num_of_dimension() << ", ";
184 
185  os << "lengths {";
186  LogRange(os, desc.get_lengths(), ", ");
187  os << "}, ";
188 
189  os << "strides {";
190  LogRange(os, desc.get_strides(), ", ");
191  os << "}";
192 
193  return os;
194  }
195 
196  private:
197  std::vector<std::size_t> mLens;
198  std::vector<std::size_t> mStrides;
199 };
200 
201 template <typename New2Old>
203  const HostTensorDescriptor& a, const New2Old& new2old)
204 {
205  std::vector<std::size_t> new_lengths(a.get_num_of_dimension());
206  std::vector<std::size_t> new_strides(a.get_num_of_dimension());
207 
208  for(std::size_t i = 0; i < a.get_num_of_dimension(); i++)
209  {
210  new_lengths[i] = a.get_lengths()[new2old[i]];
211  new_strides[i] = a.get_strides()[new2old[i]];
212  }
213 
214  return HostTensorDescriptor(new_lengths, new_strides);
215 }
216 
217 template <typename F, typename... Xs>
219 {
220  F mF;
221  static constexpr std::size_t NDIM = sizeof...(Xs);
222  std::array<std::size_t, NDIM> mLens;
223  std::array<std::size_t, NDIM> mStrides;
224  std::size_t mN1d;
225 
226  ParallelTensorFunctor(F f, Xs... xs) : mF(f), mLens({static_cast<std::size_t>(xs)...})
227  {
228  mStrides.back() = 1;
229  std::partial_sum(mLens.rbegin(),
230  mLens.rend() - 1,
231  mStrides.rbegin() + 1,
232  std::multiplies<std::size_t>());
233  mN1d = mStrides[0] * mLens[0];
234  }
235 
236  std::array<std::size_t, NDIM> GetNdIndices(std::size_t i) const
237  {
238  std::array<std::size_t, NDIM> indices;
239 
240  for(std::size_t idim = 0; idim < NDIM; ++idim)
241  {
242  indices[idim] = i / mStrides[idim];
243  i -= indices[idim] * mStrides[idim];
244  }
245 
246  return indices;
247  }
248 
249  void operator()(std::size_t num_thread = 1) const
250  {
251  std::size_t work_per_thread = (mN1d + num_thread - 1) / num_thread;
252 
253  std::vector<joinable_thread> threads(num_thread);
254 
255  for(std::size_t it = 0; it < num_thread; ++it)
256  {
257  std::size_t iw_begin = it * work_per_thread;
258  std::size_t iw_end = std::min((it + 1) * work_per_thread, mN1d);
259 
260  auto f = [this, iw_begin, iw_end] {
261  for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
262  {
263  call_f_unpack_args(this->mF, this->GetNdIndices(iw));
264  }
265  };
266  threads[it] = joinable_thread(f);
267  }
268  }
269 };
270 
271 template <typename F, typename... Xs>
273 {
274  return ParallelTensorFunctor<F, Xs...>(f, xs...);
275 }
276 
277 template <typename T>
279 {
281  using Data = std::vector<T>;
282 
283  template <typename X>
284  HostTensor(std::initializer_list<X> lens) : mDesc(lens), mData(mDesc.get_element_space_size())
285  {
286  }
287 
288  template <typename X, typename Y>
289  HostTensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
290  : mDesc(lens, strides), mData(mDesc.get_element_space_size())
291  {
292  }
293 
294  template <typename Lengths>
295  HostTensor(const Lengths& lens) : mDesc(lens), mData(mDesc.get_element_space_size())
296  {
297  }
298 
299  template <typename Lengths, typename Strides>
300  HostTensor(const Lengths& lens, const Strides& strides)
301  : mDesc(lens, strides), mData(get_element_space_size())
302  {
303  }
304 
306 
307  template <typename OutT>
309  {
310  HostTensor<OutT> ret(mDesc);
311  std::transform(mData.cbegin(), mData.cend(), ret.mData.begin(), [](auto value) {
312  return ck_tile::type_convert<OutT>(value);
313  });
314  return ret;
315  }
316 
317  HostTensor() = delete;
318  HostTensor(const HostTensor&) = default;
319  HostTensor(HostTensor&&) = default;
320 
321  ~HostTensor() = default;
322 
323  HostTensor& operator=(const HostTensor&) = default;
325 
326  template <typename FromT>
327  explicit HostTensor(const HostTensor<FromT>& other) : HostTensor(other.template CopyAsType<T>())
328  {
329  }
330 
331  std::size_t get_length(std::size_t dim) const { return mDesc.get_length(dim); }
332 
333  decltype(auto) get_lengths() const { return mDesc.get_lengths(); }
334 
335  std::size_t get_stride(std::size_t dim) const { return mDesc.get_stride(dim); }
336 
337  decltype(auto) get_strides() const { return mDesc.get_strides(); }
338 
339  std::size_t get_num_of_dimension() const { return mDesc.get_num_of_dimension(); }
340 
341  std::size_t get_element_size() const { return mDesc.get_element_size(); }
342 
343  std::size_t get_element_space_size() const { return mDesc.get_element_space_size(); }
344 
346  {
347  return sizeof(T) * get_element_space_size();
348  }
349 
350  // void SetZero() { ck_tile::ranges::fill<T>(mData, 0); }
351  void SetZero() { std::fill(mData.begin(), mData.end(), 0); }
352 
353  template <typename F>
354  void ForEach_impl(F&& f, std::vector<size_t>& idx, size_t rank)
355  {
357  {
358  f(*this, idx);
359  return;
360  }
361  // else
362  for(size_t i = 0; i < mDesc.get_lengths()[rank]; i++)
363  {
364  idx[rank] = i;
365  ForEach_impl(std::forward<F>(f), idx, rank + 1);
366  }
367  }
368 
369  template <typename F>
370  void ForEach(F&& f)
371  {
372  std::vector<size_t> idx(mDesc.get_num_of_dimension(), 0);
373  ForEach_impl(std::forward<F>(f), idx, size_t(0));
374  }
375 
376  template <typename F>
377  void ForEach_impl(const F&& f, std::vector<size_t>& idx, size_t rank) const
378  {
380  {
381  f(*this, idx);
382  return;
383  }
384  // else
385  for(size_t i = 0; i < mDesc.get_lengths()[rank]; i++)
386  {
387  idx[rank] = i;
388  ForEach_impl(std::forward<const F>(f), idx, rank + 1);
389  }
390  }
391 
392  template <typename F>
393  void ForEach(const F&& f) const
394  {
395  std::vector<size_t> idx(mDesc.get_num_of_dimension(), 0);
396  ForEach_impl(std::forward<const F>(f), idx, size_t(0));
397  }
398 
399  template <typename G>
400  void GenerateTensorValue(G g, std::size_t num_thread = 1)
401  {
402  switch(mDesc.get_num_of_dimension())
403  {
404  case 1: {
405  auto f = [&](auto i) { (*this)(i) = g(i); };
406  make_ParallelTensorFunctor(f, mDesc.get_lengths()[0])(num_thread);
407  break;
408  }
409  case 2: {
410  auto f = [&](auto i0, auto i1) { (*this)(i0, i1) = g(i0, i1); };
412  num_thread);
413  break;
414  }
415  case 3: {
416  auto f = [&](auto i0, auto i1, auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
418  mDesc.get_lengths()[0],
419  mDesc.get_lengths()[1],
420  mDesc.get_lengths()[2])(num_thread);
421  break;
422  }
423  case 4: {
424  auto f = [&](auto i0, auto i1, auto i2, auto i3) {
425  (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
426  };
428  mDesc.get_lengths()[0],
429  mDesc.get_lengths()[1],
430  mDesc.get_lengths()[2],
431  mDesc.get_lengths()[3])(num_thread);
432  break;
433  }
434  case 5: {
435  auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4) {
436  (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
437  };
439  mDesc.get_lengths()[0],
440  mDesc.get_lengths()[1],
441  mDesc.get_lengths()[2],
442  mDesc.get_lengths()[3],
443  mDesc.get_lengths()[4])(num_thread);
444  break;
445  }
446  case 6: {
447  auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
448  (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
449  };
451  mDesc.get_lengths()[0],
452  mDesc.get_lengths()[1],
453  mDesc.get_lengths()[2],
454  mDesc.get_lengths()[3],
455  mDesc.get_lengths()[4],
456  mDesc.get_lengths()[5])(num_thread);
457  break;
458  }
459  default: throw std::runtime_error("unspported dimension");
460  }
461  }
462 
463  template <typename... Is>
464  std::size_t GetOffsetFromMultiIndex(Is... is) const
465  {
466  return mDesc.GetOffsetFromMultiIndex(is...);
467  }
468 
469  template <typename... Is>
470  T& operator()(Is... is)
471  {
472  return mData[mDesc.GetOffsetFromMultiIndex(is...)];
473  }
474 
475  template <typename... Is>
476  const T& operator()(Is... is) const
477  {
478  return mData[mDesc.GetOffsetFromMultiIndex(is...)];
479  }
480 
481  T& operator()(std::vector<std::size_t> idx)
482  {
483  return mData[mDesc.GetOffsetFromMultiIndex(idx)];
484  }
485 
486  const T& operator()(std::vector<std::size_t> idx) const
487  {
488  return mData[mDesc.GetOffsetFromMultiIndex(idx)];
489  }
490 
491  HostTensor<T> transpose(std::vector<size_t> axes = {}) const
492  {
493  if(axes.empty())
494  {
495  axes.resize(this->get_num_of_dimension());
496  std::iota(axes.rbegin(), axes.rend(), 0);
497  }
498  if(axes.size() != mDesc.get_num_of_dimension())
499  {
500  throw std::runtime_error(
501  "HostTensor::transpose(): size of axes must match tensor dimension");
502  }
503  std::vector<size_t> tlengths, tstrides;
504  for(const auto& axis : axes)
505  {
506  tlengths.push_back(get_lengths()[axis]);
507  tstrides.push_back(get_strides()[axis]);
508  }
509  HostTensor<T> ret(*this);
510  ret.mDesc = HostTensorDescriptor(tlengths, tstrides);
511  return ret;
512  }
513 
514  HostTensor<T> transpose(std::vector<size_t> axes = {})
515  {
516  return const_cast<HostTensor<T> const*>(this)->transpose(axes);
517  }
518 
519  typename Data::iterator begin() { return mData.begin(); }
520 
521  typename Data::iterator end() { return mData.end(); }
522 
523  typename Data::pointer data() { return mData.data(); }
524 
525  typename Data::const_iterator begin() const { return mData.begin(); }
526 
527  typename Data::const_iterator end() const { return mData.end(); }
528 
529  typename Data::const_pointer data() const { return mData.data(); }
530 
531  typename Data::size_type size() const { return mData.size(); }
532 
533  // return a slice of this tensor
534  // for simplicity we just copy the data and return a new tensor
535  auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
536  {
537  assert(s_begin.size() == s_end.size());
538  assert(s_begin.size() == get_num_of_dimension());
539 
540  std::vector<size_t> s_len(s_begin.size());
542  s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});
543  HostTensor<T> sliced_tensor(s_len);
544 
545  sliced_tensor.ForEach([&](auto& self, auto idx) {
546  std::vector<size_t> src_idx(idx.size());
548  idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
549  self(idx) = operator()(src_idx);
550  });
551 
552  return sliced_tensor;
553  }
554 
555  template <typename U = T>
556  auto AsSpan() const
557  {
558  constexpr std::size_t FromSize = sizeof(T);
559  constexpr std::size_t ToSize = sizeof(U);
560 
561  using Element = std::add_const_t<std::remove_reference_t<U>>;
562  return ck_tile::span<Element>{reinterpret_cast<Element*>(data()),
563  size() * FromSize / ToSize};
564  }
565 
566  template <typename U = T>
567  auto AsSpan()
568  {
569  constexpr std::size_t FromSize = sizeof(T);
570  constexpr std::size_t ToSize = sizeof(U);
571 
572  using Element = std::remove_reference_t<U>;
573  return ck_tile::span<Element>{reinterpret_cast<Element*>(data()),
574  size() * FromSize / ToSize};
575  }
576 
577  friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
578  {
579  os << t.mDesc;
580  os << "[";
581  for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
582  {
583  if(0 < idx)
584  {
585  os << ", ";
586  }
587  if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
588  {
589  os << type_convert<float>(t.mData[idx]) << " #### ";
590  }
591  else
592  {
593  os << t.mData[idx];
594  }
595  }
596  os << "]";
597  return os;
598  }
599 
600  // read data from a file, as dtype
601  // the file could dumped from torch as (targeting tensor is t here)
602  // numpy.savetxt("f.txt", t.view(-1).numpy())
603  // numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
604  // numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
605  // will output f.txt, each line is a value
606  // dtype=float or int, internally will cast to real type
607  void loadtxt(std::string file_name, std::string dtype = "float")
608  {
609  std::ifstream file(file_name);
610 
611  if(file.is_open())
612  {
613  std::string line;
614 
615  index_t cnt = 0;
616  while(std::getline(file, line))
617  {
618  if(cnt >= static_cast<index_t>(mData.size()))
619  {
620  throw std::runtime_error(std::string("data read from file:") + file_name +
621  " is too big");
622  }
623 
624  if(dtype == "float")
625  {
626  mData[cnt] = type_convert<T>(std::stof(line));
627  }
628  else if(dtype == "int" || dtype == "int32")
629  {
630  mData[cnt] = type_convert<T>(std::stoi(line));
631  }
632  cnt++;
633  }
634  file.close();
635  if(cnt < static_cast<index_t>(mData.size()))
636  {
637  std::cerr << "Warning! reading from file:" << file_name
638  << ", does not match the size of this tensor" << std::endl;
639  }
640  }
641  else
642  {
643  // Print an error message to the standard error
644  // stream if the file cannot be opened.
645  throw std::runtime_error(std::string("unable to open file:") + file_name);
646  }
647  }
648 
649  // can save to a txt file and read from torch as:
650  // torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
651  void savetxt(std::string file_name, std::string dtype = "float")
652  {
653  std::ofstream file(file_name);
654 
655  if(file.is_open())
656  {
657  for(auto& itm : mData)
658  {
659  if(dtype == "float")
660  file << type_convert<float>(itm) << std::endl;
661  else if(dtype == "int")
662  file << type_convert<int>(itm) << std::endl;
663  else
664  // TODO: we didn't implement operator<< for all custom
665  // data types, here fall back to float in case compile error
666  file << type_convert<float>(itm) << std::endl;
667  }
668  file.close();
669  }
670  else
671  {
672  // Print an error message to the standard error
673  // stream if the file cannot be opened.
674  throw std::runtime_error(std::string("unable to open file:") + file_name);
675  }
676  }
677 
680 };
681 
682 template <bool is_row_major>
683 auto host_tensor_descriptor(std::size_t row,
684  std::size_t col,
685  std::size_t stride,
687 {
688  using namespace ck_tile::literals;
689 
690  if constexpr(is_row_major)
691  {
692  return HostTensorDescriptor({row, col}, {stride, 1_uz});
693  }
694  else
695  {
696  return HostTensorDescriptor({row, col}, {1_uz, stride});
697  }
698 }
699 template <bool is_row_major>
700 auto get_default_stride(std::size_t row,
701  std::size_t col,
702  std::size_t stride,
704 {
705  if(stride == 0)
706  {
707  if constexpr(is_row_major)
708  {
709  return col;
710  }
711  else
712  {
713  return row;
714  }
715  }
716  else
717  return stride;
718 }
719 
720 } // namespace ck_tile
Definition: span.hpp:18
#define CK_TILE_HOST
Definition: config.hpp:39
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
auto fill(OutputRange &&range, const T &init) -> std::void_t< decltype(std::fill(std::begin(std::forward< OutputRange >(range)), std::end(std::forward< OutputRange >(range)), init))>
Definition: algorithm.hpp:25
auto transform(InputRange &&range, OutputIterator iter, UnaryOperation unary_op) -> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op))
Definition: algorithm.hpp:36
Definition: literals.hpp:9
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:272
CK_TILE_HOST auto call_f_unpack_args(F f, T args)
Definition: host_tensor.hpp:67
CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor &a, const New2Old &new2old)
Definition: host_tensor.hpp:202
CK_TILE_HOST auto call_f_unpack_args_impl(F f, T args, std::index_sequence< Is... >)
Definition: host_tensor.hpp:61
auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, bool_constant< is_row_major >)
Definition: host_tensor.hpp:683
CK_TILE_HOST std::ostream & LogRangeAsType(std::ostream &os, Range &&range, std::string delim, int precision=std::cout.precision(), int width=0)
Definition: host_tensor.hpp:42
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST std::ostream & LogRange(std::ostream &os, Range &&range, std::string delim, int precision=std::cout.precision(), int width=0)
Definition: host_tensor.hpp:23
CK_TILE_HOST auto construct_f_unpack_args(F, T args)
Definition: host_tensor.hpp:81
CK_TILE_HOST auto construct_f_unpack_args_impl(T args, std::index_sequence< Is... >)
Definition: host_tensor.hpp:75
auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, bool_constant< is_row_major >)
Definition: host_tensor.hpp:700
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
__device__ void inner_product(const TA &a, const TB &b, TC &c)
Definition: host_tensor.hpp:97
Definition: host_tensor.hpp:89
std::size_t get_stride(std::size_t dim) const
Definition: host_tensor.hpp:164
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition: host_tensor.hpp:169
std::size_t get_element_size() const
Definition: host_tensor.hpp:141
void CalculateStrides()
Definition: host_tensor.hpp:92
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:140
std::size_t GetOffsetFromMultiIndex(std::vector< std::size_t > iss) const
Definition: host_tensor.hpp:176
HostTensorDescriptor(const std::initializer_list< X > &lens, const std::initializer_list< Y > &strides)
Definition: host_tensor.hpp:124
std::size_t get_element_space_size() const
Definition: host_tensor.hpp:147
const std::vector< std::size_t > & get_strides() const
Definition: host_tensor.hpp:166
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:162
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:160
HostTensorDescriptor(const Lengths &lens, const Strides &strides)
Definition: host_tensor.hpp:135
HostTensorDescriptor(const std::initializer_list< X > &lens)
Definition: host_tensor.hpp:107
HostTensorDescriptor(const Lengths &lens)
Definition: host_tensor.hpp:115
friend std::ostream & operator<<(std::ostream &os, const HostTensorDescriptor &desc)
Definition: host_tensor.hpp:181
Definition: host_tensor.hpp:279
void ForEach(F &&f)
Definition: host_tensor.hpp:370
std::size_t get_stride(std::size_t dim) const
Definition: host_tensor.hpp:335
void ForEach(const F &&f) const
Definition: host_tensor.hpp:393
HostTensor(HostTensor &&)=default
Data::size_type size() const
Definition: host_tensor.hpp:531
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:333
HostTensor(std::initializer_list< X > lens, std::initializer_list< Y > strides)
Definition: host_tensor.hpp:289
HostTensor & operator=(HostTensor &&)=default
friend std::ostream & operator<<(std::ostream &os, const HostTensor< T > &t)
Definition: host_tensor.hpp:577
HostTensor(std::initializer_list< X > lens)
Definition: host_tensor.hpp:284
HostTensor & operator=(const HostTensor &)=default
std::size_t get_element_space_size_in_bytes() const
Definition: host_tensor.hpp:345
decltype(auto) get_strides() const
Definition: host_tensor.hpp:337
HostTensor(const HostTensor &)=default
Data::iterator end()
Definition: host_tensor.hpp:521
void GenerateTensorValue(G g, std::size_t num_thread=1)
Definition: host_tensor.hpp:400
void SetZero()
Definition: host_tensor.hpp:351
Descriptor mDesc
Definition: host_tensor.hpp:678
const T & operator()(Is... is) const
Definition: host_tensor.hpp:476
HostTensor(const Lengths &lens)
Definition: host_tensor.hpp:295
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition: host_tensor.hpp:464
Data::pointer data()
Definition: host_tensor.hpp:523
T & operator()(Is... is)
Definition: host_tensor.hpp:470
HostTensor< OutT > CopyAsType() const
Definition: host_tensor.hpp:308
auto AsSpan() const
Definition: host_tensor.hpp:556
auto slice(std::vector< size_t > s_begin, std::vector< size_t > s_end) const
Definition: host_tensor.hpp:535
std::vector< T > Data
Definition: host_tensor.hpp:281
auto AsSpan()
Definition: host_tensor.hpp:567
Data::const_iterator begin() const
Definition: host_tensor.hpp:525
std::size_t get_num_of_dimension() const
Definition: host_tensor.hpp:339
std::size_t get_element_space_size() const
Definition: host_tensor.hpp:343
HostTensor(const Lengths &lens, const Strides &strides)
Definition: host_tensor.hpp:300
void loadtxt(std::string file_name, std::string dtype="float")
Definition: host_tensor.hpp:607
Data::const_pointer data() const
Definition: host_tensor.hpp:529
T & operator()(std::vector< std::size_t > idx)
Definition: host_tensor.hpp:481
void ForEach_impl(const F &&f, std::vector< size_t > &idx, size_t rank) const
Definition: host_tensor.hpp:377
HostTensor(const Descriptor &desc)
Definition: host_tensor.hpp:305
const T & operator()(std::vector< std::size_t > idx) const
Definition: host_tensor.hpp:486
HostTensor< T > transpose(std::vector< size_t > axes={})
Definition: host_tensor.hpp:514
Data::iterator begin()
Definition: host_tensor.hpp:519
void savetxt(std::string file_name, std::string dtype="float")
Definition: host_tensor.hpp:651
HostTensor(const HostTensor< FromT > &other)
Definition: host_tensor.hpp:327
HostTensor< T > transpose(std::vector< size_t > axes={}) const
Definition: host_tensor.hpp:491
std::size_t get_length(std::size_t dim) const
Definition: host_tensor.hpp:331
std::size_t get_element_size() const
Definition: host_tensor.hpp:341
void ForEach_impl(F &&f, std::vector< size_t > &idx, size_t rank)
Definition: host_tensor.hpp:354
Data::const_iterator end() const
Definition: host_tensor.hpp:527
Data mData
Definition: host_tensor.hpp:679
Definition: host_tensor.hpp:219
void operator()(std::size_t num_thread=1) const
Definition: host_tensor.hpp:249
ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:226
std::size_t mN1d
Definition: host_tensor.hpp:224
std::array< std::size_t, NDIM > mLens
Definition: host_tensor.hpp:222
std::array< std::size_t, NDIM > mStrides
Definition: host_tensor.hpp:223
static constexpr std::size_t NDIM
Definition: host_tensor.hpp:221
F mF
Definition: host_tensor.hpp:220
std::array< std::size_t, NDIM > GetNdIndices(std::size_t i) const
Definition: host_tensor.hpp:236
Definition: integral_constant.hpp:13
Definition: joinable_thread.hpp:12