19 template <
typename RsLengths_,
21 typename Ps2RHssMajor_,
22 typename Ps2RHssMinor_,
23 typename Ys2RHsMajor_,
24 typename Ys2RHsMinor_>
34 static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(),
"wrong!");
35 static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(),
"wrong!");
61 if constexpr(i.value == 0)
92 return ys_lengths_tmp;
103 rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
106 return rhs_major_minor_to_ys_tmp;
117 ndims_span_minor(span_major)++;
120 return ndims_span_minor;
135 index_t cnt_ndim_span_minor = 0;
142 rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
144 cnt_ndim_span_minor++;
149 return rhs_major_minor_to_span_minor;
166 distributed_spans_lengthss{{-1}};
174 const index_t span_major = rh_major - 1;
177 distributed_spans_lengthss(span_major)(span_minor) = h_length;
180 return distributed_spans_lengthss;
190 ndims_distributed_spans_minor(span_major)++;
193 return ndims_distributed_spans_minor;
198 if constexpr(
NDimR > 0)
209 if constexpr(rh_major == 0)
211 does_p_own_r(idim_p)(rh_minor) =
true;
226 if constexpr(
NDimR > 0)
233 index_t p_over_rh_derivative = 1;
235 static_for<ndim_low - 1, -1, -1>{}([&](
auto idim_low) {
241 if constexpr(rh_major == 0)
243 ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
246 p_over_rh_derivative *= rh_length;
250 return ps_over_rs_derivative;
274 return h_dim_prefix_sum;
280 [](
auto major,
auto minor) constexpr {
284 return x_dim_prefix_sum.at(major) + minor;
289 return all_ys_2_rhss;
293 template <
typename IdxSeq,
typename PrefixSumSeq>
298 constexpr
auto sorted_dims =
typename sorted_idx::type{};
299 constexpr
auto sorted_maps =
typename sorted_idx::sorted2unsorted_map{};
301 constexpr
auto sorted_histogram =
305 return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
315 printf(
"tile_distribution_encoding::detail{");
317 printf(
"ndim_rh_major_: ");
321 printf(
"ndim_span_major_: ");
325 printf(
"ndims_rhs_minor_: ");
329 printf(
"ndim_rh_major_: ");
333 printf(
"max_ndim_rh_minor_: ");
337 printf(
"rhs_lengthss_: ");
341 printf(
"ys_lengths_: ");
345 printf(
"rhs_major_minor_to_ys_: ");
349 printf(
"ndims_span_minor_: ");
353 printf(
"max_ndim_span_minor_: ");
357 printf(
"ys_to_span_major_: ");
361 printf(
"ys_to_span_minor_: ");
365 printf(
"distributed_spans_lengthss_: ");
369 printf(
"ndims_distributed_spans_minor_: ");
373 printf(
"ps_over_rs_derivative_: ");
382 printf(
"tile_distribution_encoding{");
386 printf(
"rs_lengths_: ");
390 printf(
"hs_lengthss_: ");
394 printf(
"ps_to_rhss_major_: ");
398 printf(
"ps_to_rhss_minor_: ");
402 printf(
"ys_to_rhs_major_: ");
406 printf(
"ys_to_rhs_minor_: ");
419 template <
typename OuterDstr,
typename InnerDstr>
422 static_assert(OuterDstr::NDimX == InnerDstr::NDimX,
"wrong!");
424 constexpr
index_t NDimHMajor = OuterDstr::NDimX;
432 typename InnerDstr::HsLengthss{}[i]);
437 constexpr
auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
441 rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
445 rhs_major_2_ndim_outer_rhs_minor_(i + 1) =
typename OuterDstr::HsLengthss{}[i].
size();
448 return rhs_major_2_ndim_outer_rhs_minor_;
454 constexpr
auto inner_p_2_rhss_major =
typename InnerDstr::Ps2RHssMajor{}[p];
455 constexpr
auto inner_p_2_rhss_minor =
typename InnerDstr::Ps2RHssMinor{}[p];
457 constexpr
index_t ndim_tmp = inner_p_2_rhss_minor.size();
459 constexpr
auto updated_inner_p_2_rhss_minor = [&]() {
462 for(
index_t i = 0; i < ndim_tmp; i++)
464 index_t rh_major = inner_p_2_rhss_major[i];
466 index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
468 updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
471 return updated_inner_p_2_rhss_minor_;
474 return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
479 constexpr
auto updated_inner_ys_2_rhs_minor = [&]() {
480 constexpr
auto inner_ys_2_rhs_major =
typename InnerDstr::Ys2RHsMajor{};
481 constexpr
auto inner_ys_2_rhs_minor =
typename InnerDstr::Ys2RHsMinor{};
483 constexpr
index_t ndim_tmp = inner_ys_2_rhs_minor.size();
485 constexpr
auto updated_inner_ys_2_rhs_minor_ = [&]() {
488 for(
index_t i = 0; i < ndim_tmp; i++)
490 index_t rh_major = inner_ys_2_rhs_major[i];
492 index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
494 updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
497 return updated_inner_ys_2_rhs_minor__;
500 return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
504 constexpr
auto ps_2_rhss_major =
505 container_concat(
typename OuterDstr::Ps2RHssMajor{},
typename InnerDstr::Ps2RHssMajor{});
507 constexpr
auto ps_2_rhss_minor =
508 container_concat(
typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
511 constexpr
auto ys_2_rhs_major =
512 merge_sequences(
typename OuterDstr::Ys2RHsMajor{},
typename InnerDstr::Ys2RHsMajor{});
514 constexpr
auto ys_2_rhs_minor =
515 merge_sequences(
typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
525 template <
typename InDstr,
index_t... InReduceDimXs>
532 constexpr
index_t max_ndim_r_out = 20;
533 constexpr
index_t max_ndim_y_out = 20;
536 constexpr
index_t ndim_p = InDstr::NDimP;
537 constexpr
index_t ndim_x_in = InDstr::NDimX;
538 constexpr
index_t ndim_y_in = InDstr::NDimY;
539 constexpr
index_t ndim_rh_major_in = InDstr::NDimX + 1;
540 constexpr
index_t ndim_x_out = ndim_x_in -
sizeof...(InReduceDimXs);
541 constexpr
index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
545 [&](
auto i) {
return InDstr::ps_to_rhss_major_[i].size(); },
number<ndim_p>{});
550 for(
index_t i = 0; i < reduce_dim_xs_in.
size(); i++)
552 index_t rh_major = reduce_dim_xs_in[i] + 1;
554 is_rh_major_in_for_reduce(rh_major) =
true;
560 for(
index_t i = 0; i < ndim_y_in; i++)
562 index_t rh_major = InDstr::ys_to_rhs_major_[i];
564 if(is_rh_major_in_for_reduce[rh_major])
566 is_y_in_for_reduce(i) =
true;
574 index_t rh_major = InDstr::ys_to_rhs_major_[i];
575 index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
577 if(is_y_in_for_reduce[i])
579 is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) =
true;
585 index_t cnt_ndim_rh_major_out = 0;
587 for(
index_t i = 0; i < ndim_rh_major_in; i++)
589 if(is_rh_major_in_for_reduce[i])
591 in2out_rh_major(i) = 0;
595 in2out_rh_major(i) = cnt_ndim_rh_major_out;
597 cnt_ndim_rh_major_out++;
606 for(
index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
609 rs_lengths_out(i) = InDstr::rs_lengths_[i];
612 in2out_rh_minor(0)(i) = i;
616 index_t cnt_ndim_r_out = InDstr::rs_lengths_.
size();
619 constexpr
auto h_major_in = rh_major_in - I1;
621 constexpr
index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
623 if(is_rh_major_in_for_reduce[rh_major_in])
625 for(
index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
627 if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
630 rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
633 in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
641 for(
index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
644 in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
650 const index_t ndim_r_out = cnt_ndim_r_out;
659 if(not is_rh_major_in_for_reduce[i + I1])
662 ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
665 static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
666 [&](
auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
677 static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](
auto idim_low) {
678 index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
679 index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
681 ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
682 ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
693 if(not is_y_in_for_reduce[i])
695 index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
696 index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
698 ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
699 ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
706 const index_t ndim_y_out = cnt_ndim_y_out;
717 ps_to_rhss_major_out,
718 ps_to_rhss_minor_out,
720 ys_to_rhs_minor_out);
723 template <
typename InDstr,
index_t... InReduceDimXs>
729 constexpr
index_t ndim_x = impl.template at<0>();
730 constexpr
index_t ndim_p = impl.template at<1>();
731 constexpr
index_t ndim_y = impl.template at<2>();
732 constexpr
index_t ndim_r = impl.template at<3>();
733 constexpr
auto ndims_hs_minor = impl.template at<4>();
734 constexpr
auto ndims_ps_low = impl.template at<5>();
735 constexpr
auto rs_lengths_impl = impl.template at<6>();
736 constexpr
auto hs_lengthss_impl = impl.template at<7>();
737 constexpr
auto ps_to_rhss_major_impl = impl.template at<8>();
738 constexpr
auto ps_to_rhss_minor_impl = impl.template at<9>();
739 constexpr
auto ys_to_rhs_major_impl = impl.template at<10>();
740 constexpr
auto ys_to_rhs_minor_impl = impl.template at<11>();
742 constexpr
auto rs_lengths =
TO_SEQUENCE(rs_lengths_impl, ndim_r);
744 constexpr
auto ps_to_rhss_major =
746 constexpr
auto ps_to_rhss_minor =
748 constexpr
auto ys_to_rhs_major =
TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
749 constexpr
auto ys_to_rhs_minor =
TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding_impl(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:527
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:725
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:420
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
typename sequence_merge< Seqs... >::type sequence_merge_t
Definition: sequence.hpp:1014
constexpr CK_TILE_HOST_DEVICE auto transform_sequences(F f, sequence< Xs... >)
Definition: sequence.hpp:823
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1106
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:589
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1036
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition: sequence.hpp:1093
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr auto prefix_sum_sequence(Seq)
Definition: sequence.hpp:899
static constexpr CK_TILE_HOST_DEVICE auto size()
Definition: array.hpp:78
Definition: integral_constant.hpp:13
Definition: sequence.hpp:584
Definition: sequence.hpp:52
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:56
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:53
static constexpr index_t max_ndim_span_minor_
Definition: tile_distribution_encoding.hpp:124
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_info(IdxSeq, PrefixSumSeq)
Definition: tile_distribution_encoding.hpp:294
static constexpr auto rhs_lengthss_
Definition: tile_distribution_encoding.hpp:77
static constexpr auto distributed_spans_lengthss_
Definition: tile_distribution_encoding.hpp:164
static constexpr auto does_p_own_r_
Definition: tile_distribution_encoding.hpp:197
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_h()
Definition: tile_distribution_encoding.hpp:277
static constexpr index_t max_ndim_rh_minor_
Definition: tile_distribution_encoding.hpp:73
static constexpr auto ys_to_span_major_
Definition: tile_distribution_encoding.hpp:153
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:313
static constexpr auto rhs_major_minor_to_span_minor_
Definition: tile_distribution_encoding.hpp:128
static constexpr CK_TILE_HOST_DEVICE auto get_h_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:259
static constexpr index_t ndim_span_major_
Definition: tile_distribution_encoding.hpp:56
static constexpr auto ndims_span_minor_
Definition: tile_distribution_encoding.hpp:110
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_y_info()
Definition: tile_distribution_encoding.hpp:308
static constexpr index_t ndim_rh_major_
Definition: tile_distribution_encoding.hpp:55
static constexpr auto ps_over_rs_derivative_
Definition: tile_distribution_encoding.hpp:225
static constexpr auto ys_lengths_
Definition: tile_distribution_encoding.hpp:81
static constexpr auto ndims_distributed_spans_minor_
Definition: tile_distribution_encoding.hpp:184
static constexpr auto rhs_major_minor_to_ys_
Definition: tile_distribution_encoding.hpp:96
static constexpr auto ndims_rhs_minor_
Definition: tile_distribution_encoding.hpp:59
static constexpr auto ys_to_span_minor_
Definition: tile_distribution_encoding.hpp:157
Definition: tile_distribution_encoding.hpp:26
static constexpr index_t NDimR
Definition: tile_distribution_encoding.hpp:40
static constexpr auto ps_to_rhss_minor_
Definition: tile_distribution_encoding.hpp:46
static constexpr auto rs_lengths_
Definition: tile_distribution_encoding.hpp:43
static constexpr index_t NDimP
Definition: tile_distribution_encoding.hpp:38
remove_cvref_t< Ps2RHssMinor_ > Ps2RHssMinor
Definition: tile_distribution_encoding.hpp:30
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:380
static constexpr auto ys_to_rhs_major_
Definition: tile_distribution_encoding.hpp:47
static constexpr auto ys_to_rhs_minor_
Definition: tile_distribution_encoding.hpp:48
static constexpr index_t NDimY
Definition: tile_distribution_encoding.hpp:39
remove_cvref_t< Ys2RHsMinor_ > Ys2RHsMinor
Definition: tile_distribution_encoding.hpp:32
static constexpr auto hs_lengthss_
Definition: tile_distribution_encoding.hpp:44
remove_cvref_t< Ys2RHsMajor_ > Ys2RHsMajor
Definition: tile_distribution_encoding.hpp:31
remove_cvref_t< HsLengthss_ > HsLengthss
Definition: tile_distribution_encoding.hpp:28
remove_cvref_t< Ps2RHssMajor_ > Ps2RHssMajor
Definition: tile_distribution_encoding.hpp:29
remove_cvref_t< RsLengths_ > RsLengths
Definition: tile_distribution_encoding.hpp:27
static constexpr auto ps_to_rhss_major_
Definition: tile_distribution_encoding.hpp:45
static constexpr index_t NDimX
Definition: tile_distribution_encoding.hpp:37
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10