/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/tensor/tile_distribution.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/tensor/tile_distribution.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/tensor/tile_distribution.hpp Source File
tile_distribution.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck_tile {
19 
20 namespace detail {
// Free-function wrapper that returns whatever Distribution::_get_partition_index() returns.
// NOTE(review): the declaration line (embedded line 22) is absent from this listing; the
// symbol index at the end of this page gives it as
// `CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)`.
21 template <typename Distribution>
23 {
24  return Distribution::_get_partition_index();
25 }
26 } // namespace detail
27 
28 // distributed span: a purely compile-time type carrying a pack of partial H lengths
// NOTE(review): the struct declaration line (embedded line 30) is absent from this listing;
// other code in this file names the type `tile_distributed_span<Is...>`.
29 template <index_t... PartialHsLengths>
31 {
    // The lengths pack, stored as a sequence type.
32  using Impl = sequence<PartialHsLengths...>;
33 
    // Default-constructed instance of Impl, for callers that need a value rather than a type.
34  static constexpr auto impl_ = Impl{};
35 
    // Always true: the type has no data that is not part of its template arguments.
36  CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
37 };
38 
39 // distributed index: a purely compile-time type carrying a pack of partial H indices
// NOTE(review): the struct declaration line (embedded line 41) is absent from this listing;
// other code in this file names the type `tile_distributed_index<Is...>`.
40 template <index_t... PartialHsIndices>
42 {
    // The indices pack, stored as a sequence type.
43  using Impl = sequence<PartialHsIndices...>;
44 
    // Default-constructed instance of Impl, for callers that need a value rather than a type.
45  static constexpr auto impl_ = Impl{};
46 
    // Always true: the type has no data that is not part of its template arguments.
47  CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
48 };
49 
50 namespace detail {
51 
// Builds a tile_distributed_span whose template arguments are the pack Is... .
// NOTE(review): the signature line (embedded line 53) is absent from this listing; the
// symbol index at the end of this page gives it as
// `constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_span(sequence<Is...>)`.
52 template <index_t... Is>
54 {
55  return tile_distributed_span<Is...>{};
56 }
57 
// Builds a tile_distributed_index whose template arguments are the pack Is... .
// NOTE(review): the signature line (embedded line 59) is absent from this listing; the
// symbol index at the end of this page gives it as
// `constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_index(sequence<Is...>)`.
58 template <index_t... Is>
60 {
61  return tile_distributed_index<Is...>{};
62 }
63 
64 } // namespace detail
65 
// tile_distribution: pairs an adaptor (ps_ys_to_xs_) whose top dimensions are the P and Y
// dimensions and whose bottom dimensions are the X dimensions, with a Y -> D descriptor
// (ys_to_d_), and tags both with a static distribution encoding type (DstrEncode).
// NOTE(review): this listing is missing a number of lines of the struct (the `struct`
// declaration itself, the type aliases PsYs2XsAdaptor / Ys2DDescriptor / DstrEncode /
// DstrDetail, the data members ps_ys_to_xs_ / ys_to_d_, the get_num_of_dimension_*()
// accessors, and several function signatures). They are named in the symbol index at the
// end of this page; restore them from the real header before compiling.
66 template <typename PsYs2XsAdaptor_,
67  typename Ys2DDescriptor_,
68  typename StaticTileDistributionEncoding_,
69  typename TileDistributionDetail_> // FIXME: this holds ad-hoc but useful info;
70  // should be more elegant
72 {
77 
    // NOTE(review): opening of this static_assert (embedded line 78) is missing here.
79  "wrong! should be static");
80 
    // Dimension counts. P dims are the top dims of the adaptor that are not Y dims.
81  static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
82  static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
83  static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
84  static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
85 
88 
93 
    // _get_partition_index (signature line missing here): returns the partition index as
    // an array. With one P dim it is {get_lane_id()}; the two-P-dim return statement
    // (embedded line 105) is missing from this listing.
95  {
96  // only support warp-tile and block-tile
97  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
98 
99  if constexpr(NDimP == 1)
100  {
101  return array<index_t, 1>{get_lane_id()};
102  }
103  else if constexpr(NDimP == 2)
104  {
106  }
107  }
108 
    // Returns a tuple of NDimX compile-time numbers; entry i is the product of all
    // entries of DstrEncode::HsLengthss[i] (reduction with multiplies{}, initial value 1).
109  CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
110  {
111 #if 0
112  // FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
113  ps_ys_to_xs_.GetBottomDimensionLengths();
114 #else
115  return generate_tuple(
116  [&](auto i) {
117  constexpr index_t x_length =
118  container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
119 
120  return number<x_length>{};
121  },
122  number<NDimX>{});
123 #endif
124  }
125 
    // Read-only access to the stored adaptor.
126  CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
127  {
128  return ps_ys_to_xs_;
129  }
130 
    // Read-only access to the stored Y -> D descriptor.
131  CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }
132 
    // get_static_tile_distribution_encoding (signature line missing here): returns a
    // default-constructed DstrEncode.
