29 template <
typename Shape, 
typename UnrolledDescriptorType>
 
   33 using is_tuple = decltype(std::declval<T&>().IsTuple());
 
   43 template <
typename... Ts>
 
   44 __host__ __device__ constexpr 
static auto 
   45 GenerateColumnMajorPackedStrides(
const Tuple<Ts...>& 
shape)
 
   50             if constexpr(i.value == 0)
 
   56                 return TupleReduce<Number<0>{}.value, i.value>([](
auto x, 
auto y) { 
return x * y; },
 
   60         Number<decltype(unrolled_shape)::Size()>{});
 
   70 template <
typename LayoutShape, 
typename LayoutStr
ides>
 
   71 __host__ __device__ constexpr 
auto MakeUnrolledDescriptor(
const LayoutShape& 
shape,
 
   72                                                           const LayoutStrides& strides)
 
   75     if constexpr(
is_same_v<LayoutStrides, Tuple<>>)
 
   78         const auto unrolled_strides = GenerateColumnMajorPackedStrides(unrolled_shape);
 
   79         static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
 
   80                       "Size of strides and shape are not consistent.");
 
   86         static_assert(unrolled_shape.Size() == unrolled_strides.Size(),
 
   87                       "Size of strides and shape are not consistent.");
 
  104 template <
typename Shape, 
typename Str
ides>
 
  107     using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Strides{}));
 
  109                                                  detail::MakeUnrolledDescriptor(
shape, strides));
 
  119 template <
typename Shape>
 
  122     using UnrolledDescriptorType = decltype(detail::MakeUnrolledDescriptor(Shape{}, Tuple<>{}));
 
  124                                                  detail::MakeUnrolledDescriptor(
shape, Tuple<>{}));
 
  135 template <
typename T>
 
  136 __host__ __device__ T constexpr 
get(
const T& dim)
 
  148 template <
index_t idx, 
typename... Dims>
 
  149 __host__ __device__ constexpr 
auto get(
const Tuple<Dims...>& tuple)
 
  151     return tuple.At(Number<idx>{});
 
  161 template <index_t 
idx, 
typename Shape, 
typename UnrolledDesc>
 
  165     const auto new_shape = get<idx>(
shape);
 
  167                   "Shape of sub layout must be tuple");
 
  177             if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
 
  187         Number<old_shape_dims>{});
 
  189     const auto lower_dims =
 
  190         generate_tuple([&](
auto i) { 
return Sequence<i.value>{}; }, Number<old_shape_dims>{});
 
  193             if constexpr(i < shape_offset || i >= shape_offset + new_shape_dims)
 
  198                 return Sequence<i.value - shape_offset>{};
 
  201         Number<old_shape_dims>{});
 
  203     const auto& flatten_desc = 
layout.GetUnrolledDescriptor();
 
  216 __host__ __device__ constexpr 
auto get(
const T& elem)
 
  218     return get<Idxs...>(get<Idx>(elem));
 
  229 template <
typename T>
 
  230 __host__ __device__ T constexpr 
size(
const T& dim)
 
  242 template <index_t 
idx, 
typename Shape, 
typename UnrolledDescriptorType>
 
  245     return layout.template GetLength<idx>();
 
  254 template <
typename... ShapeDims>
 
  255 __host__ __device__ constexpr 
auto size(
const Tuple<ShapeDims...>& 
shape)
 
  258     return TupleReduce<0, unrolled_shape.Size()>([](
auto x, 
auto y) { 
return x * y; },
 
  268 template <
typename Shape, 
typename UnrolledDescriptorType>
 
  271     return layout.GetLengths();
 
  281 template <
index_t idx, 
typename... Ts>
 
  282 __host__ __device__ constexpr 
auto size(
const Tuple<Ts...>& tuple)
 
  284     return size(tuple.At(Number<idx>{}));
 
  296 __host__ __device__ constexpr 
auto size(
const T& elem)
 
  298     return size(get<Idx, Idxs...>(elem));
 
  308 template <
typename Shape, 
typename UnrolledDescriptorType>
 
  309 __host__ __device__ constexpr 
auto 
  312     return Shape::Size();
 
  322 template <
typename... Dims>
 
  323 __host__ __device__ constexpr 
auto rank([[maybe_unused]] 
const Tuple<Dims...>& tuple)
 
  325     return Tuple<Dims...>::Size();
 
  335 template <index_t IDim>
 
  336 __host__ __device__ constexpr 
index_t rank([[maybe_unused]] 
const Number<IDim>& dim)
 
  348 __host__ __device__ constexpr 
index_t rank([[maybe_unused]] 
const index_t& dim) { 
return 1; }
 
  357 template <
index_t... Idxs, 
typename T>
 
  358 __host__ __device__ constexpr 
auto rank(
const T& elem)
 
  360     return rank(get<Idxs...>(elem));
 
  370 template <
typename Shape, 
typename UnrolledDescriptorType>
 
  383 template <
typename... Dims>
 
  384 __host__ __device__ constexpr 
auto depth(
const Tuple<Dims...>& tuple)
 
  396 template <index_t IDim>
 
  397 __host__ __device__ constexpr 
index_t depth([[maybe_unused]] 
const Number<IDim>& dim)
 
  409 __host__ __device__ constexpr 
index_t depth([[maybe_unused]] 
const index_t& dim) { 
return 0; }
 
  418 template <
index_t... Idxs, 
typename T>
 
  419 __host__ __device__ constexpr 
auto depth(
const T& elem)
 
  421     return depth(get<Idxs...>(elem));
 
  430 template <
typename LayoutType>
 
  431 __host__ __device__ constexpr 
const auto& 
shape(
const LayoutType& 
layout)
 
  445 template <
typename Shape, 
typename UnrolledDesc, 
typename TileLengths>
 
  447                                        const TileLengths& tile_lengths)
 
  449     auto& unrolled_desc = 
