/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/tile_distribution_encoding.hpp Source File
tile_distribution_encoding.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
16 
17 namespace ck_tile {
18 
19 template <typename RsLengths_, // sequence<...>
20  typename HsLengthss_, // tuple<sequence<...>, ...>
21  typename Ps2RHssMajor_, // tuple<sequence<...>, ...>
22  typename Ps2RHssMinor_, // tuple<sequence<...>, ...>
23  typename Ys2RHsMajor_, // sequence<...>
24  typename Ys2RHsMinor_> // sequence<...>
26 {
33 
34  static_assert(Ps2RHssMajor::size() == Ps2RHssMinor::size(), "wrong!"); // P major/minor tuples must have the same number of entries
35  static_assert(Ys2RHsMajor::size() == Ys2RHsMinor::size(), "wrong!"); // every Y needs both a major and a minor index
36 
37  static constexpr index_t NDimX = HsLengthss::size(); // number of entries in HsLengthss
38  static constexpr index_t NDimP = Ps2RHssMajor::size(); // number of entries in Ps2RHssMajor
39  static constexpr index_t NDimY = Ys2RHsMajor::size(); // number of entries in Ys2RHsMajor
40  static constexpr index_t NDimR = RsLengths::size(); // number of entries in RsLengths
41 
42  // FIXME: move into detail
    // Default-constructed instances of the parameter types, kept as constexpr values so the
    // code below can index them directly.
43  static constexpr auto rs_lengths_ = RsLengths{};
44  static constexpr auto hs_lengthss_ = HsLengthss{};
45  static constexpr auto ps_to_rhss_major_ = Ps2RHssMajor{};
46  static constexpr auto ps_to_rhss_minor_ = Ps2RHssMinor{};
47  static constexpr auto ys_to_rhs_major_ = Ys2RHsMajor{};
48  static constexpr auto ys_to_rhs_minor_ = Ys2RHsMinor{};
49 
50 #if !CK_TILE_ENC_SUPPORT_Y_TO_R
    // Unless CK_TILE_ENC_SUPPORT_Y_TO_R is enabled, reject encodings in which any Y uses major
    // index 0 (major 0 selects rs_lengths_, see detail::ndims_rhs_minor_).
    // NOTE(review): presumably container_find returns the sequence size when the value is
    // absent - confirm against container_helper.hpp.
51  static_assert(container_find(ys_to_rhs_major_, 0) == NDimY,
52  "do not support Y dim pointed to R dim");
53 #endif
54 
55  // redundant but useful info
56  // TODO: really bad code, should be overhauled
57  struct detail
58  {
59  // ndim_rh_major_, ndim_span_major_
60  static constexpr index_t ndim_rh_major_ = NDimX + 1; // major 0 selects rs_lengths_, major i >= 1 selects hs_lengthss_[i - 1]
61  static constexpr index_t ndim_span_major_ = NDimX; // span majors are rh_major - 1 (see ys_to_span_major_)
62 
63  // ndims_rhs_minor_[ndim_rh_major_]
64  static constexpr auto ndims_rhs_minor_ = generate_array(
65  [](auto i) {
66  if constexpr(i.value == 0)
67  {
68  return rs_lengths_.size();
69  }
70  else
71  {
72  return hs_lengthss_[i - number<1>{}].size();
73  }
74  },
76 
77  // max_ndim_rh_minor_
78  static constexpr index_t max_ndim_rh_minor_ =
80 
81  // rhs_lengthss_[ndim_rh_major_][max_ndim_rh_minor_]
82  static constexpr auto rhs_lengthss_ =
84 
85  // ys_lengths_[NDimY]: for each Y dim i, the rhs_lengthss_ entry at (ys_to_rhs_major_[i], ys_to_rhs_minor_[i])
86  static constexpr auto ys_lengths_ = [] {
87  array<index_t, NDimY> ys_lengths_tmp{-1}; // every entry is overwritten by the loop below
88 
89  for(index_t i = 0; i < NDimY; i++)
90  {
91  index_t rh_major = ys_to_rhs_major_[i];
92  index_t rh_minor = ys_to_rhs_minor_[i];
93 
94  ys_lengths_tmp(i) = rhs_lengthss_[rh_major][rh_minor]; // look up the length this Y points at
95  }
96 
97  return ys_lengths_tmp;
98  }();
99 
100  // rhs_major_minor_to_ys_[ndim_rh_major_][max_ndim_rh_minor_]
     // Inverse of (ys_to_rhs_major_, ys_to_rhs_minor_): slot (rh_major, rh_minor) holds the index
     // of the Y dim that points at it. Slots no Y points at keep their initial value.
101  static constexpr auto rhs_major_minor_to_ys_ = [] {
102  array<array<index_t, max_ndim_rh_minor_>, NDimX + 1> rhs_major_minor_to_ys_tmp{{-1}}; // NOTE(review): presumably fills every slot with -1; rhs_major_minor_to_span_minor_ tests `idim_y >= 0` - confirm array init semantics
103 
104  static_for<0, NDimY, 1>{}([&](auto i) {
105  constexpr index_t rh_major = ys_to_rhs_major_[i];
106  constexpr index_t rh_minor = ys_to_rhs_minor_[i];
107 
108  rhs_major_minor_to_ys_tmp(rh_major)(rh_minor) = i; // if two Ys share a slot, the later one wins
109  });
110 
111  return rhs_major_minor_to_ys_tmp;
112  }();
113 
114  // ndims_span_minor_[NDimX]: per span major (rh_major - 1), the number of Y dims that point at it
115  static constexpr auto ndims_span_minor_ = [] {
116  array<index_t, NDimX> ndims_span_minor{0}; // counters, one per span major
117 
118  for(index_t i = 0; i < NDimY; i++)
119  {
120  const index_t span_major = ys_to_rhs_major_[i] - 1; // NOTE(review): -1 if a Y uses major 0; only excluded by the static_assert when CK_TILE_ENC_SUPPORT_Y_TO_R is off
121 
122  ndims_span_minor(span_major)++;
123  }
124 
125  return ndims_span_minor;
126  }();
127 
128  // max_ndim_span_minor_
129  static constexpr index_t max_ndim_span_minor_ =
131 
132  // rhs_major_minor_to_span_minor_ [ndim_rh_major_][max_ndim_rh_minor_]
     // For every (rh_major, rh_minor) slot that some Y points at, the running position of that
     // slot among the Y-owned slots of the same rh_major, counted in increasing rh_minor order.
     // Slots without a Y keep their initial value.
133  static constexpr auto rhs_major_minor_to_span_minor_ = [] {
134  array<array<index_t, max_ndim_rh_minor_>, ndim_rh_major_> rhs_major_minor_to_span_minor{
135  {-1}};
136 
137  static_for<0, ndim_rh_major_, 1>{}([&](auto rh_major) {
138  constexpr index_t ndim_rh_minor = ndims_rhs_minor_[rh_major]; // number of minor slots under this major
139 
140  index_t cnt_ndim_span_minor = 0; // restarts at 0 for every rh_major
141 
142  static_for<0, ndim_rh_minor, 1>{}([&](auto rh_minor) {
143  constexpr index_t idim_y = rhs_major_minor_to_ys_[rh_major][rh_minor];
144 
145  if(idim_y >= 0) // a negative entry means no Y was recorded for this slot
146  {
147  rhs_major_minor_to_span_minor(rh_major)(rh_minor) = cnt_ndim_span_minor;
148 
149  cnt_ndim_span_minor++;
150  }
151  });
152  });
153 
154  return rhs_major_minor_to_span_minor;
155  }();
156 
157  // ys_to_span_major_[NDimY]: span major of each Y dim, i.e. its rh_major shifted down by one
158  static constexpr auto ys_to_span_major_ =
159  generate_array([](auto i) { return ys_to_rhs_major_[i] - 1; }, number<NDimY>{});
160 
161  // ys_to_span_minor_[NDimY]
162  static constexpr auto ys_to_span_minor_ = generate_array(
163  [](auto i) {
165  },
166  number<NDimY>{});
167 
168  // distributed_spans_lengthss_[ndim_span_major_][max_ndim_span_minor_]
169  static constexpr auto distributed_spans_lengthss_ = [] {
171  distributed_spans_lengthss{{-1}};
172 
173  static_for<0, NDimY, 1>{}([&](auto i) {
174  const index_t rh_major = ys_to_rhs_major_[i];
175  const index_t rh_minor = ys_to_rhs_minor_[i];
176 
177  const index_t h_length = hs_lengthss_[number<rh_major - 1>{}][rh_minor];
178 
179  const index_t span_major = rh_major - 1;
180  const index_t span_minor = rhs_major_minor_to_span_minor_[rh_major][rh_minor];
181 
182  distributed_spans_lengthss(span_major)(span_minor) = h_length;
183  });
184 
185  return distributed_spans_lengthss;
186  }();
187 
188  // ndims_distributed_spans_minor_[ndim_span_major_]: per span major, the number of Y dims that point at it
     // (same counting as ndims_span_minor_, written with static_for instead of a plain loop)
189  static constexpr auto ndims_distributed_spans_minor_ = [] {
190  array<index_t, ndim_span_major_> ndims_distributed_spans_minor{0};
191 
192  static_for<0, NDimY, 1>{}([&](auto i) {
193  const index_t span_major = ys_to_rhs_major_[i] - 1; // NOTE(review): -1 if a Y uses major 0; only excluded by the static_assert when CK_TILE_ENC_SUPPORT_Y_TO_R is off
194 
195  ndims_distributed_spans_minor(span_major)++;
196  });
197 
198  return ndims_distributed_spans_minor;
199  }();
200 
201  // does_p_own_r_[NDimP][NDimR]: entry (p, r) is set to true when P dim p has a low dim with rh_major == 0 and rh_minor == r
202  static constexpr auto does_p_own_r_ = [] {
203  if constexpr(NDimR > 0)
204  {
205  array<array<bool, NDimR>, NDimP> does_p_own_r{{false}};
206 
207  static_for<0, NDimP, 1>{}([&](auto idim_p) {
208  constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size(); // number of (major, minor) pairs this P maps to
209 
210  static_for<0, ndim_low, 1>{}([&](auto idim_low) {
211  constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
212  constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
213 
214  if constexpr(rh_major == 0) // major 0 selects rs_lengths_, so rh_minor is an R index here
215  {
216  does_p_own_r(idim_p)(rh_minor) = true;
217  }
218  });
219  });
220 
221  return does_p_own_r;
222  }
223  else
224  {
225  return array<array<bool, NDimR>, NDimP>{}; // no R dims: return a value-initialized array of the same type
226  }
227  }();
228 
229  // ps_over_rs_derivative_[NDimP][NDimR]
230  static constexpr auto ps_over_rs_derivative_ = [] {
231  if constexpr(NDimR > 0)
232  {
233  array<array<index_t, NDimR>, NDimP> ps_over_rs_derivative{{0}};
234 
235  static_for<0, NDimP, 1>{}([&](auto idim_p) {
236  constexpr index_t ndim_low = ps_to_rhss_major_[idim_p].size();
237 
238  index_t p_over_rh_derivative = 1;
239 
240  static_for<ndim_low - 1, -1, -1>{}([&](auto idim_low) {
241  constexpr index_t rh_major = ps_to_rhss_major_[idim_p][idim_low];
242  constexpr index_t rh_minor = ps_to_rhss_minor_[idim_p][idim_low];
243 
244  constexpr index_t rh_length = rhs_lengthss_[rh_major][rh_minor];
245 
246  if constexpr(rh_major == 0)
247  {
248  ps_over_rs_derivative(idim_p)(rh_minor) = p_over_rh_derivative;
249  }
250 
251  p_over_rh_derivative *= rh_length;
252  });
253  });
254 
255  return ps_over_rs_derivative;
256  }
257  else
258  {
260  }
261  }();
262 
264  {
265  // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5>
266  constexpr auto uniformed_h_dim_lengths = generate_sequence_v2(
267  [&](auto i) {
268  constexpr index_t size_ = HsLengthss{}[i].size();
269  return number<size_>{};
270  },
271  number<NDimX>{});
272  return uniformed_h_dim_lengths;
273  }
274 
275  // note: this function only counts the P dim lengths along H, not R
277  {
278  // e.g. tuple<seq<1, 4, 32>, seq<1, 2, 8, 4, 4>>
279  // Y P Y Y P Y P Y
280  // | | |
281  // v v v
282  // return : seq<4, 2 * 4> => seq<4, 8>
283  constexpr auto uniformed_ps_to_rhss_major_ =
284  unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
285  constexpr auto uniformed_ps_to_rhss_minor_ =
286  unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
287 
288  constexpr auto p_len_ = [&]() {
289  array<index_t, NDimX> len_{1};
290  static_for<0, NDimX, 1>{}([&](auto idim_x_) {
291  constexpr auto major_ = number<idim_x_ + 1>{}; // RDim
292  static_for<0, uniformed_ps_to_rhss_major_.size(), 1>{}([&](auto idim_u_) {
293  if constexpr(major_.value == uniformed_ps_to_rhss_major_[idim_u_])
294  {
295  constexpr auto minor_ = uniformed_ps_to_rhss_minor_[idim_u_];
296  constexpr auto h_length_ = hs_lengthss_[idim_x_][minor_];
297  len_[idim_x_] *= h_length_;
298  }
299  });
300  });
301  return len_;
302  }();
303  constexpr auto p_len_over_h_seq_ = TO_SEQUENCE(p_len_, NDimX);
304  return p_len_over_h_seq_;
305  }
306 
307  //
308  // R: seq<3>, H: tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>>
309  // => return seq<1, 3, 5>
310  // R: seq<>, H: tuple<seq<2, 4>, seq<16, 8, 8>>
311  // => return seq<0, 2, 3>
313  {
314  constexpr auto uniformed_rh_dim_lengths =
316 
317  return uniformed_rh_dim_lengths;
318  }
319 
320  // e.g. tuple<seq<1, 4, 32>, seq<4, 1, 4, 2, 4>> --> seq<3, 5> --> seq<0, 3, 8>
322  {
323  // <0, len_d0, len_d0+len_d1, ...>
324  // e.g. seq<3, 5> --> seq<0, 3, 8>
325  constexpr auto h_dim_prefix_sum = prefix_sum_sequence(get_uniformed_h_dim_lengths());
326 
327  return h_dim_prefix_sum;
328  }
329 
331  {
332  // <0, len_d0, len_d0+len_d1, ...>
333  // e.g. seq<3, 5> --> seq<0, 3, 8>
334  constexpr auto rh_dim_prefix_sum = prefix_sum_sequence(get_uniformed_rh_dim_lengths());
335 
336  return rh_dim_prefix_sum;
337  }
338 
340  {
341  // tuple<seq<xx..>, seq<yy..>> -> seq<xx..yy..>
342  constexpr auto uniformed_ps_to_rhss_major_ =
343  unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_major_);
344  constexpr auto uniformed_ps_to_rhss_minor_ =
345  unpack([](auto... xs_) { return merge_sequences(xs_...); }, ps_to_rhss_minor_);
346 
347  constexpr auto all_ps_2_rhss = transform_sequences(
348  [](auto major, auto minor) constexpr {
349  constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
350  return rh_dim_prefix_sum.at(major) + minor;
351  },
352  uniformed_ps_to_rhss_major_,
353  uniformed_ps_to_rhss_minor_);
354 
355  return all_ps_2_rhss;
356  }
357 
359  {
360  constexpr auto all_ys_2_rhss = transform_sequences(
361  [](auto major, auto minor) constexpr {
362  constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
363  return rh_dim_prefix_sum.at(major) + minor;
364  },
365  Ys2RHsMajor{},
366  Ys2RHsMinor{});
367 
368  return all_ys_2_rhss;
369  }
370 
372  {
373  // TODO: Y can't point to R
374  constexpr auto all_ys_2_rhss = transform_sequences(
375  [](auto major, auto minor) constexpr {
376  constexpr auto rh_dim_prefix_sum = get_rh_dim_lengths_prefix_sum();
377  return rh_dim_prefix_sum.at(major) + minor - NDimR;
378  },
379  Ys2RHsMajor{},
380  Ys2RHsMinor{});
381 
382  return all_ys_2_rhss;
383  }
384 
385  // return tuple of seq: one mask per H dim (NDimX in total); mask i has HsLengthss{}[i].size() entries,
     // with a 1 written at every minor position that some Y dim with major == i + 1 points at
386  CK_TILE_HOST_DEVICE static constexpr auto get_y_to_h_masks()
387  {
388  constexpr auto masks_ = generate_tuple(
389  [&](auto i) {
390  constexpr auto size_ = HsLengthss{}[i].size(); // number of minor dims under H dim i
391  constexpr auto current_y_to_h_mask_ = [&]() {
392  array<index_t, size_> m_{0}; // mask under construction, initialized from {0}
393  // TODO: we loop over all y for each h dim
394  for(auto j = 0; j < NDimY; j++)
395  {
396  if(Ys2RHsMajor{}[j] == (i + 1) /*RDim need plus 1*/)
397  {
398  m_[Ys2RHsMinor{}[j]] = 1; // mark the minor slot owned by Y dim j
399  }
400  }
401  return m_;
402  }();
403 
404  return TO_SEQUENCE(current_y_to_h_mask_, size_); // turn the constexpr array into a sequence type
405  },
406  number<NDimX>{});
407  return masks_;
408  }
409 
410  // return tuple<sorted_dims, sorted_maps, sorted_prefix_sum>
     // Both arguments are used only through their types (IdxSeq, PrefixSumSeq); the parameters are unnamed.
     //   sorted_dims       : sequence_unique_sort<IdxSeq, less, equal>::type
     //   sorted_maps       : the matching sorted2unsorted_map
     //   sorted_prefix_sum : prefix_sum_sequence of histogram_sorted_sequence(sorted_dims, PrefixSumSeq{})
411  template <typename IdxSeq, typename PrefixSumSeq>
412  CK_TILE_HOST_DEVICE static constexpr auto get_sorted_info(IdxSeq, PrefixSumSeq)
413  {
414  using sorted_idx = sequence_unique_sort<IdxSeq, less<index_t>, equal<index_t>>;
415 
416  constexpr auto sorted_dims = typename sorted_idx::type{};
417  constexpr auto sorted_maps = typename sorted_idx::sorted2unsorted_map{};
418 
419  constexpr auto sorted_histogram =
420  histogram_sorted_sequence(sorted_dims, PrefixSumSeq{}); // NOTE(review): presumably counts sorted_dims entries per PrefixSumSeq bin - confirm in sequence.hpp
421  constexpr auto sorted_prefix_sum = prefix_sum_sequence(sorted_histogram);
422 
423  return make_tuple(sorted_dims, sorted_maps, sorted_prefix_sum);
424  }
425 
426  // Note here y_to_h does not count R dim!
428  {
430  }
431 
433  {
434  printf("tile_distribution_encoding::detail{");
435  //
436  printf("ndim_rh_major_: ");
438  printf(", ");
439  //
440  printf("ndim_span_major_: ");
442  printf(", ");
443  //
444  printf("ndims_rhs_minor_: ");
446  printf(", ");
447  //
448  printf("ndim_rh_major_: ");
450  printf(", ");
451  //
452  printf("max_ndim_rh_minor_: ");
454  printf(", ");
455  //
456  printf("rhs_lengthss_: ");
458  printf(", ");
459  //
460  printf("ys_lengths_: ");
462  printf(", ");
463  //
464  printf("rhs_major_minor_to_ys_: ");
466  printf(", ");
467  //
468  printf("ndims_span_minor_: ");
470  printf(", ");
471  //
472  printf("max_ndim_span_minor_: ");
474  printf(", ");
475  //
476  printf("ys_to_span_major_: ");
478  printf(", ");
479  //
480  printf("ys_to_span_minor_: ");
482  printf(", ");
483  //
484  printf("distributed_spans_lengthss_: ");
486  printf(", ");
487  //
488  printf("ndims_distributed_spans_minor_: ");
490  printf(", ");
491  //
492  printf("ps_over_rs_derivative_: ");
494  //
495  printf("}");
496  }
497  };
498 
500  {
501  printf("tile_distribution_encoding{");
502  //
503  printf("NDimX: %d, NDimP: %d, NDimY: %d, ", NDimX, NDimP, NDimY);
504  //
505  printf("rs_lengths_: ");
507  printf(", ");
508  //
509  printf("hs_lengthss_: ");
511  printf(", ");
512  //
513  printf("ps_to_rhss_major_: ");
515  printf(", ");
516  //
517  printf("ps_to_rhss_minor_: ");
519  printf(", ");
520  //
521  printf("ys_to_rhs_major_: ");
523  printf(", ");
524  //
525  printf("ys_to_rhs_minor_: ");
527  printf(", ");
528  //
529  printf("detail: ");
530  print(detail{});
531  //
532  printf("}");
533  }
534 };
535 
536 namespace detail {
537 
538 template <typename OuterDstr, typename InnerDstr>
539 CK_TILE_HOST_DEVICE constexpr auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
540 {
541  static_assert(OuterDstr::NDimX == InnerDstr::NDimX, "wrong!");
542 
543  constexpr index_t NDimHMajor = OuterDstr::NDimX;
544 
545  using RsLengths =
547 
548  constexpr auto hs_lengthss = generate_tuple(
549  [&](auto i) {
550  return merge_sequences(typename OuterDstr::HsLengthss{}[i],
551  typename InnerDstr::HsLengthss{}[i]);
552  },
554 
555  //
556  constexpr auto rhs_major_2_ndim_outer_rhs_minor = [&]() {
557  array<index_t, NDimHMajor + 1> rhs_major_2_ndim_outer_rhs_minor_;
558 
559  // R dimension
560  rhs_major_2_ndim_outer_rhs_minor_(0) = OuterDstr::RsLengths::size();
561 
562  // Hs dimensions
563  static_for<0, NDimHMajor, 1>{}([&](auto i) {
564  rhs_major_2_ndim_outer_rhs_minor_(i + 1) = typename OuterDstr::HsLengthss{}[i].size();
565  });
566 
567  return rhs_major_2_ndim_outer_rhs_minor_;
568  }();
569 
570  // Ps2RHssMinor
571  constexpr auto updated_inner_ps_2_rhss_minor = generate_tuple(
572  [&](auto p) {
573  constexpr auto inner_p_2_rhss_major = typename InnerDstr::Ps2RHssMajor{}[p];
574  constexpr auto inner_p_2_rhss_minor = typename InnerDstr::Ps2RHssMinor{}[p];
575 
576  constexpr index_t ndim_tmp = inner_p_2_rhss_minor.size();
577 
578  constexpr auto updated_inner_p_2_rhss_minor = [&]() {
579  array<index_t, ndim_tmp> updated_inner_p_2_rhss_minor_;
580 
581  for(index_t i = 0; i < ndim_tmp; i++)
582  {
583  index_t rh_major = inner_p_2_rhss_major[i];
584 
585  index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
586 
587  updated_inner_p_2_rhss_minor_(i) = inner_p_2_rhss_minor[i] + ndim_outer_h_minor;
588  }
589 
590  return updated_inner_p_2_rhss_minor_;
591  }();
592 
593  return TO_SEQUENCE(updated_inner_p_2_rhss_minor, ndim_tmp);
594  },
596 
597  // Ys2RHsMinor
598  constexpr auto updated_inner_ys_2_rhs_minor = [&]() {
599  constexpr auto inner_ys_2_rhs_major = typename InnerDstr::Ys2RHsMajor{};
600  constexpr auto inner_ys_2_rhs_minor = typename InnerDstr::Ys2RHsMinor{};
601 
602  constexpr index_t ndim_tmp = inner_ys_2_rhs_minor.size();
603 
604  constexpr auto updated_inner_ys_2_rhs_minor_ = [&]() {
605  array<index_t, ndim_tmp> updated_inner_ys_2_rhs_minor__;
606 
607  for(index_t i = 0; i < ndim_tmp; i++)
608  {
609  index_t rh_major = inner_ys_2_rhs_major[i];
610 
611  index_t ndim_outer_h_minor = rhs_major_2_ndim_outer_rhs_minor[rh_major];
612 
613  updated_inner_ys_2_rhs_minor__(i) = inner_ys_2_rhs_minor[i] + ndim_outer_h_minor;
614  }
615 
616  return updated_inner_ys_2_rhs_minor__;
617  }();
618 
619  return TO_SEQUENCE(updated_inner_ys_2_rhs_minor_, ndim_tmp);
620  }();
621 
622  //
623  constexpr auto ps_2_rhss_major =
624  container_concat(typename OuterDstr::Ps2RHssMajor{}, typename InnerDstr::Ps2RHssMajor{});
625 
626  constexpr auto ps_2_rhss_minor =
627  container_concat(typename OuterDstr::Ps2RHssMinor{}, updated_inner_ps_2_rhss_minor);
628 
629  //
630  constexpr auto ys_2_rhs_major =
631  merge_sequences(typename OuterDstr::Ys2RHsMajor{}, typename InnerDstr::Ys2RHsMajor{});
632 
633  constexpr auto ys_2_rhs_minor =
634  merge_sequences(typename OuterDstr::Ys2RHsMinor{}, updated_inner_ys_2_rhs_minor);
635 
636  return tile_distribution_encoding<RsLengths,
637  remove_cvref_t<decltype(hs_lengthss)>,
638  remove_cvref_t<decltype(ps_2_rhss_major)>,
639  remove_cvref_t<decltype(ps_2_rhss_minor)>,
640  remove_cvref_t<decltype(ys_2_rhs_major)>,
641  remove_cvref_t<decltype(ys_2_rhs_minor)>>{};
642 }
643 
644 template <typename InDstr, index_t... InReduceDimXs>
645 CK_TILE_HOST_DEVICE constexpr auto
647 {
648  constexpr auto I1 = number<1>{};
649 
650  // FIXME: increase if fail
651  constexpr index_t max_ndim_r_out = 20;
652  constexpr index_t max_ndim_y_out = 20;
653 
654  //
655  constexpr index_t ndim_p = InDstr::NDimP;
656  constexpr index_t ndim_x_in = InDstr::NDimX;
657  constexpr index_t ndim_y_in = InDstr::NDimY;
658  constexpr index_t ndim_rh_major_in = InDstr::NDimX + 1;
659  constexpr index_t ndim_x_out = ndim_x_in - sizeof...(InReduceDimXs);
660  constexpr index_t max_ndim_rh_minor_in = InDstr::detail::max_ndim_rh_minor_;
661 
662  // ndims_ps_low
663  constexpr auto ndims_ps_low = generate_array(
664  [&](auto i) { return InDstr::ps_to_rhss_major_[i].size(); }, number<ndim_p>{});
665 
666  // is_rh_major_in_for_reduce
667  array<bool, ndim_rh_major_in> is_rh_major_in_for_reduce{false};
668 
669  for(index_t i = 0; i < reduce_dim_xs_in.size(); i++)
670  {
671  index_t rh_major = reduce_dim_xs_in[i] + 1;
672 
673  is_rh_major_in_for_reduce(rh_major) = true;
674  }
675 
676  // is_y_in_for_reduce
677  array<bool, ndim_y_in> is_y_in_for_reduce{false};
678 
679  for(index_t i = 0; i < ndim_y_in; i++)
680  {
681  index_t rh_major = InDstr::ys_to_rhs_major_[i];
682 
683  if(is_rh_major_in_for_reduce[rh_major])
684  {
685  is_y_in_for_reduce(i) = true;
686  }
687  }
688 
689  // is_rh_minor_in_for_y_reduce
690  array<array<bool, max_ndim_rh_minor_in>, ndim_rh_major_in> is_rh_minor_in_for_y_reduce{{false}};
691 
692  static_for<0, ndim_y_in, 1>{}([&](auto i) {
693  index_t rh_major = InDstr::ys_to_rhs_major_[i];
694  index_t rh_minor = InDstr::ys_to_rhs_minor_[i];
695 
696  if(is_y_in_for_reduce[i])
697  {
698  is_rh_minor_in_for_y_reduce(rh_major)(rh_minor) = true;
699  }
700  });
701 
702  // in2out_rh_major
703  array<index_t, ndim_rh_major_in> in2out_rh_major{-1};
704  index_t cnt_ndim_rh_major_out = 0;
705 
706  for(index_t i = 0; i < ndim_rh_major_in; i++)
707  {
708  if(is_rh_major_in_for_reduce[i])
709  {
710  in2out_rh_major(i) = 0;
711  }
712  else
713  {
714  in2out_rh_major(i) = cnt_ndim_rh_major_out;
715 
716  cnt_ndim_rh_major_out++;
717  }
718  }
719 
720  // rs_lengths_out, in2out_rh_minor
721  array<index_t, max_ndim_r_out> rs_lengths_out{-1};
722  array<array<index_t, max_ndim_rh_minor_in>, ndim_rh_major_in> in2out_rh_minor{{-1}};
723 
724  // loop over input R dim
725  for(index_t i = 0; i < InDstr::rs_lengths_.size(); i++)
726  {
727  // rs_lengths_out
728  rs_lengths_out(i) = InDstr::rs_lengths_[i];
729 
730  // in2out_rh_minor
731  in2out_rh_minor(0)(i) = i;
732  }
733 
734  // loop over input H Dim
735  index_t cnt_ndim_r_out = InDstr::rs_lengths_.size();
736 
737  static_for<1, ndim_rh_major_in, 1>{}([&](auto rh_major_in) {
738  constexpr auto h_major_in = rh_major_in - I1;
739 
740  constexpr index_t ndim_rh_minor_in = InDstr::hs_lengthss_[h_major_in].size();
741 
742  if(is_rh_major_in_for_reduce[rh_major_in])
743  {
744  for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
745  {
746  if(not is_rh_minor_in_for_y_reduce[rh_major_in][rh_minor_in])
747  {
748  // rs_lengths_out
749  rs_lengths_out(cnt_ndim_r_out) = InDstr::hs_lengthss_[h_major_in][rh_minor_in];
750 
751  // in2out_rh_minor
752  in2out_rh_minor(rh_major_in)(rh_minor_in) = cnt_ndim_r_out;
753 
754  cnt_ndim_r_out++;
755  }
756  }
757  }
758  else
759  {
760  for(index_t rh_minor_in = 0; rh_minor_in < ndim_rh_minor_in; rh_minor_in++)
761  {
762  // in2out_rh_minor
763  in2out_rh_minor(rh_major_in)(rh_minor_in) = rh_minor_in;
764  }
765  }
766  });
767 
768  // ndim_r_out
769  const index_t ndim_r_out = cnt_ndim_r_out;
770 
771  // ndims_hs_minor_out, hs_lengthss_out
772  array<index_t, ndim_x_out> ndims_hs_minor_out{-1};
773  array<array<index_t, max_ndim_rh_minor_in>, ndim_x_out> hs_lengthss_out{{-1}};
774 
775  index_t cnt_ndim_x_out = 0;
776 
777  static_for<0, ndim_x_in, 1>{}([&](auto i) {
778  if(not is_rh_major_in_for_reduce[i + I1])
779  {
780  // ndims_hs_minor_out
781  ndims_hs_minor_out(cnt_ndim_x_out) = InDstr::hs_lengthss_[i].size();
782 
783  // hs_lengthss_out
784  static_for<0, InDstr::hs_lengthss_[i].size(), 1>{}(
785  [&](auto j) { hs_lengthss_out(cnt_ndim_x_out)(j) = InDstr::hs_lengthss_[i][j]; });
786 
787  cnt_ndim_x_out++;
788  }
789  });
790 
791  // ps_to_rhss_major_out, ps_to_rhss_minor_out
792  array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_major_out{{-1}};
793  array<array<index_t, max_ndim_rh_minor_in>, ndim_p> ps_to_rhss_minor_out{{-1}};
794 
795  static_for<0, ndim_p, 1>{}([&](auto idim_p) {
796  static_for<0, InDstr::ps_to_rhss_major_[idim_p].size(), 1>{}([&](auto idim_low) {
797  index_t rh_major_in = InDstr::ps_to_rhss_major_[idim_p][idim_low];
798  index_t rh_minor_in = InDstr::ps_to_rhss_minor_[idim_p][idim_low];
799 
800  ps_to_rhss_major_out(idim_p)(idim_low) = in2out_rh_major[rh_major_in];
801  ps_to_rhss_minor_out(idim_p)(idim_low) = in2out_rh_minor[rh_major_in][rh_minor_in];
802  });
803  });
804 
805  // ys_to_rhs_major_out, ys_to_rhs_minor_out
806  array<index_t, max_ndim_y_out> ys_to_rhs_major_out{-1};
807  array<index_t, max_ndim_y_out> ys_to_rhs_minor_out{-1};
808 
809  index_t cnt_ndim_y_out = 0;
810 
811  static_for<0, ndim_y_in, 1>{}([&](auto i) {
812  if(not is_y_in_for_reduce[i])
813  {
814  index_t rh_major_in = InDstr::ys_to_rhs_major_[i];
815  index_t rh_minor_in = InDstr::ys_to_rhs_minor_[i];
816 
817  ys_to_rhs_major_out(cnt_ndim_y_out) = in2out_rh_major[rh_major_in];
818  ys_to_rhs_minor_out(cnt_ndim_y_out) = in2out_rh_minor[rh_major_in][rh_minor_in];
819 
820  cnt_ndim_y_out++;
821  }
822  });
823 
824  // ndim_y_out
825  const index_t ndim_y_out = cnt_ndim_y_out;
826 
827  //
828  return make_tuple(ndim_x_out,
829  ndim_p,
830  ndim_y_out,
831  ndim_r_out,
832  ndims_hs_minor_out,
833  ndims_ps_low,
834  rs_lengths_out,
835  hs_lengthss_out,
836  ps_to_rhss_major_out,
837  ps_to_rhss_minor_out,
838  ys_to_rhs_major_out,
839  ys_to_rhs_minor_out);
840 }
841 
842 template <typename InDstr, index_t... InReduceDimXs>
843 CK_TILE_HOST_DEVICE constexpr auto
845 {
846  constexpr auto impl = make_reduce_tile_distribution_encoding_impl(InDstr{}, reduce_dim_xs_in);
847 
848  constexpr index_t ndim_x = impl.template at<0>();
849  constexpr index_t ndim_p = impl.template at<1>();
850  constexpr index_t ndim_y = impl.template at<2>();
851  constexpr index_t ndim_r = impl.template at<3>();
852  constexpr auto ndims_hs_minor = impl.template at<4>();
853  constexpr auto ndims_ps_low = impl.template at<5>();
854  constexpr auto rs_lengths_impl = impl.template at<6>();
855  constexpr auto hs_lengthss_impl = impl.template at<7>();
856  constexpr auto ps_to_rhss_major_impl = impl.template at<8>();
857  constexpr auto ps_to_rhss_minor_impl = impl.template at<9>();
858  constexpr auto ys_to_rhs_major_impl = impl.template at<10>();
859  constexpr auto ys_to_rhs_minor_impl = impl.template at<11>();
860 
861  constexpr auto rs_lengths = TO_SEQUENCE(rs_lengths_impl, ndim_r);
862  constexpr auto hs_lengthss = TO_TUPLE_OF_SEQUENCE(hs_lengthss_impl, ndim_x, ndims_hs_minor);
863  constexpr auto ps_to_rhss_major =
864  TO_TUPLE_OF_SEQUENCE(ps_to_rhss_major_impl, ndim_p, ndims_ps_low);
865  constexpr auto ps_to_rhss_minor =
866  TO_TUPLE_OF_SEQUENCE(ps_to_rhss_minor_impl, ndim_p, ndims_ps_low);
867  constexpr auto ys_to_rhs_major = TO_SEQUENCE(ys_to_rhs_major_impl, ndim_y);
868  constexpr auto ys_to_rhs_minor = TO_SEQUENCE(ys_to_rhs_minor_impl, ndim_y);
869 
870  return tile_distribution_encoding<remove_cvref_t<decltype(rs_lengths)>,
871  remove_cvref_t<decltype(hs_lengthss)>,
872  remove_cvref_t<decltype(ps_to_rhss_major)>,
873  remove_cvref_t<decltype(ps_to_rhss_minor)>,
874  remove_cvref_t<decltype(ys_to_rhs_major)>,
875  remove_cvref_t<decltype(ys_to_rhs_minor)>>{};
876 }
877 
878 } // namespace detail
879 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding_impl(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:646
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:844
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:539
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
typename sequence_merge< Seqs... >::type sequence_merge_t
Definition: sequence.hpp:1014
constexpr CK_TILE_HOST_DEVICE auto transform_sequences(F f, sequence< Xs... >)
Definition: sequence.hpp:823
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1106
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:594
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1036
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:406
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto unpack(F &&f, X &&x)
Definition: functional.hpp:200
constexpr CK_TILE_HOST_DEVICE auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition: sequence.hpp:1093
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr auto prefix_sum_sequence(Seq)
Definition: sequence.hpp:899
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
static constexpr CK_TILE_HOST_DEVICE auto size()
Definition: array.hpp:97
Definition: integral_constant.hpp:13
Definition: math.hpp:122
Definition: sequence.hpp:584
Definition: sequence.hpp:52
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:56
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:58
static constexpr index_t max_ndim_span_minor_
Definition: tile_distribution_encoding.hpp:129
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_p_to_h()
Definition: tile_distribution_encoding.hpp:339
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_info(IdxSeq, PrefixSumSeq)
Definition: tile_distribution_encoding.hpp:412
static constexpr auto rhs_lengthss_
Definition: tile_distribution_encoding.hpp:82
static constexpr auto distributed_spans_lengthss_
Definition: tile_distribution_encoding.hpp:169
static constexpr auto does_p_own_r_
Definition: tile_distribution_encoding.hpp:202
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_h()
Definition: tile_distribution_encoding.hpp:371
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_rh_dim_lengths()
Definition: tile_distribution_encoding.hpp:312
static constexpr index_t max_ndim_rh_minor_
Definition: tile_distribution_encoding.hpp:78
static constexpr auto ys_to_span_major_
Definition: tile_distribution_encoding.hpp:158
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:432
static constexpr CK_TILE_HOST_DEVICE auto get_sorted_y_to_h_info()
Definition: tile_distribution_encoding.hpp:427
static constexpr auto rhs_major_minor_to_span_minor_
Definition: tile_distribution_encoding.hpp:133
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_p_dim_lengths_over_h()
Definition: tile_distribution_encoding.hpp:276
static constexpr CK_TILE_HOST_DEVICE auto get_h_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:321
static constexpr index_t ndim_span_major_
Definition: tile_distribution_encoding.hpp:61
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_h_dim_lengths()
Definition: tile_distribution_encoding.hpp:263
static constexpr auto ndims_span_minor_
Definition: tile_distribution_encoding.hpp:115
static constexpr CK_TILE_HOST_DEVICE auto get_uniformed_idx_y_to_rh()
Definition: tile_distribution_encoding.hpp:358
static constexpr index_t ndim_rh_major_
Definition: tile_distribution_encoding.hpp:60
static constexpr auto ps_over_rs_derivative_
Definition: tile_distribution_encoding.hpp:230
static constexpr CK_TILE_HOST_DEVICE auto get_y_to_h_masks()
Definition: tile_distribution_encoding.hpp:386
static constexpr auto ys_lengths_
Definition: tile_distribution_encoding.hpp:86
static constexpr auto ndims_distributed_spans_minor_
Definition: tile_distribution_encoding.hpp:189
static constexpr auto rhs_major_minor_to_ys_
Definition: tile_distribution_encoding.hpp:101
static constexpr auto ndims_rhs_minor_
Definition: tile_distribution_encoding.hpp:64
static constexpr auto ys_to_span_minor_
Definition: tile_distribution_encoding.hpp:162
static constexpr CK_TILE_HOST_DEVICE auto get_rh_dim_lengths_prefix_sum()
Definition: tile_distribution_encoding.hpp:330
Definition: tile_distribution_encoding.hpp:26
static constexpr index_t NDimR
Definition: tile_distribution_encoding.hpp:40
static constexpr auto ps_to_rhss_minor_
Definition: tile_distribution_encoding.hpp:46
static constexpr auto rs_lengths_
Definition: tile_distribution_encoding.hpp:43
static constexpr index_t NDimP
Definition: tile_distribution_encoding.hpp:38
remove_cvref_t< Ps2RHssMinor_ > Ps2RHssMinor
Definition: tile_distribution_encoding.hpp:30
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution_encoding.hpp:499
static constexpr auto ys_to_rhs_major_
Definition: tile_distribution_encoding.hpp:47
static constexpr auto ys_to_rhs_minor_
Definition: tile_distribution_encoding.hpp:48
static constexpr index_t NDimY
Definition: tile_distribution_encoding.hpp:39
remove_cvref_t< Ys2RHsMinor_ > Ys2RHsMinor
Definition: tile_distribution_encoding.hpp:32
static constexpr auto hs_lengthss_
Definition: tile_distribution_encoding.hpp:44
remove_cvref_t< Ys2RHsMajor_ > Ys2RHsMajor
Definition: tile_distribution_encoding.hpp:31
remove_cvref_t< HsLengthss_ > HsLengthss
Definition: tile_distribution_encoding.hpp:28
remove_cvref_t< Ps2RHssMajor_ > Ps2RHssMajor
Definition: tile_distribution_encoding.hpp:29
remove_cvref_t< RsLengths_ > RsLengths
Definition: tile_distribution_encoding.hpp:27
static constexpr auto ps_to_rhss_major_
Definition: tile_distribution_encoding.hpp:45
static constexpr index_t NDimX
Definition: tile_distribution_encoding.hpp:37
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10