134  {
135  return DstrEncode{};
136  }
137 
138 #if 1
139  // Calculate Replication index [R0, R1, ...] based on Partition index
140  // FIXME: very nasty implementation
    // ps_idx must have exactly NDimP entries. The function builds an adaptor coordinate
    // from ps_idx followed by an array<index_t, NDimY>{0}, then, for every P -> (rh_major,
    // rh_minor) mapping whose rh_major is 0, copies the matching hidden index of that
    // coordinate into rs_idx(rh_minor).
    // NOTE(review): rs_idx is declared without an initializer and only the entries reached
    // by the loop below are assigned; whether the remaining entries are well-defined
    // depends on array's default construction - confirm.
141  template <typename PartitionIndex>
142  CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
143  {
144  static_assert(PartitionIndex::size() == NDimP, "wrong!");
145 
146  const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
147 
148  const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
149 
150  array<index_t, NDimR> rs_idx;
151 
152  static_for<0, NDimP, 1>{}([&](auto idim_p) {
153  constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
154 
155  static_for<0, ndim_low, 1>{}([&](auto i) {
156  constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
157  constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
158 
159  // 0-th rh_major is the replicate dimension
160  if constexpr(rh_major == 0)
161  {
162  constexpr index_t adaptor_hidden_id =
163  DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
164 
165  // fill in
166  rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
167  }
168  });
169  });
170 
171  return rs_idx;
172  }
173 #endif
174 
    // calculate_index: returns the bottom index of a coordinate built from ps_idx followed
    // by an array<index_t, NDimY>{0}. ps_idx defaults to _get_partition_index().
    // NOTE(review): the return-type line (embedded line 176) and the right-hand side that
    // initializes window_adaptor_thread_coord_tmp (embedded line 181) are missing here.
175  template <typename PartitionIndex = decltype(_get_partition_index())>
177  calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
178  {
179  const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
180  const auto window_adaptor_thread_coord_tmp =
182  return window_adaptor_thread_coord_tmp.get_bottom_index();
183  }
184 
    // get_distributed_spans (signature line missing here): produces a tuple with NDimX
    // entries, one per index i, derived from the encoding's distributed span lengths.
    // NOTE(review): the statement returning each entry (embedded line 197) is missing.
186  {
187  constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
188  constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
189 
190  return generate_tuple(
191  [&](auto i) {
192  constexpr auto span_impl = distributed_spans_impl[i];
193  constexpr index_t ndim_span_minor = ndims_spans_minor[i];
194 
195  constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
196 
198  },
199  number<NDimX>{});
200  }
201 
202  // FIXME: it's hacky to get Y index from Distributed-Index
    // For each Y dim i, looks up (span_major, span_minor) in the encoding and reads
    // DistributedIndices{}[span_major].impl_[span_minor]; the NDimY results are returned
    // as a compile-time sequence.
    // NOTE(review): the line with the function name and parameter list (embedded line 205)
    // is missing; the symbol index names it
    // get_y_indices_from_distributed_indices(DistributedIndices).
