/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File
tile_distribution_encoding.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck_tile {
18 
19 template <typename RsLengths_, // sequence<...>
20  typename HsLengthss_, // tuple<sequence<...>, ...>
21  typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
22  typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
23  typename Ys2RHsMajor_, // sequence<...>
24  typename Ys2RHsMinor_> // sequence<...>
26 {
33 
// Each "major" index list must have a matching "minor" list of the same size,
// for both the P mappings and the Y mappings.
34  static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!");
35  static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!");
36 
// Dimension counts, read directly off the sizes of the template arguments.
// NOTE(review): the aliases RsLengths, HsLengthss, ... are declared on listing
// lines 27-32, which are not part of this extract.
37  static constexpr index_t NDimX = HsLengthss::size();
38  static constexpr index_t NDimP = Ps2RHssMajor::size();
39  static constexpr index_t NDimY = Ys2RHsMajor::size();
40  static constexpr index_t NDimR = RsLengths::size();
41 
42  // FIXME: move into detail
// Default-constructed instances of the six template arguments, so that their
// contents can be indexed as values (e.g. ys_to_rhs_major_[i]).
43  static constexpr auto rs_lengths_ = RsLengths{};
44  static constexpr auto hs_lengthss_ = HsLengthss{};
45  static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
46  static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
47  static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
48  static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
49 
50  // redundant but useful info
51  // TODO: really bad code, should be over-hauled
// Compile-time lookup tables derived from the encoding's R/H lengths and the
// P/Y -> (rh_major, rh_minor) mappings. rh_major 0 addresses the R lengths and
// rh_major i >= 1 addresses hs_lengthss_[i - 1] (see ndims_rhs_minor_ below).
// NOTE(review): this extract of the listing is missing several lines (70, 74,
// 78, 125, 159, 165, 254, 259, 277, 310, 313 and the print(...) calls inside
// print()); the comments below only describe what is visible.
52  struct detail
53  {
54  // ndim_rh_major_, ndim_span_major_
// ndim_rh_major_ counts the R slot plus one slot per X dimension;
// ndim_span_major_ counts the X dimensions only.
55  static constexpr index_t ndim_rh_major_ = NDimX + 1;
56  static constexpr index_t ndim_span_major_ = NDimX;
57 
58  // ndims_rhs_minor_[ndim_rh_major_]
// Entry 0 is the number of R lengths; entry i >= 1 is the number of lengths in
// hs_lengthss_[i - 1].
// NOTE(review): the final argument of this generate_array call (listing line
// 70) is absent from this extract.
59  static constexpr auto ndims_rhs_minor_ = generate_array(
60  [](auto i) {
61  if constexpr(i.value == 0)
62  {
63  return rs_lengths_.size();
64  }
65  else
66  {
67  return hs_lengthss_[i - number<1>{}].size();
68  }
69  },
71 
72  // max_ndim_rh_minor_
// NOTE(review): initializer (listing line 74) is absent from this extract;
// presumably the maximum of ndims_rhs_minor_ -- confirm against the header.
73  static constexpr index_t max_ndim_rh_minor_ =
75 
76  // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
// NOTE(review): initializer (listing line 78) is absent from this extract.
77  static constexpr auto rhs_lengthss_ =
79 
80  // ys_lengths_
// Length of every Y dimension: Y i takes the length stored in rhs_lengthss_ at
// the (rh_major, rh_minor) position that Y i maps to.
81  static constexpr auto ys_lengths_ = [] {
82  array<index_t, NDimY> ys_lengths_tmp{-1};
83 
84  for(index_t i = 0; i < NDimY; i++)
85  {
86  index_t rh_major = ys_to_rhs_major_[i];
87  index_t rh_minor = ys_to_rhs_minor_[i];
88 
89  ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor];
90  }
91 
92  return ys_lengths_tmp;
93  }();
94 
95  // rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
// Inverse of the Y mapping: entry (rh_major, rh_minor) holds the index of the Y
// dimension that maps there.
// NOTE(review): positions no Y maps to keep whatever the `{{-1}}` initializer
// leaves in them; the `idim_y >= 0` test in rhs_major_minor_to_span_minor_
// assumes that is negative -- confirm against ck_tile::array's constructor.
96  static constexpr auto rhs_major_minor_to_ys_ = [] {
97  array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}};
98 
99  static_for<0, NDimY, 1>{}([&](auto i) {
100  constexpr index_t rh_major = ys_to_rhs_major_[i];
101  constexpr index_t rh_minor = ys_to_rhs_minor_[i];
102 
103  rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i;
104  });
105 
106  return rhs_major_minor_to_ys_tmp;
107  }();
108 
109  // ndims_span_minor_[NDimX]
// For each span major (rh_major - 1), counts how many Y dimensions map to it.
// NOTE(review): a Y with rh_major == 0 would index position -1 here; nothing
// visible in this block rules that out.
110  static constexpr auto ndims_span_minor_ = [] {
111  array<index_t, NDimX> ndims_span_minor{0};
112 
113  for(index_t i = 0; i < NDimY; i++)
114  {
115  const index_t span_major = ys_to_rhs_major_[i] - 1;
116 
117  ndims_span_minor(span_major)++;
118  }
119 
120  return ndims_span_minor;
121  }();
122 
123  // max_ndim_span_minor_
// NOTE(review): initializer (listing line 125) is absent from this extract.
124  static constexpr index_t max_ndim_span_minor_ =
126 
127  // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
// Within each rh_major, the minors that some Y maps to are numbered 0, 1, 2, ...
// in increasing rh_minor order; that number is the span-minor index. Minors
// with no Y are left at their initial value.
128  static constexpr auto rhs_major_minor_to_span_minor_ = [] {
129  array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
130  {-1}};
131 
132  static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
133  constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major];
134 
// The counter restarts for every rh_major.
135  index_t cnt_ndim_span_minor = 0;
136 
137  static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
138  constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
139 
140  if(idim_y >= 0)
141  {
142  rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
143 
144  cnt_ndim_span_minor++;
145  }
146  });
147  });
148 
149  return rhs_major_minor_to_span_minor;
150  }();
151 
152  // ys_to_span_major_[NDimY]
// Span major of Y i is its rh_major shifted down by one (dropping the R slot).
153  static constexpr auto ys_to_span_major_ =
154  generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
155 
156  // ys_to_span_minor_[NDimY]
// NOTE(review): the lambda body (listing line 159) is absent from this extract.
157  static constexpr auto ys_to_span_minor_ = generate_array(
158  [](auto i) {
160  },
161  number<NDimY>{});
162 
163  // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
// For every Y, stores the H length it maps to at that Y's (span_major,
// span_minor) position.
// NOTE(review): the first half of the array declaration (listing line 165) is
// absent from this extract.
164  static constexpr auto distributed_spans_lengthss_ = [] {
166  distributed_spans_lengthss{{-1}};
167 
168  static_for<0, NDimY, 1>{}([&](auto i) {
169  const index_t rh_major = ys_to_rhs_major_[i];
170  const index_t rh_minor = ys_to_rhs_minor_[i];
171 
// rh_major - 1 converts the RH major index into an index into hs_lengthss_.
172  const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
173 
174  const index_t span_major = rh_major - 1;
175  const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
176 
177  distributed_spans_lengthss(span_major)(span_minor) = h_length;
178  });
179 
180  return distributed_spans_lengthss;
181  }();
182 
183  // ndims_distributed_spans_minor_[ndim_span_major_]
// Counts the Y dimensions per span major; the same tally as ndims_span_minor_,
// written with static_for instead of a plain loop.
184  static constexpr auto ndims_distributed_spans_minor_ = [] {
185  array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
186 
187  static_for<0, NDimY, 1>{}([&](auto i) {
188  const index_t span_major = ys_to_rhs_major_[i] - 1;
189 
190  ndims_distributed_spans_minor(span_major)++;
191  });
192 
193  return ndims_distributed_spans_minor;
194  }();
195 
196  // does_p_own_r_[NDimP][NDimR]
// Entry (p, r) is set to true when P dimension p lists (rh_major 0, rh_minor r)
// among the dimensions it maps to. With NDimR == 0 a default-constructed array
// is returned instead.
197  static constexpr auto does_p_own_r_ = [] {
198  if constexpr(NDimR > 0)
199  {
200  array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
201 
202  static_for<0, NDimP, 1>{}([&](auto idim_p) {
203  constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
204 
205  static_for<0, ndim_low, 1>{}([&](auto idim_low) {
206  constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
207  constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
208 
209  if constexpr(rh_major == 0)
210  {
211  does_p_own_r(idim_p)(rh_minor) = true;
212  }
213  });
214  });
215 
216  return does_p_own_r;
217  }
218  else
219  {
220  return array<array<bool, NDimR>, NDimP>{};
221  }
222  }();
223 
224  // ps_over_rs_derivative_[NDimP][NDimR]
// For each P, walks the dimensions it maps to with static_for<ndim_low - 1, -1, -1>
// (presumably last to first) while keeping a running product of their lengths.
// When the visited dimension is an R dimension, the product accumulated *before*
// multiplying in that dimension's own length is stored at (p, r).
// NOTE(review): the else-branch return (listing line 254) is absent from this
// extract.
225  static constexpr auto ps_over_rs_derivative_ = [] {
226  if constexpr(NDimR > 0)
227  {
228  array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
229 
230  static_for<0, NDimP, 1>{}([&](auto idim_p) {
231  constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
232 
233  index_t p_over_rh_derivative = 1;
234 
235  static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
236  constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
237  constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
238 
239  constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
240 
241  if constexpr(rh_major == 0)
242  {
243  ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
244  }
245 
246  p_over_rh_derivative *= rh_length;
247  });
248  });
249 
250  return ps_over_rs_derivative;
251  }
252  else
253  {
255  }
256  }();
257 
258  // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
// NOTE(review): the signature (listing line 259) is absent from this extract;
// the page index names it get_h_dim_lengths_prefix_sum().
260  {
261  // <len_d0, len_d1, ...>
262  // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
263  constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
264  [&](auto i) {
265  constexpr index_t size = HsLengthss{}[i].size();
266  return number<size>{};
267  },
268  number<NDimX>{});
269 
270  // <0, len_d0, len_d0+len_d1, ...>
271  // e.g. seq<3, 5> --> seq<0, 3, 8>
272  constexpr auto h_dim_prefix_sum = prefix_sum_sequence(uniformed_h_dim_lengths);
273 
274  return h_dim_prefix_sum;
275  }
276 
// Maps every Y's (major, minor) pair to a single flat index: the prefix-sum
// offset of its major plus its minor. A leading 0 is prepended to the H prefix
// sums so that major 0 (R) also has an offset.
// NOTE(review): the signature (listing line 277) is absent from this extract;
// the page index names it get_uniformed_idx_y_to_h().
278  {
279  constexpr auto all_ys_2_rhss = transform_sequences(
280  [](auto major, auto minor) constexpr {
281  // <0, 0, len_d0, len_d0+len_d1, ...>
282  constexpr auto x_dim_prefix_sum = merge_sequences(
283  sequence<0>{} /*for R dims*/, get_h_dim_lengths_prefix_sum());
284  return x_dim_prefix_sum.at(major) + minor;
285  },
286  Ys2RHsMajor{},
287  Ys2RHsMinor{});
288 
289  return all_ys_2_rhss;
290  }
291 
292  // return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
// Sorts IdxSeq with sequence_unique_sort, histograms the sorted values against
// PrefixSumSeq, and prefix-sums that histogram. Both arguments are used only
// for their types.
293  template <typename IdxSeq, typename PrefixSumSeq>
294  CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
295  {
296  using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
297 
298  constexpr auto sorted_dims = typename sorted_idx::type{};
299  constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
300 
301  constexpr auto sorted_histogram =
302  histogram_sorted_sequence(sorted_dims, PrefixSumSeq{});
303  constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
304 
305  return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
306  }
307 
// NOTE(review): the body (listing line 310) is absent from this extract.
308  CK_TILE_HOST_DEVICE static constexpr auto get_sorted_y_info()
309  {
311  }
312 
// Prints the tables of this struct as "label: value" pairs separated by ", ".
// NOTE(review): the signature (listing line 313, `print() const` per the page
// index) and every print(...) call between the printf pairs are absent from
// this extract, so only the labels are visible here.
314  {
315  printf("tile_distribution_encoding::detail{");
316  //
317  printf("ndim_rh_major_: ");
319  printf(", ");
320  //
321  printf("ndim_span_major_: ");
323  printf(", ");
324  //
325  printf("ndims_rhs_minor_: ");
327  printf(", ");
328  //
// NOTE(review): this label repeats "ndim_rh_major_: " from listing line 317;
// looks like an unintended duplicate -- confirm what value follows it.
329  printf("ndim_rh_major_: ");
331  printf(", ");
332  //
333  printf("max_ndim_rh_minor_: ");
335  printf(", ");
336  //
337  printf("rhs_lengthss_: ");
339  printf(", ");
340  //
341  printf("ys_lengths_: ");
343  printf(", ");
344  //
345  printf("rhs_major_minor_to_ys_: ");
347  printf(", ");
348  //
349  printf("ndims_span_minor_: ");
351  printf(", ");
352  //
353  printf("max_ndim_span_minor_: ");
355  printf(", ");
356  //
357  printf("ys_to_span_major_: ");
359  printf(", ");
360  //
361  printf("ys_to_span_minor_: ");
363  printf(", ");
364  //
365  printf("distributed_spans_lengthss_: ");
367  printf(", ");
368  //
369  printf("ndims_distributed_spans_minor_: ");
371  printf(", ");
372  //
// NOTE(review): no labels exist for rhs_major_minor_to_span_minor_ or
// does_p_own_r_, so those two tables are not part of the printed output.
373  printf("ps_over_rs_derivative_: ");
375  //
376  printf("}");
377  }
378  };
379 
// Body of the encoding's print(): emits the dimension counts, then one
// "label: value" pair per stored template argument, then the detail tables.
// NOTE(review): the signature (listing line 380, `print() const` per the page
// index) and the print(...) calls between the printf pairs are absent from this
// extract.
381  {
382  printf("tile_distribution_encoding{");
383  //
// NDimR is not included in this summary line.
384  printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
385  //
386  printf("rs_lengths_: ");
388  printf(", ");
389  //
390  printf("hs_lengthss_: ");
392  printf(", ");
393  //
394  printf("ps_to_rhss_major_: ");
396  printf(", ");
397  //
398  printf("ps_to_rhss_minor_: ");
400  printf(", ");
401  //
402  printf("ys_to_rhs_major_: ");
404  printf(", ");
405  //
406  printf("ys_to_rhs_minor_: ");
408  printf(", ");
409  //
410  printf("detail: ");
411  print(detail{});
412  //
413  printf("}");
414  }
415 };
416 
417 namespace detail {
418 
// Builds a new tile_distribution_encoding that combines OuterDstr and InnerDstr.
// Both must have the same NDimX. For every H dimension the outer lengths come
// first and the inner lengths are appended, so every inner rh_minor index is
// shifted up by the number of outer minors under the same rh_major. The P and Y
// mappings are the outer mappings followed by the (shifted) inner mappings.
// Both arguments are used only for their types.
// NOTE(review): listing lines 427, 434 and 476 are absent from this extract
// (the RsLengths definition and the closing arguments of two generate_tuple
// calls).
419 template <typename OuterDstr, typename InnerDstr>
420 CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
421 {
422  static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
423 
424  constexpr index_t NDimHMajor = OuterDstr::NDimX;
425 
// NOTE(review): right-hand side (listing line 427) is absent from this extract.
426  using RsLengths =
428 
// Per H dimension: outer lengths followed by inner lengths.
429  constexpr auto hs_lengthss = generate_tuple(
430  [&](auto i) {
431  return merge_sequences(typename OuterDstr::HsLengthss{}[i],
432  typename InnerDstr::HsLengthss{}[i]);
433  },
435 
436  // number of outer minors under each rh_major (0 = R, i + 1 = H dimension i)
437  constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
438  array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
439 
440  // R dimension
441  rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
442 
443  // Hs dimensions
444  static_for<0, NDimHMajor, 1>{}([&](auto i) {
445  rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
446  });
447 
448  return rhs_major_2_ndim_outer_rhs_minor_;
449  }();
450 
451  // Ps2RHssMinor
// Inner P minors, each shifted past the outer minors of its rh_major.
452  constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
453  [&](auto p) {
454  constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
455  constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
456 
457  constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
458 
459  constexpr auto updated_inner_p_2_rhss_minor = [&]() {
460  array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
461 
462  for(index_t i = 0; i < ndim_tmp; i++)
463  {
464  index_t rh_major = inner_p_2_rhss_major[i];
465 
466  index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
467 
468  updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
469  }
470 
471  return updated_inner_p_2_rhss_minor_;
472  }();
473 
// The array is turned back into a sequence type via the TO_SEQUENCE macro.
474  return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
475  },
477 
478  // Ys2RHsMinor
// Same shift as above, applied to the inner Y minors.
479  constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
480  constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
481  constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
482 
483  constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
484 
485  constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
486  array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
487 
488  for(index_t i = 0; i < ndim_tmp; i++)
489  {
490  index_t rh_major = inner_ys_2_rhs_major[i];
491 
492  index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
493 
494  updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
495  }
496 
497  return updated_inner_ys_2_rhs_minor__;
498  }();
499 
500  return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
501  }();
502 
503  // P mappings: outer entries first, then inner (majors unchanged, minors shifted)
504  constexpr auto ps_2_rhss_major =
505  container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
506 
507  constexpr auto ps_2_rhss_minor =
508  container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
509 
510  // Y mappings: outer entries first, then inner (majors unchanged, minors shifted)
511  constexpr auto ys_2_rhs_major =
512  merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
513 
514  constexpr auto ys_2_rhs_minor =
515  merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
516 
517  return tile_distribution_encoding<RsLengths,
518  remove_cvref_t<decltype(hs_lengthss)>,
519  remove_cvref_t<decltype(ps_2_rhss_major)>,
520  remove_cvref_t<decltype(ps_2_rhss_minor)>,
521  remove_cvref_t<decltype(ys_2_rhs_major)>,
522  remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
523 }
524 
// Computes, as plain arrays and counts, the pieces of the encoding obtained by
// reducing the X dimensions listed in reduce_dim_xs_in out of InDstr.
// Reduced H majors are folded into rh_major 0 (the R slot): each of their minors
// that no reduced Y maps to is appended to the output R lengths. All other
// majors are renumbered in order and keep their minors unchanged.
// Returns a tuple of (ndim_x_out, ndim_p, ndim_y_out, ndim_r_out,
// ndims_hs_minor_out, ndims_ps_low, rs_lengths_out, hs_lengthss_out,
// ps_to_rhss_major_out, ps_to_rhss_minor_out, ys_to_rhs_major_out,
// ys_to_rhs_minor_out).
// NOTE(review): the line carrying the function name and parameter list (listing
// line 527) is absent from this extract; per the page index it is
// make_reduce_tile_distribution_encoding_impl(InDstr, sequence<InReduceDimXs...>
// reduce_dim_xs_in).
525 template <typename InDstr, index_t... InReduceDimXs>
526 CK_TILE_HOST_DEVICE constexpr auto
528 {
529  constexpr auto I1 = number<1>{};
530 
531  // FIXME: increase if fail
// Fixed capacities of the R and Y output buffers below.
// NOTE(review): no bound check against these capacities is visible where the
// buffers are written.
532  constexpr index_t max_ndim_r_out = 20;
533  constexpr index_t max_ndim_y_out = 20;
534 
535  //
536  constexpr index_t ndim_p = InDstr::NDimP;
537  constexpr index_t ndim_x_in = InDstr::NDimX;
538  constexpr index_t ndim_y_in = InDstr::NDimY;
539  constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
540  constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
541  constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
542 
543  // ndims_ps_low
// Number of (rh_major, rh_minor) entries each P maps to; unchanged by the reduce.
544  constexpr auto ndims_ps_low = generate_array(
545  [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
546 
547  // is_rh_major_in_for_reduce
// Reduced X dimension x corresponds to rh_major x + 1 (rh_major 0 is R).
548  array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
549 
550  for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
551  {
552  index_t rh_major = reduce_dim_xs_in[i] + 1;
553 
554  is_rh_major_in_for_reduce(rh_major) = true;
555  }
556 
557  // is_y_in_for_reduce
// A Y is reduced when the rh_major it maps to is reduced.
558  array<bool, ndim_y_in> is_y_in_for_reduce{false};
559 
560  for(index_t i = 0; i < ndim_y_in; i++)
561  {
562  index_t rh_major = InDstr::ys_to_rhs_major_[i];
563 
564  if(is_rh_major_in_for_reduce[rh_major])
565  {
566  is_y_in_for_reduce(i) = true;
567  }
568  }
569 
570  // is_rh_minor_in_for_y_reduce
// Marks every (rh_major, rh_minor) position that a reduced Y maps to.
571  array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
572 
573  static_for<0, ndim_y_in, 1>{}([&](auto i) {
574  index_t rh_major = InDstr::ys_to_rhs_major_[i];
575  index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
576 
577  if(is_y_in_for_reduce[i])
578  {
579  is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
580  }
581  });
582 
583  // in2out_rh_major
// Reduced majors map to output major 0; the remaining majors (including R at
// index 0) are numbered consecutively in their original order.
584  array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
585  index_t cnt_ndim_rh_major_out = 0;
586 
587  for(index_t i = 0; i < ndim_rh_major_in; i++)
588  {
589  if(is_rh_major_in_for_reduce[i])
590  {
591  in2out_rh_major(i) = 0;
592  }
593  else
594  {
595  in2out_rh_major(i) = cnt_ndim_rh_major_out;
596 
597  cnt_ndim_rh_major_out++;
598  }
599  }
600 
601  // rs_lengths_out, in2out_rh_minor
602  array<index_t, max_ndim_r_out> rs_lengths_out{-1};
603  array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
604 
605  // loop over input R dim
// Input R dimensions keep their lengths and positions.
606  for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
607  {
608  // rs_lengths_out
609  rs_lengths_out(i) = InDstr::rs_lengths_[i];
610 
611  // in2out_rh_minor
612  in2out_rh_minor(0)(i) = i;
613  }
614 
615  // loop over input H Dim
// New R dimensions are appended after the input ones.
616  index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
617 
618  static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
619  constexpr auto h_major_in = rh_major_in - I1;
620 
621  constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
622 
623  if(is_rh_major_in_for_reduce[rh_major_in])
624  {
// Reduced major: every minor that no reduced Y maps to becomes a new R
// dimension. Minors that a reduced Y does map to get no entry written here.
625  for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
626  {
627  if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
628  {
629  // rs_lengths_out
630  rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
631 
632  // in2out_rh_minor
633  in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
634 
635  cnt_ndim_r_out++;
636  }
637  }
638  }
639  else
640  {
// Kept major: minors map to themselves.
641  for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
642  {
643  // in2out_rh_minor
644  in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
645  }
646  }
647  });
648 
649  // ndim_r_out
650  const index_t ndim_r_out = cnt_ndim_r_out;
651 
652  // ndims_hs_minor_out, hs_lengthss_out
// Copies the lengths of every kept H dimension, packed in original order.
653  array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
654  array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
655 
656  index_t cnt_ndim_x_out = 0;
657 
658  static_for<0, ndim_x_in, 1>{}([&](auto i) {
659  if(not is_rh_major_in_for_reduce[i + I1])
660  {
661  // ndims_hs_minor_out
662  ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
663 
664  // hs_lengthss_out
665  static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
666  [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
667 
668  cnt_ndim_x_out++;
669  }
670  });
671 
672  // ps_to_rhss_major_out, ps_to_rhss_minor_out
// Every P keeps all of its entries; only the indices are translated.
// NOTE(review): the inner capacity is max_ndim_rh_minor_in, while the number of
// entries written per P is ps_to_rhss_major_[p].size(); nothing visible here
// guarantees the latter fits -- confirm.
673  array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
674  array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
675 
676  static_for<0, ndim_p, 1>{}([&](auto idim_p) {
677  static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
678  index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
679  index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
680 
681  ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
682  ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
683  });
684  });
685 
686  // ys_to_rhs_major_out, ys_to_rhs_minor_out
// Only the Ys that are not reduced survive, packed in original order.
687  array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
688  array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
689 
690  index_t cnt_ndim_y_out = 0;
691 
692  static_for<0, ndim_y_in, 1>{}([&](auto i) {
693  if(not is_y_in_for_reduce[i])
694  {
695  index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
696  index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
697 
698  ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
699  ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
700 
701  cnt_ndim_y_out++;
702  }
703  });
704 
705  // ndim_y_out
706  const index_t ndim_y_out = cnt_ndim_y_out;
707 
708  // counts first, then per-dimension sizes, then the length and mapping tables
709  return make_tuple(ndim_x_out,
710  ndim_p,
711  ndim_y_out,
712  ndim_r_out,
713  ndims_hs_minor_out,
714  ndims_ps_low,
715  rs_lengths_out,
716  hs_lengthss_out,
717  ps_to_rhss_major_out,
718  ps_to_rhss_minor_out,
719  ys_to_rhs_major_out,
720  ys_to_rhs_minor_out);
721 }
722 
// Type-level wrapper around make_reduce_tile_distribution_encoding_impl: runs
// the impl at compile time, unpacks its result tuple by position, converts the
// arrays back into sequence / tuple-of-sequence types with the TO_SEQUENCE and
// TO_TUPLE_OF_SEQUENCE macros, and returns the resulting
// tile_distribution_encoding.
// NOTE(review): the line carrying the function name and parameter list (listing
// line 725) is absent from this extract; per the page index it is
// make_reduce_tile_distribution_encoding(InDstr, sequence<InReduceDimXs...>
// reduce_dim_xs_in).
723 template <typename InDstr, index_t... InReduceDimXs>
724 CK_TILE_HOST_DEVICE constexpr auto
726 {
727  constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
728 
// Positions 0-3 are counts, 4-5 are per-dimension sizes, 6-11 are the tables;
// the order must match the make_tuple call at the end of the impl.
729  constexpr index_t ndim_x = impl.template at<0>();
730  constexpr index_t ndim_p = impl.template at<1>();
731  constexpr index_t ndim_y = impl.template at<2>();
732  constexpr index_t ndim_r = impl.template at<3>();
733  constexpr auto ndims_hs_minor = impl.template at<4>();
734  constexpr auto ndims_ps_low = impl.template at<5>();
735  constexpr auto rs_lengths_impl = impl.template at<6>();
736  constexpr auto hs_lengthss_impl = impl.template at<7>();
737  constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
738  constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
739  constexpr auto ys_to_rhs_major_impl = impl.template at<10>();
740  constexpr auto ys_to_rhs_minor_impl = impl.template at<11>();
741 
// Each conversion is given the count(s) computed by the impl alongside the array.
742  constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
743  constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
744  constexpr auto ps_to_rhss_major =
745  TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
746  constexpr auto ps_to_rhss_minor =
747  TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
748  constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
749  constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
750 
751  return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
752  remove_cvref_t<decltype(hs_lengthss)>,
753  remove_cvref_t<decltype(ps_to_rhss_major)>,
754  remove_cvref_t<decltype(ps_to_rhss_minor)>,
755  remove_cvref_t<decltype(ys_to_rhs_major)>,
756  remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
757 }
758 
759 } // namespace detail
760 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding_impl(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:527
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:725
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:420
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
typename sequence_merge< Seqs... >::type sequence_merge_t
Definition: sequence.hpp:1014
constexpr CK_TILE_HOST_DEVICE auto transform_sequences(F f, sequence< Xs... >)
Definition: sequence.hpp:823
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1106
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:589
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1036
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition: sequence.hpp:1093
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr auto prefix_sum_sequence(Seq)
Definition: sequence.hpp:899
Definition: array.hpp:24
static constexpr CK_TILE_HOST_DEVICE auto size()
Definition: array.hpp:78
Definition: integral_constant.hpp:13
Definition: math.hpp:122
Definition: sequence.hpp:584
Definition: sequence.hpp:52
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:56
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:53
static constexpr index_t max_ndim_span_minor_
Definition: tile_distribution_encoding.hpp:124
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_info(IdxSeq, PrefixSumSeq)
Definition: tile_distribution_encoding.hpp:294
static constexpr auto rhs_lengthss_
Definition: tile_distribution_encoding.hpp:77
static constexpr auto distributed_spans_lengthss_
Definition: tile_distribution_encoding.hpp:164
static constexpr auto does_p_own_r_
Definition: tile_distribution_encoding.hpp:197
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_h()
Definition: tile_distribution_encoding.hpp:277
static constexpr index_t max_ndim_rh_minor_
Definition: tile_distribution_encoding.hpp:73
static constexpr auto ys_to_span_major_
Definition: tile_distribution_encoding.hpp:153
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:313
static constexpr auto rhs_major_minor_to_span_minor_
Definition: tile_distribution_encoding.hpp:128
static constexpr CK_TILE_HOST_DEVICE auto get_h_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:259
static constexpr index_t ndim_span_major_
Definition: tile_distribution_encoding.hpp:56
static constexpr auto ndims_span_minor_
Definition: tile_distribution_encoding.hpp:110
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_y_info()
Definition: tile_distribution_encoding.hpp:308
static constexpr index_t ndim_rh_major_
Definition: tile_distribution_encoding.hpp:55
static constexpr auto ps_over_rs_derivative_
Definition: tile_distribution_encoding.hpp:225
static constexpr auto ys_lengths_
Definition: tile_distribution_encoding.hpp:81
static constexpr auto ndims_distributed_spans_minor_
Definition: tile_distribution_encoding.hpp:184
static constexpr auto rhs_major_minor_to_ys_
Definition: tile_distribution_encoding.hpp:96
static constexpr auto ndims_rhs_minor_
Definition: tile_distribution_encoding.hpp:59
static constexpr auto ys_to_span_minor_
Definition: tile_distribution_encoding.hpp:157
Definition: tile_distribution_encoding.hpp:26
static constexpr index_t NDimR
Definition: tile_distribution_encoding.hpp:40
static constexpr auto ps_to_rhss_minor_
Definition: tile_distribution_encoding.hpp:46
static constexpr auto rs_lengths_
Definition: tile_distribution_encoding.hpp:43
static constexpr index_t NDimP
Definition: tile_distribution_encoding.hpp:38
remove_cvref_t< Ps2RHssMinor_ > Ps2RHssMinor
Definition: tile_distribution_encoding.hpp:30
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:380
static constexpr auto ys_to_rhs_major_
Definition: tile_distribution_encoding.hpp:47
static constexpr auto ys_to_rhs_minor_
Definition: tile_distribution_encoding.hpp:48
static constexpr index_t NDimY
Definition: tile_distribution_encoding.hpp:39
remove_cvref_t< Ys2RHsMinor_ > Ys2RHsMinor
Definition: tile_distribution_encoding.hpp:32
static constexpr auto hs_lengthss_
Definition: tile_distribution_encoding.hpp:44
remove_cvref_t< Ys2RHsMajor_ > Ys2RHsMajor
Definition: tile_distribution_encoding.hpp:31
remove_cvref_t< HsLengthss_ > HsLengthss
Definition: tile_distribution_encoding.hpp:28
remove_cvref_t< Ps2RHssMajor_ > Ps2RHssMajor
Definition: tile_distribution_encoding.hpp:29
remove_cvref_t< RsLengths_ > RsLengths
Definition: tile_distribution_encoding.hpp:27
static constexpr auto ps_to_rhss_major_
Definition: tile_distribution_encoding.hpp:45
static constexpr index_t NDimX
Definition: tile_distribution_encoding.hpp:37
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10