32 template <
typename... Ts, 
typename... Ls>
 
   33 __host__ __device__ constexpr 
auto CalculateLocalPartitionShape(
const Tuple<Ts...>& 
shape,
 
   34                                                                 const Tuple<Ls...>& thread_lengths)
 
   36     static_assert(Tuple<Ts...>::Size() == Tuple<Ls...>::Size(), 
"Wrong thread_lengths shape.");
 
   39             constexpr 
auto num_i = Number<i>{};
 
   40             const auto slice_len =
 
   44         Number<Tuple<Ls...>::Size()>{});
 
   56 template <
typename MultiIndex, 
typename ProjectionTuple>
 
   57 __host__ __device__ constexpr 
auto 
   58 ApplyProjection([[maybe_unused]] 
const MultiIndex& base_tuple,
 
   59                 [[maybe_unused]] 
const ProjectionTuple& projection)
 
   61     if constexpr(
is_same_v<ProjectionTuple, Tuple<>>)
 
   69                 const auto i_num = 
Number<i.value>{};
 
   71                     is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value ||
 
   72                     is_same_v<tuple_element_t<i_num, ProjectionTuple>, Number<1>>);
 
   73                 if constexpr(
is_detected<is_slice, tuple_element_t<i_num, ProjectionTuple>>::value)
 
   86         return UnrollNestedTuple<0, 1>(base_tuple_after_projection);
 
   99 template <
typename... Ts, 
typename... Ps>
 
  100 __host__ __device__ constexpr 
auto CalculateShapeWithProjection(
const Tuple<Ts...>& 
shape,
 
  101                                                                 const Tuple<Ps...>& projection)
 
  107                 return size<i>(projection).to_;
 
  114                     detail::ApplyProjection(TupleSlice<0, i>(Tuple<Ts...>{}),
 
  115                                             TupleSlice<0, i>(Tuple<Ps...>{}))
 
  117                 return size<shape_i>(
shape);
 
  120         Number<Tuple<Ps...>::Size()>{});
 
  130 template <
typename... Ts, 
typename... Ls, 
typename... Ps>
 
  131 __host__ __device__ constexpr 
auto CalculateGridSize(
const Tuple<Ts...>& 
shape,
 
  132                                                      const Tuple<Ls...>& tile_shape)
 
  136         Number<Tuple<Ls...>::Size()>{});
 
  147 template <
typename ThreadIdxs, 
typename PartitionLengthsSeq, 
typename OldOffsetIdxs>
 
  148 __host__ __device__ constexpr 
auto 
  149 CalculateOffsetMultiIdxs(
const ThreadIdxs& thread_idxs,
 
  150                          const PartitionLengthsSeq& partition_lengths_seq,
 
  151                          const OldOffsetIdxs& old_offset_idxs)
 
  153     return thread_idxs * partition_lengths_seq + old_offset_idxs;
 
  162 template <
typename BlockIdxs>
 
  163 __host__ __device__ constexpr 
auto GetDimsToPartition([[maybe_unused]] 
const BlockIdxs& block_idxs)
 
  167             if constexpr(!
is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
 
  176         Number<BlockIdxs::Size()>{});
 
  178     return UnrollNestedTuple<0, 1>(dims_to_partition);
 
  187 template <
typename BlockIdxs>
 
  188 __host__ __device__ constexpr 
auto ReplaceSlicesWithZeros(
const BlockIdxs& block_idxs)
 
  192             if constexpr(!
is_detected<is_slice, tuple_element_t<i, BlockIdxs>>::value)
 
  194                 return block_idxs.At(i);
 
  201         Number<BlockIdxs::Size()>{});
 
  210 template <
typename TileShape>
 
  211 __host__ __device__ constexpr 
auto 
  212 GenerateDefaultProjection([[maybe_unused]] 
const TileShape tile_shape)
 
  224 template <
typename ThreadShape, 
typename ThreadUnrolledDesc>
 
  225 __host__ __device__ constexpr 