203  template <typename DistributedIndices>
204  CK_TILE_HOST_DEVICE static constexpr auto
206  {
207  constexpr auto ys_idx_arr = [] {
208  array<index_t, NDimY> ys_idx;
209 
210  static_for<0, NDimY, 1>{}([&](auto i) {
211  constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
212  constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
213 
214  constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
215 
216  ys_idx(i) = dstr_index.impl_[span_minor];
217  });
218 
219  return ys_idx;
220  }();
221 
222  constexpr index_t ndim_y = NDimY;
223 
224  return TO_SEQUENCE(ys_idx_arr, ndim_y);
225  }
226 
    // NOTE(review): the body of is_static() (embedded line 229) is missing from this listing.
227  CK_TILE_HOST_DEVICE static constexpr bool is_static()
228  {
230  }
231 
    // print (signature line missing here): writes the encoding, ps_ys_to_xs_ and ys_to_d_
    // via printf/print, wrapped in "tile_distribution{...}".
    // NOTE(review): the call that prints ps_ys_to_xs_ (embedded line 241) is missing.
233  {
234  printf("tile_distribution{");
235  //
236  printf("tile_distribution_encoding: ");
237  print(DstrEncode{});
238  printf(", ");
239  //
240  printf("ps_ys_to_xs_: ");
242  printf(", ");
243  //
244  printf("ys_to_d_: ");
245  print(ys_to_d_);
246  //
247  printf("}");
248  }
249 };
250 
251 namespace detail {
252 
// Returns `arr` with arr(i) = ibegin + i for i in [0, iend - ibegin); i.e. the first
// (iend - ibegin) entries hold ibegin, ibegin + 1, ..., iend - 1.
// NOTE(review): the signature (embedded line 254) and the declaration of `arr` (embedded
// line 256) are absent from this listing; the symbol index at the end of this page gives
// the signature as
// `constexpr CK_TILE_HOST_DEVICE auto make_sequential_index(index_t ibegin, index_t iend)`.
// Presumably `arr` has NDimMax entries; no check that iend - ibegin fits is visible here.
253 template <index_t NDimMax>
255 {
257 
258  for(index_t i = 0; i < iend - ibegin; ++i)
259  {
260  arr(i) = ibegin + i;
261  }
262 
263  return arr;
264 }
265 
266 // this returns a constexpr encoding of tile_distribution
//
// Builds, from the static encoding type, a 4-element tuple:
//   [0] encoding of the adaptor with top dims [p..., y...] and bottom dims [x...]:
//       (transform list, transform count, bottom dim ids, #bottom, top dim ids, #top)
//   [1] encoding of the [y...] -> [d] adaptor, in the same 6-field layout
//   [2] d_length: product of all Y lengths
//   [3] table mapping (rh_major, rh_minor) to a hidden dimension id
// Hidden dimension ids are handed out in order: [0, ndim_x) are the bottom X dims, then
// the R minors, then the H minors of each X dim, then one id per P dim.
// NOTE(review): three lines are missing from this listing (embedded lines 310, 340, 431),
// each the first element/opening of a transform tuple; restore from the real header.
267 template <typename StaticTileDistributionEncoding_>
268 CK_TILE_HOST_DEVICE constexpr auto
269  make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
270 {
271  using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
272  using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
273  using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
274  using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
275  using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
276  using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
277 
278  // FIXME: increase max value if fail
    // Fixed capacities of the encoding buffers. One transform is appended for R, one per
    // X dim and one per P dim; no explicit check against these limits is visible here.
279  constexpr index_t kMaxNumTransforms = 20;
280  constexpr index_t kMaxMetaDataSize = 128;
281  constexpr index_t kMaxNumDim = 10;
282 
283  using Name = coord_transform_enum;
284  using MetaData = meta_data_buffer<kMaxMetaDataSize>;
285  using NumDim = index_t;
286  using Dims = array<index_t, kMaxNumDim>;
287  using Lengths = array<index_t, kMaxNumDim>;
288 
289  // Tile Adaptor
290  // bottom dims [x0, x1, x2, ...]
291  // top dims [p0, p1, ..., y0, y1, ...]
292  constexpr index_t ndim_x = HsLengthss::size();
293 
294  // Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
    // Row 0 is the R (replicate) major; row idim_x + 1 is the H major of X dim idim_x.
295  array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
296  array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
297 
    // Each transform record: (name, metadata, #lower dims, lower dim ids, #upper dims,
    // upper dim ids).
298  auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
299 
300  index_t num_tran = 0;
    // Ids [0, ndim_x) are taken by the bottom X dims, so new hidden ids start at ndim_x.
301  index_t hidden_dim_cnt = ndim_x;
302 
303  // this is replicate transform
304  {
305  constexpr index_t ndim_r_minor = RsLengths::size();
306 
307  constexpr auto r_minor_lengths = RsLengths{};
308 
    // NOTE(review): the transform-name element (embedded line 310) is missing here.
309  trans(num_tran++) = {
311  MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
312  NumDim{0},
313  Dims{},
314  NumDim{ndim_r_minor},
315  make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
316 
    // Record id and length of every R minor, consuming one hidden id each.
317  for(index_t i = 0; i < ndim_r_minor; ++i)
318  {
319  rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
320  rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
321 
322  hidden_dim_cnt++;
323  }
324  };
325 
326  // these are Unmerge transforms for X dimensions
327  static_for<0, ndim_x, 1>{}([&trans,
328  &num_tran,
329  &hidden_dim_cnt,
330  &rh_major_minor_to_hidden_ids,
331  &rh_major_minor_to_hidden_lengths](auto idim_x) {
332  // typename HsLengthss::base{}.foo();
333  constexpr auto h_minor_lengths =
334  HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
335  // constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
336 
337  constexpr index_t ndim_h_minor = h_minor_lengths.size();
338 
    // NOTE(review): the transform-name element (embedded line 340) is missing here.
339  trans(num_tran++) = {
341  MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
342  NumDim{1},
343  Dims{idim_x},
344  NumDim{ndim_h_minor},
345  make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
346 
    // Record id and length of every H minor of this X dim (row idim_x + 1).
347  for(index_t i = 0; i < ndim_h_minor; ++i)
348  {
349  rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
350  rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
351 
352  hidden_dim_cnt++;
353  }
354  });
355 
356  // transform: P dimensions
357  constexpr index_t ndim_p = Ps2RHssMajor::size();
358 
359  Dims hidden_dim_id_ps;
360 
    // One merge transform per P dim: its lower dims are the R/H hidden dims that the
    // encoding maps this P to, its single upper dim is a freshly allocated hidden id.
361  static_for<0, ndim_p, 1>{}([&](auto iDimP) {
362  //
363  index_t hidden_dim_id_p = hidden_dim_cnt++;
364 
365  hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
366 
367  constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
368  constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
369 
370  static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");
371 
372  constexpr index_t ndim_low = p2RHsMajor.size();
373 
374  Dims low_dims;
375  Lengths low_lengths;
376 
377  for(index_t i = 0; i < ndim_low; ++i)
378  {
379  index_t rh_major = p2RHsMajor[i];
380  index_t rh_minor = p2RHsMinor[i];
381  low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
382  low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
383  }
384 
385  trans(num_tran++) = {coord_transform_enum::merge,
386  MetaData{to_array<index_t, ndim_low>(low_lengths)},
387  NumDim{ndim_low},
388  low_dims,
389  NumDim{1},
390  Dims{hidden_dim_id_p}};
391  });
392 
393  constexpr index_t ndim_bottom = ndim_x;
394 
395  constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
396 
397  constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
398  constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
399 
400  constexpr index_t ndim_y = Ys2RHsMajor::size();
401  constexpr index_t ndim_top = ndim_p + ndim_y;
402 
    // Top dim ids: the P hidden ids first, then (filled below) the hidden id of the R/H
    // minor each Y dim maps to.
403  auto top_dim_ids = hidden_dim_id_ps;
404 
405  {
406  for(index_t i = 0; i < ndim_y; ++i)
407  {
408  index_t rh_major = ys_to_rhs_major[i];
409  index_t rh_minor = ys_to_rhs_minor[i];
410  top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
411  }
412  }
413 
414  //
415  const auto ps_ys_to_xs_adaptor_encoding =
416  make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
417 
418  // descriptor: [y0, y1, ...] to [d]
419  Lengths y_lengths;
420  index_t d_length = 1;
421 
    // Collect each Y length from the table and accumulate their product into d_length.
422  for(index_t i = 0; i < ndim_y; ++i)
423  {
424  index_t rh_major = ys_to_rhs_major[i];
425  index_t rh_minor = ys_to_rhs_minor[i];
426  index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
427  y_lengths(i) = y_length;
428  d_length *= y_length;
429  }
430 
    // NOTE(review): the opening of the statement that defines `tran` (embedded line 431)
    // is missing here; the lines below are its remaining arguments.
432  MetaData{to_array<index_t, ndim_y>(y_lengths)},
433  NumDim{1},
434  Dims{0},
435  NumDim{ndim_y},
436  make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
437 
    // Single-transform adaptor: bottom dim id 0 (d), top dim ids 1..ndim_y (the Ys).
438  const auto ys_to_d_adaptor_encoding = make_tuple(
439  make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
440 
441  return make_tuple(ps_ys_to_xs_adaptor_encoding,
442  ys_to_d_adaptor_encoding,
443  d_length,
444  rh_major_minor_to_hidden_ids);
445 }
446 
447 // FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
// Holds the (rh_major, rh_minor) -> adaptor hidden id table as a constant obtained from
// to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{}).
// NOTE(review): the struct declaration (embedded line 449) and the left-hand side of the
// member definition (embedded line 451) are absent from this listing; the symbol index at
// the end of this page names the member
// `static constexpr auto rh_major_minor_to_adaptor_hidden_idss_`.
448 template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
450 {
452  to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
453 };
454 
455 } // namespace detail
456 
457 #if 0
458 // this returns a constexpr tile_distribution
// NOTE: this function sits inside an `#if 0` region and is therefore not compiled. It
// differs from make_static_tile_distribution only in using
// CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING instead of
// CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING for the two adaptors.
459 template <typename StaticTileDistributionEncoding_>
460 CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
461 {
462  using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
463 
    // Tuple of (P/Y->X adaptor encoding, Y->D adaptor encoding, d_length, hidden-id table).
464  constexpr auto adaptor_impl =
465  detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
466 
467  constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
468  constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
469  constexpr index_t d_length = adaptor_impl.template at<2>();
470  constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
471 
472  constexpr auto ps_ys_to_xs_adaptor =
473  CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
474 
475  constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
476 
    // d_length is passed as the element space size of the Y -> D descriptor.
477  constexpr auto ys_to_d_descriptor =
478  make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
479 
480  //
481  constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
482  constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
483 
    // Convert the hidden-id table from array-of-array form into a tuple of sequences so
    // that it can be carried as a type.
484  constexpr auto rh_major_minor_to_hidden_ids =
485  TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
486 
487  return tile_distribution<
488  remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
489  remove_cvref_t<decltype(ys_to_d_descriptor)>,
490  remove_cvref_t<DstrEncode>,
491  detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
492  ps_ys_to_xs_adaptor, ys_to_d_descriptor};
493 }
494 #endif
495 
496 // this returns a static tile_distribution
// Steps: build the adaptor encodings from the encoding type, turn both encodings into
// static tensor adaptors, derive the Y -> D descriptor, convert the hidden-id table into a
// tuple of sequences, and return a tile_distribution constructed from the two objects.
// NOTE(review): three lines are absent from this listing: the `DstrEncode` alias
// (embedded line 500), the right-hand side defining ys_to_d_descriptor (embedded line 517)
// and the third template argument of the returned tile_distribution (embedded line 529).
497 template <typename StaticTileDistributionEncoding_>
498 CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
499 {
501 
    // Tuple of (P/Y->X adaptor encoding, Y->D adaptor encoding, d_length, hidden-id table).
502  constexpr auto adaptor_impl =
503  detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
504 
505  constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
506  constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
507  constexpr index_t d_length = adaptor_impl.template at<2>();
508  constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
509 
510  constexpr auto ps_ys_to_xs_adaptor =
511  CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
512 
513  constexpr auto ys_to_d_adaptor =
514  CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
515 
516  constexpr auto ys_to_d_descriptor =
518 
519  //
520  constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
521  constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
522 
    // Convert the hidden-id table from array-of-array form into a tuple of sequences so
    // that it can be carried as a type.
523  constexpr auto rh_major_minor_to_hidden_ids =
524  TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
525 
526  return tile_distribution<
527  remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
528  remove_cvref_t<decltype(ys_to_d_descriptor)>,
530  detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
531  ps_ys_to_xs_adaptor, ys_to_d_descriptor};
532 }
533 
534 //***********************************************************************************
535 
536 namespace detail {
537 //
538 // Slice the tensor along x_dim; the result is a split in y_dim, not in p_dim.
539 // We don't support slicing across p_dim (i.e. slicing across different threads).
540 // Also, the y_dim being sliced needs to be the first dim of the current dim.
541 // Multiple Y dims before the sliced dim do not make sense.
542 //
543 // e.g
544 // X0 X1
545 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 32>, (0 means all length)
546 // Y P P Y P Y P Y
547 // => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
548 // |--> slice along this Y dim, is the first dim of X1, totally 4 slices
549 //
550 // X0 X1
551 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 8>, (0 means all length)
552 // Y P P Y P Y P Y
553 // => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
554 // |--> slice along this Y dim, the P dim is 1 in the left, so is OK
555 // totally 16 slices
556 //
557 // X0 X1
558 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 4>, (0 means all length)
559 // Y P P Y P Y P Y
560 // => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
561 // |--> slice along this P dim, will split threads, not supported
562 //
563 // X0 X1
564 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice origin:<0, 0>, len:<0, 16>, (0 means all length)
565 // Y P P Y P Y P Y
566 // => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
567 // |--> slice along this Y dim, but this Y dim needs to be split into 2
568 // sub-dims
569 // the P dim on the left is 1, which means the slice does not actually cross P
570 //
// Computes, for a slice [x_slice_begins, x_slice_ends) of the X dims of Distribution:
//   - new per-X H lengths (from reverse_slice_sequence on each HsLengthss entry),
//   - the slice origin in Y coordinates,
//   - the slice lengths in Y coordinates,
// and returns a 3-element tuple whose first element is built from a
// tile_distribution_encoding that differs from the source encoding only in its H lengths,
// followed by the Y origins and Y lengths as compile-time sequences.
// NOTE(review): several lines are absent from this listing: the line with the function
// name (embedded line 572), the two loop headers that introduce `i` (embedded lines 612
// and 626), the start of the call that writes y_slice_sorted_origins (embedded line 636)
// and the start of the first returned tuple element (embedded line 659). The symbol index
// at the end of this page gives the signature as
// `constexpr CK_TILE_HOST_DEVICE auto slice_distribution_from_x(Distribution,
//  sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)`.
571 template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
573  Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
574 {
575  // NOTE: this function need to be called under constexpr context,
576  // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
577  using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
578 
579  static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
580 
    // Per-X-dim slice length = end - begin.
581  constexpr auto x_slice_lengths = x_slice_ends - x_slice_begins;
582 
    // src_y_info is indexed below as [0] = Y dims, [1] = Y maps, [2] = Y prefix sums.
583  constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
584  constexpr auto src_y_info = Encoding::detail::get_sorted_y_info();
585  constexpr auto src_y_dims = src_y_info[number<0>{}];
586  constexpr auto src_y_maps = src_y_info[number<1>{}];
587  constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
588 
    // Immediately-invoked lambda returning (new H lengths, Y slice origins, Y slice lengths).
589  constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr
590  {
591  auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
592  auto y_slice_lengths = Encoding::detail::ys_lengths_;
593 
594  // This lambda will modify some value outside, so c++ will not treat return value as
595  // constexpr
596  // TODO: ugly
    // Called once per X dim with (h_len, id) = (HsLengthss entry, X dim index).
597  auto new_h_lengths = transform_tuples(
598  [&](auto h_len, auto id) {
599  constexpr auto sliced_h =
600  reverse_slice_sequence(h_len, number<x_slice_lengths[id]>{});
601 
    // Elements [0] and [2] of the reverse_slice_sequence result are used here as the
    // sliced lengths and the index of the H dim at which the slice happens.
602  constexpr auto sliced_h_lens = sliced_h[number<0>{}];
603  constexpr auto sliced_h_index = sliced_h[number<2>{}];
604 
605  // update y_slice_lengths
    // Offset the per-X H index by this X dim's prefix sum before looking it up in
    // src_y_dims.
606  constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
607  constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
608 
609  static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
610  "not sliced at y dim, please check");
611 
    // NOTE(review): the loop header introducing `i` (embedded line 612) is missing here.
613  y_slice_lengths(src_y_maps[found_y_index - i]) =
614  sliced_h_lens[sliced_h_index - i];
615  });
616  // TODO: add validations not across p dim
617 
618  // NOTE: this y_origin is for all dims, not only current dim
619  // will later use pick to select target dim
620  constexpr auto y_origin = [&]() {
621  constexpr auto h_trans = make_merge_transform_v3_division_mod(h_len);
622  auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
623  h_trans.calculate_lower_index(h_origin_, sequence<x_slice_begins[id].value>{});
624 
625  auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
    // NOTE(review): the loop header introducing `i` (embedded line 626) is missing here.
627  y_origin_(found_y_index - i) = h_origin_[sliced_h_index - i];
628  });
629  return y_origin_;
630  }();
631 
    // Range of sorted-Y positions belonging to this X dim.
632  constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
633  src_y_prefix_sum[id + 1],
634  1>::type{};
635 
    // NOTE(review): the start of this call (embedded line 636) is missing here.
637  y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
638  return sliced_h_lens;
639  },
640  typename Encoding::HsLengthss{},
641  typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
642 
    // Reorder the origins using src_y_maps as the old-to-new mapping.
