13 namespace tensor_operation {
// CalculateMaxRead
// Takes per-dimension `lengths` and `strides` for NumDim1 + NumDim2 dimensions, with
// the first NumDim1 entries forming one group and the remaining NumDim2 the second.
// It returns make_tuple(continous_dim, max_subsequent_elems).
//
// max_subsequent_elems is the product of lengths over the trailing dimensions of one
// group, scanned from end_idx down toward begin_idx, that chain together
// contiguously. A dimension joins the chain when its stride equals the product
// accumulated so far.
//
// Throws std::runtime_error when lengths.size() or strides.size() differs from
// NumDim1 + NumDim2.
//
// NOTE(review): this text is a Doxygen source-listing extraction, not compilable
// code. The leading integers are original header line numbers. Braces and several
// statements are missing; in particular, no visible line assigns begin_idx or
// continous_dim, although both are read below. Confirm details against the real
// header before relying on them.
32 template <index_t NumDim1, index_t NumDim2>
33 auto CalculateMaxRead(
const std::vector<index_t>& lengths,
const std::vector<index_t>& strides)
// Validate that `lengths` describes exactly NumDim1 + NumDim2 dimensions.
35 if(lengths.size() != NumDim1 + NumDim2)
37 std::ostringstream err;
38 err <<
"Incorrect number of lengths in "
39 <<
"device_contraction_utils.hpp"
40 <<
":" << __LINE__ <<
", in function: " << __func__;
41 throw std::runtime_error(err.str());
// Same validation for `strides`.
43 if(strides.size() != NumDim1 + NumDim2)
45 std::ostringstream err;
46 err <<
"Incorrect number of strides in "
47 <<
"device_contraction_utils.hpp"
48 <<
":" << __LINE__ <<
", in function: " << __func__;
49 throw std::runtime_error(err.str());
// [begin_idx, end_idx] is the dimension range scanned by the loop at the end.
// Only consecutive_stride is initialized here; it is the running contiguous-run
// length and starts at 1.
53 index_t begin_idx, end_idx, continous_dim, consecutive_stride = 1;
// Case: the last dimension of BOTH groups has unit stride.
54 if(strides[NumDim1 - 1] == 1 && strides[NumDim1 + NumDim2 - 1] == 1)
// Determine whether every dimension of the first group has length 1.
57 bool dims1_are_ones =
true;
58 for(
index_t dim_idx = 0; dim_idx < NumDim1; dim_idx++)
60 if(lengths[dim_idx] != 1)
62 dims1_are_ones =
false;
// NOTE(review): one of the next two assignments is presumably chosen by
// dims1_are_ones, but the selecting condition is missing from this listing.
// Scan the second group (ends at its last dimension):
69 end_idx = NumDim1 + NumDim2 - 1;
// ...or scan the first group (ends at its last dimension):
75 end_idx = NumDim1 - 1;
// Case: only the first group's last dimension has unit stride.
79 else if(strides[NumDim1 - 1] == 1)
82 end_idx = NumDim1 - 1;
// Case: only the second group's last dimension has unit stride.
85 else if(strides[NumDim1 + NumDim2 - 1] == 1)
88 end_idx = NumDim1 + NumDim2 - 1;
// NOTE(review): this assignment and the early return appear to belong to a
// fallback branch whose header is missing here. That branch presumably covers the
// case where neither group's last dimension has unit stride. Confirm against the
// header. The result reports a contiguous run of 1.
95 consecutive_stride = 1;
97 return make_tuple(continous_dim, consecutive_stride);
// Walk from end_idx down to begin_idx. Each dimension whose stride equals the
// product accumulated so far extends the contiguous run by its length. index_t is
// int32_t (per the tooltip below), so dim_idx can go to -1 when begin_idx == 0
// without unsigned wrap-around. NOTE(review): whether the loop stops at the first
// non-matching stride is not visible in this listing.
100 for(
index_t dim_idx = end_idx; dim_idx >= begin_idx; --dim_idx)
102 if(strides[dim_idx] == consecutive_stride)
104 consecutive_stride *= lengths[dim_idx];
// The accumulated run length is the reported maximum number of subsequent elements.
111 const index_t max_subsequent_elems = consecutive_stride;
112 return make_tuple(continous_dim, max_subsequent_elems);
auto CalculateMaxRead(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition: device_contraction_utils.hpp:33
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289