auto CalculateThreadMultiIdx(
 
  229     static_assert(ThreadUnrolledDesc::GetNumOfTransform() == 1,
 
  230                   "Thread layout should not be transformed.");
 
  231     constexpr 
auto embed_transform = ThreadUnrolledDesc{}.GetTransforms().At(Number<0>{});
 
  232     constexpr 
auto shape           = ThreadShape{};
 
  233     constexpr 
auto strides         = embed_transform.coefficients_;
 
  237             constexpr 
auto num_i = Number<i>{};
 
  238             return (thread_id / strides.At(num_i)) % 
shape.At(num_i);
 
  240         Number<ThreadShape::Size()>{});
 
  258 template <
typename TensorType,
 
  259           typename ThreadShape,
 
  260           typename ThreadUnrolledDesc,
 
  261           typename ProjectionTuple>
 
  262 __host__ __device__ constexpr 
auto 
  266                      const ProjectionTuple& projection)
 
  270     const auto& tensor_shape = 
shape(tensor);
 
  272     constexpr 
auto projected_thread_lengths =
 
  273         detail::ApplyProjection(ThreadShape{}, ProjectionTuple{});
 
  274     constexpr 
auto partition_shape =
 
  275         detail::CalculateLocalPartitionShape(decltype(tensor_shape){}, projected_thread_lengths);
 
  276     constexpr 
auto partition_shape_seq =
 
  278                              Number<decltype(partition_shape)::Size()>{});
 
  280     const auto thread_idxs = detail::CalculateThreadMultiIdx(thread_layout, thread_id);
 
  282     const auto projected_thread_idxs = detail::ApplyProjection(thread_idxs, projection);
 
  283     const auto offset_multi_idxs     = detail::CalculateOffsetMultiIdxs(
 
  284         projected_thread_idxs, partition_shape_seq, tensor.GetMultiIdxOffsets());
 
  286     auto& unrolled_desc = 
layout(tensor).GetUnrolledDescriptor();
 
  291                                         offset_multi_idxs.At(i),
 
  292                                         partition_shape.At(i) + offset_multi_idxs.At(i));
 
  295     const auto lower_upper_dims =
 
  301     const auto partition_layout =
 
  303             partition_shape, sliced_desc);
 
  304     auto partition_tensor =
 
  305         make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), partition_layout);
 
  307     return partition_tensor;
 
  319 template <
typename TensorType, 
typename ThreadShape, 
typename ThreadUnrolledDesc>
 
  320 __host__ __device__ constexpr 
auto 
  325     const auto projection = detail::GenerateDefaultProjection(ThreadShape{});
 
  346 template <
typename TensorType,
 
  347           typename BlockShapeTuple,
 
  349           typename ProjectionTuple>
 
  351                                                    const BlockShapeTuple& tile_shape,
 
  352                                                    const BlockIdxs& block_idxs,
 
  353                                                    const ProjectionTuple& projection)
 
  358     constexpr 
auto I0 = Number<0>{};
 
  359     constexpr 
auto I1 = Number<1>{};
 
  360     constexpr 
auto I2 = Number<2>{};
 
  362     auto& aligned_desc = 
layout(tensor).GetMergedNestingDescriptor();
 
  364     constexpr 
auto projected_tile_shape =
 
  365         detail::ApplyProjection(BlockShapeTuple{}, ProjectionTuple{});
 
  367     constexpr 
auto dims_to_partition = detail::GetDimsToPartition(BlockIdxs{});
 
  368     const auto parsed_block_idxs     = detail::ReplaceSlicesWithZeros(block_idxs);
 
  369     if constexpr(decltype(dims_to_partition)::Size() == I2)
 
  371         const auto shape_with_projection_dims =
 
  372             detail::CalculateShapeWithProjection(
shape(tensor), projection);
 
  374         const auto M             = shape_with_projection_dims.At(dims_to_partition.At(I0));
 
  375         const auto N             = shape_with_projection_dims.At(dims_to_partition.At(I1));
 
  376         constexpr 
auto MPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I0));
 
  377         constexpr 
