13 namespace tensor_operation {
 
   32 template <index_t NumDim1, index_t NumDim2>
 
   33 auto CalculateMaxRead(
const std::vector<index_t>& lengths, 
const std::vector<index_t>& strides)
 
   35     if(lengths.size() != NumDim1 + NumDim2)
 
   37         std::ostringstream err;
 
   38         err << 
"Incorrect number of lengths in " 
   39             << 
"device_contraction_utils.hpp" 
   40             << 
":" << __LINE__ << 
", in function: " << __func__;
 
   41         throw std::runtime_error(err.str());
 
   43     if(strides.size() != NumDim1 + NumDim2)
 
   45         std::ostringstream err;
 
   46         err << 
"Incorrect number of strides in " 
   47             << 
"device_contraction_utils.hpp" 
   48             << 
":" << __LINE__ << 
", in function: " << __func__;
 
   49         throw std::runtime_error(err.str());
 
   53     index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
 
   54     if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
 
   57         bool dims1_are_ones = 
true;
 
   58         for(
index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
 
   60             if(lengths[dim_idx] != 1)
 
   62                 dims1_are_ones = 
false;
 
   69             end_idx       = NumDim1 + NumDim2 - 1;
 
   75             end_idx       = NumDim1 - 1;
 
   79     else if(strides[NumDim1 - 1] == 1)
 
   82         end_idx       = NumDim1 - 1;
 
   85     else if(strides[NumDim1 + NumDim2 - 1] == 1)
 
   88         end_idx       = NumDim1 + NumDim2 - 1;
 
   95         consecutive_stride = 1;
 
   97         return make_tuple(continous_dim, consecutive_stride);
 
  100     for(
index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
 
  102         if(strides[dim_idx] == consecutive_stride)
 
  104             consecutive_stride *= lengths[dim_idx];
 
  111     const index_t max_subsequent_elems = consecutive_stride;
 
  112     return make_tuple(continous_dim, max_subsequent_elems);
 
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition: device_contraction_utils.hpp:33
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:289