14 namespace tensor_operation {
// Collapses a Rank-dimensional length list into a 2-D problem size.
//
// The first (Rank - NumReduceDim) entries of inLengths are multiplied into the
// "invariant" length and the remaining NumReduceDim entries into the "reduce"
// length, so the caller must pass lengths with the invariant dimensions first.
//
// Template parameters:
//   Rank          number of entries of inLengths that are consumed (at most 12)
//   NumReduceDim  how many of the trailing entries belong to the reduction
//
// Returns {invariant_total_length, reduce_total_length}. Both products are
// accumulated in long_index_t so they are not limited to the index_t range.
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::vector<index_t>& inLengths)
{
    static_assert(Rank <= 12, "bigger Rank size not supported!");
    // A NumReduceDim outside [0, Rank] would make the loops below index out of bounds.
    static_assert(NumReduceDim >= 0 && NumReduceDim <= Rank,
                  "NumReduceDim must be in the range [0, Rank]");

    // The loops read exactly Rank entries; a shorter vector is a caller bug.
    assert(inLengths.size() >= static_cast<std::size_t>(Rank));

    constexpr int NumInvariantDim = Rank - NumReduceDim;

    long_index_t invariant_total_length = 1;
    long_index_t reduce_total_length    = 1;

    for(int i = 0; i < NumInvariantDim; i++)
        invariant_total_length *= inLengths[i];

    for(int i = NumInvariantDim; i < Rank; i++)
        reduce_total_length *= inLengths[i];

    return std::make_pair(invariant_total_length, reduce_total_length);
}
// Collapses a Rank-dimensional length array into a 2-D problem size.
//
// The first (Rank - NumReduceDim) entries of inLengths are multiplied into the
// "invariant" length and the remaining NumReduceDim entries into the "reduce"
// length, so the caller must pass lengths with the invariant dimensions first.
//
// Template parameters:
//   Rank          size of the inLengths array (at most 12)
//   NumReduceDim  how many of the trailing entries belong to the reduction
//
// Returns {invariant_total_length, reduce_total_length}. Both products are
// accumulated in long_index_t so they are not limited to the index_t range.
template <index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> get_2d_lengths(const std::array<index_t, Rank>& inLengths)
{
    static_assert(Rank <= 12, "bigger Rank size not supported!");
    // A NumReduceDim outside [0, Rank] would make the loops below index out of bounds.
    static_assert(NumReduceDim >= 0 && NumReduceDim <= Rank,
                  "NumReduceDim must be in the range [0, Rank]");

    constexpr int NumInvariantDim = Rank - NumReduceDim;

    long_index_t invariant_total_length = 1;
    long_index_t reduce_total_length    = 1;

    for(int i = 0; i < NumInvariantDim; i++)
        invariant_total_length *= inLengths[i];

    for(int i = NumInvariantDim; i < Rank; i++)
        reduce_total_length *= inLengths[i];

    return std::make_pair(invariant_total_length, reduce_total_length);
}
64 template <index_t arraySize>
67 static_assert(arraySize >= 1 && arraySize <= 6,
"The tensor should have 1 to 6 dimensions");
// Reorders per-dimension values (lengths or strides) so that every dimension
// NOT listed in reduceDims comes first, followed by the dimensions that are
// listed. Within each group the original relative order is preserved; the
// order of the entries inside reduceDims itself does not matter.
//
// Template parameters:
//   Rank          expected size of origLengthsStrides (at most 64, one mask bit each)
//   NumReduceDim  expected size of reduceDims
//
// Parameters:
//   origLengthsStrides  one value per tensor dimension
//   reduceDims          indices, each in [0, Rank), of the dimensions to move to the back
//
// Returns a new vector with the same Rank values in the shuffled order.
template <index_t Rank, index_t NumReduceDim>
std::vector<index_t> shuffle_tensor_dimensions(const std::vector<index_t>& origLengthsStrides,
                                               const std::vector<int>& reduceDims)
{
    static_assert(Rank >= 0 && Rank <= 64, "Rank must fit in the 64-bit dimension mask");

    // Explicit casts avoid comparing the signed template values against size_t.
    assert(origLengthsStrides.size() == static_cast<std::size_t>(Rank) &&
           reduceDims.size() == static_cast<std::size_t>(NumReduceDim));

    // One bit per dimension; unsigned so that every shift below is well defined.
    unsigned long long reduceMask = 0;

    for(const int dim : reduceDims)
    {
        // A negative or too-large index would make the shift undefined.
        assert(dim >= 0 && dim < Rank);
        reduceMask |= 1ULL << dim;
    }

    std::vector<index_t> newLengthsStrides;
    newLengthsStrides.reserve(static_cast<std::size_t>(Rank));

    // Invariant (non-reduced) dimensions first ...
    for(int i = 0; i < Rank; i++)
    {
        if((reduceMask & (1ULL << i)) == 0)
            newLengthsStrides.push_back(origLengthsStrides[i]);
    }

    // ... then the reduced dimensions.
    for(int i = 0; i < Rank; i++)
    {
        if((reduceMask & (1ULL << i)) != 0)
            newLengthsStrides.push_back(origLengthsStrides[i]);
    }

    return newLengthsStrides;
}
// Reorders per-dimension values (lengths or strides) so that every dimension
// NOT listed in reduceDims comes first, followed by the dimensions that are
// listed. Within each group the original relative order is preserved; the
// order of the entries inside reduceDims itself does not matter.
//
// Template parameters:
//   Rank          number of tensor dimensions (at most 64, one mask bit each)
//   NumReduceDim  number of entries in reduceDims
//
// Parameters:
//   origLengthsStrides  one value per tensor dimension
//   reduceDims          indices, each in [0, Rank), of the dimensions to move to the back
//
// Returns a new array with the same Rank values in the shuffled order.
template <index_t Rank, index_t NumReduceDim>
std::array<index_t, Rank>
shuffle_tensor_dimensions(const std::array<index_t, Rank>& origLengthsStrides,
                          const std::array<int, NumReduceDim>& reduceDims)
{
    static_assert(Rank >= 0 && Rank <= 64, "Rank must fit in the 64-bit dimension mask");

    // One bit per dimension; unsigned so that every shift below is well defined.
    unsigned long long reduceMask = 0;

    for(const int dim : reduceDims)
    {
        // A negative or too-large index would make the shift undefined.
        assert(dim >= 0 && dim < Rank);
        reduceMask |= 1ULL << dim;
    }

    // Every slot is written exactly once below: each index i lands in one of the two passes.
    std::array<index_t, Rank> newLengthsStrides{};
    int pos = 0;

    // Invariant (non-reduced) dimensions first ...
    for(int i = 0; i < Rank; i++)
    {
        if((reduceMask & (1ULL << i)) == 0)
            newLengthsStrides[pos++] = origLengthsStrides[i];
    }

    // ... then the reduced dimensions.
    for(int i = 0; i < Rank; i++)
    {
        if((reduceMask & (1ULL << i)) != 0)
            newLengthsStrides[pos++] = origLengthsStrides[i];
    }

    return newLengthsStrides;
}
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition: device_reduce_common.hpp:75
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition: device_reduce_common.hpp:20
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition: device_reduce_common.hpp:65
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition: device_reduce_common.hpp:59
int64_t long_index_t
Definition: ck.hpp:290
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:256
Definition: integral_constant.hpp:10