21 template <
typename Distribution>
24 return Distribution::_get_partition_index();
29 template <
index_t... PartialHsLengths>
40 template <
index_t... PartialHsIndices>
66 template <
typename PsYs2XsAdaptor_,
67 typename Ys2DDescriptor_,
68 typename StaticTileDistributionEncoding_,
69 typename TileDistributionDetail_>
79 "wrong! should be static");
81 static constexpr
index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
82 static constexpr
index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
84 static constexpr
index_t NDimR = StaticTileDistributionEncoding_::NDimR;
97 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
99 if constexpr(
NDimP == 1)
103 else if constexpr(
NDimP == 2)
141 template <
typename PartitionIndex>
144 static_assert(PartitionIndex::size() ==
NDimP,
"wrong!");
153 constexpr
index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
156 constexpr
index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
157 constexpr
index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
160 if constexpr(rh_major == 0)
162 constexpr
index_t adaptor_hidden_id =
163 DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
166 rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
175 template <
typename PartitionIndex = decltype(_get_partition_index())>
180 const auto window_adaptor_thread_coord_tmp =
182 return window_adaptor_thread_coord_tmp.get_bottom_index();
187 constexpr
auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
188 constexpr
auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
192 constexpr
auto span_impl = distributed_spans_impl[i];
193 constexpr
index_t ndim_span_minor = ndims_spans_minor[i];
203 template <
typename DistributedIndices>
207 constexpr
auto ys_idx_arr = [] {
211 constexpr
index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
212 constexpr
index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
216 ys_idx(i) = dstr_index.impl_[span_minor];
234 printf(
"tile_distribution{");
236 printf(
"tile_distribution_encoding: ");
240 printf(
"ps_ys_to_xs_: ");
244 printf(
"ys_to_d_: ");
253 template <index_t NDimMax>
258 for(
index_t i = 0; i < iend - ibegin; ++i)
267 template <
typename StaticTileDistributionEncoding_>
271 using RsLengths =
typename StaticTileDistributionEncoding_::RsLengths;
272 using HsLengthss =
typename StaticTileDistributionEncoding_::HsLengthss;
273 using Ps2RHssMajor =
typename StaticTileDistributionEncoding_::Ps2RHssMajor;
274 using Ps2RHssMinor =
typename StaticTileDistributionEncoding_::Ps2RHssMinor;
275 using Ys2RHsMajor =
typename StaticTileDistributionEncoding_::Ys2RHsMajor;
276 using Ys2RHsMinor =
typename StaticTileDistributionEncoding_::Ys2RHsMinor;
279 constexpr
index_t kMaxNumTransforms = 20;
280 constexpr
index_t kMaxMetaDataSize = 128;
281 constexpr
index_t kMaxNumDim = 10;
292 constexpr
index_t ndim_x = HsLengthss::size();
301 index_t hidden_dim_cnt = ndim_x;
305 constexpr
index_t ndim_r_minor = RsLengths::size();
307 constexpr
auto r_minor_lengths = RsLengths{};
309 trans(num_tran++) = {
311 MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
314 NumDim{ndim_r_minor},
315 make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
317 for(
index_t i = 0; i < ndim_r_minor; ++i)
319 rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
320 rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
330 &rh_major_minor_to_hidden_ids,
331 &rh_major_minor_to_hidden_lengths](
auto idim_x) {
333 constexpr
auto h_minor_lengths =
334 HsLengthss{}.get(idim_x);
337 constexpr
index_t ndim_h_minor = h_minor_lengths.size();
339 trans(num_tran++) = {
341 MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
344 NumDim{ndim_h_minor},
345 make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
347 for(
index_t i = 0; i < ndim_h_minor; ++i)
349 rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
350 rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
357 constexpr
index_t ndim_p = Ps2RHssMajor::size();
359 Dims hidden_dim_id_ps;
363 index_t hidden_dim_id_p = hidden_dim_cnt++;
365 hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
367 constexpr
auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
368 constexpr
auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
370 static_assert(p2RHsMajor.size() == p2RHsMinor.size(),
"wrong!");
372 constexpr
index_t ndim_low = p2RHsMajor.size();
377 for(
index_t i = 0; i < ndim_low; ++i)
379 index_t rh_major = p2RHsMajor[i];
380 index_t rh_minor = p2RHsMinor[i];
381 low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
382 low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
386 MetaData{to_array<index_t, ndim_low>(low_lengths)},
390 Dims{hidden_dim_id_p}};
393 constexpr
index_t ndim_bottom = ndim_x;
395 constexpr
auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
397 constexpr
auto ys_to_rhs_major = Ys2RHsMajor{};
398 constexpr
auto ys_to_rhs_minor = Ys2RHsMinor{};
400 constexpr
index_t ndim_y = Ys2RHsMajor::size();
401 constexpr
index_t ndim_top = ndim_p + ndim_y;
403 auto top_dim_ids = hidden_dim_id_ps;
406 for(
index_t i = 0; i < ndim_y; ++i)
408 index_t rh_major = ys_to_rhs_major[i];
409 index_t rh_minor = ys_to_rhs_minor[i];
410 top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
415 const auto ps_ys_to_xs_adaptor_encoding =
416 make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
422 for(
index_t i = 0; i < ndim_y; ++i)
424 index_t rh_major = ys_to_rhs_major[i];
425 index_t rh_minor = ys_to_rhs_minor[i];
426 index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
427 y_lengths(i) = y_length;
428 d_length *= y_length;
432 MetaData{to_array<index_t, ndim_y>(y_lengths)},
436 make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
438 const auto ys_to_d_adaptor_encoding =
make_tuple(
439 make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
441 return make_tuple(ps_ys_to_xs_adaptor_encoding,
442 ys_to_d_adaptor_encoding,
444 rh_major_minor_to_hidden_ids);
448 template <
typename RhMajorMinor2AdaptorHiddenIdss>
459 template <
typename StaticTileDistributionEncoding_>
462 using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
464 constexpr
auto adaptor_impl =
467 constexpr
auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
468 constexpr
auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
469 constexpr
index_t d_length = adaptor_impl.template at<2>();
470 constexpr
auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
472 constexpr
auto ps_ys_to_xs_adaptor =
477 constexpr
auto ys_to_d_descriptor =
481 constexpr
index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
482 constexpr
auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
484 constexpr
auto rh_major_minor_to_hidden_ids =
487 return tile_distribution<
490 remove_cvref_t<DstrEncode>,
491 detail::tile_distribution_detail<
remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
492 ps_ys_to_xs_adaptor, ys_to_d_descriptor};
497 template <
typename StaticTileDistributionEncoding_>
502 constexpr
auto adaptor_impl =
505 constexpr
auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
506 constexpr
auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
507 constexpr
index_t d_length = adaptor_impl.template at<2>();
508 constexpr
auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
510 constexpr
auto ps_ys_to_xs_adaptor =
513 constexpr
auto ys_to_d_adaptor =
516 constexpr
auto ys_to_d_descriptor =
520 constexpr
index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
521 constexpr
auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
523 constexpr
auto rh_major_minor_to_hidden_ids =
531 ps_ys_to_xs_adaptor, ys_to_d_descriptor};
571 template <
typename Distribution,
index_t... XSliceBegins,
index_t... XSliceEnds>
577 using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
579 static_assert(
sizeof...(XSliceBegins) ==
sizeof...(XSliceEnds));
581 constexpr
auto x_slice_lengths = x_slice_ends - x_slice_begins;
583 constexpr
auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
584 constexpr
auto src_y_info = Encoding::detail::get_sorted_y_info();
585 constexpr
auto src_y_dims = src_y_info[
number<0>{}];
586 constexpr
auto src_y_maps = src_y_info[
number<1>{}];
587 constexpr
auto src_y_prefix_sum = src_y_info[
number<2>{}];
589 constexpr
auto sliced_hlen_yidx_ylen = [&]() constexpr
591 auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
592 auto y_slice_lengths = Encoding::detail::ys_lengths_;
598 [&](
auto h_len,
auto id) {
599 constexpr
auto sliced_h =
602 constexpr
auto sliced_h_lens = sliced_h[
number<0>{}];
603 constexpr
auto sliced_h_index = sliced_h[
number<2>{}];
607 constexpr
auto found_y_index =
container_find(src_y_dims, uniformed_h_index);
609 static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
610 "not sliced at y dim, please check");
613 y_slice_lengths(src_y_maps[found_y_index - i]) =
614 sliced_h_lens[sliced_h_index - i];
620 constexpr
auto y_origin = [&]() {
623 h_trans.calculate_lower_index(h_origin_,
sequence<x_slice_begins[
id].value>{});
625 auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
627 y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
633 src_y_prefix_sum[
id + 1],
638 return sliced_h_lens;
640 typename Encoding::HsLengthss{},
645 return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
649 constexpr
auto sliced_h_lengths = sliced_hlen_yidx_ylen[
number<0>{}];
650 constexpr
auto sliced_y_origins_array = sliced_hlen_yidx_ylen[
number<1>{}];
651 constexpr
auto sliced_y_origins_size = sliced_y_origins_array.size();
652 constexpr
auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[
number<2>{}];
653 constexpr
auto sliced_y_lengths_size = sliced_y_lengths_array.size();
655 constexpr
auto sliced_y_origins =
TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
656 constexpr
auto sliced_y_lengths =
TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
664 typename Encoding::Ps2RHssMajor,
665 typename Encoding::Ps2RHssMinor,
666 typename Encoding::Ys2RHsMajor,
667 typename Encoding::Ys2RHsMinor>{}),
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_sequential_index(index_t ibegin, index_t iend)
Definition: tile_distribution.hpp:254
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_span(sequence< Is... >)
Definition: tile_distribution.hpp:53
constexpr CK_TILE_HOST_DEVICE auto slice_distribution_from_x(Distribution, sequence< XSliceBegins... > x_slice_begins, sequence< XSliceEnds... > x_slice_ends)
Definition: tile_distribution.hpp:572
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_index(sequence< Is... >)
Definition: tile_distribution.hpp:59
constexpr CK_TILE_HOST_DEVICE auto make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:269
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_zero_multi_index()
Definition: multi_index.hpp:26
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition: container_helper.hpp:48
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
coord_transform_enum
Definition: coordinate_transform.hpp:16
CK_TILE_DEVICE index_t get_lane_id()
Definition: arch.hpp:69
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constexpr auto reverse_slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition: sequence.hpp:1205
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:589
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE auto make_tensor_descriptor_from_adaptor(const Adaptor &adaptor, const ElementSpaceSize &element_space_size)
Definition: tensor_descriptor.hpp:164
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1666
CK_TILE_DEVICE index_t get_warp_id()
Definition: arch.hpp:71
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:86
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr CK_TILE_HOST_DEVICE auto transform_tuples(F f, const X &x)
Definition: tuple.hpp:471
Definition: sequence.hpp:278
Definition: integral_constant.hpp:13
Definition: tile_distribution.hpp:450
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_
Definition: tile_distribution.hpp:451
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution.hpp:42
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:47
static constexpr auto impl_
Definition: tile_distribution.hpp:45
Definition: tile_distribution.hpp:31
static constexpr auto impl_
Definition: tile_distribution.hpp:34
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:36
Definition: tile_distribution_encoding.hpp:26
Definition: tile_distribution.hpp:72
remove_cvref_t< Ys2DDescriptor_ > Ys2DDescriptor
Definition: tile_distribution.hpp:74
PsYs2XsAdaptor ps_ys_to_xs_
Definition: tile_distribution.hpp:86
static constexpr CK_TILE_HOST_DEVICE auto get_distributed_spans()
Definition: tile_distribution.hpp:185
static CK_TILE_HOST_DEVICE auto _get_partition_index()
Definition: tile_distribution.hpp:94
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
static constexpr index_t NDimY
Definition: tile_distribution.hpp:82
remove_cvref_t< StaticTileDistributionEncoding_ > DstrEncode
Definition: tile_distribution.hpp:75
remove_cvref_t< TileDistributionDetail_ > DstrDetail
Definition: tile_distribution.hpp:76
CK_TILE_HOST_DEVICE auto calculate_index(const PartitionIndex &ps_idx=_get_partition_index()) const
Definition: tile_distribution.hpp:177
static constexpr CK_TILE_HOST_DEVICE auto get_lengths()
Definition: tile_distribution.hpp:109
static constexpr index_t NDimP
Definition: tile_distribution.hpp:83
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_x()
Definition: tile_distribution.hpp:89
static constexpr CK_TILE_HOST_DEVICE auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition: tile_distribution.hpp:205
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex &ps_idx) const
Definition: tile_distribution.hpp:142
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_p()
Definition: tile_distribution.hpp:91
constexpr CK_TILE_HOST_DEVICE const auto & get_ys_to_d_descriptor() const
Definition: tile_distribution.hpp:131
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution.hpp:232
remove_cvref_t< PsYs2XsAdaptor_ > PsYs2XsAdaptor
Definition: tile_distribution.hpp:73
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_r()
Definition: tile_distribution.hpp:92
static constexpr index_t NDimR
Definition: tile_distribution.hpp:84
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:227
Ys2DDescriptor ys_to_d_
Definition: tile_distribution.hpp:87
static constexpr index_t NDimX
Definition: tile_distribution.hpp:81
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_y()
Definition: tile_distribution.hpp:90
static constexpr CK_TILE_HOST_DEVICE auto get_static_tile_distribution_encoding()
Definition: tile_distribution.hpp:133
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:833
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:709
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10