layout.GetUnrolledDescriptor();
 
  451     constexpr 
auto do_pads_seq =
 
  458         [&](
auto i) { 
return padded_desc.GetLength(Number<i>{}); }, 
Number<TileLengths::Size()>{});
 
  473 template <index_t Idx, 
typename Shape, 
typename UnrolledDesc, 
typename NewLengths, 
typename NewIdxs>
 
  475                                            const NewLengths& new_lengths,
 
  476                                            [[maybe_unused]] 
const NewIdxs& new_indexes)
 
  479     auto& unrolled_desc      = 
layout.GetUnrolledDescriptor();
 
  480     constexpr 
auto dims      = Shape::Size();
 
  484             if constexpr(i == Idx)
 
  495     constexpr 
auto lower_dims =
 
  496         generate_tuple([&](
auto i) { 
return Sequence<i.value>{}; }, Number<dims>{});
 
  507                 return Sequence<index>{};
 
  512     const auto unmerged_desc =
 
  514     const auto unmerged_shape =
 
  515         generate_tuple([&](
auto i) { 
return unmerged_desc.GetLength(Number<i>{}); },
 
  516                        Number<decltype(unmerged_desc)::GetNumOfVisibleDimension()>{});
 
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition: helper.hpp:70
 
__host__ constexpr __device__ auto rank([[maybe_unused]] const Layout< Shape, UnrolledDescriptorType > &layout)
Get layout rank (num elements in shape).
Definition: layout_utils.hpp:310
 
__host__ constexpr __device__ auto depth(const Layout< Shape, UnrolledDescriptorType > &layout)
Get depth of the layout shape (return 0 if scalar).
Definition: layout_utils.hpp:371
 
__host__ constexpr __device__ auto get(const Tuple< Dims... > &tuple)
Get element from tuple (Shape/Strides/Idxs).
Definition: layout_utils.hpp:149
 
__host__ constexpr __device__ auto size(const Layout< Shape, UnrolledDescriptorType > &layout)
Get length (product if tuple).
Definition: layout_utils.hpp:243
 
__host__ constexpr __device__ auto unmerge(const Layout< Shape, UnrolledDesc > &layout, const NewLengths &new_lengths, [[maybe_unused]] const NewIdxs &new_indexes)
Unmerge selected dim in layout.
Definition: layout_utils.hpp:474
 
__host__ constexpr __device__ const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition: layout_utils.hpp:431
 
__host__ constexpr __device__ auto make_layout(const Shape &shape, const Strides &strides)
Make layout function.
Definition: layout_utils.hpp:105
 
__host__ constexpr __device__ auto PadTensorDescriptor(const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
Definition: matrix_padder.hpp:19
 
__host__ constexpr __device__ auto TupleReduce(F &&f, const Tuple< Ts... > &tuple)
Definition: tuple_helper.hpp:161
 
__host__ constexpr __device__ auto to_sequence(Tuple< Number< Is >... >)
Definition: sequence_helper.hpp:32
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
 
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
 
__host__ constexpr __device__ auto UnrollNestedTuple(const Tuple<> &element)
Definition: tuple_helper.hpp:120
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
 
__host__ constexpr __device__ auto TupleDepth(const T &)
Definition: tuple_helper.hpp:188
 
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
constexpr bool is_same_v
Definition: type.hpp:283
 
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:297
 
decltype(ck::declval< T & >().IsTuple()) is_tuple
Definition: tuple_helper.hpp:176
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
integral_constant< index_t, N > Number
Definition: number.hpp:12
 
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24
 
__host__ constexpr __device__ const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition: tensor_utils.hpp:162