14 namespace tensor_operation {
 
   19 template <index_t Rank, 
int NumReduceDim>
 
   20 std::pair<long_index_t, long_index_t> 
get_2d_lengths(
const std::vector<index_t>& inLengths)
 
   22     static_assert(Rank <= 12, 
"bigger Rank size not supported!");
 
   27     constexpr 
int NumInvariantDim = Rank - NumReduceDim;
 
   29     for(
int i = NumInvariantDim; i < Rank; i++)
 
   30         reduce_total_length *= inLengths[i];
 
   32     for(
int i = 0; i < NumInvariantDim; i++)
 
   33         invariant_total_length *= inLengths[i];
 
   35     return std::make_pair(invariant_total_length, reduce_total_length);
 
   38 template <index_t Rank, 
int NumReduceDim>
 
   39 std::pair<long_index_t, long_index_t> 
get_2d_lengths(
const std::array<index_t, Rank>& inLengths)
 
   41     static_assert(Rank <= 12, 
"bigger Rank size not supported!");
 
   46     constexpr 
int NumInvariantDim = Rank - NumReduceDim;
 
   48     for(
int i = NumInvariantDim; i < Rank; i++)
 
   49         reduce_total_length *= inLengths[i];
 
   51     for(
int i = 0; i < NumInvariantDim; i++)
 
   52         invariant_total_length *= inLengths[i];
 
   54     return std::make_pair(invariant_total_length, reduce_total_length);
 
   64 template <index_t arraySize>
 
   67     static_assert(arraySize >= 1 && arraySize <= 6, 
"The tensor should have 1 to 6 dimensions");
 
   74 template <index_t Rank, index_t NumReduceDim>
 
   76                                                const std::vector<int>& reduceDims)
 
   78     std::vector<index_t> newLengthsStrides;
 
   80     assert(Rank == origLengthsStrides.size() && NumReduceDim == reduceDims.size());
 
   85     for(
int i = 0; i < NumReduceDim; i++)
 
   87         reduceFlag |= 1 << reduceDims[i];
 
   91     for(
int i = 0; i < Rank; i++)
 
   92         if((reduceFlag & (1 << i)) == 0)
 
   94             newLengthsStrides.push_back(origLengthsStrides[i]);
 
   98     for(
int i = 0; i < Rank; i++)
 
   99         if((reduceFlag & (1 << i)) > 0)
 
  101             newLengthsStrides.push_back(origLengthsStrides[i]);
 
  104     return newLengthsStrides;
 
  107 template <index_t Rank, index_t NumReduceDim>
 
  108 std::array<index_t, Rank>
 
  110                           const std::array<int, NumReduceDim>& reduceDims)
 
  112     std::array<index_t, Rank> newLengthsStrides;
 
  117     for(
int i = 0; i < NumReduceDim; i++)
 
  119         reduceFlag |= 1 << reduceDims[i];
 
  124     for(
int i = 0; i < Rank; i++)
 
  125         if((reduceFlag & (1 << i)) == 0)
 
  127             newLengthsStrides[pos++] = origLengthsStrides[i];
 
  131     for(
int i = 0; i < Rank; i++)
 
  132         if((reduceFlag & (1 << i)) > 0)
 
  134             newLengthsStrides[pos++] = origLengthsStrides[i];
 
  137     return newLengthsStrides;
 
std::vector< index_t > shuffle_tensor_dimensions(const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
Definition: device_reduce_common.hpp:75
 
std::pair< long_index_t, long_index_t > get_2d_lengths(const std::vector< index_t > &inLengths)
Definition: device_reduce_common.hpp:20
 
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition: device_reduce_common.hpp:65
 
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition: device_reduce_common.hpp:59
 
int64_t long_index_t
Definition: ck.hpp:290
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:289
 
Definition: sequence.hpp:43
 
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:256
 
Definition: integral_constant.hpp:10