auto NPerBlock = BlockShapeTuple{}.At(dims_to_partition.At(I1));
 
  380         const auto grid_size = detail::CalculateGridSize(shape_with_projection_dims, tile_shape);
 
  382         const index_t block_id_1d     = block_lengths_desc.CalculateOffset(parsed_block_idxs);
 
  384         const auto block_2_tile_map =
 
  385             BlockToCTileMap_M00_N0_M01Adapt<MPerBlock,
 
  388         const auto block_work_idx =
 
  390         const index_t m_block_data_idx_on_grid =
 
  391             __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
 
  392         const index_t n_block_data_idx_on_grid =
 
  393             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
 
  397                 if constexpr(i == dims_to_partition.At(I0))
 
  399                     return m_block_data_idx_on_grid;
 
  401                 else if constexpr(i == dims_to_partition.At(I1))
 
  403                     return n_block_data_idx_on_grid;
 
  410             Number<BlockShapeTuple::Size()>{});
 
  411         const auto projected_offset_multi_idxs =
 
  412             detail::ApplyProjection(offset_multi_idxs, projection);
 
  414         const auto tile_layout =
 
  416                 projected_tile_shape, aligned_desc);
 
  418             make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
 
  420         tile_tensor.SetMultiIdxOffset(
to_multi_index(projected_offset_multi_idxs));
 
  427         using ProjectedTileShapeTuple = decltype(projected_tile_shape);
 
  428         constexpr 
auto projected_tile_shape_seq =
 
  430                                  Number<ProjectedTileShapeTuple::Size()>{});
 
  432         const auto projected_block_idxs =
 
  433             to_multi_index(detail::ApplyProjection(parsed_block_idxs, projection));
 
  434         const auto offset_multi_idxs = detail::CalculateOffsetMultiIdxs(
 
  435             projected_block_idxs, projected_tile_shape_seq, tensor.GetMultiIdxOffsets());
 
  437         const auto tile_layout =
 
  439                 projected_tile_shape, aligned_desc);
 
  441             make_tensor<TensorType::TensorBufferAddressSpace>(tensor.GetPointer(), tile_layout);
 
  460 template <
typename TensorType, 
typename BlockShapeTuple, 
typename BlockIdxs>
 
  462                                                    const BlockShapeTuple& tile_shape,
 
  463                                                    const BlockIdxs& block_idxs)
 
  465     const auto projection = detail::GenerateDefaultProjection(BlockShapeTuple{});
 
__host__ constexpr __device__ const auto & shape(const LayoutType &layout)
Get Layout shape.
Definition: layout_utils.hpp:431
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
 
__host__ constexpr __device__ auto IsNestedTuple(const Tuple< Ts... > &)
Definition: tuple_helper.hpp:180
 
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
 
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
 
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
 
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:110
 
constexpr bool is_same_v
Definition: type.hpp:283
 
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: is_detected.hpp:34
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
int32_t index_t
Definition: ck.hpp:297
 
Array< index_t, N > MultiIndex
Definition: array_multi_index.hpp:12
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
integral_constant< index_t, N > Number
Definition: number.hpp:12
 
Layout wrapper that performs the tensor descriptor logic.
Definition: layout.hpp:24
 
__host__ static constexpr __device__ index_t Size()
Definition: array.hpp:20
 
__host__ constexpr __device__ auto make_local_partition(TensorType &tensor, [[maybe_unused]] const Layout< ThreadShape, ThreadUnrolledDesc > &thread_layout, const index_t thread_id, const ProjectionTuple &projection)
Create local partition for thread (currently, only packed partitions are supported).
Definition: tensor_partition.hpp:263
 
__host__ constexpr __device__ auto make_local_tile(const TensorType &tensor, const BlockShapeTuple &tile_shape, const BlockIdxs &block_idxs, const ProjectionTuple &projection)
Create local tile for thread block (currently, only packed tiles are supported).
Definition: tensor_partition.hpp:350
 
__host__ constexpr __device__ const auto & layout(const Tensor< BufferAddressSpace, ElementType, Shape, UnrolledDescriptorType > &tensor)
Get Tensor Layout.
Definition: tensor_utils.hpp:162