643  auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
644 
645  return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
646  }
647  ();
648 
649  constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
650  constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
651  constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
652  constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
653  constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
654 
    // Convert the arrays to compile-time sequences for the return value.
655  constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
656  constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
657 
    // NOTE(review): the start of the first tuple element (embedded line 659) is missing here.
658  return make_tuple(
660  tile_distribution_encoding<typename Encoding::RsLengths,
661  remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
662  // change the
663  // h_lengths type
664  typename Encoding::Ps2RHssMajor,
665  typename Encoding::Ps2RHssMinor,
666  typename Encoding::Ys2RHsMajor,
667  typename Encoding::Ys2RHsMinor>{}),
668  sliced_y_origins,
669  sliced_y_lengths);
670 }
671 
672 } // namespace detail
673 } // namespace ck_tile
Definition: span.hpp:18
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_sequential_index(index_t ibegin, index_t iend)
Definition: tile_distribution.hpp:254
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_span(sequence< Is... >)
Definition: tile_distribution.hpp:53
constexpr CK_TILE_HOST_DEVICE auto slice_distribution_from_x(Distribution, sequence< XSliceBegins... > x_slice_begins, sequence< XSliceEnds... > x_slice_ends)
Definition: tile_distribution.hpp:572
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_index(sequence< Is... >)
Definition: tile_distribution.hpp:59
constexpr CK_TILE_HOST_DEVICE auto make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:269
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_zero_multi_index()
Definition: multi_index.hpp:26
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition: container_helper.hpp:48
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
coord_transform_enum
Definition: coordinate_transform.hpp:16
CK_TILE_DEVICE index_t get_lane_id()
Definition: arch.hpp:69
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constexpr auto reverse_slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition: sequence.hpp:1205
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:589
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE auto make_tensor_descriptor_from_adaptor(const Adaptor &adaptor, const ElementSpaceSize &element_space_size)
Definition: tensor_descriptor.hpp:164
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1666
CK_TILE_DEVICE index_t get_warp_id()
Definition: arch.hpp:71
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:86
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr CK_TILE_HOST_DEVICE auto transform_tuples(F f, const X &x)
Definition: tuple.hpp:471
Definition: sequence.hpp:278
Definition: array.hpp:24
Definition: integral_constant.hpp:13
Definition: tile_distribution.hpp:450
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_
Definition: tile_distribution.hpp:451
Definition: meta_data_buffer.hpp:16
Definition: math.hpp:98
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution.hpp:42
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:47
static constexpr auto impl_
Definition: tile_distribution.hpp:45
Definition: tile_distribution.hpp:31
static constexpr auto impl_
Definition: tile_distribution.hpp:34
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:36
Definition: tile_distribution_encoding.hpp:26
Definition: tile_distribution.hpp:72
remove_cvref_t< Ys2DDescriptor_ > Ys2DDescriptor
Definition: tile_distribution.hpp:74
PsYs2XsAdaptor ps_ys_to_xs_
Definition: tile_distribution.hpp:86
static constexpr CK_TILE_HOST_DEVICE auto get_distributed_spans()
Definition: tile_distribution.hpp:185
static CK_TILE_HOST_DEVICE auto _get_partition_index()
Definition: tile_distribution.hpp:94
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
static constexpr index_t NDimY
Definition: tile_distribution.hpp:82
remove_cvref_t< StaticTileDistributionEncoding_ > DstrEncode
Definition: tile_distribution.hpp:75
remove_cvref_t< TileDistributionDetail_ > DstrDetail
Definition: tile_distribution.hpp:76
CK_TILE_HOST_DEVICE auto calculate_index(const PartitionIndex &ps_idx=_get_partition_index()) const
Definition: tile_distribution.hpp:177
static constexpr CK_TILE_HOST_DEVICE auto get_lengths()
Definition: tile_distribution.hpp:109
static constexpr index_t NDimP
Definition: tile_distribution.hpp:83
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_x()
Definition: tile_distribution.hpp:89
static constexpr CK_TILE_HOST_DEVICE auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition: tile_distribution.hpp:205
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex &ps_idx) const
Definition: tile_distribution.hpp:142
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_p()
Definition: tile_distribution.hpp:91
constexpr CK_TILE_HOST_DEVICE const auto & get_ys_to_d_descriptor() const
Definition: tile_distribution.hpp:131
CK_TILE_HOST_DEVICE void print() const
Definition: tile_distribution.hpp:232
remove_cvref_t< PsYs2XsAdaptor_ > PsYs2XsAdaptor
Definition: tile_distribution.hpp:73
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_r()
Definition: tile_distribution.hpp:92
static constexpr index_t NDimR
Definition: tile_distribution.hpp:84
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:227
Ys2DDescriptor ys_to_d_
Definition: tile_distribution.hpp:87
static constexpr index_t NDimX
Definition: tile_distribution.hpp:81
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_y()
Definition: tile_distribution.hpp:90
static constexpr CK_TILE_HOST_DEVICE auto get_static_tile_distribution_encoding()
Definition: tile_distribution.hpp:133
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:833
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:709
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10