19 template <
typename RsLengths_,
21 typename Ps2RHssMajor_,
22 typename Ps2RHssMinor_,
23 typename Ys2RHsMajor_,
24 typename Ys2RHsMinor_>
34 static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(),
"wrong!");
35 static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(),
"wrong!");
50 #if !CK_TILE_ENC_SUPPORT_Y_TO_R
52 "do not support Y dim pointed to R dim");
66 if constexpr(i.value == 0)
97 return ys_lengths_tmp;
108 rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
111 return rhs_major_minor_to_ys_tmp;
122 ndims_span_minor(span_major)++;
125 return ndims_span_minor;
140 index_t cnt_ndim_span_minor = 0;
147 rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
149 cnt_ndim_span_minor++;
154 return rhs_major_minor_to_span_minor;
171 distributed_spans_lengthss{{-1}};
179 const index_t span_major = rh_major - 1;
182 distributed_spans_lengthss(span_major)(span_minor) = h_length;
185 return distributed_spans_lengthss;
195 ndims_distributed_spans_minor(span_major)++;
198 return ndims_distributed_spans_minor;
203 if constexpr(
NDimR > 0)
214 if constexpr(rh_major == 0)
216 does_p_own_r(idim_p)(rh_minor) =
true;
231 if constexpr(
NDimR > 0)
238 index_t p_over_rh_derivative = 1;
240 static_for<ndim_low - 1, -1, -1>{}([&](
auto idim_low) {
246 if constexpr(rh_major == 0)
248 ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
251 p_over_rh_derivative *= rh_length;
255 return ps_over_rs_derivative;
272 return uniformed_h_dim_lengths;
283 constexpr
auto uniformed_ps_to_rhss_major_ =
285 constexpr
auto uniformed_ps_to_rhss_minor_ =
288 constexpr
auto p_len_ = [&]() {
292 static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](
auto idim_u_) {
293 if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
295 constexpr
auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
296 constexpr
auto h_length_ =
hs_lengthss_[idim_x_][minor_];
297 len_[idim_x_] *= h_length_;
304 return p_len_over_h_seq_;
314 constexpr
auto uniformed_rh_dim_lengths =
317 return uniformed_rh_dim_lengths;
327 return h_dim_prefix_sum;
336 return rh_dim_prefix_sum;
342 constexpr
auto uniformed_ps_to_rhss_major_ =
344 constexpr
auto uniformed_ps_to_rhss_minor_ =
348 [](
auto major,
auto minor) constexpr {
350 return rh_dim_prefix_sum.at(major) + minor;
352 uniformed_ps_to_rhss_major_,
353 uniformed_ps_to_rhss_minor_);
355 return all_ps_2_rhss;
361 [](
auto major,
auto minor) constexpr {
363 return rh_dim_prefix_sum.at(major) + minor;
368 return all_ys_2_rhss;
375 [](
auto major,
auto minor) constexpr {
377 return rh_dim_prefix_sum.at(major) + minor -
NDimR;
382 return all_ys_2_rhss;
390 constexpr
auto size_ =
HsLengthss{}[i].size();
391 constexpr
auto current_y_to_h_mask_ = [&]() {
394 for(
auto j = 0; j <
NDimY; j++)
411 template <
typename IdxSeq,
typename PrefixSumSeq>
416 constexpr
auto sorted_dims =
typename sorted_idx::type{};
417 constexpr
auto sorted_maps =
typename sorted_idx::sorted2unsorted_map{};
419 constexpr
auto sorted_histogram =
423 return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
434 printf(
"tile_distribution_encoding::detail{");
436 printf(
"ndim_rh_major_: ");
440 printf(
"ndim_span_major_: ");
444 printf(
"ndims_rhs_minor_: ");
448 printf(
"ndim_rh_major_: ");
452 printf(
"max_ndim_rh_minor_: ");
456 printf(
"rhs_lengthss_: ");
460 printf(
"ys_lengths_: ");
464 printf(
"rhs_major_minor_to_ys_: ");
468 printf(
"ndims_span_minor_: ");
472 printf(
"max_ndim_span_minor_: ");
476 printf(
"ys_to_span_major_: ");
480 printf(
"ys_to_span_minor_: ");
484 printf(
"distributed_spans_lengthss_: ");
488 printf(
"ndims_distributed_spans_minor_: ");
492 printf(
"ps_over_rs_derivative_: ");
501 printf(
"tile_distribution_encoding{");
505 printf(
"rs_lengths_: ");
509 printf(
"hs_lengthss_: ");
513 printf(
"ps_to_rhss_major_: ");
517 printf(
"ps_to_rhss_minor_: ");
521 printf(
"ys_to_rhs_major_: ");
525 printf(
"ys_to_rhs_minor_: ");
538 template <
typename OuterDstr,
typename InnerDstr>
541 static_assert(OuterDstr::NDimX == InnerDstr::NDimX,
"wrong!");
543 constexpr
index_t NDimHMajor = OuterDstr::NDimX;
551 typename InnerDstr::HsLengthss{}[i]);
556 constexpr
auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
560 rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
564 rhs_major_2_ndim_outer_rhs_minor_(i + 1) =
typename OuterDstr::HsLengthss{}[i].
size();
567 return rhs_major_2_ndim_outer_rhs_minor_;
573 constexpr
auto inner_p_2_rhss_major =
typename InnerDstr::Ps2RHssMajor{}[p];
574 constexpr
auto inner_p_2_rhss_minor =
typename InnerDstr::Ps2RHssMinor{}[p];
576 constexpr
index_t ndim_tmp = inner_p_2_rhss_minor.size();
578 constexpr
auto updated_inner_p_2_rhss_minor = [&]() {
581 for(
index_t i = 0; i < ndim_tmp; i++)
583 index_t rh_major = inner_p_2_rhss_major[i];
585 index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
587 updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
590 return updated_inner_p_2_rhss_minor_;
593 return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
598 constexpr
auto updated_inner_ys_2_rhs_minor = [&]() {
599 constexpr
auto inner_ys_2_rhs_major =
typename InnerDstr::Ys2RHsMajor{};
600 constexpr
auto inner_ys_2_rhs_minor =
typename InnerDstr::Ys2RHsMinor{};
602 constexpr
index_t ndim_tmp = inner_ys_2_rhs_minor.size();
604 constexpr
auto updated_inner_ys_2_rhs_minor_ = [&]() {
607 for(
index_t i = 0; i < ndim_tmp; i++)
609 index_t rh_major = inner_ys_2_rhs_major[i];
611 index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
613 updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
616 return updated_inner_ys_2_rhs_minor__;
619 return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
623 constexpr
auto ps_2_rhss_major =
624 container_concat(
typename OuterDstr::Ps2RHssMajor{},
typename InnerDstr::Ps2RHssMajor{});
626 constexpr
auto ps_2_rhss_minor =
627 container_concat(
typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
630 constexpr
auto ys_2_rhs_major =
631 merge_sequences(
typename OuterDstr::Ys2RHsMajor{},
typename InnerDstr::Ys2RHsMajor{});
633 constexpr
auto ys_2_rhs_minor =
634 merge_sequences(
typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
644 template <
typename InDstr,
index_t... InReduceDimXs>
651 constexpr
index_t max_ndim_r_out = 20;
652 constexpr
index_t max_ndim_y_out = 20;
655 constexpr
index_t ndim_p = InDstr::NDimP;
656 constexpr
index_t ndim_x_in = InDstr::NDimX;
657 constexpr
index_t ndim_y_in = InDstr::NDimY;
658 constexpr
index_t ndim_rh_major_in = InDstr::NDimX + 1;
659 constexpr
index_t ndim_x_out = ndim_x_in -
sizeof...(InReduceDimXs);
660 constexpr
index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
664 [&](
auto i) {
return InDstr::ps_to_rhss_major_[i].size(); },
number<ndim_p>{});
669 for(
index_t i = 0; i < reduce_dim_xs_in.
size(); i++)
671 index_t rh_major = reduce_dim_xs_in[i] + 1;
673 is_rh_major_in_for_reduce(rh_major) =
true;
679 for(
index_t i = 0; i < ndim_y_in; i++)
681 index_t rh_major = InDstr::ys_to_rhs_major_[i];
683 if(is_rh_major_in_for_reduce[rh_major])
685 is_y_in_for_reduce(i) =
true;
693 index_t rh_major = InDstr::ys_to_rhs_major_[i];
694 index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
696 if(is_y_in_for_reduce[i])
698 is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) =
true;
704 index_t cnt_ndim_rh_major_out = 0;
706 for(
index_t i = 0; i < ndim_rh_major_in; i++)
708 if(is_rh_major_in_for_reduce[i])
710 in2out_rh_major(i) = 0;
714 in2out_rh_major(i) = cnt_ndim_rh_major_out;
716 cnt_ndim_rh_major_out++;
725 for(
index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
728 rs_lengths_out(i) = InDstr::rs_lengths_[i];
731 in2out_rh_minor(0)(i) = i;
735 index_t cnt_ndim_r_out = InDstr::rs_lengths_.
size();
738 constexpr
auto h_major_in = rh_major_in - I1;
740 constexpr
index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
742 if(is_rh_major_in_for_reduce[rh_major_in])
744 for(
index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
746 if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
749 rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
752 in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
760 for(
index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
763 in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
769 const index_t ndim_r_out = cnt_ndim_r_out;
778 if(not is_rh_major_in_for_reduce[i + I1])
781 ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
784 static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
785 [&](
auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
796 static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](
auto idim_low) {
797 index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
798 index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
800 ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
801 ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
812 if(not is_y_in_for_reduce[i])
814 index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
815 index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
817 ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
818 ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
825 const index_t ndim_y_out = cnt_ndim_y_out;
836 ps_to_rhss_major_out,
837 ps_to_rhss_minor_out,
839 ys_to_rhs_minor_out);
842 template <
typename InDstr,
index_t... InReduceDimXs>
848 constexpr
index_t ndim_x = impl.template at<0>();
849 constexpr
index_t ndim_p = impl.template at<1>();
850 constexpr
index_t ndim_y = impl.template at<2>();
851 constexpr
index_t ndim_r = impl.template at<3>();
852 constexpr
auto ndims_hs_minor = impl.template at<4>();
853 constexpr
auto ndims_ps_low = impl.template at<5>();
854 constexpr
auto rs_lengths_impl = impl.template at<6>();
855 constexpr
auto hs_lengthss_impl = impl.template at<7>();
856 constexpr
auto ps_to_rhss_major_impl = impl.template at<8>();
857 constexpr
auto ps_to_rhss_minor_impl = impl.template at<9>();
858 constexpr
auto ys_to_rhs_major_impl = impl.template at<10>();
859 constexpr
auto ys_to_rhs_minor_impl = impl.template at<11>();
861 constexpr
auto rs_lengths =
TO_SEQUENCE(rs_lengths_impl, ndim_r);
863 constexpr
auto ps_to_rhss_major =
865 constexpr
auto ps_to_rhss_minor =
867 constexpr
auto ys_to_rhs_major =
TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
868 constexpr
auto ys_to_rhs_minor =
TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding_impl(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:646
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:844
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:539
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
typename sequence_merge< Seqs... >::type sequence_merge_t
Definition: sequence.hpp:1014
constexpr CK_TILE_HOST_DEVICE auto transform_sequences(F f, sequence< Xs... >)
Definition: sequence.hpp:823
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1106
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:594
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1036
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:406
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto unpack(F &&f, X &&x)
Definition: functional.hpp:200
constexpr CK_TILE_HOST_DEVICE auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition: sequence.hpp:1093
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr auto prefix_sum_sequence(Seq)
Definition: sequence.hpp:899
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
static constexpr CK_TILE_HOST_DEVICE auto size()
Definition: array.hpp:97
Definition: integral_constant.hpp:13
Definition: sequence.hpp:584
Definition: sequence.hpp:52
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:56
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:58
static constexpr index_t max_ndim_span_minor_
Definition: tile_distribution_encoding.hpp:129
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_p_to_h()
Definition: tile_distribution_encoding.hpp:339
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_info(IdxSeq, PrefixSumSeq)
Definition: tile_distribution_encoding.hpp:412
static constexpr auto rhs_lengthss_
Definition: tile_distribution_encoding.hpp:82
static constexpr auto distributed_spans_lengthss_
Definition: tile_distribution_encoding.hpp:169
static constexpr auto does_p_own_r_
Definition: tile_distribution_encoding.hpp:202
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_h()
Definition: tile_distribution_encoding.hpp:371
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_rh_dim_lengths()
Definition: tile_distribution_encoding.hpp:312
static constexpr index_t max_ndim_rh_minor_
Definition: tile_distribution_encoding.hpp:78
static constexpr auto ys_to_span_major_
Definition: tile_distribution_encoding.hpp:158
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:432
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_y_to_h_info()
Definition: tile_distribution_encoding.hpp:427
static constexpr auto rhs_major_minor_to_span_minor_
Definition: tile_distribution_encoding.hpp:133
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_p_dim_lengths_over_h()
Definition: tile_distribution_encoding.hpp:276
static constexpr CK_TILE_HOST_DEVICE auto get_h_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:321
static constexpr index_t ndim_span_major_
Definition: tile_distribution_encoding.hpp:61
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_h_dim_lengths()
Definition: tile_distribution_encoding.hpp:263
static constexpr auto ndims_span_minor_
Definition: tile_distribution_encoding.hpp:115
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_rh()
Definition: tile_distribution_encoding.hpp:358
static constexpr index_t ndim_rh_major_
Definition: tile_distribution_encoding.hpp:60
static constexpr auto ps_over_rs_derivative_
Definition: tile_distribution_encoding.hpp:230
static constexpr CK_TILE_HOST_DEVICE auto get_y_to_h_masks()
Definition: tile_distribution_encoding.hpp:386
static constexpr auto ys_lengths_
Definition: tile_distribution_encoding.hpp:86
static constexpr auto ndims_distributed_spans_minor_
Definition: tile_distribution_encoding.hpp:189
static constexpr auto rhs_major_minor_to_ys_
Definition: tile_distribution_encoding.hpp:101
static constexpr auto ndims_rhs_minor_
Definition: tile_distribution_encoding.hpp:64
static constexpr auto ys_to_span_minor_
Definition: tile_distribution_encoding.hpp:162
static constexpr CK_TILE_HOST_DEVICE auto get_rh_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:330
Definition: tile_distribution_encoding.hpp:26
static constexpr index_t NDimR
Definition: tile_distribution_encoding.hpp:40
static constexpr auto ps_to_rhss_minor_
Definition: tile_distribution_encoding.hpp:46
static constexpr auto rs_lengths_
Definition: tile_distribution_encoding.hpp:43
static constexpr index_t NDimP
Definition: tile_distribution_encoding.hpp:38
remove_cvref_t< Ps2RHssMinor_ > Ps2RHssMinor
Definition: tile_distribution_encoding.hpp:30
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:499
static constexpr auto ys_to_rhs_major_
Definition: tile_distribution_encoding.hpp:47
static constexpr auto ys_to_rhs_minor_
Definition: tile_distribution_encoding.hpp:48
static constexpr index_t NDimY
Definition: tile_distribution_encoding.hpp:39
remove_cvref_t< Ys2RHsMinor_ > Ys2RHsMinor
Definition: tile_distribution_encoding.hpp:32
static constexpr auto hs_lengthss_
Definition: tile_distribution_encoding.hpp:44
remove_cvref_t< Ys2RHsMajor_ > Ys2RHsMajor
Definition: tile_distribution_encoding.hpp:31
remove_cvref_t< HsLengthss_ > HsLengthss
Definition: tile_distribution_encoding.hpp:28
remove_cvref_t< Ps2RHssMajor_ > Ps2RHssMajor
Definition: tile_distribution_encoding.hpp:29
remove_cvref_t< RsLengths_ > RsLengths
Definition: tile_distribution_encoding.hpp:27
static constexpr auto ps_to_rhss_major_
Definition: tile_distribution_encoding.hpp:45
static constexpr index_t NDimX
Definition: tile_distribution_encoding.hpp:37
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10