19 template <
typename RsLengths_,
21 typename Ps2RHssMajor_,
22 typename Ps2RHssMinor_,
23 typename Ys2RHsMajor_,
24 typename Ys2RHsMinor_>
34 static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(),
"wrong!");
35 static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(),
"wrong!");
50 #if !CK_TILE_ENC_SUPPORT_Y_TO_R
52 "do not support Y dim pointed to R dim");
66 if constexpr(i.value == 0)
97 return ys_lengths_tmp;
108 rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
111 return rhs_major_minor_to_ys_tmp;
122 ndims_span_minor(span_major)++;
125 return ndims_span_minor;
140 index_t cnt_ndim_span_minor = 0;
147 rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
149 cnt_ndim_span_minor++;
154 return rhs_major_minor_to_span_minor;
171 distributed_spans_lengthss{{-1}};
179 const index_t span_major = rh_major - 1;
182 distributed_spans_lengthss(span_major)(span_minor) = h_length;
185 return distributed_spans_lengthss;
195 ndims_distributed_spans_minor(span_major)++;
198 return ndims_distributed_spans_minor;
203 if constexpr(
NDimR > 0)
214 if constexpr(rh_major == 0)
216 does_p_own_r(idim_p)(rh_minor) =
true;
231 if constexpr(
NDimR > 0)
238 index_t p_over_rh_derivative = 1;
240 static_for<ndim_low - 1, -1, -1>{}([&](
auto idim_low) {
246 if constexpr(rh_major == 0)
248 ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
251 p_over_rh_derivative *= rh_length;
255 return ps_over_rs_derivative;
272 return uniformed_h_dim_lengths;
283 constexpr
auto uniformed_ps_to_rhss_major_ =
285 constexpr
auto uniformed_ps_to_rhss_minor_ =
288 constexpr
auto p_len_ = [&]() {
292 static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](
auto idim_u_) {
293 if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
295 constexpr
auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
296 constexpr
auto h_length_ =
hs_lengthss_[idim_x_][minor_];
297 len_[idim_x_] *= h_length_;
304 return p_len_over_h_seq_;
314 constexpr
auto uniformed_rh_dim_lengths =
317 return uniformed_rh_dim_lengths;
327 return h_dim_prefix_sum;
336 return rh_dim_prefix_sum;
342 constexpr
auto uniformed_ps_to_rhss_major_ =
344 constexpr
auto uniformed_ps_to_rhss_minor_ =
348 [](
auto major,
auto minor) constexpr {
350 return rh_dim_prefix_sum.at(major) + minor;
352 uniformed_ps_to_rhss_major_,
353 uniformed_ps_to_rhss_minor_);
355 return all_ps_2_rhss;
361 [](
auto major,
auto minor) constexpr {
363 return rh_dim_prefix_sum.at(major) + minor;
368 return all_ys_2_rhss;
375 [](
auto major,
auto minor) constexpr {
377 return rh_dim_prefix_sum.at(major) + minor -
NDimR;
382 return all_ys_2_rhss;
390 constexpr
auto size_ =
HsLengthss{}[i].size();
391 constexpr
auto current_y_to_h_mask_ = [&]() {
394 for(
auto j = 0; j <
NDimY; j++)
411 template <
typename IdxSeq,
typename PrefixSumSeq>
416 constexpr
auto sorted_dims =
typename sorted_idx::type{};
417 constexpr
auto sorted_maps =
typename sorted_idx::sorted2unsorted_map{};
419 constexpr
auto sorted_histogram =
423 return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
434 printf(
"tile_distribution_encoding::detail{");
436 printf(
"ndim_rh_major_: ");
440 printf(
"ndim_span_major_: ");
444 printf(
"ndims_rhs_minor_: ");
448 printf(
"ndim_rh_major_: ");
452 printf(
"max_ndim_rh_minor_: ");
456 printf(
"rhs_lengthss_: ");
460 printf(
"ys_lengths_: ");
464 printf(
"rhs_major_minor_to_ys_: ");
468 printf(
"ndims_span_minor_: ");
472 printf(
"max_ndim_span_minor_: ");
476 printf(
"ys_to_span_major_: ");
480 printf(
"ys_to_span_minor_: ");
484 printf(
"distributed_spans_lengthss_: ");
488 printf(
"ndims_distributed_spans_minor_: ");
492 printf(
"ps_over_rs_derivative_: ");
501 printf(
"tile_distribution_encoding{");
505 printf(
"rs_lengths_: ");
509 printf(
"hs_lengthss_: ");
513 printf(
"ps_to_rhss_major_: ");
517 printf(
"ps_to_rhss_minor_: ");
521 printf(
"ys_to_rhs_major_: ");
525 printf(
"ys_to_rhs_minor_: ");
536 template <
typename encoding,
typename shuffle>
538 template <
typename encoding,
index_t... shuffle>
541 template <
typename Ys2RHs>
546 typename encoding::HsLengthss,
547 typename encoding::Ps2RHssMajor,
548 typename encoding::Ps2RHssMinor,
552 template <
typename encoding,
typename shuffle>
558 template <
typename OuterDstr,
typename InnerDstr>
561 static_assert(OuterDstr::NDimX == InnerDstr::NDimX,
"wrong!");
563 constexpr
index_t NDimHMajor = OuterDstr::NDimX;
571 typename InnerDstr::HsLengthss{}[i]);
576 constexpr
auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
580 rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
584 rhs_major_2_ndim_outer_rhs_minor_(i + 1) =
typename OuterDstr::HsLengthss{}[i].
size();
587 return rhs_major_2_ndim_outer_rhs_minor_;
593 constexpr
auto inner_p_2_rhss_major =
typename InnerDstr::Ps2RHssMajor{}[p];
594 constexpr
auto inner_p_2_rhss_minor =
typename InnerDstr::Ps2RHssMinor{}[p];
596 constexpr
index_t ndim_tmp = inner_p_2_rhss_minor.size();
598 constexpr
auto updated_inner_p_2_rhss_minor = [&]() {
601 for(
index_t i = 0; i < ndim_tmp; i++)
603 index_t rh_major = inner_p_2_rhss_major[i];
605 index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
607 updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
610 return updated_inner_p_2_rhss_minor_;
613 return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
618 constexpr
auto updated_inner_ys_2_rhs_minor = [&]() {
619 constexpr
auto inner_ys_2_rhs_major =
typename InnerDstr::Ys2RHsMajor{};
620 constexpr
auto inner_ys_2_rhs_minor =
typename InnerDstr::Ys2RHsMinor{};
622 constexpr
index_t ndim_tmp = inner_ys_2_rhs_minor.size();
624 constexpr
auto updated_inner_ys_2_rhs_minor_ = [&]() {
627 for(
index_t i = 0; i < ndim_tmp; i++)
629 index_t rh_major = inner_ys_2_rhs_major[i];
631 index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
633 updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
636 return updated_inner_ys_2_rhs_minor__;
639 return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
643 constexpr
auto ps_2_rhss_major =
644 container_concat(
typename OuterDstr::Ps2RHssMajor{},
typename InnerDstr::Ps2RHssMajor{});
646 constexpr
auto ps_2_rhss_minor =
647 container_concat(
typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
650 constexpr
auto ys_2_rhs_major =
651 merge_sequences(
typename OuterDstr::Ys2RHsMajor{},
typename InnerDstr::Ys2RHsMajor{});
653 constexpr
auto ys_2_rhs_minor =
654 merge_sequences(
typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
664 template <
typename InDstr,
index_t... InReduceDimXs>
671 constexpr
index_t max_ndim_r_out = 20;
672 constexpr
index_t max_ndim_y_out = 20;
675 constexpr
index_t ndim_p = InDstr::NDimP;
676 constexpr
index_t ndim_x_in = InDstr::NDimX;
677 constexpr
index_t ndim_y_in = InDstr::NDimY;
678 constexpr
index_t ndim_rh_major_in = InDstr::NDimX + 1;
679 constexpr
index_t ndim_x_out = ndim_x_in -
sizeof...(InReduceDimXs);
680 constexpr
index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
684 [&](
auto i) {
return InDstr::ps_to_rhss_major_[i].size(); },
number<ndim_p>{});
689 for(
index_t i = 0; i < reduce_dim_xs_in.
size(); i++)
691 index_t rh_major = reduce_dim_xs_in[i] + 1;
693 is_rh_major_in_for_reduce(rh_major) =
true;
699 for(
index_t i = 0; i < ndim_y_in; i++)
701 index_t rh_major = InDstr::ys_to_rhs_major_[i];
703 if(is_rh_major_in_for_reduce[rh_major])
705 is_y_in_for_reduce(i) =
true;
713 index_t rh_major = InDstr::ys_to_rhs_major_[i];
714 index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
716 if(is_y_in_for_reduce[i])
718 is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) =
true;
724 index_t cnt_ndim_rh_major_out = 0;
726 for(
index_t i = 0; i < ndim_rh_major_in; i++)
728 if(is_rh_major_in_for_reduce[i])
730 in2out_rh_major(i) = 0;
734 in2out_rh_major(i) = cnt_ndim_rh_major_out;
736 cnt_ndim_rh_major_out++;
745 for(
index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
748 rs_lengths_out(i) = InDstr::rs_lengths_[i];
751 in2out_rh_minor(0)(i) = i;
755 index_t cnt_ndim_r_out = InDstr::rs_lengths_.
size();
758 constexpr
auto h_major_in = rh_major_in - I1;
760 constexpr
index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
762 if(is_rh_major_in_for_reduce[rh_major_in])
764 for(
index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
766 if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
769 rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
772 in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
780 for(
index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
783 in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
789 const index_t ndim_r_out = cnt_ndim_r_out;
798 if(not is_rh_major_in_for_reduce[i + I1])
801 ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
804 static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
805 [&](
auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
816 static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](
auto idim_low) {
817 index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
818 index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
820 ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
821 ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
832 if(not is_y_in_for_reduce[i])
834 index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
835 index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
837 ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
838 ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
845 const index_t ndim_y_out = cnt_ndim_y_out;
856 ps_to_rhss_major_out,
857 ps_to_rhss_minor_out,
859 ys_to_rhs_minor_out);
862 template <
typename InDstr,
index_t... InReduceDimXs>
868 constexpr
index_t ndim_x = impl.template at<0>();
869 constexpr
index_t ndim_p = impl.template at<1>();
870 constexpr
index_t ndim_y = impl.template at<2>();
871 constexpr
index_t ndim_r = impl.template at<3>();
872 constexpr
auto ndims_hs_minor = impl.template at<4>();
873 constexpr
auto ndims_ps_low = impl.template at<5>();
874 constexpr
auto rs_lengths_impl = impl.template at<6>();
875 constexpr
auto hs_lengthss_impl = impl.template at<7>();
876 constexpr
auto ps_to_rhss_major_impl = impl.template at<8>();
877 constexpr
auto ps_to_rhss_minor_impl = impl.template at<9>();
878 constexpr
auto ys_to_rhs_major_impl = impl.template at<10>();
879 constexpr
auto ys_to_rhs_minor_impl = impl.template at<11>();
881 constexpr
auto rs_lengths =
TO_SEQUENCE(rs_lengths_impl, ndim_r);
883 constexpr
auto ps_to_rhss_major =
885 constexpr
auto ps_to_rhss_minor =
887 constexpr
auto ys_to_rhs_major =
TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
888 constexpr
auto ys_to_rhs_minor =
TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
Definition: tile_distribution_encoding.hpp:537
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding_impl(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:666
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:864
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:559
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
typename sequence_merge< Seqs... >::type sequence_merge_t
Definition: sequence.hpp:1014
constexpr CK_TILE_HOST_DEVICE auto transform_sequences(F f, sequence< Xs... >)
Definition: sequence.hpp:823
typename tile_distribution_encoding_shuffle< encoding, shuffle >::type tile_distribution_encoding_shuffle_t
Definition: tile_distribution_encoding.hpp:554
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1106
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:613
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1036
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:412
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
constexpr CK_TILE_HOST_DEVICE auto unpack(F &&f, X &&x)
Definition: functional.hpp:200
constexpr CK_TILE_HOST_DEVICE auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition: sequence.hpp:1093
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr auto prefix_sum_sequence(Seq)
Definition: sequence.hpp:899
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
static constexpr CK_TILE_HOST_DEVICE auto size()
Definition: array.hpp:97
Definition: integral_constant.hpp:13
Definition: sequence.hpp:584
Definition: sequence.hpp:52
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:56
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:58
static constexpr index_t max_ndim_span_minor_
Definition: tile_distribution_encoding.hpp:129
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_p_to_h()
Definition: tile_distribution_encoding.hpp:339
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_info(IdxSeq, PrefixSumSeq)
Definition: tile_distribution_encoding.hpp:412
static constexpr auto rhs_lengthss_
Definition: tile_distribution_encoding.hpp:82
static constexpr auto distributed_spans_lengthss_
Definition: tile_distribution_encoding.hpp:169
static constexpr auto does_p_own_r_
Definition: tile_distribution_encoding.hpp:202
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_h()
Definition: tile_distribution_encoding.hpp:371
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_rh_dim_lengths()
Definition: tile_distribution_encoding.hpp:312
static constexpr index_t max_ndim_rh_minor_
Definition: tile_distribution_encoding.hpp:78
static constexpr auto ys_to_span_major_
Definition: tile_distribution_encoding.hpp:158
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:432
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_y_to_h_info()
Definition: tile_distribution_encoding.hpp:427
static constexpr auto rhs_major_minor_to_span_minor_
Definition: tile_distribution_encoding.hpp:133
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_p_dim_lengths_over_h()
Definition: tile_distribution_encoding.hpp:276
static constexpr CK_TILE_HOST_DEVICE auto get_h_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:321
static constexpr index_t ndim_span_major_
Definition: tile_distribution_encoding.hpp:61
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_h_dim_lengths()
Definition: tile_distribution_encoding.hpp:263
static constexpr auto ndims_span_minor_
Definition: tile_distribution_encoding.hpp:115
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_rh()
Definition: tile_distribution_encoding.hpp:358
static constexpr index_t ndim_rh_major_
Definition: tile_distribution_encoding.hpp:60
static constexpr auto ps_over_rs_derivative_
Definition: tile_distribution_encoding.hpp:230
static constexpr CK_TILE_HOST_DEVICE auto get_y_to_h_masks()
Definition: tile_distribution_encoding.hpp:386
static constexpr auto ys_lengths_
Definition: tile_distribution_encoding.hpp:86
static constexpr auto ndims_distributed_spans_minor_
Definition: tile_distribution_encoding.hpp:189
static constexpr auto rhs_major_minor_to_ys_
Definition: tile_distribution_encoding.hpp:101
static constexpr auto ndims_rhs_minor_
Definition: tile_distribution_encoding.hpp:64
static constexpr auto ys_to_span_minor_
Definition: tile_distribution_encoding.hpp:162
static constexpr CK_TILE_HOST_DEVICE auto get_rh_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:330
Definition: tile_distribution_encoding.hpp:26
static constexpr index_t NDimR
Definition: tile_distribution_encoding.hpp:40
static constexpr auto ps_to_rhss_minor_
Definition: tile_distribution_encoding.hpp:46
static constexpr auto rs_lengths_
Definition: tile_distribution_encoding.hpp:43
static constexpr index_t NDimP
Definition: tile_distribution_encoding.hpp:38
remove_cvref_t< Ps2RHssMinor_ > Ps2RHssMinor
Definition: tile_distribution_encoding.hpp:30
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:499
static constexpr auto ys_to_rhs_major_
Definition: tile_distribution_encoding.hpp:47
static constexpr auto ys_to_rhs_minor_
Definition: tile_distribution_encoding.hpp:48
static constexpr index_t NDimY
Definition: tile_distribution_encoding.hpp:39
remove_cvref_t< Ys2RHsMinor_ > Ys2RHsMinor
Definition: tile_distribution_encoding.hpp:32
static constexpr auto hs_lengthss_
Definition: tile_distribution_encoding.hpp:44
remove_cvref_t< Ys2RHsMajor_ > Ys2RHsMajor
Definition: tile_distribution_encoding.hpp:31
remove_cvref_t< HsLengthss_ > HsLengthss
Definition: tile_distribution_encoding.hpp:28
remove_cvref_t< Ps2RHssMajor_ > Ps2RHssMajor
Definition: tile_distribution_encoding.hpp:29
remove_cvref_t< RsLengths_ > RsLengths
Definition: tile_distribution_encoding.hpp:27
static constexpr auto ps_to_rhss_major_
Definition: tile_distribution_encoding.hpp:45
static constexpr index_t NDimX
Definition: tile_distribution_encoding.hpp:37
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10