/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/algorithm/coordinate_transform.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/algorithm/coordinate_transform.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/algorithm/coordinate_transform.hpp Source File
coordinate_transform.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
12 
13 namespace ck_tile {
14 
16 {
17  undefined,
19  pad,
20  embed,
21  merge,
22  unmerge,
23  replicate,
24  xor_t,
25  offset,
26  indexing,
27 };
28 
29 template <index_t NDimLow, index_t NDimUp>
31 {
32  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
33  {
35  }
36 
37  CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return NDimLow; }
38 
39  CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return NDimUp; }
40 
41  // return safe value for vector length/stride, based on compile-time known only
42  // variables
43  // MUST be static function
44  template <typename LowVectorLengths, typename LowVectorStrides>
45  CK_TILE_HOST_DEVICE static constexpr auto
47  const LowVectorStrides&)
48  {
49  if constexpr(NDimUp > 0)
50  {
51  array<index_t, NDimUp> up_vector_lengths{-1};
52  array<index_t, NDimUp> up_vector_strides{-1};
53 
54  return make_tuple(up_vector_lengths, up_vector_strides);
55  }
56  else
57  {
59  }
60  }
61 };
62 
63 template <typename LowLength>
64 struct pass_through : public base_transform<1, 1>
65 {
67 
70 
71  using UpLengths = decltype(make_tuple(LowLength{}));
72 
74 
75  CK_TILE_HOST_DEVICE constexpr pass_through() = default;
76 
77  CK_TILE_HOST_DEVICE constexpr pass_through(const LowLength& low_length)
78  : up_lengths_{make_tuple(low_length)}
79  {
80  }
81 
82  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
83  {
85  }
86 
87  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
88 
89  template <typename LowIdx, typename UpIdx>
90  CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
91  const UpIdx& idx_up)
92  {
93  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
94  "wrong! inconsistent # of dimension");
95 
96  idx_low(number<0>{}) = idx_up[number<0>{}];
97  }
98 
99  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
100  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
101  const UpIdxDiff& idx_diff_up,
102  LowIdx& idx_low,
103  const UpIdx&)
104  {
105  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
106  UpIdx::size() == 1,
107  "wrong! inconsistent # of dimension");
108 
109  constexpr auto I0 = number<0>{};
110 
111  idx_diff_low[I0] = idx_diff_up[I0];
112 
113  idx_low += idx_diff_low;
114  }
115 
116  CK_TILE_HOST_DEVICE static constexpr bool
118  {
119  return true;
120  }
121 
122  template <typename UpIdx>
123  CK_TILE_HOST_DEVICE static constexpr bool
125  {
126  return true;
127  }
128 
130  {
132  }
133 
134  // MUST be static function
135  template <typename LowVectorLengths, typename LowVectorStrides>
136  CK_TILE_HOST_DEVICE static constexpr auto
137  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
138  const LowVectorStrides& low_vector_strides)
139  {
140  return make_tuple(low_vector_lengths, low_vector_strides);
141  }
142 
144  {
145  printf("pass_through{");
146 
147  //
148  printf("up_lengths_:");
150 
151  //
152  printf("}");
153  }
154 };
155 
156 template <typename LowLength,
157  typename LeftPadLength,
158  typename RightPadLength,
159  bool SkipIsValidCheck = false>
160 struct pad : public base_transform<1, 1>
161 {
164 
165  using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{} + RightPadLength{}));
166 
168  LeftPadLength left_pad_length_;
169  RightPadLength right_pad_length_;
170 
172 
173  CK_TILE_HOST_DEVICE constexpr pad(const LowLength& low_length,
174  const LeftPadLength& left_pad_length,
175  const RightPadLength& right_pad_length)
176  : up_lengths_{make_tuple(low_length + left_pad_length + right_pad_length)},
177  left_pad_length_{left_pad_length},
178  right_pad_length_{right_pad_length}
179  {
180  }
181 
182  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
183 
184  template <typename LowIdx, typename UpIdx>
185  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
186  const UpIdx& idx_up) const
187  {
188  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
189  "wrong! inconsistent # of dimension");
190 
191  idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
192  }
193 
194  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
195  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
196  const UpIdxDiff& idx_diff_up,
197  LowIdx& idx_low,
198  const UpIdx&)
199  {
200  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
201  UpIdx::size() == 1,
202  "wrong! inconsistent # of dimension");
203 
204  constexpr auto I0 = number<0>{};
205 
206  idx_diff_low[I0] = idx_diff_up[I0];
207 
208  idx_low += idx_diff_low;
209  }
210 
211  CK_TILE_HOST_DEVICE static constexpr bool
213  {
214  return SkipIsValidCheck;
215  }
216 
217  template <typename UpIdx>
218  CK_TILE_HOST_DEVICE constexpr bool
220  {
221  return SkipIsValidCheck ||
222  ((idx_up[number<0>{}] >= left_pad_length_) &&
223  (idx_up[number<0>{}] < up_lengths_[number<0>{}] - right_pad_length_));
224  }
225 
227  {
231  }
232 
234  {
235  printf("pad{");
236 
237  //
238  printf("up_lengths_: ");
240  printf(", ");
241 
242  //
243  printf("left_pad_length_: ");
245  printf(", ");
246 
247  //
248  printf("right_pad_length_: ");
250 
251  printf("}");
252  }
253 };
254 
255 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
256 struct left_pad
257 {
260 
261  using UpLengths = decltype(make_tuple(LowLength{} + LeftPadLength{}));
262 
264  LeftPadLength left_pad_length_;
265 
266  CK_TILE_HOST_DEVICE constexpr left_pad() = default;
267 
268  CK_TILE_HOST_DEVICE constexpr left_pad(const LowLength& low_length,
269  const LeftPadLength& left_pad_length)
270  : up_lengths_{make_tuple(low_length + left_pad_length)}, left_pad_length_{left_pad_length}
271  {
272  }
273 
274  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
275 
276  template <typename LowIdx, typename UpIdx>
277  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
278  const UpIdx& idx_up) const
279  {
280  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
281  "wrong! inconsistent # of dimension");
282 
283  idx_low(number<0>{}) = idx_up[number<0>{}] - left_pad_length_;
284  }
285 
286  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
287  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
288  const UpIdxDiff& idx_diff_up,
289  LowIdx& idx_low,
290  const UpIdx&)
291  {
292  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
293  UpIdx::size() == 1,
294  "wrong! inconsistent # of dimension");
295 
296  constexpr auto I0 = number<0>{};
297 
298  idx_diff_low[I0] = idx_diff_up[I0];
299 
300  idx_low += idx_diff_low;
301  }
302 
303  CK_TILE_HOST_DEVICE static constexpr bool
305  {
306  return SkipIsValidCheck;
307  }
308 
309  template <typename UpIdx>
310  CK_TILE_HOST_DEVICE constexpr bool
312  {
313  return SkipIsValidCheck || (idx_up[number<0>{}] >= left_pad_length_);
314  }
315 
317  {
320  }
321 
322  // MUST be static function
323  template <typename LowVectorLengths, typename LowVectorStrides>
324  CK_TILE_HOST_DEVICE static constexpr auto
325  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
326  const LowVectorStrides& low_vector_strides)
327  {
328  // TODO: we allow pass through this vector length. If one need per-pixel check,
329  // should change the guaranteed vector length while creating the tensor view.
330  // It's up to runtime to check the padding length should be multiple of vector length
331  return make_tuple(low_vector_lengths, low_vector_strides);
332  }
333 
335  {
336  printf("left_pad{");
337 
338  //
339  printf("up_lengths_: ");
341  printf(", ");
342 
343  //
344  printf("left_pad_length_: ");
346 
347  printf("}");
348  }
349 };
350 
351 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
352 struct right_pad : public base_transform<1, 1>
353 {
356 
357  using UpLengths = decltype(make_tuple(LowLength{} + RightPadLength{}));
358 
360  LowLength low_length_;
361  RightPadLength right_pad_length_;
362 
363  CK_TILE_HOST_DEVICE constexpr right_pad() = default;
364 
365  CK_TILE_HOST_DEVICE constexpr right_pad(const LowLength& low_length,
366  const RightPadLength& right_pad_length)
367  : up_lengths_{make_tuple(low_length + right_pad_length)},
368  low_length_{low_length},
369  right_pad_length_{right_pad_length}
370  {
371  }
372 
373  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
374 
375  template <typename LowIdx, typename UpIdx>
376  CK_TILE_HOST_DEVICE static constexpr void calculate_lower_index(LowIdx& idx_low,
377  const UpIdx& idx_up)
378  {
379  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
380  "wrong! inconsistent # of dimension");
381 
382  idx_low(number<0>{}) = idx_up[number<0>{}];
383  }
384 
385  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
386  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
387  const UpIdxDiff& idx_diff_up,
388  LowIdx& idx_low,
389  const UpIdx&)
390  {
391  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
392  UpIdx::size() == 1,
393  "wrong! inconsistent # of dimension");
394 
395  constexpr auto I0 = number<0>{};
396 
397  idx_diff_low[I0] = idx_diff_up[I0];
398 
399  idx_low += idx_diff_low;
400  }
401 
402  CK_TILE_HOST_DEVICE static constexpr bool
404  {
405  return SkipIsValidCheck;
406  }
407 
408  template <typename UpIdx>
409  CK_TILE_HOST_DEVICE constexpr bool
411  {
412  return SkipIsValidCheck || (idx_up[number<0>{}] < low_length_);
413  }
414 
416  {
420  }
421 
422  // MUST be static function
423  template <typename LowVectorLengths, typename LowVectorStrides>
424  CK_TILE_HOST_DEVICE static constexpr auto
425  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
426  const LowVectorStrides& low_vector_strides)
427  {
428  // TODO: we allow pass through this vector length. If one need per-pixel check,
429  // should change the guaranteed vector length while creating the tensor view.
430  // It's up to runtime to check the padding length should be multiple of vector length
431  return make_tuple(low_vector_lengths, low_vector_strides);
432  }
433 
435  {
436  printf("right_pad{");
437 
438  //
439  printf("up_lengths_: ");
441  printf(", ");
442 
443  //
444  printf("right_pad_length_: ");
446 
447  printf("}");
448  }
449 };
450 
451 // idx_low = coefficients[0, ...nDimUp-1] * idx_up[0, ...nDimUp-1]
452 // UpLengths and Coefficients can be either of the followings:
453 // 1) Tuple of index_t, which is known at run-time, or
454 // 2) Tuple of number, which is known at compile-time, or
455 // 3) Tuple of mixture of index_t and number, which is known partially at run-time and partially
456 // at compile-time
457 template <typename UpLengths,
458  typename Coefficients,
459  typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
460 struct embed : public base_transform<1, UpLengths::size()>
461 {
462  static constexpr index_t NDimUp = UpLengths::size();
463 
466 
467  UpLengths up_lengths_;
468  Coefficients coefficients_;
469 
470  CK_TILE_HOST_DEVICE constexpr embed() = default;
471 
472  CK_TILE_HOST_DEVICE constexpr embed(const UpLengths& up_lengths,
473  const Coefficients& coefficients)
474  : up_lengths_{up_lengths}, coefficients_{coefficients}
475  {
476  }
477 
478  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
479  {
481  }
482 
483  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
484 
485  template <typename LowIdx, typename UpIdx>
486  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
487  const UpIdx& idx_up) const
488  {
489  static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
490  "wrong! inconsistent # of dimension");
491 
492  idx_low(number<0>{}) = 0;
493 
494  static_for<0, NDimUp, 1>{}([&idx_low, &idx_up, this](auto i) {
495  idx_low(number<0>{}) += idx_up[i] * this->coefficients_[i];
496  });
497  }
498 
499  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
500  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
501  const UpIdxDiff& idx_diff_up,
502  LowIdx& idx_low,
503  const UpIdx&) const
504  {
505  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
506  LowIdx::size() == 1 && UpIdx::size() == NDimUp,
507  "wrong! inconsistent # of dimension");
508 
509  idx_diff_low(number<0>{}) = 0;
510 
512  [&](auto i) { idx_diff_low(number<0>{}) += idx_diff_up[i] * coefficients_[i]; });
513 
514  idx_low += idx_diff_low;
515  }
516 
517  CK_TILE_HOST_DEVICE static constexpr bool
519  {
520  return true;
521  }
522 
523  template <typename UpIdx>
524  CK_TILE_HOST_DEVICE static constexpr bool
526  {
527  return true;
528  }
529 
531  {
534  }
535 
537  {
538  printf("embed{");
539 
540  //
541  printf("up_lengths_: ");
543  printf(", ");
544 
545  //
546  printf("coefficients_: ");
548 
549  printf("}");
550  }
551 };
552 
553 template <typename LowLengths>
555 {
556  template <index_t I>
557  CK_TILE_HOST_DEVICE constexpr auto operator()(number<I> i) const
558  {
559  return magic_division::calculate_magic_numbers(LowLengths{}[i]);
560  }
561 };
562 
563 // Implementation of "merge" transformation primitive that uses magic-number-division to do lowering
564 // of both multi-index and delta of multi-index
565 // Caution:
566 // 1. The magic number division implementation being used would produce correct result if the
567 // dividended is uint32_t and its value is with in 31-bit value range of uint32_t.
568 // 2. The magic number division for int32_t dividened has not been implemented, the int32_t
569 // dividend would be bit-wise interpreted as uint32_t and magic number division implementation for
570 // uint32_t is then used.
571 // 3. For merge primitive, upper-index is the dividend.
572 // 4. When upper-index is uint32_t, its value need to be within 31-bit range.
573 // 5. When upper-index is int32_t type (when index_t is int32_t), its value need to be
574 // non-negative.
575 template <typename LowLengths>
576 struct merge_v2_magic_division : public base_transform<LowLengths::size(), 1>
577 {
578  static constexpr index_t NDimLow = LowLengths::size();
579 
582 
583  using UpLengths =
584  decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
585 
588  number<NDimLow>{}));
589 
590  LowLengths low_lengths_;
593 
594  static constexpr auto I0 = number<0>{};
595  static constexpr auto I1 = number<1>{};
596 
598 
599  CK_TILE_HOST_DEVICE constexpr merge_v2_magic_division(const LowLengths& low_lengths)
600  : low_lengths_{low_lengths},
602  [&](auto i) { return magic_division::calculate_magic_numbers(low_lengths[i]); },
603  number<NDimLow>{})},
604  up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, I1))}
605  {
606  static_assert(LowerIndex::size() == NDimLow, "wrong!");
607  }
608 
609  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
610  {
612  }
613 
614  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
615 
616  template <typename LowIdx, typename UpIdx>
617  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
618  const UpIdx& idx_up) const
619  {
620  static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
621  "wrong! inconsistent # of dimension");
622 
623  index_t tmp = idx_up[I0];
624 
625  static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
626  index_t tmp2 =
628  this->low_lengths_magic_divisor_[i][I0],
629  this->low_lengths_magic_divisor_[i][I1]);
630  idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
631  tmp = tmp2;
632  });
633 
634  idx_low(number<0>{}) = tmp;
635  }
636 
637  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
638  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
639  const UpIdxDiff&,
640  LowIdx& idx_low,
641  const UpIdx& idx_up_new) const
642  {
643  static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
644  LowIdx::size() == NDimLow && UpIdx::size() == 1,
645  "wrong! inconsistent # of dimension");
646 
647  index_t tmp = idx_up_new[number<0>{}];
648 
649  static_for<NDimLow - 1, 0, -1>{}([&, this](auto i) {
650  index_t tmp2 =
652  this->low_lengths_magic_divisor_[i][I0],
653  this->low_lengths_magic_divisor_[i][I1]);
654 
655  index_t idx_low_old = idx_low[i];
656 
657  idx_low(i) = tmp - tmp2 * this->low_lengths_[i];
658  tmp = tmp2;
659 
660  idx_diff_low(i) = idx_low[i] - idx_low_old;
661  });
662 
663  idx_diff_low(number<0>{}) = tmp - idx_low(number<0>{});
664 
665  idx_low(number<0>{}) = tmp;
666  }
667 
668  CK_TILE_HOST_DEVICE static constexpr bool
670  {
671  return true;
672  }
673 
675  {
679  }
680 
681  template <typename UpIdx>
682  CK_TILE_HOST_DEVICE static constexpr bool
684  {
685  return true;
686  }
687 
688  // MUST be static function
689  template <typename LowVectorLengths, typename LowVectorStrides>
690  CK_TILE_HOST_DEVICE static constexpr auto
691  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
692  const LowVectorStrides& low_vector_strides)
693  {
694  array<index_t, 1> up_vector_lengths{-1};
695  array<index_t, 1> up_vector_strides{-1};
696 
697  up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
698  up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
699 
700  return make_tuple(up_vector_lengths, up_vector_strides);
701  }
702 
704  {
705  printf("merge_v2_magic_division{");
706 
707  //
708  printf("low_lengths_ ");
709  print(low_lengths_);
710  printf(", ");
711 
712  //
713  printf("up_lengths_ ");
714  print(up_lengths_);
715 
716  printf("}");
717  }
718 };
719 
720 // Implementation of "merge" transformation primitive that uses division and mod. It is supposed to
721 // be used for low_lengths that are known at compile time and are power of 2, otherwise performance
722 // will be very bad
723 template <typename LowLengths>
724 struct merge_v3_division_mod : public base_transform<LowLengths::size(), 1>
725 {
726  static constexpr index_t NDimLow = LowLengths::size();
727 
730 
732  decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number<1>{}));
733 
734  using UpLengths =
735  decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number<1>{})));
736 
737  LowLengths low_lengths_;
740 
742 
743  CK_TILE_HOST_DEVICE constexpr merge_v3_division_mod(const LowLengths& low_lengths)
744  : low_lengths_{low_lengths},
745  low_lengths_scan_{
746  container_reverse_exclusive_scan(low_lengths, multiplies{}, number<1>{})},
747  up_lengths_{make_tuple(container_reduce(low_lengths, multiplies{}, number<1>{}))}
748  {
749  static_assert(LowerIndex::size() == NDimLow, "wrong!");
750  }
751 
752  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
753 
754  template <typename LowIdx, typename UpIdx>
755  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
756  const UpIdx& idx_up) const
757  {
758  static_assert(LowIdx::size() == NDimLow && UpIdx::size() == 1,
759  "wrong! inconsistent # of dimension");
760 
761  index_t tmp = idx_up[number<0>{}];
762 
763  // division and mod
764  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
765  idx_low(i) = tmp / this->low_lengths_scan_[i];
766  tmp %= this->low_lengths_scan_[i];
767  });
768 
769  idx_low(number<NDimLow - 1>{}) = tmp;
770  }
771 
772  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
773  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
774  const UpIdxDiff&,
775  LowIdx& idx_low,
776  const UpIdx& idx_up_new) const
777  {
778  static_assert(LowIdxDiff::size() == NDimLow && UpIdxDiff::size() == 1 &&
779  LowIdx::size() == NDimLow && UpIdx::size() == 1,
780  "wrong! inconsistent # of dimension");
781 
782  constexpr auto I0 = number<0>{};
783  constexpr auto INm1 = number<NDimLow - 1>{};
784 
785  index_t tmp = idx_up_new[I0];
786 
787  static_for<0, NDimLow - 1, 1>{}([&](auto i) {
788  const index_t tmp2 = idx_low[i];
789  idx_low(i) = tmp / this->low_lengths_scan_[i];
790  idx_diff_low(i) = idx_low[i] - tmp2;
791  tmp %= this->low_lengths_scan_[i];
792  });
793 
794  const index_t tmp2 = idx_low[INm1];
795  idx_low(INm1) = tmp;
796  idx_diff_low(INm1) = idx_low[INm1] - tmp2;
797  }
798 
799  CK_TILE_HOST_DEVICE static constexpr bool
801  {
802  return true;
803  }
804 
806  {
810  }
811 
812  template <typename UpIdx>
813  CK_TILE_HOST_DEVICE static constexpr bool
815  {
816  return true;
817  }
818 
819  // MUST be static function
820  template <typename LowVectorLengths, typename LowVectorStrides>
821  CK_TILE_HOST_DEVICE static constexpr auto
822  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
823  const LowVectorStrides& low_vector_strides)
824  {
825  array<index_t, 1> up_vector_lengths{-1};
826  array<index_t, 1> up_vector_strides{-1};
827 
828  up_vector_lengths[0] = low_vector_lengths[number<NDimLow - 1>{}];
829  up_vector_strides[0] = low_vector_strides[number<NDimLow - 1>{}];
830 
831  return make_tuple(up_vector_lengths, up_vector_strides);
832  }
833 
835  {
836  printf("Merge_v3_direct_division_mod{");
837 
838  //
839  printf("low_lengths_ ");
840  print(low_lengths_);
841  printf(", ");
842 
843  //
844  printf("low_lengths_scan_ ");
845  print(low_lengths_scan_);
846  printf(", ");
847 
848  //
849  printf("up_lengths_ ");
850  print(up_lengths_);
851 
852  printf("}");
853  }
854 };
855 
856 template <typename UpLengths, bool Use24BitIntegerCalculation>
857 struct unmerge : public base_transform<1, UpLengths::size()>
858 {
859  static constexpr index_t NDimUp = UpLengths::size();
860 
863 
865  decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number<1>{}));
866 
867  UpLengths up_lengths_;
869 
870  CK_TILE_HOST_DEVICE constexpr unmerge() = default;
871 
872  CK_TILE_HOST_DEVICE constexpr unmerge(const UpLengths& up_lengths)
873  : up_lengths_{up_lengths},
874  up_lengths_scan_{container_reverse_exclusive_scan(up_lengths, multiplies{}, number<1>{})}
875  {
876  }
877 
878  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
879  {
881  }
882 
883  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
884 
885  template <typename LowIdx, typename UpIdx>
886  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
887  const UpIdx& idx_up) const
888  {
889  if constexpr(!Use24BitIntegerCalculation)
890  {
891  idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
892 
893  static_for<0, NDimUp - 1, 1>{}(
894  [&](auto i) { idx_low(number<0>{}) += idx_up[i] * up_lengths_scan_[i]; });
895  }
896  else
897  {
898  idx_low(number<0>{}) = idx_up[number<NDimUp - 1>{}];
899 
900  static_for<0, NDimUp - 1, 1>{}([&](auto i) {
901  idx_low(number<0>{}) =
902  (0x00ffffff & idx_low[number<0>{}]) +
903  (0x00ffffff & idx_up[i]) * (0x00ffffff & up_lengths_scan_[i]);
904  });
905  }
906  }
907 
908  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
909  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
910  const UpIdxDiff& idx_diff_up,
911  LowIdx& idx_low,
912  const UpIdx&) const
913  {
914  calculate_lower_index(idx_diff_low, idx_diff_up);
915 
916  idx_low += idx_diff_low;
917  }
918 
919  CK_TILE_HOST_DEVICE static constexpr bool
921  {
922  return true;
923  }
924 
925  template <typename UpIdx>
926  CK_TILE_HOST_DEVICE static constexpr bool
928  {
929  return true;
930  }
931 
933  {
936  }
937 
938  // MUST be static function
939  template <typename LowVectorLengths, typename LowVectorStrides>
940  CK_TILE_HOST_DEVICE static constexpr auto
941  calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths& low_vector_lengths,
942  const LowVectorStrides& low_vector_strides)
943  {
944  array<index_t, NDimUp> up_vector_lengths{-1};
945  array<index_t, NDimUp> up_vector_strides{-1};
946 
947  constexpr auto up_length_last = UpLengths{}[number<NDimUp - 1>{}];
948 
949  if constexpr(ck_tile::is_known_at_compile_time<decltype(up_length_last)>::value)
950  {
951  if(low_vector_lengths[0] != -1)
952  {
953  up_vector_lengths(NDimUp - 1) = gcd(low_vector_lengths[0], up_length_last);
954  }
955  }
956 
957  up_vector_strides(NDimUp - 1) = low_vector_strides[0];
958 
959  return make_tuple(up_vector_lengths, up_vector_strides);
960  }
961 
963  {
964  printf("unmerge{");
965 
966  //
967  printf("up_lengths_");
968  print(up_lengths_);
969  printf(", ");
970 
971  //
972  printf("up_lengths_scan_");
973  print(up_lengths_scan_);
974 
975  printf("}");
976  }
977 };
978 
979 template <typename LowerIndex>
980 struct freeze : public base_transform<1, 0>
981 {
982  LowerIndex low_idx_;
983 
984  CK_TILE_HOST_DEVICE constexpr freeze() = default;
985 
986  CK_TILE_HOST_DEVICE constexpr freeze(const LowerIndex& low_idx) : low_idx_{low_idx} {}
987 
988  CK_TILE_HOST_DEVICE static constexpr auto get_upper_lengths() { return tuple<>{}; }
989 
990  template <typename LowIdx, typename UpIdx>
991  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
992  const UpIdx& /* idx_up */) const
993  {
994  static_assert(LowIdx::size() == 1 && UpIdx::size() == 0,
995  "wrong! inconsistent # of dimension");
996 
997  idx_low(number<0>{}) = low_idx_;
998  }
999 
1000  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1001  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
1002  const UpIdxDiff& /* idx_diff_up */,
1003  LowIdx& /* idx_low */,
1004  const UpIdx& /* idx_up_new */)
1005  {
1006  idx_diff_low(number<0>{}) = 0;
1007  }
1008 
1009  CK_TILE_HOST_DEVICE static constexpr bool
1011  {
1012  return true;
1013  }
1014 
1015  template <typename UpIdx>
1016  CK_TILE_HOST_DEVICE static constexpr bool
1018  {
1019  return true;
1020  }
1021 
1023  {
1025  }
1026 
1028  {
1029  printf("freeze{");
1030 
1031  //
1032  printf("low_idx_: ");
1033  print(low_idx_);
1034 
1035  printf("}");
1036  }
1037 };
1038 
1039 // insert a dangling upper dimension without lower dimension
1040 template <typename UpperLength>
1041 struct insert : public base_transform<0, 1>
1042 {
1043  using UpLengths = decltype(make_tuple(UpperLength{}));
1044 
1046 
1047  CK_TILE_HOST_DEVICE constexpr insert() = default;
1048 
1049  CK_TILE_HOST_DEVICE constexpr insert(const UpperLength& up_length)
1050  : up_lengths_{make_tuple(up_length)}
1051  {
1052  }
1053 
1054  CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_lower_dimension() { return 0; }
1055 
1056  CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_upper_dimension() { return 1; }
1057 
1058  CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
1059 
1060  template <typename LowIdx, typename UpIdx>
1061  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
1062  {
1063  static_assert(LowIdx::size() == 0 && UpIdx::size() == 1,
1064  "wrong! inconsistent # of dimension");
1065  }
1066 
1067  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1068  CK_TILE_HOST_DEVICE static void
1069  update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
1070  {
1071  static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == 1 && LowIdx::size() == 0 &&
1072  UpIdx::size() == 1,
1073  "wrong! inconsistent # of dimension");
1074  }
1075 
1076  CK_TILE_HOST_DEVICE static constexpr bool IsLinearTransform() { return true; }
1077 
1078  CK_TILE_HOST_DEVICE static constexpr bool
1080  {
1081  return true;
1082  }
1083 
1084  template <typename UpIdx>
1085  CK_TILE_HOST_DEVICE static constexpr bool
1087  {
1088  return true;
1089  }
1090 
1092  {
1094  }
1095 
1097  {
1098  printf("insert{");
1099 
1100  //
1101  print(up_lengths_);
1102 
1103  printf("}");
1104  }
1105 };
1106 
1107 // replicate the original tensor and create a higher dimensional tensor
1108 template <typename UpLengths>
1109 struct replicate : public base_transform<0, UpLengths::size()>
1110 {
1111  static constexpr index_t NDimUp = UpLengths::size();
1112 
1113  CK_TILE_HOST_DEVICE constexpr replicate() = default;
1114 
1115  CK_TILE_HOST_DEVICE constexpr replicate(const UpLengths& up_lengths) : up_lengths_{up_lengths}
1116  {
1117  }
1118 
1119  CK_TILE_HOST_DEVICE constexpr auto get_upper_lengths() const { return up_lengths_; }
1120 
1121  template <typename LowIdx, typename UpIdx>
1122  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx&, const UpIdx&) const
1123  {
1124  static_assert(LowIdx::size() == 0 && UpIdx::size() == NDimUp,
1125  "wrong! inconsistent # of dimension");
1126  }
1127 
1128  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1129  CK_TILE_HOST_DEVICE static void
1130  update_lower_index(LowIdxDiff&, const UpIdxDiff&, LowIdx&, const UpIdx&)
1131  {
1132  static_assert(LowIdxDiff::size() == 0 && UpIdxDiff::size() == NDimUp &&
1133  LowIdx::size() == 0 && UpIdx::size() == NDimUp,
1134  "wrong! inconsistent # of dimension");
1135  }
1136 
1137  CK_TILE_HOST_DEVICE static constexpr bool
1139  {
1140  return true;
1141  }
1142 
1143  template <typename UpIdx>
1144  CK_TILE_HOST_DEVICE static constexpr bool
1146  {
1147  return true;
1148  }
1149 
1151  {
1153  }
1154 
1156  {
1157  printf("replicate{");
1158 
1159  //
1160  printf("up_lengths_: ");
1161  print(up_lengths_);
1162 
1163  printf("}");
1164  }
1165 
1166  //
1167  UpLengths up_lengths_;
1168 };
1169 
1170 template <typename LowLength, typename SliceBegin, typename SliceEnd>
1171 struct slice : public base_transform<1, 1>
1172 {
1175 
1176  using UpLengths = decltype(make_tuple(SliceEnd{} - SliceBegin{}));
1177 
1179  SliceBegin slice_begin_;
1180  SliceEnd slice_end_;
1181 
1182  CK_TILE_HOST_DEVICE constexpr slice() = default;
1183 
1184  CK_TILE_HOST_DEVICE constexpr slice(const LowLength&,
1185  const SliceBegin& slice_begin,
1186  const SliceEnd& slice_end)
1187  : up_lengths_{make_tuple(slice_end - slice_begin)},
1188  slice_begin_{slice_begin},
1189  slice_end_{slice_end}
1190  {
1191  }
1192 
1193  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1194 
1195  template <typename LowIdx, typename UpIdx>
1196  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1197  const UpIdx& idx_up) const
1198  {
1199  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1200  "wrong! inconsistent # of dimension");
1201 
1202  idx_low(number<0>{}) = idx_up[number<0>{}] + slice_begin_;
1203  }
1204 
1205  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1206  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
1207  const UpIdxDiff& idx_diff_up,
1208  LowIdx& idx_low,
1209  const UpIdx&)
1210  {
1211  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1212  UpIdx::size() == 1,
1213  "wrong! inconsistent # of dimension");
1214 
1215  constexpr auto I0 = number<0>{};
1216 
1217  idx_diff_low[I0] = idx_diff_up[I0];
1218 
1219  idx_low += idx_diff_low;
1220  }
1221 
1222  CK_TILE_HOST_DEVICE static constexpr bool
1224  {
1225  return true;
1226  }
1227 
1228  template <typename UpIdx>
1229  CK_TILE_HOST_DEVICE constexpr bool
1231  {
1232  return true;
1233  }
1234 
1236  {
1240  }
1241 
1243  { // print(): dumps up_lengths_, slice_begin_ and slice_end_
1244  printf("slice{");
1245 
1246  // upper lengths (slice_end_ - slice_begin_)
1247  printf("up_lengths_: ");
1248  print(up_lengths_);
1249  printf(", ");
1250 
1251  // first selected lower index
1252  printf("slice_begin_: ");
1253  print(slice_begin_);
1254  printf(", ");
1255 
1256  // one past the last selected lower index
1257  printf("slice_end_: ");
1258  print(slice_end_);
1259 
1260  printf("}");
1261  } // print()
1262 }; // struct slice
1263 
1264 /*
1265  * \brief lower_idx = upper_idx % modulus.
1266  * TODO: Need an improved implementation since the modulo operation is expensive.
1267  */
1268 template <typename Modulus, typename UpLength>
1269 struct modulo : public base_transform<1, 1>
1270 {
1273  using UpLengths = decltype(make_tuple(UpLength{}));
1274 
1275  Modulus modulus_;
1277 
1278  CK_TILE_HOST_DEVICE constexpr modulo() = default;
1279 
1280  CK_TILE_HOST_DEVICE constexpr modulo(const Modulus& modulus, const UpLength& up_length)
1281  : modulus_{modulus}, up_lengths_{make_tuple(up_length)}
1282  {
1283  }
1284 
1285  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1286 
1287  template <typename LowIdx, typename UpIdx>
1288  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1289  const UpIdx& idx_up) const
1290  {
1291  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1292  "wrong! inconsistent # of dimension");
1293 
1294  idx_low(number<0>{}) = idx_up[number<0>{}] % modulus_;
1295  }
1296 
1297  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1298  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1299  const UpIdxDiff& idx_diff_up,
1300  LowIdx& idx_low,
1301  const UpIdx& up_idx) const
1302  {
1303  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1304  UpIdx::size() == 1,
1305  "wrong! inconsistent # of dimension");
1306 
1307  constexpr auto I0 = number<0>{};
1308 
1309  const auto idx_low_old = idx_low;
1310  idx_low[I0] = (up_idx[I0] + idx_diff_up[I0]) % modulus_;
1311  idx_diff_low[I0] = idx_low - idx_low_old;
1312  }
1313 
1314  CK_TILE_HOST_DEVICE static constexpr bool
1316  {
1317  return true;
1318  }
1319 
1320  template <typename UpIdx>
1321  CK_TILE_HOST_DEVICE static constexpr bool
1323  {
1324  return true;
1325  }
1326 
1328  {
1330  }
1331 
1333  {
1334  printf("Modulus{");
1335 
1336  //
1337  printf("up_lengths_: ");
1338  print(up_lengths_);
1339 
1340  printf("}");
1341  }
1342 };
1343 
1344 // 2D XOR, NOTE: "xor" is a keyword
1345 template <typename LowLengths>
1346 struct xor_t : public base_transform<2, 2>
1347 {
1348  static constexpr auto type_enum = coord_transform_enum::xor_t;
1349 
1352 
1353  using UpLengths = LowLengths;
1354 
1356 
1357  CK_TILE_HOST_DEVICE constexpr xor_t() : up_lengths_{} {}
1358 
1359  CK_TILE_HOST_DEVICE constexpr xor_t(const LowLengths& low_lengths) : up_lengths_{low_lengths} {}
1360 
1361  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1362  {
1363  return coord_transform_enum::xor_t;
1364  }
1365 
1366  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1367 
1368  template <typename LowIdx, typename UpIdx>
1369  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1370  const UpIdx& idx_up) const
1371  {
1372  static_assert(LowIdx::size() == 2 && UpIdx::size() == 2,
1373  "wrong! inconsistent # of dimension");
1374 
1375  idx_low(number<0>{}) = idx_up[number<0>{}];
1376 
1377  idx_low(number<1>{}) =
1378  idx_up[number<1>{}] ^ (idx_up[number<0>{}] % up_lengths_[number<1>{}]);
1379  }
1380 
1381  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1382  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1383  const UpIdxDiff&,
1384  LowIdx& idx_low,
1385  const UpIdx& idx_up) const
1386  {
1387  static_assert(LowIdxDiff::size() == 2 && UpIdxDiff::size() == 2 && LowIdx::size() == 2 &&
1388  UpIdx::size() == 2,
1389  "wrong! inconsistent # of dimension");
1390 
1391  const auto idx_low_old = idx_low;
1392 
1393  calculate_lower_index(idx_low, idx_up);
1394 
1395  idx_diff_low = idx_low - idx_low_old;
1396  }
1397 
1398  CK_TILE_HOST_DEVICE static constexpr bool
1400  {
1401  return true;
1402  }
1403 
1404  template <typename UpIdx>
1405  CK_TILE_HOST_DEVICE static constexpr bool
1407  {
1408  return true;
1409  }
1410 
1412  {
1414  }
1415 
1416  // MUST be static function
1417  template <typename LowVectorLengths, typename LowVectorStrides>
1419  const LowVectorLengths& low_vector_lengths,
1420  const LowVectorStrides& low_vector_strides) const
1421  {
1422  array<index_t, 2> up_vector_lengths = low_vector_lengths;
1423  array<index_t, 2> up_vector_strides = low_vector_strides;
1424 
1425  return make_tuple(up_vector_lengths, up_vector_strides);
1426  }
1427 
1429  {
1430  printf("xor_t{");
1431 
1432  //
1433  printf("up_lengths_: ");
1434  print(up_lengths_);
1435  printf(", ");
1436 
1437  printf("}");
1438  }
1439 };
1440 
1441 template <typename LowLength, typename OffsetLength>
1442 struct offset : public base_transform<1, 1>
1443 {
1446 
1447  using UpLengths = decltype(make_tuple(LowLength{}));
1448 
1450  OffsetLength offset_length_;
1451 
1452  CK_TILE_HOST_DEVICE constexpr offset() = default;
1453 
1454  CK_TILE_HOST_DEVICE constexpr offset(const LowLength& low_length,
1455  const OffsetLength& offset_length)
1456  : up_lengths_{make_tuple(low_length)}, offset_length_{offset_length}
1457  {
1458  }
1459 
1460  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1461  {
1462  return coord_transform_enum::offset;
1463  }
1464 
1465  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1466 
1467  template <typename LowIdx, typename UpIdx>
1468  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1469  const UpIdx& idx_up) const
1470  {
1471  static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
1472  "wrong! inconsistent # of dimension");
1473 
1474  idx_low(number<0>{}) = idx_up[number<0>{}] + offset_length_;
1475  }
1476 
1477  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1478  CK_TILE_HOST_DEVICE static void update_lower_index(LowIdxDiff& idx_diff_low,
1479  const UpIdxDiff& idx_diff_up,
1480  LowIdx& idx_low,
1481  const UpIdx&)
1482  {
1483  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
1484  UpIdx::size() == 1,
1485  "wrong! inconsistent # of dimension");
1486 
1487  constexpr auto I0 = number<0>{};
1488 
1489  idx_diff_low[I0] = idx_diff_up[I0];
1490 
1491  idx_low += idx_diff_low;
1492  }
1493 
1494  CK_TILE_HOST_DEVICE static constexpr bool
1496  {
1497  return true;
1498  }
1499 
1500  template <typename UpIdx>
1501  CK_TILE_HOST_DEVICE constexpr bool
1503  {
1504  return true;
1505  }
1506 
1508  {
1511  }
1512 
1514  {
1515  printf("offset{");
1516 
1517  //
1518  printf("up_lengths_: ");
1519  print(up_lengths_);
1520  printf(", ");
1521 
1522  //
1523  printf("offset_length_: ");
1524  print(offset_length_);
1525 
1526  printf("}");
1527  }
1528 };
1529 
1530 template <typename UpLength, typename IndexingAdaptor>
1531 struct indexing : public base_transform<1, 1>
1532 {
1533  static constexpr index_t NDimUp = 1;
1534 
1537 
1538  using UpLengths = decltype(make_tuple(UpLength{}));
1540  IndexingAdaptor iadaptor_;
1541 
1542  CK_TILE_HOST_DEVICE constexpr indexing() = default;
1543 
1544  CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
1545  const IndexingAdaptor& iadaptor)
1546  : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
1547  {
1548  }
1549 
1550  CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
1551  {
1552  return coord_transform_enum::indexing;
1553  }
1554 
1555  CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }
1556 
1557  template <typename LowIdx, typename UpIdx>
1558  CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
1559  const UpIdx& idx_up) const
1560  {
1561  static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
1562  "wrong! inconsistent # of dimension");
1563  iadaptor_.calculate_lower_index(idx_low, idx_up);
1564  }
1565 
1566  template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
1567  CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
1568  const UpIdxDiff& idx_diff_up,
1569  LowIdx& idx_low,
1570  const UpIdx& idx_up) const
1571  {
1572  // TODO: nonthing changed here
1573  static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
1574  LowIdx::size() == 1 && UpIdx::size() == NDimUp,
1575  "wrong! inconsistent # of dimension");
1576 
1577  iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
1578  }
1579 
1580  CK_TILE_HOST_DEVICE static constexpr bool
1582  {
1583  return true;
1584  }
1585 
1586  template <typename UpIdx>
1587  CK_TILE_HOST_DEVICE static constexpr bool
1589  {
1590  return true;
1591  }
1592 
1594  {
1597  }
1598 
1600  {
1601  printf("embed{");
1602 
1603  //
1604  printf("up_lengths_: ");
1605  print(up_lengths_);
1606  printf(", ");
1607 
1608  printf("}");
1609  }
1610 };
1611 
1612 //*******************************************************************************************************
1613 
1614 template <typename LowLength>
1615 CK_TILE_HOST_DEVICE constexpr auto make_pass_through_transform(const LowLength& low_length)
1616 {
1617  return pass_through<LowLength>{low_length};
1618 }
1619 
1620 template <typename LowLength, typename LeftPad, typename RightPad, bool SkipIsValidCheck = false>
1621 CK_TILE_HOST_DEVICE constexpr auto
1622 make_pad_transform(const LowLength& low_length,
1623  const LeftPad& left_pad,
1624  const RightPad& right_pad,
1626 {
1627  return pad<LowLength, LeftPad, RightPad, SkipIsValidCheck>{low_length, left_pad, right_pad};
1628 }
1629 
1630 template <typename LowLength, typename LeftPadLength, bool SkipIsValidCheck = false>
1631 CK_TILE_HOST_DEVICE constexpr auto
1632 make_left_pad_transform(const LowLength& low_length,
1633  const LeftPadLength& left_pad_,
1635 {
1636  return left_pad<LowLength, LeftPadLength, SkipIsValidCheck>{low_length, left_pad_};
1637 }
1638 
1639 template <typename LowLength, typename RightPadLength, bool SkipIsValidCheck = false>
1640 CK_TILE_HOST_DEVICE constexpr auto
1641 make_right_pad_transform(const LowLength& low_length,
1642  const RightPadLength& right_pad_,
1644 {
1645  return right_pad<LowLength, RightPadLength, SkipIsValidCheck>{low_length, right_pad_};
1646 }
1647 
1648 template <typename UpLengths,
1649  typename Coefficients,
1650  typename std::enable_if<UpLengths::size() == Coefficients::size(), bool>::type = false>
1651 CK_TILE_HOST_DEVICE constexpr auto make_embed_transform(const UpLengths& up_lengths,
1652  const Coefficients& coefficients)
1653 {
1654  return embed<UpLengths, Coefficients>{up_lengths, coefficients};
1655 }
1656 
1657 template <typename LowLengths>
1658 CK_TILE_HOST_DEVICE constexpr auto
1659 make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
1660 {
1661  return merge_v2_magic_division<LowLengths>{low_lengths};
1662 }
1663 
1664 template <typename LowLengths>
1665 CK_TILE_HOST_DEVICE constexpr auto
1666 make_merge_transform_v3_division_mod(const LowLengths& low_lengths)
1667 {
1668  return merge_v3_division_mod<LowLengths>{low_lengths};
1669 }
1670 
1671 template <typename LowLengths>
1672 CK_TILE_HOST_DEVICE constexpr auto make_merge_transform(const LowLengths& low_lengths)
1673 {
1674  return make_merge_transform_v2_magic_division(low_lengths);
1675 }
1676 
1677 template <typename UpLengths, bool Use24BitIntegerCalculation = false>
1678 CK_TILE_HOST_DEVICE constexpr auto
1679 make_unmerge_transform(const UpLengths& up_lengths,
1681 {
1682  return unmerge<UpLengths, Use24BitIntegerCalculation>{up_lengths};
1683 }
1684 
1685 template <typename LowerIndex>
1686 CK_TILE_HOST_DEVICE constexpr auto make_freeze_transform(const LowerIndex& low_idx)
1687 {
1688  return freeze<LowerIndex>{low_idx};
1689 }
1690 
1691 template <typename UpperIndex>
1692 CK_TILE_HOST_DEVICE constexpr auto make_insert_transform(const UpperIndex& up_idx)
1693 {
1694  return insert<UpperIndex>{up_idx};
1695 }
1696 
1697 template <typename UpLengths>
1698 CK_TILE_HOST_DEVICE constexpr auto make_replicate_transform(const UpLengths& up_lengths)
1699 {
1700  return replicate<UpLengths>{up_lengths};
1701 }
1702 
1703 template <typename LowLength, typename SliceBegin, typename SliceEnd>
1704 CK_TILE_HOST_DEVICE constexpr auto make_slice_transform(const LowLength& low_length,
1705  const SliceBegin& slice_begin,
1706  const SliceEnd& slice_end)
1707 {
1708  return slice<LowLength, SliceBegin, SliceEnd>{low_length, slice_begin, slice_end};
1709 }
1710 
1711 template <typename Modulus, typename UpLength>
1712 CK_TILE_HOST_DEVICE constexpr auto make_modulo_transform(const Modulus& modulus,
1713  const UpLength& up_length)
1714 {
1715  return modulo<Modulus, UpLength>{modulus, up_length};
1716 }
1717 
1718 template <typename LowLengths>
1719 CK_TILE_HOST_DEVICE constexpr auto make_xor_transform(const LowLengths& low_lengths)
1720 {
1721  return xor_t<LowLengths>{low_lengths};
1722 }
1723 
1724 template <typename LowLength, typename OffsetLength>
1725 CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_length,
1726  const OffsetLength& offset_length)
1727 {
1728  return offset<LowLength, OffsetLength>{low_length, offset_length};
1729 }
1730 
1731 } // namespace ck_tile
1732 
1734 namespace ck_tile {
1735 
1736 template <typename UpLength, typename Indices>
1737 CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
1738  const Indices& indices)
1739 {
1740  // by default we use the simplest one
1743 }
1744 
1745 template <typename UpLength, typename IndexingAdaptor>
1746 CK_TILE_HOST_DEVICE constexpr auto
1747 make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
1748 {
1749  return indexing<UpLength, IndexingAdaptor>{up_lengths, iadaptor};
1750 }
1751 
1752 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
__host__ constexpr __device__ auto unmerge(const Layout< Shape, UnrolledDesc > &layout, const NewLengths &new_lengths, [[maybe_unused]] const NewIdxs &new_indexes)
Unmerge selected dim in layout.
Definition: layout_utils.hpp:474
__host__ constexpr __device__ index_t gcd(index_t x, index_t y)
Definition: math.hpp:154
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_insert_transform(const UpperIndex &up_idx)
Definition: coordinate_transform.hpp:1692
constexpr CK_TILE_HOST_DEVICE auto container_reduce(const Container &x, Reduce reduce, Init init, number< IBegin >=number< 0 >{}, number< IEnd >=number< Container::size()>{}, number< IStep >=number< 1 >{})
Definition: container_helper.hpp:198
constexpr CK_TILE_HOST_DEVICE auto make_left_pad_transform(const LowLength &low_length, const LeftPadLength &left_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1632
coord_transform_enum
Definition: coordinate_transform.hpp:16
constexpr CK_TILE_HOST_DEVICE auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad_, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1641
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
constexpr CK_TILE_HOST_DEVICE auto make_indexing_transform_with_adaptor(const UpLength &up_lengths, const IndexingAdaptor &iadaptor)
Definition: coordinate_transform.hpp:1747
constexpr CK_TILE_HOST_DEVICE auto make_offset_transform(const LowLength &low_length, const OffsetLength &offset_length)
Definition: coordinate_transform.hpp:1725
is_static< T > is_known_at_compile_time
Definition: type_traits.hpp:93
constexpr CK_TILE_HOST_DEVICE auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: coordinate_transform.hpp:1704
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1672
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
constexpr CK_TILE_HOST_DEVICE auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1622
constexpr CK_TILE_HOST_DEVICE auto make_unmerge_transform(const UpLengths &up_lengths, bool_constant< Use24BitIntegerCalculation >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1679
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1666
constexpr CK_TILE_HOST_DEVICE auto make_modulo_transform(const Modulus &modulus, const UpLength &up_length)
Definition: coordinate_transform.hpp:1712
constexpr CK_TILE_HOST_DEVICE auto make_indexing_transform(const UpLength &up_lengths, const Indices &indices)
Definition: coordinate_transform.hpp:1737
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto make_xor_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1719
constexpr CK_TILE_HOST_DEVICE auto make_replicate_transform(const UpLengths &up_lengths)
Definition: coordinate_transform.hpp:1698
constexpr CK_TILE_HOST_DEVICE auto make_freeze_transform(const LowerIndex &low_idx)
Definition: coordinate_transform.hpp:1686
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform_v2_magic_division(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1659
constexpr CK_TILE_HOST_DEVICE auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: coordinate_transform.hpp:1651
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:10
__host__ constexpr __device__ auto container_reverse_exclusive_scan(const Array< TData, NSize > &x, Reduce f, TData init)
Definition: container_helper.hpp:213
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto container_reduce(const Container &x, Reduce reduce, Init init, Number< IBegin >=Number< 0 >{}, Number< IEnd >=Number< Container::Size()>{}, Number< IStep >=Number< 1 >{})
Definition: container_helper.hpp:111
Definition: array.hpp:24
Definition: coordinate_transform.hpp:31
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:32
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_upper_dimension()
Definition: coordinate_transform.hpp:39
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_lower_dimension()
Definition: coordinate_transform.hpp:37
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &, const LowVectorStrides &)
Definition: coordinate_transform.hpp:46
Definition: integral_constant.hpp:13
Definition: coordinate_transform.hpp:461
constexpr CK_TILE_HOST_DEVICE embed()=default
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:483
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:518
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:536
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:486
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &) const
Definition: coordinate_transform.hpp:500
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:478
UpLengths up_lengths_
Definition: coordinate_transform.hpp:467
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:525
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:530
Coefficients coefficients_
Definition: coordinate_transform.hpp:468
constexpr CK_TILE_HOST_DEVICE embed(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: coordinate_transform.hpp:472
static constexpr index_t NDimUp
Definition: coordinate_transform.hpp:462
Definition: coordinate_transform.hpp:981
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &) const
Definition: coordinate_transform.hpp:991
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition: coordinate_transform.hpp:1001
LowerIndex low_idx_
Definition: coordinate_transform.hpp:982
static constexpr CK_TILE_HOST_DEVICE auto get_upper_lengths()
Definition: coordinate_transform.hpp:988
constexpr CK_TILE_HOST_DEVICE freeze()=default
constexpr CK_TILE_HOST_DEVICE freeze(const LowerIndex &low_idx)
Definition: coordinate_transform.hpp:986
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1010
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1022
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:1027
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1017
Definition: type_traits.hpp:75
Definition: indexing_adaptor.hpp:20
Definition: coordinate_transform.hpp:1532
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1588
constexpr CK_TILE_HOST_DEVICE indexing()=default
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1558
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1581
decltype(make_tuple(UpLength{})) UpLengths
Definition: coordinate_transform.hpp:1538
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:1599
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1539
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1555
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1567
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1593
IndexingAdaptor iadaptor_
Definition: coordinate_transform.hpp:1540
constexpr CK_TILE_HOST_DEVICE indexing(const UpLength &up_length, const IndexingAdaptor &iadaptor)
Definition: coordinate_transform.hpp:1544
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:1550
Definition: coordinate_transform.hpp:1042
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition: coordinate_transform.hpp:1069
constexpr CK_TILE_HOST_DEVICE insert(const UpperLength &up_length)
Definition: coordinate_transform.hpp:1049
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1045
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:1096
decltype(make_tuple(UpperLength{})) UpLengths
Definition: coordinate_transform.hpp:1043
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1086
constexpr CK_TILE_HOST_DEVICE insert()=default
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &, const UpIdx &) const
Definition: coordinate_transform.hpp:1061
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_lower_dimension()
Definition: coordinate_transform.hpp:1054
constexpr CK_TILE_HOST_DEVICE auto get_upper_lengths() const
Definition: coordinate_transform.hpp:1058
static constexpr CK_TILE_HOST_DEVICE bool IsLinearTransform()
Definition: coordinate_transform.hpp:1076
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1091
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1079
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_upper_dimension()
Definition: coordinate_transform.hpp:1056
constexpr CK_TILE_HOST_DEVICE auto operator()(number< I > i) const
Definition: coordinate_transform.hpp:557
Definition: coordinate_transform.hpp:257
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:304
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:316
LeftPadLength left_pad_length_
Definition: coordinate_transform.hpp:264
UpLengths up_lengths_
Definition: coordinate_transform.hpp:263
constexpr CK_TILE_HOST_DEVICE left_pad()=default
decltype(make_tuple(LowLength{}+LeftPadLength{})) UpLengths
Definition: coordinate_transform.hpp:261
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:311
constexpr CK_TILE_HOST_DEVICE left_pad(const LowLength &low_length, const LeftPadLength &left_pad_length)
Definition: coordinate_transform.hpp:268
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:287
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:334
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:274
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:277
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:325
static constexpr CK_TILE_HOST_DEVICE auto calculate_magic_numbers(uint32_t divisor)
Definition: magic_div.hpp:29
static constexpr CK_TILE_DEVICE uint32_t do_magic_division(uint32_t dividend, uint32_t multiplier, uint32_t shift)
Definition: magic_div.hpp:60
Definition: coordinate_transform.hpp:577
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:691
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new) const
Definition: coordinate_transform.hpp:638
LowLengthsMagicDivisor low_lengths_magic_divisor_
Definition: coordinate_transform.hpp:591
static constexpr auto I1
Definition: coordinate_transform.hpp:595
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:669
static constexpr index_t NDimLow
Definition: coordinate_transform.hpp:578
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:674
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:614
decltype(generate_tuple(lambda_merge_generate_MagicDivision_calculate_magic_divisor< LowLengths >{}, number< NDimLow >{})) LowLengthsMagicDivisor
Definition: coordinate_transform.hpp:588
LowLengths low_lengths_
Definition: coordinate_transform.hpp:590
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:683
static constexpr auto I0
Definition: coordinate_transform.hpp:594
UpLengths up_lengths_
Definition: coordinate_transform.hpp:592
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number< 1 >{}))) UpLengths
Definition: coordinate_transform.hpp:584
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:617
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:703
constexpr CK_TILE_HOST_DEVICE merge_v2_magic_division()=default
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:609
constexpr CK_TILE_HOST_DEVICE merge_v2_magic_division(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:599
Definition: coordinate_transform.hpp:725
decltype(make_tuple(container_reduce(LowLengths{}, multiplies{}, number< 1 >{}))) UpLengths
Definition: coordinate_transform.hpp:735
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:822
decltype(container_reverse_exclusive_scan(LowLengths{}, multiplies{}, number< 1 >{})) LowLengthsScan
Definition: coordinate_transform.hpp:732
LowLengths low_lengths_
Definition: coordinate_transform.hpp:737
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:814
UpLengths up_lengths_
Definition: coordinate_transform.hpp:739
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:752
constexpr CK_TILE_HOST_DEVICE merge_v3_division_mod(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:743
LowLengthsScan low_lengths_scan_
Definition: coordinate_transform.hpp:738
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:805
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:755
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up_new) const
Definition: coordinate_transform.hpp:773
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:834
constexpr CK_TILE_HOST_DEVICE merge_v3_division_mod()=default
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:800
Definition: coordinate_transform.hpp:1270
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1276
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1322
decltype(make_tuple(UpLength{})) UpLengths
Definition: coordinate_transform.hpp:1273
constexpr CK_TILE_HOST_DEVICE modulo(const Modulus &modulus, const UpLength &up_length)
Definition: coordinate_transform.hpp:1280
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &up_idx) const
Definition: coordinate_transform.hpp:1298
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1288
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1285
Modulus modulus_
Definition: coordinate_transform.hpp:1275
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:1332
constexpr CK_TILE_HOST_DEVICE modulo()=default
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1315
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1327
Definition: math.hpp:98
Definition: coordinate_transform.hpp:1443
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:1478
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1507
decltype(make_tuple(LowLength{})) UpLengths
Definition: coordinate_transform.hpp:1447
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1495
OffsetLength offset_length_
Definition: coordinate_transform.hpp:1450
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1465
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:1513
constexpr CK_TILE_HOST_DEVICE offset()=default
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1449
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:1460
constexpr CK_TILE_HOST_DEVICE offset(const LowLength &low_length, const OffsetLength &offset_length)
Definition: coordinate_transform.hpp:1454
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1468
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &) const
Definition: coordinate_transform.hpp:1502
Definition: coordinate_transform.hpp:161
constexpr CK_TILE_HOST_DEVICE pad(const LowLength &low_length, const LeftPadLength &left_pad_length, const RightPadLength &right_pad_length)
Definition: coordinate_transform.hpp:173
decltype(make_tuple(LowLength{}+LeftPadLength{}+RightPadLength{})) UpLengths
Definition: coordinate_transform.hpp:165
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:182
LeftPadLength left_pad_length_
Definition: coordinate_transform.hpp:168
UpLengths up_lengths_
Definition: coordinate_transform.hpp:167
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:226
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:219
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:195
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:185
constexpr CK_TILE_HOST_DEVICE pad()
Definition: coordinate_transform.hpp:171
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:233
RightPadLength right_pad_length_
Definition: coordinate_transform.hpp:169
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:212
Definition: coordinate_transform.hpp:65
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:124
UpLengths up_lengths_
Definition: coordinate_transform.hpp:73
constexpr CK_TILE_HOST_DEVICE pass_through(const LowLength &low_length)
Definition: coordinate_transform.hpp:77
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:117
decltype(make_tuple(LowLength{})) UpLengths
Definition: coordinate_transform.hpp:71
constexpr CK_TILE_HOST_DEVICE pass_through()=default
static constexpr auto type_enum
Definition: coordinate_transform.hpp:66
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:100
static constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up)
Definition: coordinate_transform.hpp:90
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:143
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:87
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:129
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:82
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:137
Definition: coordinate_transform.hpp:1110
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1138
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1150
constexpr CK_TILE_HOST_DEVICE replicate()=default
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1145
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &, const UpIdx &) const
Definition: coordinate_transform.hpp:1122
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &, const UpIdxDiff &, LowIdx &, const UpIdx &)
Definition: coordinate_transform.hpp:1130
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1167
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:1155
constexpr CK_TILE_HOST_DEVICE auto get_upper_lengths() const
Definition: coordinate_transform.hpp:1119
constexpr CK_TILE_HOST_DEVICE replicate(const UpLengths &up_lengths)
Definition: coordinate_transform.hpp:1115
Definition: coordinate_transform.hpp:353
LowLength low_length_
Definition: coordinate_transform.hpp:360
constexpr CK_TILE_HOST_DEVICE right_pad(const LowLength &low_length, const RightPadLength &right_pad_length)
Definition: coordinate_transform.hpp:365
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:373
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:415
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:386
static constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up)
Definition: coordinate_transform.hpp:376
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:425
RightPadLength right_pad_length_
Definition: coordinate_transform.hpp:361
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:434
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:410
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:403
constexpr CK_TILE_HOST_DEVICE right_pad()=default
decltype(make_tuple(LowLength{}+RightPadLength{})) UpLengths
Definition: coordinate_transform.hpp:357
UpLengths up_lengths_
Definition: coordinate_transform.hpp:359
Definition: coordinate_transform.hpp:1172
constexpr CK_TILE_HOST_DEVICE slice(const LowLength &, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: coordinate_transform.hpp:1184
static CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &)
Definition: coordinate_transform.hpp:1206
constexpr CK_TILE_HOST_DEVICE slice()=default
constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &) const
Definition: coordinate_transform.hpp:1230
SliceBegin slice_begin_
Definition: coordinate_transform.hpp:1179
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1178
decltype(make_tuple(SliceEnd{} - SliceBegin{})) UpLengths
Definition: coordinate_transform.hpp:1176
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1193
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1223
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1196
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:1242
SliceEnd slice_end_
Definition: coordinate_transform.hpp:1180
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1235
Definition: functional.hpp:43
Definition: tuple.hpp:192
Definition: coordinate_transform.hpp:858
static constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides)
Definition: coordinate_transform.hpp:941
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &idx_diff_up, LowIdx &idx_low, const UpIdx &) const
Definition: coordinate_transform.hpp:909
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:927
constexpr CK_TILE_HOST_DEVICE unmerge(const UpLengths &up_lengths)
Definition: coordinate_transform.hpp:872
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:886
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:932
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:883
UpLengthsScan up_lengths_scan_
Definition: coordinate_transform.hpp:868
UpLengths up_lengths_
Definition: coordinate_transform.hpp:867
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:920
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:962
constexpr CK_TILE_HOST_DEVICE unmerge()=default
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:878
decltype(container_reverse_exclusive_scan(UpLengths{}, multiplies{}, number< 1 >{})) UpLengthsScan
Definition: coordinate_transform.hpp:865
Definition: coordinate_transform.hpp:1347
constexpr CK_TILE_HOST_DEVICE const auto & get_upper_lengths() const
Definition: coordinate_transform.hpp:1366
constexpr CK_TILE_HOST_DEVICE xor_t(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1359
constexpr CK_TILE_HOST_DEVICE void calculate_lower_index(LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1369
constexpr CK_TILE_HOST_DEVICE auto calculate_upper_dimension_safe_vector_length_strides(const LowVectorLengths &low_vector_lengths, const LowVectorStrides &low_vector_strides) const
Definition: coordinate_transform.hpp:1418
LowLengths UpLengths
Definition: coordinate_transform.hpp:1353
UpLengths up_lengths_
Definition: coordinate_transform.hpp:1355
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx &)
Definition: coordinate_transform.hpp:1406
static constexpr CK_TILE_HOST_DEVICE bool is_known_at_compile_time()
Definition: coordinate_transform.hpp:1411
CK_TILE_HOST_DEVICE void print() const
Definition: coordinate_transform.hpp:1428
static constexpr CK_TILE_HOST_DEVICE auto get_type_enum()
Definition: coordinate_transform.hpp:1361
constexpr CK_TILE_HOST_DEVICE xor_t()
Definition: coordinate_transform.hpp:1357
CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff &idx_diff_low, const UpIdxDiff &, LowIdx &idx_low, const UpIdx &idx_up) const
Definition: coordinate_transform.hpp:1382
static constexpr CK_TILE_HOST_DEVICE bool is_valid_upper_index_always_mapped_to_valid_lower_index()
Definition: coordinate_transform.hpp:1399