21 template <
typename Range>
22 std::ostream&
LogRange(std::ostream& os, Range&& range, std::string delim)
36 template <
typename T,
typename Range>
37 std::ostream&
LogRangeAsType(std::ostream& os, Range&& range, std::string delim)
48 if constexpr(std::is_same_v<RangeType, ck::f8_t> || std::is_same_v<RangeType, ck::bf8_t> ||
49 std::is_same_v<RangeType, ck::bhalf_t>)
51 os << ck::type_convert<float>(v);
53 else if constexpr(std::is_same_v<RangeType, ck::pk_i4_t>)
55 const auto packed_floats = ck::type_convert<ck::float2_t>(v);
57 os << vector_of_floats.template AsType<float>()[
ck::Number<0>{}] << delim
58 << vector_of_floats.template AsType<float>()[
ck::Number<1>{}];
62 os << static_cast<T>(v);
68 template <
typename F,
typename T, std::size_t... Is>
71 return f(std::get<Is>(args)...);
74 template <
typename F,
typename T>
77 constexpr std::size_t N = std::tuple_size<T>{};
82 template <
typename F,
typename T, std::size_t... Is>
85 return F(std::get<Is>(args)...);
88 template <
typename F,
typename T>
91 constexpr std::size_t N = std::tuple_size<T>{};
93 return construct_f_unpack_args_impl<F>(args, std::make_index_sequence<N>{});
102 template <
typename X,
typename = std::enable_if_t<std::is_convertible_v<X, std::
size_t>>>
109 : mLens(lens.begin(), lens.end())
114 template <
typename Lengths,
116 std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> ||
117 std::is_convertible_v<ck::ranges::range_value_t<Lengths>,
ck::long_index_t>>>
123 template <
typename X,
125 typename = std::enable_if_t<std::is_convertible_v<X, std::size_t> &&
126 std::is_convertible_v<Y, std::size_t>>>
128 const std::initializer_list<Y>& strides)
129 : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
134 const std::initializer_list<ck::long_index_t>& strides)
135 : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
139 template <
typename Lengths,
142 (std::is_convertible_v<ck::ranges::range_value_t<Lengths>, std::size_t> &&
143 std::is_convertible_v<ck::ranges::range_value_t<Strides>, std::size_t>) ||
147 : mLens(lens.begin(), lens.end()), mStrides(strides.begin(), strides.end())
158 template <
typename... Is>
162 std::initializer_list<std::size_t> iss{
static_cast<std::size_t
>(is)...};
174 std::vector<std::size_t> mLens;
175 std::vector<std::size_t> mStrides;
178 template <
typename New2Old>
180 const New2Old& new2old)
196 template <
typename... Xs>
211 template <
typename F,
typename... Xs>
215 static constexpr std::size_t
NDIM =
sizeof...(Xs);
216 std::array<std::size_t, NDIM>
mLens;
223 std::partial_sum(
mLens.rbegin(),
226 std::multiplies<std::size_t>());
232 std::array<std::size_t, NDIM> indices;
234 for(std::size_t idim = 0; idim <
NDIM; ++idim)
237 i -= indices[idim] *
mStrides[idim];
245 std::size_t work_per_thread = (
mN1d + num_thread - 1) / num_thread;
247 std::vector<joinable_thread> threads(num_thread);
249 for(std::size_t it = 0; it < num_thread; ++it)
251 std::size_t iw_begin = it * work_per_thread;
252 std::size_t iw_end =
std::min((it + 1) * work_per_thread,
mN1d);
255 for(std::size_t iw = iw_begin; iw < iw_end; ++iw)
265 template <
typename F,
typename... Xs>
271 template <
typename T>
277 template <
typename X>
282 template <
typename X,
typename Y>
283 Tensor(std::initializer_list<X> lens, std::initializer_list<Y> strides)
288 template <
typename Lengths>
293 template <
typename Lengths,
typename Str
ides>
294 Tensor(
const Lengths& lens,
const Strides& strides)
301 template <
typename OutT>
307 mData, ret.
mData.begin(), [](
auto value) { return ck::type_convert<OutT>(value); });
321 template <
typename FromT>
350 template <
typename F>
366 template <
typename F>
373 template <
typename F>
389 template <
typename F>
396 template <
typename G>
402 auto f = [&](
auto i) { (*this)(i) = g(i); };
407 auto f = [&](
auto i0,
auto i1) { (*this)(i0, i1) = g(i0, i1); };
412 auto f = [&](
auto i0,
auto i1,
auto i2) { (*this)(i0, i1, i2) = g(i0, i1, i2); };
418 auto f = [&](
auto i0,
auto i1,
auto i2,
auto i3) {
419 (*this)(i0, i1, i2, i3) = g(i0, i1, i2, i3);
429 auto f = [&](
auto i0,
auto i1,
auto i2,
auto i3,
auto i4) {
430 (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4);
441 auto f = [&](
auto i0,
auto i1,
auto i2,
auto i3,
auto i4,
auto i5) {
442 (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
454 auto f = [&](
auto i0,
466 (*this)(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11) =
467 g(i0, i1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11);
484 default:
throw std::runtime_error(
"unspported dimension");
488 template <
typename... Is>
501 template <
typename... Is>
514 template <
typename... Is>
// Mutable overload: forwards to mData.end(), yielding the past-the-end
// iterator of the element storage (mData).
553 typename Data::iterator
end() {
return mData.end(); }
// Const overload: forwards to mData.begin(), yielding a read-only iterator
// to the first element of the element storage (mData).
557 typename Data::const_iterator
begin()
const {
return mData.begin(); }
// Const overload: forwards to mData.end(), yielding the read-only
// past-the-end iterator of the element storage (mData).
559 typename Data::const_iterator
end()
const {
return mData.end(); }
// Const overload: forwards to mData.data(), yielding a read-only pointer
// to the element storage held by mData.
561 typename Data::const_pointer
data()
const {
return mData.data(); }
// Forwards to mData.size(): the number of elements currently held in mData.
// NOTE(review): this is the storage container's size; whether it always
// equals the descriptor's element-space size is not visible here — confirm.
563 typename Data::size_type
size()
const {
return mData.size(); }
565 template <
typename U = T>
568 constexpr std::size_t FromSize =
sizeof(T);
569 constexpr std::size_t ToSize =
sizeof(U);
571 using Element = std::add_const_t<std::remove_reference_t<U>>;
575 template <
typename U = T>
578 constexpr std::size_t FromSize =
sizeof(T);
579 constexpr std::size_t ToSize =
sizeof(U);
581 using Element = std::remove_reference_t<U>;
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
auto call_f_unpack_args_impl(F f, T args, std::index_sequence< Is... >)
Definition: host_tensor.hpp:69
std::ostream & LogRangeAsType(std::ostream &os, Range &&range, std::string delim)
Definition: host_tensor.hpp:37
auto construct_f_unpack_args_impl(T args, std::index_sequence< Is... >)
Definition: host_tensor.hpp:83
auto call_f_unpack_args(F f, T args)
Definition: host_tensor.hpp:75
auto construct_f_unpack_args(F, T args)
Definition: host_tensor.hpp:89
auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:266
std::ostream & LogRange(std::ostream &os, Range &&range, std::string delim)
Definition: host_tensor.hpp:22
HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old(const HostTensorDescriptor &a, const New2Old &new2old)
Definition: host_tensor.hpp:179
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
auto transform(InputRange &&range, OutputIterator iter, UnaryOperation unary_op) -> decltype(std::transform(std::begin(range), std::end(range), iter, unary_op))
Definition: algorithm.hpp:36
iter_value_t< ranges::iterator_t< R > > range_value_t
Definition: ranges.hpp:28
int64_t long_index_t
Definition: ck.hpp:290
constexpr bool is_same_v
Definition: type.hpp:283
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:13
__device__ void inner_product(const TA &a, const TB &b, TC &c)
Definition: host_tensor.hpp:97
HostTensorDescriptor(const Lengths &lens)
Definition: host_tensor.hpp:118
const std::vector< std::size_t > & GetStrides() const
HostTensorDescriptor(const std::initializer_list< X > &lens)
Definition: host_tensor.hpp:103
std::size_t GetElementSize() const
const std::vector< std::size_t > & GetLengths() const
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition: host_tensor.hpp:159
HostTensorDescriptor(const std::initializer_list< X > &lens, const std::initializer_list< Y > &strides)
Definition: host_tensor.hpp:127
HostTensorDescriptor(const std::initializer_list< ck::long_index_t > &lens, const std::initializer_list< ck::long_index_t > &strides)
Definition: host_tensor.hpp:133
std::size_t GetOffsetFromMultiIndex(std::vector< std::size_t > iss) const
Definition: host_tensor.hpp:166
std::size_t GetNumOfDimension() const
std::size_t GetElementSpaceSize() const
HostTensorDescriptor()=default
HostTensorDescriptor(const Lengths &lens, const Strides &strides)
Definition: host_tensor.hpp:146
friend std::ostream & operator<<(std::ostream &os, const HostTensorDescriptor &desc)
HostTensorDescriptor(const std::initializer_list< ck::long_index_t > &lens)
Definition: host_tensor.hpp:108
Definition: host_tensor.hpp:213
std::array< std::size_t, NDIM > GetNdIndices(std::size_t i) const
Definition: host_tensor.hpp:230
F mF
Definition: host_tensor.hpp:214
std::size_t mN1d
Definition: host_tensor.hpp:218
ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:220
std::array< std::size_t, NDIM > mLens
Definition: host_tensor.hpp:216
std::array< std::size_t, NDIM > mStrides
Definition: host_tensor.hpp:217
void operator()(std::size_t num_thread=1) const
Definition: host_tensor.hpp:243
static constexpr std::size_t NDIM
Definition: host_tensor.hpp:215
Tensor wrapper that performs static and dynamic buffer logic. The tensor is based on a descriptor sto...
Definition: host_tensor.hpp:273
auto AsSpan() const
Definition: host_tensor.hpp:566
Tensor(const Lengths &lens, const Strides &strides)
Definition: host_tensor.hpp:294
std::size_t GetNumOfDimension() const
Definition: host_tensor.hpp:330
void ForEach(const F &&f) const
Definition: host_tensor.hpp:390
decltype(auto) GetLengths() const
Definition: host_tensor.hpp:326
Data::const_iterator end() const
Definition: host_tensor.hpp:559
std::size_t GetOffsetFromMultiIndex(Is... is) const
Definition: host_tensor.hpp:489
Tensor< OutT > CopyAsType() const
Definition: host_tensor.hpp:302
T & operator()(std::vector< std::size_t > idx)
Definition: host_tensor.hpp:527
void ForEach(F &&f)
Definition: host_tensor.hpp:367
Data::pointer data()
Definition: host_tensor.hpp:555
void ForEach_impl(F &&f, std::vector< size_t > &idx, size_t rank)
Definition: host_tensor.hpp:351
std::size_t GetElementSpaceSizeInBytes() const
Definition: host_tensor.hpp:346
void ForEach_impl(const F &&f, std::vector< size_t > &idx, size_t rank) const
Definition: host_tensor.hpp:374
Tensor & operator=(const Tensor &)=default
std::vector< T > Data
Definition: host_tensor.hpp:275
Data mData
Definition: host_tensor.hpp:586
Data::iterator end()
Definition: host_tensor.hpp:553
std::size_t GetElementSize() const
Definition: host_tensor.hpp:332
void SetZero()
Definition: host_tensor.hpp:348
Tensor(const Lengths &lens)
Definition: host_tensor.hpp:289
Tensor(Tensor &&)=default
const T & operator()(Is... is) const
Definition: host_tensor.hpp:515
Data::const_pointer data() const
Definition: host_tensor.hpp:561
auto AsSpan()
Definition: host_tensor.hpp:576
Data::iterator begin()
Definition: host_tensor.hpp:551
Tensor(std::initializer_list< X > lens, std::initializer_list< Y > strides)
Definition: host_tensor.hpp:283
Tensor(const Tensor &)=default
Tensor(const Descriptor &desc)
Definition: host_tensor.hpp:299
const T & operator()(std::vector< std::size_t > idx) const
Definition: host_tensor.hpp:539
Descriptor mDesc
Definition: host_tensor.hpp:585
Tensor & operator=(Tensor &&)=default
Data::const_iterator begin() const
Definition: host_tensor.hpp:557
std::size_t GetElementSpaceSize() const
Definition: host_tensor.hpp:334
Tensor(const Tensor< FromT > &other)
Definition: host_tensor.hpp:322
Data::size_type size() const
Definition: host_tensor.hpp:563
void GenerateTensorValue(G g, std::size_t num_thread=1)
Definition: host_tensor.hpp:397
decltype(auto) GetStrides() const
Definition: host_tensor.hpp:328
T & operator()(Is... is)
Definition: host_tensor.hpp:502
Tensor(std::initializer_list< X > lens)
Definition: host_tensor.hpp:278
Definition: integral_constant.hpp:10
Definition: data_type.hpp:320
Definition: data_type.hpp:347
Definition: host_tensor.hpp:195
joinable_thread(joinable_thread &&)=default
joinable_thread(Xs &&... xs)
Definition: host_tensor.hpp:197
~joinable_thread()
Definition: host_tensor.hpp:204
joinable_thread & operator=(joinable_thread &&)=default