/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_distribution.hpp Source File
tile_distribution.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
17 
18 namespace ck_tile {
19 
20 template <typename Distribution>
22 {
24 }
25 
26 // distributed span
27 template <index_t... PartialHsLengths>
29 {
30  using Impl = sequence<PartialHsLengths...>;
31 
32  static constexpr auto impl_ = Impl{};
33 
34  CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
35 };
36 
37 // distributed index
38 template <index_t... PartialHsIndices>
40 {
41  using Impl = sequence<PartialHsIndices...>;
42 
43  static constexpr auto impl_ = Impl{};
44 
45  CK_TILE_HOST_DEVICE static constexpr bool is_static() { return true; }
46 };
47 
48 namespace detail {
49 
50 template <index_t... Is>
52 {
53  return tile_distributed_span<Is...>{};
54 }
55 
56 template <index_t... Is>
58 {
59  return tile_distributed_index<Is...>{};
60 }
61 
62 } // namespace detail
63 
64 template <typename PsYs2XsAdaptor_,
65  typename Ys2DDescriptor_,
66  typename StaticTileDistributionEncoding_,
67  typename TileDistributionDetail_> // FIXME: this is for hold ad-hoc but useful info,
68  // should be more elegant
70 {
75 
77  "wrong! should be static");
78 
79  static constexpr index_t NDimX = PsYs2XsAdaptor::get_num_of_bottom_dimension();
80  static constexpr index_t NDimY = Ys2DDescriptor::get_num_of_top_dimension();
81  static constexpr index_t NDimP = PsYs2XsAdaptor::get_num_of_top_dimension() - NDimY;
82  static constexpr index_t NDimR = StaticTileDistributionEncoding_::NDimR;
83 
86 
91 
93  {
94  // only support warp-tile and block-tile
95  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
96 
97  if constexpr(NDimP == 1)
98  {
99  return array<index_t, 1>{get_lane_id()};
100  }
101  else if constexpr(NDimP == 2)
102  {
103  return array<index_t, 2>{get_warp_id(), get_lane_id()};
104  }
105  }
106 
// Returns the extent of every X dimension as a tuple of NDimX compile-time
// number<> values. Entry i is DstrEncode::HsLengthss{}[i] folded with
// multiplies{} starting from an initial value of 1.
107  CK_TILE_HOST_DEVICE static constexpr auto get_lengths()
108  {
109 #if 0
110  // FIXME: tensor_adaptor::GetBottomDimensionLengths is wrong. re-enable this after it's fixed
111  ps_ys_to_xs_.GetBottomDimensionLengths();
112 #else
// Active branch: the lengths are derived from the encoding's H lengths only;
// the ps_ys_to_xs_ member is not consulted here.
113  return generate_tuple(
114  [&](auto i) {
115  constexpr index_t x_length =
116  container_reduce(typename DstrEncode::HsLengthss{}[i], multiplies{}, 1);
117 
// Wrap the reduced value in number<> so it stays a compile-time constant.
118  return number<x_length>{};
119  },
120  number<NDimX>{});
121 #endif
122  }
123 
// Accessor: returns a const reference to the ps_ys_to_xs_ member.
124  CK_TILE_HOST_DEVICE constexpr const auto& get_ps_ys_to_xs_adaptor() const
125  {
126  return ps_ys_to_xs_;
127  }
128 
// Accessor: returns a const reference to the ys_to_d_ member.
129  CK_TILE_HOST_DEVICE constexpr const auto& get_ys_to_d_descriptor() const { return ys_to_d_; }
130 
132  {
133  return DstrEncode{};
134  }
135 
136 #if 1
137  // Calculate Replication index [R0, R1, ...] based on Partition index
138  // FIXME: very nasty implementation
// Computes the R index (an array<index_t, NDimR>) that corresponds to the
// given P index.
//
// ps_idx: partition index; its size() must equal NDimP (checked by the
//         static_assert below).
// Returns: rs_idx, where only the entries reached through a P -> (rh_major == 0)
//          mapping are written by this function.
// NOTE(review): rs_idx is declared without an initializer; whether the
// remaining entries are zeroed depends on array<>'s default construction,
// which is not visible here -- confirm before relying on them.
139  template <typename PartitionIndex>
140  CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex& ps_idx) const
141  {
142  static_assert(PartitionIndex::size() == NDimP, "wrong!");
143 
// Build a full [P..., Y...] top index by appending an NDimY-sized Y part
// constructed from {0} (presumably all zeros -- TODO confirm array's
// brace-init semantics).
144  const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
145 
// The coordinate is only used to read hidden indices back out of the adaptor.
146  const auto dummy_adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, ps_ys_idx);
147 
148  array<index_t, NDimR> rs_idx;
149 
// Walk every P dimension and every (rh_major, rh_minor) pair it maps to.
150  static_for<0, NDimP, 1>{}([&](auto idim_p) {
151  constexpr index_t ndim_low = DstrEncode::ps_to_rhss_major_[idim_p].size();
152 
153  static_for<0, ndim_low, 1>{}([&](auto i) {
154  constexpr index_t rh_major = DstrEncode::ps_to_rhss_major_[idim_p][i];
155  constexpr index_t rh_minor = DstrEncode::ps_to_rhss_minor_[idim_p][i];
156 
157  // 0-th rh_major is the replicate dimension
158  if constexpr(rh_major == 0)
159  {
// Look up which hidden dimension of the adaptor holds this R component.
160  constexpr index_t adaptor_hidden_id =
161  DstrDetail::rh_major_minor_to_adaptor_hidden_idss_[rh_major][rh_minor];
162 
163  // fill in
164  rs_idx(rh_minor) = dummy_adaptor_coord.get_hidden_index()[adaptor_hidden_id];
165  }
166  });
167  });
168 
169  return rs_idx;
170  }
171 #endif
172 
173  template <typename PartitionIndex = decltype(get_partition_index())>
175  calculate_index(const PartitionIndex& ps_idx = get_partition_index()) const
176  {
177  const auto ps_ys_idx = container_concat(ps_idx, array<index_t, NDimY>{0});
178  const auto window_adaptor_thread_coord_tmp =
180  return window_adaptor_thread_coord_tmp.get_bottom_index();
181  }
182 
184  {
185  constexpr auto distributed_spans_impl = DstrEncode::detail::distributed_spans_lengthss_;
186  constexpr auto ndims_spans_minor = DstrEncode::detail::ndims_distributed_spans_minor_;
187 
188  return generate_tuple(
189  [&](auto i) {
190  constexpr auto span_impl = distributed_spans_impl[i];
191  constexpr index_t ndim_span_minor = ndims_spans_minor[i];
192 
193  constexpr auto span = TO_SEQUENCE(span_impl, ndim_span_minor);
194 
196  },
197  number<NDimX>{});
198  }
199 
200  // FIXME: it's hacky to get Y index from Distributed-Index
201  template <typename DistributedIndices>
202  CK_TILE_HOST_DEVICE static constexpr auto
204  {
205  constexpr auto ys_idx_arr = [] {
206  array<index_t, NDimY> ys_idx;
207 
208  static_for<0, NDimY, 1>{}([&](auto i) {
209  constexpr index_t span_major = DstrEncode::detail::ys_to_span_major_[i];
210  constexpr index_t span_minor = DstrEncode::detail::ys_to_span_minor_[i];
211 
212  constexpr auto dstr_index = DistributedIndices{}[number<span_major>{}];
213 
214  ys_idx(i) = dstr_index.impl_[span_minor];
215  });
216 
217  return ys_idx;
218  }();
219 
220  constexpr index_t ndim_y = NDimY;
221 
222  return TO_SEQUENCE(ys_idx_arr, ndim_y);
223  }
224 
225  CK_TILE_HOST_DEVICE static constexpr bool is_static()
226  {
228  }
229 };
230 
231 template <typename T>
233 {
234 };
235 template <typename PsYs2XsAdaptor,
236  typename Ys2DDescriptor,
237  typename StaticTileDistributionEncoding,
238  typename TileDistributionDetail>
240  Ys2DDescriptor,
241  StaticTileDistributionEncoding,
242  TileDistributionDetail>> : std::true_type
243 {
244 };
245 template <typename T>
247 
248 namespace detail {
249 
250 template <index_t NDimMax>
252 {
254 
255  for(index_t i = 0; i < iend - ibegin; ++i)
256  {
257  arr(i) = ibegin + i;
258  }
259 
260  return arr;
261 }
262 
263 // this returns a constexpr encoding of tile_distribution
264 template <typename StaticTileDistributionEncoding_>
265 CK_TILE_HOST_DEVICE constexpr auto
266 make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
267 {
268  using RsLengths = typename StaticTileDistributionEncoding_::RsLengths;
269  using HsLengthss = typename StaticTileDistributionEncoding_::HsLengthss;
270  using Ps2RHssMajor = typename StaticTileDistributionEncoding_::Ps2RHssMajor;
271  using Ps2RHssMinor = typename StaticTileDistributionEncoding_::Ps2RHssMinor;
272  using Ys2RHsMajor = typename StaticTileDistributionEncoding_::Ys2RHsMajor;
273  using Ys2RHsMinor = typename StaticTileDistributionEncoding_::Ys2RHsMinor;
274 
275  // FIXME: increase max value if fail
276  constexpr index_t kMaxNumTransforms = 20;
277  constexpr index_t kMaxMetaDataSize = 128;
278  constexpr index_t kMaxNumDim = 10;
279 
280  using Name = coord_transform_enum;
281  using MetaData = meta_data_buffer<kMaxMetaDataSize>;
282  using NumDim = index_t;
283  using Dims = array<index_t, kMaxNumDim>;
284  using Lengths = array<index_t, kMaxNumDim>;
285 
286  // Tile Adaptor
287  // bottom dims [x0, x1, x2, ...]
288  // top dims [p0, p1, ..., y0, y1, ...]
289  constexpr index_t ndim_x = HsLengthss::size();
290 
291  // Dim Ids: [idim_x_major, idim_x_minor] to [idim_hidden]
292  array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_ids;
293  array<array<index_t, kMaxNumDim>, ndim_x + 1> rh_major_minor_to_hidden_lengths;
294 
295  auto trans = array<tuple<Name, MetaData, NumDim, Dims, NumDim, Dims>, kMaxNumTransforms>{};
296 
297  index_t num_tran = 0;
298  index_t hidden_dim_cnt = ndim_x;
299 
300  // this is replicate transform
301  {
302  constexpr index_t ndim_r_minor = RsLengths::size();
303 
304  constexpr auto r_minor_lengths = RsLengths{};
305 
306  trans(num_tran++) = {
308  MetaData{to_array<index_t, ndim_r_minor>(r_minor_lengths)},
309  NumDim{0},
310  Dims{},
311  NumDim{ndim_r_minor},
312  make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_r_minor)};
313 
314  for(index_t i = 0; i < ndim_r_minor; ++i)
315  {
316  rh_major_minor_to_hidden_ids(0)(i) = hidden_dim_cnt;
317  rh_major_minor_to_hidden_lengths(0)(i) = r_minor_lengths[i];
318 
319  hidden_dim_cnt++;
320  }
321  };
322 
323  // these are Unmerge transforms for X dimensions
324  static_for<0, ndim_x, 1>{}([&trans,
325  &num_tran,
326  &hidden_dim_cnt,
327  &rh_major_minor_to_hidden_ids,
328  &rh_major_minor_to_hidden_lengths](auto idim_x) {
329  // typename HsLengthss::base{}.foo();
330  constexpr auto h_minor_lengths =
331  HsLengthss{}.get(idim_x); // std::tuple_element_t<idim_x, HsLengthss>{};
332  // constexpr auto h_minor_lengths = impl::getv<idim_x>(HsLengthss{});
333 
334  constexpr index_t ndim_h_minor = h_minor_lengths.size();
335 
336  trans(num_tran++) = {
338  MetaData{to_array<index_t, ndim_h_minor>(h_minor_lengths)},
339  NumDim{1},
340  Dims{idim_x},
341  NumDim{ndim_h_minor},
342  make_sequential_index<kMaxNumDim>(hidden_dim_cnt, hidden_dim_cnt + ndim_h_minor)};
343 
344  for(index_t i = 0; i < ndim_h_minor; ++i)
345  {
346  rh_major_minor_to_hidden_ids(idim_x + 1)(i) = hidden_dim_cnt;
347  rh_major_minor_to_hidden_lengths(idim_x + 1)(i) = h_minor_lengths[i];
348 
349  hidden_dim_cnt++;
350  }
351  });
352 
353  // transform: P dimensions
354  constexpr index_t ndim_p = Ps2RHssMajor::size();
355 
356  Dims hidden_dim_id_ps;
357 
358  static_for<0, ndim_p, 1>{}([&](auto iDimP) {
359  //
360  index_t hidden_dim_id_p = hidden_dim_cnt++;
361 
362  hidden_dim_id_ps(iDimP) = hidden_dim_id_p;
363 
364  constexpr auto p2RHsMajor = Ps2RHssMajor{}[iDimP];
365  constexpr auto p2RHsMinor = Ps2RHssMinor{}[iDimP];
366 
367  static_assert(p2RHsMajor.size() == p2RHsMinor.size(), "wrong!");
368 
369  constexpr index_t ndim_low = p2RHsMajor.size();
370 
371  Dims low_dims;
372  Lengths low_lengths;
373 
374  for(index_t i = 0; i < ndim_low; ++i)
375  {
376  index_t rh_major = p2RHsMajor[i];
377  index_t rh_minor = p2RHsMinor[i];
378  low_dims(i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
379  low_lengths(i) = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
380  }
381 
382  trans(num_tran++) = {coord_transform_enum::merge,
383  MetaData{to_array<index_t, ndim_low>(low_lengths)},
384  NumDim{ndim_low},
385  low_dims,
386  NumDim{1},
387  Dims{hidden_dim_id_p}};
388  });
389 
390  constexpr index_t ndim_bottom = ndim_x;
391 
392  constexpr auto bottom_dim_ids = make_sequential_index<kMaxNumDim>(0, ndim_bottom);
393 
394  constexpr auto ys_to_rhs_major = Ys2RHsMajor{};
395  constexpr auto ys_to_rhs_minor = Ys2RHsMinor{};
396 
397  constexpr index_t ndim_y = Ys2RHsMajor::size();
398  constexpr index_t ndim_top = ndim_p + ndim_y;
399 
400  auto top_dim_ids = hidden_dim_id_ps;
401 
402  {
403  for(index_t i = 0; i < ndim_y; ++i)
404  {
405  index_t rh_major = ys_to_rhs_major[i];
406  index_t rh_minor = ys_to_rhs_minor[i];
407  top_dim_ids(ndim_p + i) = rh_major_minor_to_hidden_ids[rh_major][rh_minor];
408  }
409  }
410 
411  //
412  const auto ps_ys_to_xs_adaptor_encoding =
413  make_tuple(trans, num_tran, bottom_dim_ids, ndim_bottom, top_dim_ids, ndim_top);
414 
415  // descriptor: [y0, y1, ...] to [d]
416  Lengths y_lengths;
417  index_t d_length = 1;
418 
419  for(index_t i = 0; i < ndim_y; ++i)
420  {
421  index_t rh_major = ys_to_rhs_major[i];
422  index_t rh_minor = ys_to_rhs_minor[i];
423  index_t y_length = rh_major_minor_to_hidden_lengths[rh_major][rh_minor];
424  y_lengths(i) = y_length;
425  d_length *= y_length;
426  }
427 
429  MetaData{to_array<index_t, ndim_y>(y_lengths)},
430  NumDim{1},
431  Dims{0},
432  NumDim{ndim_y},
433  make_sequential_index<kMaxNumDim>(1, ndim_y + 1));
434 
435  const auto ys_to_d_adaptor_encoding = make_tuple(
436  make_tuple(tran), 1, Dims{0}, 1, make_sequential_index<kMaxNumDim>(1, ndim_y + 1), ndim_y);
437 
438  return make_tuple(ps_ys_to_xs_adaptor_encoding,
439  ys_to_d_adaptor_encoding,
440  d_length,
441  rh_major_minor_to_hidden_ids);
442 }
443 
444 // FIXME: this is nasty. Move it inside TileDistributionEncoding::detail
445 template <typename RhMajorMinor2AdaptorHiddenIdss> // tuple<sequence<...>, ...>
447 {
449  to_array_of_array(RhMajorMinor2AdaptorHiddenIdss{});
450 };
451 
452 } // namespace detail
453 
454 #if 0
455 // this returns a constexpr tile_distribution
// NOTE: this function sits inside an `#if 0` region and is therefore not compiled.
// It builds a tile_distribution object from an encoding type, using the
// CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING macro for both adaptors.
456 template <typename StaticTileDistributionEncoding_>
457 CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
458 {
459  using DstrEncode = remove_cvref_t<StaticTileDistributionEncoding_>;
460 
461  constexpr auto adaptor_impl =
462  detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
463 
// Unpack the 4-element result, in the order the helper returns it:
// [0] ps_ys -> xs adaptor encoding, [1] ys -> d adaptor encoding,
// [2] d_length, [3] rh_major/minor -> hidden-id table.
464  constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
465  constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
466  constexpr index_t d_length = adaptor_impl.template at<2>();
467  constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
468 
469  constexpr auto ps_ys_to_xs_adaptor =
470  CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
471 
472  constexpr auto ys_to_d_adaptor = CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
473 
// d_length is passed as the element-space size of the Y -> D descriptor.
474  constexpr auto ys_to_d_descriptor =
475  make_tensor_descriptor_from_adaptor(ys_to_d_adaptor, d_length);
476 
477  //
478  constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
479  constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
480 
// Convert the hidden-id table into a tuple-of-sequence value so that its type
// can parameterize tile_distribution_detail below.
481  constexpr auto rh_major_minor_to_hidden_ids =
482  TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
483 
// The returned object stores the two adaptors/descriptors; the encoding and
// the hidden-id table are carried purely as template arguments.
484  return tile_distribution<
485  remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
486  remove_cvref_t<decltype(ys_to_d_descriptor)>,
487  remove_cvref_t<DstrEncode>,
488  detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
489  ps_ys_to_xs_adaptor, ys_to_d_descriptor};
490 }
491 #endif
492 
493 // this returns a static tile_distribution
494 template <typename StaticTileDistributionEncoding_>
495 CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistributionEncoding_)
496 {
498 
499  constexpr auto adaptor_impl =
500  detail::make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_{});
501 
502  constexpr auto ps_ys_to_xs_adaptor_impl = adaptor_impl.template at<0>();
503  constexpr auto ys_to_d_adaptor_impl = adaptor_impl.template at<1>();
504  constexpr index_t d_length = adaptor_impl.template at<2>();
505  constexpr auto rh_major_minor_to_hidden_ids_impl = adaptor_impl.template at<3>();
506 
507  constexpr auto ps_ys_to_xs_adaptor =
508  CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ps_ys_to_xs_adaptor_impl);
509 
510  constexpr auto ys_to_d_adaptor =
511  CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(ys_to_d_adaptor_impl);
512 
513  constexpr auto ys_to_d_descriptor =
515 
516  //
517  constexpr index_t ndim_rh_major = DstrEncode::detail::ndim_rh_major_;
518  constexpr auto ndims_rhs_minor = DstrEncode::detail::ndims_rhs_minor_;
519 
520  constexpr auto rh_major_minor_to_hidden_ids =
521  TO_TUPLE_OF_SEQUENCE(rh_major_minor_to_hidden_ids_impl, ndim_rh_major, ndims_rhs_minor);
522 
523  return tile_distribution<
524  remove_cvref_t<decltype(ps_ys_to_xs_adaptor)>,
525  remove_cvref_t<decltype(ys_to_d_descriptor)>,
527  detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
528  ps_ys_to_xs_adaptor, ys_to_d_descriptor};
529 }
530 
531 //***********************************************************************************
532 
533 namespace detail {
534 //
535 // Slice a tensor along x_dim; the result is split along y_dim, never along p_dim.
536 // Slicing across p_dim (i.e. splitting across different threads) is not supported.
537 // Also, the y_dim being sliced must be the first dim of the current dim.
538 // Having multiple Y dims before the sliced dim does not make sense.
539 //
540 // e.g
541 // X0 X1
542 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 32>, (-1 means the last one)
543 // Y P P Y P Y P Y
544 // => <1, 4, 32> - <1, 1, 4, 2, 4> -> OK
545 // |--> slice along this Y dim, is the first dim of X1, totally 4 slices
546 //
547 // X0 X1
548 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 8>, (-1 means the last one)
549 // Y P P Y P Y P Y
550 // => <1, 4, 32> - <1, 1, 1, 2, 4> -> OK
551 // |--> slice along this Y dim, the P dim is 1 in the left, so is OK
552 // totally 16 slices
553 //
554 // X0 X1
555 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 4>, (-1 means the last one)
556 // Y P P Y P Y P Y
557 // => <1, 4, 32> - <1, 1, 1, 1, 4> -> Fail
558 // |--> slice along this P dim, will split threads, not supported
559 //
560 // X0 X1
561 // <1, 4, 32> - <4, 1, 4, 2, 4> | slice start:<0, 0>, end:<-1, 16>, (-1 means the last one)
562 // Y P P Y P Y P Y
563 // => <1, 4, 32> - <1, 1, 2, 2, 4> -> OK
564 // |--> slice along this Y dim, but this Y dim needs to be split into 2
565 // sub-dims
566 // the P dim on the left is 1, which means the slice does not actually cross P
567 //
568 template <typename Distribution, index_t... XSliceBegins, index_t... XSliceEnds>
570  Distribution, sequence<XSliceBegins...> x_slice_begins, sequence<XSliceEnds...> x_slice_ends)
571 {
572  // NOTE: this function need to be called under constexpr context,
573  // due to https://wg21.link/p2280r0 we have to use non-reference type for distribution
574  using Encoding = decltype(Distribution::get_static_tile_distribution_encoding());
575 
576  static_assert(sizeof...(XSliceBegins) == sizeof...(XSliceEnds));
577  static_assert(sizeof...(XSliceBegins) == Encoding::NDimX, "only support slice over h, not r");
578 
579  constexpr auto p_len_over_h = Encoding::detail::get_uniformed_p_dim_lengths_over_h();
580 
581  constexpr auto x_slice_ends_ = generate_sequence_v2(
582  [&](auto i) {
583  if constexpr(x_slice_ends[i] == -1)
584  {
585  // -1 means till the end
586  constexpr auto x_length_ =
587  container_reduce(typename Encoding::HsLengthss{}[i], multiplies{}, number<1>{});
588  return x_length_;
589  }
590  else
591  {
592  return x_slice_ends[i];
593  }
594  },
595  number<x_slice_ends.size()>{});
596 
597  constexpr auto x_slice_lengths = x_slice_ends_ - x_slice_begins;
598 
599  constexpr auto x_slice_lengths_without_p = generate_sequence_v2(
600  [&](auto i) constexpr {
601  constexpr auto len_ = x_slice_lengths[i];
602  static_assert(len_ % p_len_over_h[i] == 0,
603  "slice length must be dividable by p_len_over_h");
604  return number<len_ / p_len_over_h[i]>{};
605  },
606  number<x_slice_lengths.size()>{});
607 
608  constexpr auto src_h_prefix_sum = Encoding::detail::get_h_dim_lengths_prefix_sum();
609  constexpr auto src_y_info = Encoding::detail::get_sorted_y_to_h_info();
610  constexpr auto src_y_dims = src_y_info[number<0>{}];
611  constexpr auto src_y_maps = src_y_info[number<1>{}];
612  constexpr auto src_y_prefix_sum = src_y_info[number<2>{}];
613 
614  constexpr auto sliced_hlen_yidx_ylen = [&]() constexpr {
615  auto y_slice_sorted_origins = make_zero_multi_index<Encoding::NDimY>();
616  auto y_slice_lengths = Encoding::detail::ys_lengths_;
617  constexpr auto y_to_h_masks = Encoding::detail::get_y_to_h_masks();
618 
619  // This lambda will modify some value outside, so c++ will not treat return value as
620  // constexpr
621  // TODO: ugly
622  auto new_h_lengths = transform_tuples(
623  [&](auto h_len, auto id) {
624  constexpr auto sliced_h = reverse_slice_sequence(
625  h_len, number<x_slice_lengths_without_p[id]>{}, y_to_h_masks[id]);
626 
627  constexpr auto sliced_h_lens = sliced_h[number<0>{}];
628  constexpr auto sliced_h_index = sliced_h[number<2>{}];
629 
630  // update y_slice_lengths
631  constexpr auto uniformed_h_index = sliced_h_index + number<src_h_prefix_sum[id]>{};
632  constexpr auto found_y_index = container_find(src_y_dims, uniformed_h_index);
633  constexpr auto y_to_h_dim_end = src_y_prefix_sum[id + 1];
634 
635  static_assert(found_y_index >= 0 && found_y_index < src_y_dims.size(),
636  "not sliced at y dim, please check");
637 
638  {
639  constexpr auto sliced_y_to_h_lens =
640  pick_sequence_elements_by_mask(sliced_h_lens, y_to_h_masks[id]);
641  constexpr auto sliced_y_to_h_dims = sliced_y_to_h_lens.size();
643  y_slice_lengths(src_y_maps[y_to_h_dim_end - 1 - i]) =
644  sliced_y_to_h_lens[sliced_y_to_h_dims - 1 - i];
645  });
646  }
647  // TODO: add validations not across p dim
648 
649  // NOTE: this y_origin is for all dims, not only current dim
650  // will later use pick to select target dim
651  constexpr auto y_origin = [&]() {
652  // can't use Encoding::Ys2RHsMajor/Ys2RHsMinor, these are unordered
653  constexpr auto y_to_h_len =
654  pick_sequence_elements_by_mask(h_len, y_to_h_masks[id]);
655  constexpr auto y_to_h_dims = y_to_h_len.size();
656 
657  constexpr auto h_trans = make_merge_transform_v3_division_mod(y_to_h_len);
658  auto h_origin_ = make_zero_multi_index<h_trans.NDimLow>();
659  constexpr auto y_begin_ = x_slice_begins[id] / p_len_over_h[id];
660  h_trans.calculate_lower_index(h_origin_, sequence<y_begin_.value>{});
661 
662  auto y_origin_ = make_zero_multi_index<Encoding::NDimY>();
663 
664  static_for<0, y_to_h_dims, 1>{}([&](auto i) {
665  y_origin_(y_to_h_dim_end - 1 - i) = h_origin_[y_to_h_dims - 1 - i];
666  });
667  return y_origin_;
668  }();
669 
670  constexpr auto y_picks = typename arithmetic_sequence_gen<src_y_prefix_sum[id],
671  src_y_prefix_sum[id + 1],
672  1>::type{};
673 
675  y_slice_sorted_origins, y_picks, get_container_subset(y_origin, y_picks));
676  return sliced_h_lens;
677  },
678  typename Encoding::HsLengthss{},
679  typename arithmetic_sequence_gen<0, Encoding::HsLengthss::size(), 1>::type{});
680 
681  auto y_slice_origins = container_reorder_given_old2new(y_slice_sorted_origins, src_y_maps);
682 
683  return make_tuple(new_h_lengths, y_slice_origins, y_slice_lengths);
684  }();
685 
686  constexpr auto sliced_h_lengths = sliced_hlen_yidx_ylen[number<0>{}];
687  constexpr auto sliced_y_origins_array = sliced_hlen_yidx_ylen[number<1>{}];
688  constexpr auto sliced_y_origins_size = sliced_y_origins_array.size();
689  constexpr auto sliced_y_lengths_array = sliced_hlen_yidx_ylen[number<2>{}];
690  constexpr auto sliced_y_lengths_size = sliced_y_lengths_array.size();
691 
692  constexpr auto sliced_y_origins = TO_SEQUENCE(sliced_y_origins_array, sliced_y_origins_size);
693  constexpr auto sliced_y_lengths = TO_SEQUENCE(sliced_y_lengths_array, sliced_y_lengths_size);
694 
695  return make_tuple(
697  tile_distribution_encoding<typename Encoding::RsLengths,
698  remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
699  // change the
700  // h_lengths type
701  typename Encoding::Ps2RHssMajor,
702  typename Encoding::Ps2RHssMinor,
703  typename Encoding::Ys2RHsMajor,
704  typename Encoding::Ys2RHsMinor>{}),
705  sliced_y_origins,
706  sliced_y_lengths);
707 }
708 
709 } // namespace detail
710 
711 // Free print function for tile_distribution
// Writes a one-line textual dump of a tile_distribution via printf:
//   tile_distribution{tile_distribution_encoding: <...>, ps_ys_to_xs_: <...>, ys_to_d_: <...>}
// followed by a newline. Each <...> part is produced by the print() overload
// for the corresponding type.
712 template <typename PsYs2XsAdaptor_,
713  typename Ys2DDescriptor_,
714  typename StaticTileDistributionEncoding_,
715  typename TileDistributionDetail_>
716 CK_TILE_HOST_DEVICE void print(const tile_distribution<PsYs2XsAdaptor_,
717  Ys2DDescriptor_,
718  StaticTileDistributionEncoding_,
719  TileDistributionDetail_>& distribution)
720 {
721  printf("tile_distribution{");
722  printf("tile_distribution_encoding: ");
// The encoding is printed from a default-constructed value of its type,
// not from the `distribution` argument.
723  print(StaticTileDistributionEncoding_{});
724  printf(", ");
725  printf("ps_ys_to_xs_: ");
726  print(distribution.ps_ys_to_xs_);
727  printf(", ");
728  printf("ys_to_d_: ");
729  print(distribution.ys_to_d_);
730  printf("}\n");
731 }
732 
733 } // namespace ck_tile
Definition: span.hpp:18
Concept for encoding of Unicode characters.
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
constexpr CK_TILE_HOST_DEVICE auto make_sequential_index(index_t ibegin, index_t iend)
Definition: tile_distribution.hpp:251
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_span(sequence< Is... >)
Definition: tile_distribution.hpp:51
constexpr CK_TILE_HOST_DEVICE auto slice_distribution_from_x(Distribution, sequence< XSliceBegins... > x_slice_begins, sequence< XSliceEnds... > x_slice_ends)
Definition: tile_distribution.hpp:569
constexpr CK_TILE_HOST_DEVICE auto make_tile_distributed_index(sequence< Is... >)
Definition: tile_distribution.hpp:57
constexpr CK_TILE_HOST_DEVICE auto make_adaptor_encoding_for_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:266
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_zero_multi_index()
Definition: multi_index.hpp:26
constexpr CK_TILE_HOST_DEVICE auto container_reorder_given_old2new(const array< TData, NSize > &old_array, sequence< IRs... > old2new)
Definition: container_helper.hpp:48
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
coord_transform_enum
Definition: coordinate_transform.hpp:17
constexpr CK_TILE_HOST_DEVICE auto pick_sequence_elements_by_mask(Seq, Mask)
Definition: sequence.hpp:956
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:56
constexpr auto reverse_slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition: sequence.hpp:1234
constexpr CK_TILE_HOST_DEVICE auto to_array_of_array(tuple< Seqs... > t_of_s)
Definition: tuple.hpp:630
int32_t index_t
Definition: integer.hpp:9
CK_TILE_HOST_DEVICE void print(const tile_distribution_encoding_pattern_2d< BlockSize, YPerTile, XPerTile, VecSize, DistributionPattern, NumWaveGroups > &)
Definition: static_encoding_pattern.hpp:341
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1056
constexpr CK_TILE_HOST_DEVICE auto make_tensor_descriptor_from_adaptor(const Adaptor &adaptor, const ElementSpaceSize &element_space_size)
Definition: tensor_descriptor.hpp:183
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1609
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr index_t container_find(sequence< Is... > seq, index_t value)
Definition: container_helper.hpp:447
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:495
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
constexpr CK_TILE_HOST_DEVICE auto transform_tuples(F f, const X &x)
Definition: tuple.hpp:505
constexpr bool is_tile_distribution_v
Definition: tile_distribution.hpp:246
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Definition: sequence.hpp:298
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: tile_distribution.hpp:447
static constexpr auto rh_major_minor_to_adaptor_hidden_idss_
Definition: tile_distribution.hpp:448
Definition: tile_distribution.hpp:233
Definition: meta_data_buffer.hpp:16
Definition: math.hpp:98
Definition: sequence.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t size()
Definition: sequence.hpp:53
Definition: functional.hpp:43
Definition: tile_distribution.hpp:40
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:45
static constexpr auto impl_
Definition: tile_distribution.hpp:43
Definition: tile_distribution.hpp:29
static constexpr auto impl_
Definition: tile_distribution.hpp:32
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:34
Definition: tile_distribution_encoding.hpp:26
Definition: tile_distribution.hpp:70
remove_cvref_t< Ys2DDescriptor_ > Ys2DDescriptor
Definition: tile_distribution.hpp:72
PsYs2XsAdaptor ps_ys_to_xs_
Definition: tile_distribution.hpp:84
static constexpr CK_TILE_HOST_DEVICE auto get_distributed_spans()
Definition: tile_distribution.hpp:183
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:124
static constexpr index_t NDimY
Definition: tile_distribution.hpp:80
remove_cvref_t< StaticTileDistributionEncoding_ > DstrEncode
Definition: tile_distribution.hpp:73
remove_cvref_t< TileDistributionDetail_ > DstrDetail
Definition: tile_distribution.hpp:74
static constexpr CK_TILE_HOST_DEVICE auto get_lengths()
Definition: tile_distribution.hpp:107
static constexpr index_t NDimP
Definition: tile_distribution.hpp:81
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_x()
Definition: tile_distribution.hpp:87
static constexpr CK_TILE_HOST_DEVICE auto get_y_indices_from_distributed_indices(DistributedIndices)
Definition: tile_distribution.hpp:203
CK_TILE_HOST_DEVICE auto calculate_rs_index_from_ps_index(const PartitionIndex &ps_idx) const
Definition: tile_distribution.hpp:140
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_p()
Definition: tile_distribution.hpp:89
constexpr CK_TILE_HOST_DEVICE const auto & get_ys_to_d_descriptor() const
Definition: tile_distribution.hpp:129
remove_cvref_t< PsYs2XsAdaptor_ > PsYs2XsAdaptor
Definition: tile_distribution.hpp:71
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_r()
Definition: tile_distribution.hpp:90
CK_TILE_HOST_DEVICE auto calculate_index(const PartitionIndex &ps_idx=get_partition_index()) const
Definition: tile_distribution.hpp:175
static constexpr index_t NDimR
Definition: tile_distribution.hpp:82
static constexpr CK_TILE_HOST_DEVICE bool is_static()
Definition: tile_distribution.hpp:225
Ys2DDescriptor ys_to_d_
Definition: tile_distribution.hpp:85
static CK_TILE_HOST_DEVICE auto get_partition_index()
Definition: tile_distribution.hpp:92
static constexpr index_t NDimX
Definition: tile_distribution.hpp:79
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_dimension_y()
Definition: tile_distribution.hpp:88
static constexpr CK_TILE_HOST_DEVICE auto get_static_tile_distribution_encoding()
Definition: tile_distribution.hpp:131
#define TO_TUPLE_OF_SEQUENCE(a_of_b_impl, a_size, bs_sizes)
Definition: container_helper.hpp:486
#define CONSTRUCT_STATIC_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:840
#define CONSTRUCT_TENSOR_ADAPTOR_FROM_ENCODING(encoded_tensor_adaptor)
Definition: tensor_adaptor.hpp:716
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10