/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/tile_window.hpp Source File
tile_window.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
20 
21 namespace ck_tile {
22 
// NOTE(review): this SOURCE is a documentation-site scrape; the leading numbers
// are the ORIGINAL file's line numbers and several original lines were dropped
// by the extractor (e.g. 38-39 with the `struct ... :` head and 48 with the
// presumed `using Base = ...` alias) -- consult the real tile_window.hpp.
34 template <typename BottomTensorView_,
35  typename WindowLengths_,
36  typename StaticTileDistribution_,
37  index_t NumCoord>
// Class head (base-class argument list): the derived window type plus the
// bottom tensor view, window lengths and static tile distribution.
40  tile_window_with_static_distribution<BottomTensorView_,
41  WindowLengths_,
42  StaticTileDistribution_,
43  NumCoord>,
44  BottomTensorView_,
45  WindowLengths_,
46  StaticTileDistribution_>
47 {
// Same instantiation repeated -- presumably the target of a `using Base =`
// alias whose first line (original 48) was dropped; TODO confirm.
49  tile_window_with_static_distribution<BottomTensorView_,
50  WindowLengths_,
51  StaticTileDistribution_,
52  NumCoord>,
53  BottomTensorView_,
54  WindowLengths_,
55  StaticTileDistribution_>;
56 
// Compile-time indices used below to select the (window-adaptor, bottom-tensor)
// halves of each pre-computed coordinate bundle.
57  static constexpr auto I0 = number<0>{};
58  static constexpr auto I1 = number<1>{};
// Only a single pre-computed coordinate bundle is currently supported.
59  static_assert(NumCoord == 1);
60 
// The space-filling-curve accesses must split evenly across the bundles.
61  static_assert(Base::Traits::NumAccess % NumCoord == 0,
62  "wrong! # of access is not divisible by NumCoord");
63  static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord;
64 
66 
// Main constructor: copies the window description into the base-class members,
// then pre-computes the per-thread coordinate bundles.  The constructor-name
// line (original 67) and the assignment targets on the omitted lines 81/89
// (presumably pre_computed_coords_ / pre_computed_warp_coords_) were dropped
// by the extractor -- TODO confirm against the original source.
68  const typename Base::BottomTensorView& bottom_tensor_view,
69  const typename Base::WindowLengths& window_lengths,
70  const typename Base::BottomTensorIndex& window_origin,
71  const typename Base::TileDstr& tile_distribution,
72  decltype(get_partition_index(tile_distribution)) partition_index)
74  {
75 
76  this->window_origin_ = window_origin;
77  this->window_lengths_ = window_lengths;
78  this->bottom_tensor_view_ = bottom_tensor_view;
79  this->tile_dstr_ = tile_distribution;
80 
// Thread-level coordinate bundles for this thread's partition index.
82  prepare_coords(bottom_tensor_view, window_origin, tile_distribution, partition_index);
83  if constexpr(Base::BottomTensorView::buffer_view::get_address_space() ==
84  address_space_enum::global)
85  {
// For global memory, additionally pre-compute coordinates with component 1 of
// the partition index (presumably the lane id) forced to 0, i.e. wave-uniform.
86  auto use_lane_id_0 = partition_index;
87  use_lane_id_0[1] = 0;
88 
90  prepare_coords(bottom_tensor_view, window_origin, tile_distribution, use_lane_id_0);
91  }
92  }
93 
// Convenience constructor: delegates to the main constructor; the final
// delegated argument (omitted lines 102-103) is presumably
// get_partition_index(tile_distribution) -- TODO confirm.
95  const typename Base::BottomTensorView& bottom_tensor_view,
96  const typename Base::WindowLengths& window_lengths,
97  const typename Base::BottomTensorIndex& window_origin,
98  const typename Base::TileDstr& tile_distribution)
99  : tile_window_with_static_distribution(bottom_tensor_view,
100  window_lengths,
101  window_origin,
104  {
105  }
106 
// Build the NumCoord pre-computed (window-adaptor coord, bottom-tensor coord)
// pairs for the thread identified by `partition_index` and return them.
// The declared type of `coords` (original line 113) was dropped by the
// extractor, as were the first arguments of the calls on lines 117 and 142.
107  CK_TILE_DEVICE constexpr auto
108  prepare_coords(const typename Base::BottomTensorView& bottom_tensor_view,
109  const typename Base::BottomTensorIndex& window_origin,
110  const typename Base::TileDstr& tile_distribution,
111  decltype(get_partition_index(tile_distribution)) partition_index) const
112  {
114  coords;
115 
// Coordinate in the P..Y adaptor space for this thread, with all Y = 0.
116  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
118  container_concat(partition_index, multi_index<Base::NDimY>{0}));
119 
// Bottom-tensor origin of this thread = window origin + adaptor bottom index.
120  typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
121  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
122 
123  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
124  bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
125 
126  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
127  // future load/store() calls (might allocate more registers)
128  using Traits = typename Base::Traits;
129  using SFC_Ys = typename Traits::SFC_Ys;
130 
131  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
132  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
133  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
134 
// Advance bundle iCoord to its first access along the Y space-filling curve.
135  constexpr auto idx_diff_ys =
136  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
137 
// Prepend zero P-dimension steps so the diff covers the full P..Y index.
138  constexpr auto idx_diff_ps_ys = container_concat(
139  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
140  idx_diff_ys);
141 
143  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
144 
145  coords(iCoord) = make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
146  });
147 
148  return coords;
149  }
150 
// Plain tile load: forwards to load_with_offset() with a zero linear offset
// and returns the freshly built distributed tensor.  The signature lines
// (original 152-153) were dropped by the extractor.
151  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
154  {
155  return load_with_offset(
156  0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
157  }
158 
// Allocate a static distributed tensor matching this window's distribution,
// fill it via the in-place load_with_offset() overload, and return it by
// value.  The signature lines (original 162-164) were dropped.
159  template <index_t i_access_unsupport_ = -1,
160  bool oob_conditional_check = true,
161  typename offset_t = index_t>
165  {
166  constexpr auto tile_dstr = typename Base::TileDstr{};
167  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
168  load_with_offset(offset,
169  dst_tensor,
170  number<i_access_unsupport_>{},
171  bool_constant<oob_conditional_check>{});
172  return dst_tensor;
173  }
174 
// Element-wise load: reads from a tuple of tile windows, combines the values
// with `elementwise` via the in-place overload below, and returns the
// resulting distributed tensor.  Trailing defaulted parameters (original
// 191-192) were dropped by the extractor.
185  template <typename TileWindow_,
186  typename ElementWise_,
187  index_t i_access_unsupport_ = -1,
188  bool oob_conditional_check = true>
189  CK_TILE_DEVICE auto load(const TileWindow_& tile_window,
190  ElementWise_ elementwise,
193  {
194  constexpr auto tile_dstr = typename Base::TileDstr{};
195  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
196  load(dst_tensor,
197  tile_window,
198  elementwise,
199  number<i_access_unsupport_>{},
200  bool_constant<oob_conditional_check>{});
201  return dst_tensor;
202  }
203 
// Element-wise load into an existing distributed tensor: for every access of
// the Y space-filling curve, reads one vector from EACH window in the
// `tile_window` tuple (all windows share coords from window 0) and combines
// the per-window scalars into dst_tensor with `elementwise`.  Several lines
// were dropped by the extractor (original 212-213, 224, 261 -- the latter
// presumably an `apply`/`unpack`-style call over idx_vec_value -- and 278,
// presumably the coordinate-move helper call); TODO confirm.
204  template <typename DistributedTensor,
205  typename TileWindow_,
206  typename ElementWise_,
207  index_t i_access_unsupport_ = -1,
208  bool oob_conditional_check = true>
209  CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
210  const TileWindow_& tile_window,
211  ElementWise_ elementwise,
214  {
215 
216  using Traits = typename Base::Traits;
217  using vector_t = typename Traits::vector_t;
218  using SFC_Ys = typename Traits::SFC_Ys;
219 
220  constexpr auto tile_dstr = typename Base::TileDstr{};
221  constexpr auto sizeOfTuple = TileWindow_::size();
222  // loop over thread tensor space [y0, y1, ...]
223  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
// Coordinates are taken from the FIRST window of the tuple.
225  auto window_adaptor_thread_coord =
226  tile_window[number<0>{}].pre_computed_coords_[iCoord][I0];
227  auto bottom_tensor_thread_coord =
228  tile_window[number<0>{}].pre_computed_coords_[iCoord][I1];
229 
230  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
231  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
232 
233  // data index [y0, y1, ...]
234  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
235 
236  // read from bottom tensor
237  const auto idx_vec_value = generate_tuple(
238  [&](auto jj) {
239  return tile_window[number<jj>{}]
240  .get_bottom_tensor_view()
241  .template get_vectorized_elements<vector_t>(
242  bottom_tensor_thread_coord,
243  0,
244  bool_constant<oob_conditional_check>{});
245  },
246  number<sizeOfTuple>{});
247 
248  // write into distributed tensor
249  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
250  constexpr auto idx_ys = generate_tuple(
251  [&](auto jj) {
252  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
253  : idx_ys_start[jj];
254  },
255  number<Base::NDimY>{});
256 
257  constexpr index_t d =
258  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
259  Traits::PackedSize;
260 
262  [&](auto&&... t) {
263  elementwise(dst_tensor.get_thread_buffer().template at<d>(),
264  t.template get_as<
265  typename Base::DataType>()[j / Traits::PackedSize]...);
266  },
267  idx_vec_value);
268  });
269  // move thread coordinate
270  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
271  {
272  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
273 
274  constexpr auto idx_diff_ps_ys = container_concat(
275  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
276  idx_diff_ys);
277 
279  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
280  }
281  });
282  });
283  }
284 
// Load into an existing distributed tensor with a zero linear offset; the
// forwarded call's name on the omitted line 292 is presumably
// load_with_offset -- TODO confirm.
285  template <typename DistributedTensor,
286  index_t i_access_unsupport_ = -1,
287  bool oob_conditional_check = true>
288  CK_TILE_DEVICE void load(DistributedTensor& dst_tensor,
291  {
293  0, dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
294  }
295 
// Convert a compile-time multi-dimensional offset (offset_t) into the bottom
// tensor's linear element offset, then broadcast it from the first active
// lane (amd_wave_read_first_lane) so the value is wave-uniform.
296  template <typename offset_t>
297  CK_TILE_DEVICE constexpr auto get_load_offset(offset_t = {}) const
298  {
299  constexpr auto bottom_tensor_idx_off = to_multi_index(offset_t{});
300  const auto bottom_tensor_coord_off = make_tensor_coordinate(
301  this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_idx_off);
302  return amd_wave_read_first_lane(bottom_tensor_coord_off.get_offset());
303  }
304 
// Core load: walks the Y space-filling curve from the pre-computed coordinate
// bundles, reads one vector per access at `linear_off` extra elements, and
// scatters the scalars into dst_tensor's thread buffer.  Parts of the
// signature (original 310, 312-314) and the coordinate-move call name
// (original 374) were dropped by the extractor.
305  template <typename DataType,
306  typename StaticTileDistribution,
307  index_t i_access_unsupport_ = -1,
308  bool oob_conditional_check = true,
309  typename offset_t>
311  offset_t offset,
315  {
316  using Traits = typename Base::Traits;
317  using vector_t = typename Traits::vector_t;
318  using SFC_Ys = typename Traits::SFC_Ys;
319 
320  constexpr auto tile_dstr = typename Base::TileDstr{};
321 
// Normalize the offset: runtime integral -> as-is, ck_tile constant -> its
// value, otherwise a multi-index type resolved via get_load_offset().
322  const index_t linear_off = [&]() {
323  if constexpr(std::is_integral_v<offset_t>)
324  return offset;
325  else if constexpr(is_constant_v<offset_t>)
326  return offset_t::value;
327  else
328  return get_load_offset(offset_t{});
329  }();
330  // loop over thread tensor space [y0, y1, ...]
331  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
333  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
334  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
335 
336  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
337  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
338 
339  // data index [y0, y1, ...]
340  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
341 
342  // read from bottom tensor
343  const vector_t vec_value =
344  this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
345  bottom_tensor_thread_coord,
346  linear_off,
347  bool_constant<oob_conditional_check>{});
348  // write into distributed tensor
349  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
350  constexpr auto idx_ys = generate_tuple(
351  [&](auto jj) {
352  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
353  : idx_ys_start[jj];
354  },
355  number<Base::NDimY>{});
356 
357  constexpr index_t d =
358  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
359  Traits::PackedSize;
360 
361  dst_tensor.get_thread_buffer().template at<d>() =
362  vec_value
363  .template get_as<typename Base::DataType>()[j / Traits::PackedSize];
364  });
365  // move thread coordinate
366  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
367  {
368  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
369 
370  constexpr auto idx_diff_ps_ys = container_concat(
371  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
372  idx_diff_ys);
373 
375  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
376  }
377  });
378  });
379  }
380 
// "Raw" load: reinterprets dst_tensor's thread buffer as an array of
// vector_t and fills each slot directly via get_vectorized_elements_raw,
// optionally emitting a leading nop (pre_nop) before the very first access.
// Parts of the signature (original 386-387) and the coordinate-move call
// name (original 445) were dropped by the extractor.
381  template <typename DstTile,
382  index_t i_access_unsupport_ = -1,
383  bool oob_conditional_check = true,
384  bool pre_nop = false>
385  CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
388  bool_constant<pre_nop> = {}) const
389  {
390  using Traits = typename Base::Traits;
391  using vector_t = typename Traits::vector_t;
392  using SFC_Ys = typename Traits::SFC_Ys;
393  static constexpr index_t YElementSize =
394  typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
395  static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
396  using vectorized_tbuf =
397  array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
398 
399  constexpr auto tile_dstr = typename Base::TileDstr{};
400 
// View the destination thread buffer as whole vectors so each raw read
// lands directly in place.
401  auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
402 
403  // loop over thread tensor space [y0, y1, ...]
404  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
406  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
407  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
408 
409  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
410  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// pre_nop is honored only on the very first access of the first bundle.
411  constexpr auto pre_nop_ = [&]() {
412  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
413  return bool_constant<true>{};
414  else
415  return bool_constant<false>{};
416  }();
417 
418  // data index [y0, y1, ...]
419  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
420  constexpr index_t d =
421  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
422  Traits::PackedSize;
423  static_assert(d % Traits::ScalarPerVector == 0);
424 
425  this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
426  dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
427  bottom_tensor_thread_coord,
428  0 ,
429  bool_constant<oob_conditional_check>{},
430  pre_nop_);
431 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
432  CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
433  asm volatile(
434  ""); // this is starting from rocm-6.2, but same sympton, reuse this flag
435 #endif
436  // move thread coordinate
437  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
438  {
439  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
440 
441  constexpr auto idx_diff_ps_ys = container_concat(
442  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
443  idx_diff_ys);
444 
446  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
447  }
448  });
449  });
450  }
451 
452  // TODO: currently async load only implemented in inline asm
// Asynchronous raw load into an LDS tile window: computes per-buffer /
// per-wave / per-issue byte strides from the 3-D (issues, warps, lanes) LDS
// descriptor, programs the m0 register (the call consuming m0_init_value on
// the omitted line 490 is presumably an m0-set helper -- TODO confirm), then
// issues one async vectorized read per SFC access, bumping m0 between issues.
// Parts of the signature (original 458-459) were dropped by the extractor.
453  template <typename LdsTileWindow_,
454  index_t i_access_unsupport_ = -1,
455  bool oob_conditional_check = true,
456  bool pre_nop = false>
457  CK_TILE_DEVICE void async_load_raw(LdsTileWindow_&& lds_tile,
460  bool_constant<pre_nop> = {}) const
461  {
462  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
463  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
464  using LdsDataType = typename LdsTileWindow::DataType;
465  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
466 
467  // issues * warps * lanes
468  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
469 
470  const index_t size_per_buf =
471  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
472  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
473  sizeof(LdsDataType);
474 
475  const index_t size_per_wave =
476  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
477  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
478  sizeof(LdsDataType) -
479  size_per_buf;
480 
481  const index_t size_per_issue =
482  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
483  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
484  sizeof(LdsDataType) -
485  size_per_buf;
486 
487  // Use VALU so the compiler can optimize redundant/repeated computations
488  const index_t m0_init_value =
489  size_per_buf + size_per_wave * get_warp_id(/*ReturnSgpr=*/bool_constant<false>{});
491  amd_wave_read_first_lane(m0_init_value)); // This should be wave independent
492 
493  using Traits = typename Base::Traits;
494 
495  using vector_t = typename Traits::vector_t;
496  using SFC_Ys = typename Traits::SFC_Ys;
497 
498  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
499 
500  // loop over thread tensor space [y0, y1, ...]
501  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
503  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
504  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
505 
506  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
507  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
// pre_nop only applies to the very first access of the first bundle.
508  constexpr auto pre_nop_ = [&]() {
509  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
510  return bool_constant<true>{};
511  else
512  return bool_constant<false>{};
513  }();
514 
515  // read from bottom tensor
516  this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
517  smem, bottom_tensor_thread_coord, 0, pre_nop_);
518 
519  // move thread coordinate
520  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
521  {
522  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
523 
524  constexpr auto idx_diff_ps_ys = container_concat(
525  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
526  idx_diff_ys);
527 
529  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
530 
// Advance m0 by one issue's worth of bytes for the next async read.
531  m0_inc_with_memory(size_per_issue);
532  }
533  });
534  });
535  }
536 
// Asynchronous load into an LDS tile window without inline-asm m0 handling:
// computes the LDS destination address per access from the pre-computed
// warp-level coordinates and issues async vectorized reads.  When
// static_move_ys is set, per-access offsets are resolved at compile time
// instead of moving the coordinates.  Parts of the signature (original
// 542, 544-545) and the two coordinate-move call names (627, 633) were
// dropped by the extractor.
537  template <typename LdsTileWindow_,
538  index_t i_access_unsupport_ = -1,
539  bool oob_conditional_check = true,
540  bool static_move_ys = false,
541  typename = std::enable_if_t<std::is_class_v<remove_cvref_t<LdsTileWindow_>>>>
543  LdsTileWindow_&& lds_tile,
546  bool_constant<static_move_ys> = {}) const
547  {
548  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
549  using LdsDataType = typename LdsTileWindow::DataType;
550  using Traits = typename Base::Traits;
551 
552  using vector_t = typename Traits::vector_t;
553  using SFC_Ys = typename Traits::SFC_Ys;
554 
555  // Precompute invariant values outside loops
556  const auto window_origin = lds_tile.get_window_origin();
557  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
558  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
559  auto lds_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
560 
561  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
562  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
563  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
564 
// Warp-level bundles (lane component forced to 0 in the constructor).
565  auto window_adaptor_warp_coord = pre_computed_warp_coords_[iCoord][I0];
566  auto bottom_tensor_warp_coord = pre_computed_warp_coords_[iCoord][I1];
567 
568  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
569  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
570 
// Compile-time bottom index of this access relative to access 0.
571  constexpr auto idx_ys_offset = [&]() {
572  constexpr auto idx_off_ys = SFC_Ys::get_step_between(number<0>{}, iAccess);
573  constexpr auto adapter_ys_offset = make_tensor_adaptor_coordinate(
574  StaticTileDistribution_{}.get_ps_ys_to_xs_adaptor(),
575  container_concat(array<index_t, Base::NDimP>{0},
576  to_array<index_t, idx_off_ys.size()>(idx_off_ys)));
577  return adapter_ys_offset.get_bottom_index();
578  }();
579  const auto lds_ys_offset = [&]() {
580  if constexpr(static_move_ys)
581  {
582  const auto coord_ys_offset =
583  make_tensor_coordinate(tensor_descriptor, idx_ys_offset);
584  return coord_ys_offset.get_offset();
585  }
586  else
587  return 0;
588  }();
589 
590  // Use precomputed window origin & tensor descriptor
591  auto lds_bottom_tensor_thread_idx =
592  window_origin + window_adaptor_warp_coord.get_bottom_index();
593  const auto lds_coord =
594  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
595 
596  // Calculate SMEM address using base pointer
597  CK_TILE_LDS_ADDR LdsDataType* smem = lds_base_ptr +
598  lds_coord.get_offset() / Traits::PackedSize +
599  lds_ys_offset / Traits::PackedSize;
600 
601  const auto dram_ys_offset = [&]() {
602  if constexpr(static_move_ys)
603  {
604  const auto coord_ys_offset = make_tensor_coordinate(
605  this->get_bottom_tensor_view().get_tensor_descriptor(), idx_ys_offset);
606  return coord_ys_offset.get_offset();
607  }
608  else
609  return 0;
610  }();
611 
612  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
613  smem,
614  bottom_tensor_thread_coord,
615  offset + dram_ys_offset,
616  bool_constant<oob_conditional_check>{});
617 
618  // Move thread coordinate if not last access
619  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
620  {
621  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
622  constexpr auto idx_diff_ps_ys = container_concat(
623  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
624  idx_diff_ys);
625 
626  if constexpr(!static_move_ys)
628  window_adaptor_thread_coord,
629  bottom_tensor_thread_coord,
630  idx_diff_ps_ys);
631 
632  if constexpr(!static_move_ys)
634  window_adaptor_warp_coord, bottom_tensor_warp_coord, idx_diff_ps_ys);
635  }
636  });
637  });
638  }
639 
// Transposing load with a zero offset: forwards to
// load_transpose_with_offset<Policy>().  The signature lines (original
// 641-642) were dropped by the extractor.
640  template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
643  {
644  return this->template load_transpose_with_offset<Policy>(
645  0, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
646  }
647 
// Allocate a distributed tensor, fill it with the transposing in-place
// overload below, and return it by value.  The signature lines (original
// 649-651) were dropped by the extractor.
648  template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
652  {
653  constexpr auto tile_dstr = typename Base::TileDstr{};
654  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
655  this->template load_transpose_with_offset<Policy>(offset,
656  dst_tensor,
657  number<i_access_unsupport_>{},
658  bool_constant<oob_conditional_check>{});
659  return dst_tensor;
660  }
661 
// Transposing load into an existing distributed tensor: reads transposed
// vectors from the bottom tensor and remaps each scalar's Y-index through
// Policy::group_func before computing its linear slot in the thread buffer.
// Parts of the signature (original 666, 668-669) and the coordinate-move
// call name (original 722) were dropped by the extractor.
662  template <typename Policy,
663  typename DistributedTensor,
664  index_t i_access_unsupport_ = -1,
665  bool oob_conditional_check = true>
667  DistributedTensor& dst_tensor,
670  {
671  using Traits = typename Base::Traits;
672  using vector_t = typename Traits::vector_t;
673  using SFC_Ys = typename Traits::SFC_Ys;
674 
675  constexpr auto tile_dstr = typename Base::TileDstr{};
676 
677  constexpr auto group_func = Policy::group_func;
678 
679  // loop over thread tensor space [y0, y1, ...]
680  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
682  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
683  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
684 
685  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
686  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
687 
688  // data index [y0, y1, ...]
689  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
690 
691  // read from bottom tensor
692  const vector_t vec_value =
693  this->get_bottom_tensor_view()
694  .template get_transpose_vectorized_elements<vector_t>(
695  bottom_tensor_thread_coord, offset);
696  // write into distributed tensor
697  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
698  constexpr auto orig_idx_ys = generate_tuple(
699  [&](auto jj) {
700  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
701  : idx_ys_start[jj];
702  },
703  number<Base::NDimY>{});
704 
// Policy-defined regrouping of the Y index (e.g. for transposed layouts).
705  constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
706 
707  constexpr index_t linear_distributed_index =
708  tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
709 
710  dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
711  vec_value.template get_as<typename Base::DataType>()[j];
712  });
713  // move thread coordinate
714  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
715  {
716  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
717 
718  constexpr auto idx_diff_ps_ys = container_concat(
719  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
720  idx_diff_ys);
721 
723  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
724  }
725  });
726  });
727  }
728 
// Store: the mirror of load_with_offset -- for each SFC access, gathers
// ScalarPerVector scalars from the distributed tensor's thread buffer into a
// vector and writes it to the bottom tensor with set_vectorized_elements.
// Parts of the signature (original 730, 732-733) and the coordinate-move
// call name (original 791) were dropped by the extractor.
729  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
731  typename Base::TileDstr>& dstr_tensor,
734  {
735  using Traits = typename Base::Traits;
736 
737  using vector_t = typename Traits::vector_t;
738  using SFC_Ys = typename Traits::SFC_Ys;
739 
740  constexpr auto tile_dstr = typename Base::TileDstr{};
741 
742  // loop over thread tensor space [y0, y1, ...]
743  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
744  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
745  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
746 
747  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
748  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
749 
750  // data index [y0, y1, ...]
751  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
752 
753  // read from distributed tensor
754  // vector_type_t vec;
755  vector_t vec_value;
756 
757  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
758  constexpr auto idx_ys = generate_tuple(
759  [&](auto jj) {
760  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
761  : idx_ys_start[jj];
762  },
763  number<Base::NDimY>{});
764 
765  constexpr index_t d =
766  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
767  Traits::PackedSize;
768 
769  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
770  dstr_tensor.get_thread_buffer().template at<d>();
771  });
772 
773  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
774 
775  // write into bottom tensor
776  this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
777  bottom_tensor_thread_coord,
778  0,
779  vec_value,
780  bool_constant<oob_conditional_check>{});
781 
782  // move thread coordinate
783  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
784  {
785  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
786 
787  constexpr auto idx_diff_ps_ys = container_concat(
788  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
789  idx_diff_ys);
790 
792  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
793  }
794  });
795  });
796  }
797 
// Raw store: same traversal as store() but writes through
// set_vectorized_elements_raw with OOB checking hard-enabled.  The method
// name / first parameter line (original 800) and the coordinate-move call
// name (original 854) were dropped by the extractor.
798  template <index_t i_access_unsupport_ = -1>
799  CK_TILE_DEVICE void
801  dstr_tensor,
802  number<i_access_unsupport_> = {}) const
803  {
804  using Traits = typename Base::Traits;
805 
806  using vector_t = typename Traits::vector_t;
807  using SFC_Ys = typename Traits::SFC_Ys;
808 
809  constexpr auto tile_dstr = typename Base::TileDstr{};
// Raw path always bounds-checks (not template-configurable here).
810  static constexpr bool oob_conditional_check = true;
811 
812  // loop over thread tensor space [y0, y1, ...]
813  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
815  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
816  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
817 
818  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
819  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
820 
821  // data index [y0, y1, ...]
822  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
823 
824  // read from distributed tensor
825  vector_t vec_value;
826  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
827  constexpr auto idx_ys = generate_tuple(
828  [&](auto jj) {
829  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
830  : idx_ys_start[jj];
831  },
832  number<Base::NDimY>{});
833  constexpr index_t d =
834  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
835  Traits::PackedSize;
836  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
837  dstr_tensor.get_thread_buffer().template at<d>();
838  });
839 
840  // write into bottom tensor
841  this->get_bottom_tensor_view()
842  .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
843  bottom_tensor_thread_coord, 0, vec_value);
844 
845  // move thread coordinate
846  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
847  {
848  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
849 
850  constexpr auto idx_diff_ps_ys = container_concat(
851  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
852  idx_diff_ys);
853 
855  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
856  }
857  });
858  });
859  }
860 
// Read-modify-write: same traversal as store() but writes through
// update_vectorized_elements (which combines with the memory contents).
// The method name / first parameter line (original 863, 865-866) and the
// coordinate-move call name (original 922) were dropped by the extractor.
861  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
862  CK_TILE_DEVICE void
864  dstr_tensor,
867  {
868  using Traits = typename Base::Traits;
869 
870  using vector_t = typename Traits::vector_t;
871  using SFC_Ys = typename Traits::SFC_Ys;
872 
873  constexpr auto tile_dstr = typename Base::TileDstr{};
874 
875  // loop over thread tensor space [y0, y1, ...]
876  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
878  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
879  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
880 
881  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
882  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
883 
884  // data index [y0, y1, ...]
885  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
886 
887  // read from distributed tensor
888  vector_t vec_value;
889 
890  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
891  constexpr auto idx_ys = generate_tuple(
892  [&](auto jj) {
893  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
894  : idx_ys_start[jj];
895  },
896  number<Base::NDimY>{});
897 
898  constexpr index_t d =
899  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
900  Traits::PackedSize;
901 
902  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
903  dstr_tensor.get_thread_buffer().template at<d>();
904  });
905 
906  // write into bottom tensor
907  this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
908  bottom_tensor_thread_coord,
909  0,
910  vec_value,
911  bool_constant<oob_conditional_check>{});
912 
913  // move thread coordinate
914  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
915  {
916  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
917 
918  constexpr auto idx_diff_ps_ys = container_concat(
919  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
920  idx_diff_ys);
921 
923  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
924  }
925  });
926  });
927  }
928 
// Raw read-modify-write: like update() but through
// update_vectorized_elements_raw, with an optional leading nop (pre_nop).
// The method name / first parameter line (original 931, 933-934) and the
// coordinate-move call name (original 992) were dropped by the extractor.
929  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
930  CK_TILE_DEVICE void
932  dstr_tensor,
935  bool_constant<pre_nop> = {}) const
936  {
937  using Traits = typename Base::Traits;
938 
939  using vector_t = typename Traits::vector_t;
940  using SFC_Ys = typename Traits::SFC_Ys;
941 
942  constexpr auto tile_dstr = typename Base::TileDstr{};
943 
944  // loop over thread tensor space [y0, y1, ...]
945  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
947  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
948  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
949 
950  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
951  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
952 
953  // data index [y0, y1, ...]
954  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
955 
956  // read from distributed tensor
957  vector_t vec_value;
958 
959  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
960  constexpr auto idx_ys = generate_tuple(
961  [&](auto jj) {
962  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
963  : idx_ys_start[jj];
964  },
965  number<Base::NDimY>{});
966 
967  constexpr index_t d =
968  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
969  Traits::PackedSize;
970 
971  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
972  dstr_tensor.get_thread_buffer().template at<d>();
973  });
974 
975  // write into bottom tensor
976  this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
977  bottom_tensor_thread_coord,
978  0,
979  vec_value,
980  bool_constant<oob_conditional_check>{},
981  bool_constant<pre_nop>{});
982 
983  // move thread coordinate
984  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
985  {
986  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
987 
988  constexpr auto idx_diff_ps_ys = container_concat(
989  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
990  idx_diff_ys);
991 
993  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
994  }
995  });
996  });
997  }
998 
999  // Custom move behavior
1001  {
1002  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
1003  move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
1004  pre_computed_coords_(iCoord)(I1),
1005  step);
1006  });
1007 
1008  if constexpr(Base::BottomTensorView::buffer_view::get_address_space() ==
1009  address_space_enum::global)
1010  {
1011  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
1012  move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
1013  pre_computed_warp_coords_(iCoord)(I1),
1014  step);
1015  });
1016  }
1017  }
1018 
1020  {
1021  // TODO: this use less register for FA, but more register for GEMM
1022  // need investigation
1023  const auto window_adaptor_thread_coord_tmp =
1024  make_tensor_adaptor_coordinate(this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
1027 
1028  typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
1029  this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
1030 
1031  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
1032  this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
1033 
1034  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
1035  // future load/store() calls (might allocate more registers)
1036  using Traits = typename Base::Traits;
1037  using SFC_Ys = typename Traits::SFC_Ys;
1038 
1039  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
1040  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
1041  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
1042 
1043  constexpr auto idx_diff_ys =
1044  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
1045 
1046  constexpr auto idx_diff_ps_ys = container_concat(
1047  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
1048  idx_diff_ys);
1049 
1051  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
1052 
1053  pre_computed_coords_(iCoord) =
1054  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
1055  });
1056  }
1057 
1058  // this contains:
1059  // per-thread coordinate for window adaptor
1060  // per-thread coordinate for bottom tensor
1063  // pre_computed_warp_coords_ exists only in the global memory tile_window
1065  Base::BottomTensorView::buffer_view::get_address_space() == address_space_enum::global,
1067  std::byte>
1069 };
1070 
1071 // TODO: use strategy
1072 template <typename TensorView_,
1073  typename WindowLengths_,
1074  typename StaticTileDistribution_,
1075  index_t NumCoord = 1>
1076 CK_TILE_DEVICE constexpr auto
1077 make_tile_window(const TensorView_& tensor_view,
1078  const WindowLengths_& window_lengths,
1079  const multi_index<TensorView_::get_num_of_dimension()>& origin,
1080  const StaticTileDistribution_& tile_distribution,
1081  number<NumCoord> = {})
1082 {
1083  return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
1084  remove_cvref_t<WindowLengths_>,
1085  remove_cvref_t<StaticTileDistribution_>,
1086  NumCoord>{
1087  tensor_view, window_lengths, origin, tile_distribution};
1088 }
1089 
1090 template <typename TensorView_,
1091  typename WindowLengths_,
1092  typename StaticTileDistribution_,
1093  index_t NumCoord = 1,
1094  typename = std::enable_if_t<is_tensor_view_v<TensorView_> &&
1095  is_tile_distribution_v<StaticTileDistribution_>>>
1096 CK_TILE_DEVICE constexpr auto
1097 make_tile_window(const TensorView_& tensor_view,
1098  const WindowLengths_& window_lengths,
1099  const multi_index<TensorView_::get_num_of_dimension()>& origin,
1100  const StaticTileDistribution_& tile_distribution,
1101  decltype(get_partition_index(tile_distribution)) partition_index,
1102  number<NumCoord> = {})
1103 {
1104  return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
1105  remove_cvref_t<WindowLengths_>,
1106  remove_cvref_t<StaticTileDistribution_>,
1107  NumCoord>{
1108  tensor_view, window_lengths, origin, tile_distribution, partition_index};
1109 }
1110 
1111 // this version can't be called in a constexpr context
1112 template <typename TensorView_,
1113  typename WindowLengths_,
1114  typename StaticTileDistribution_,
1115  index_t NumCoord = 1>
1116 CK_TILE_DEVICE auto
1118  const WindowLengths_& window_lengths,
1119  const multi_index<TensorView_::get_num_of_dimension()>& origin,
1120  const StaticTileDistribution_& tile_distribution,
1121  number<NumCoord> = {})
1122 {
1123  auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
1124  remove_cvref_t<WindowLengths_>,
1125  remove_cvref_t<StaticTileDistribution_>,
1126  NumCoord>{
1127  tensor_view, window_lengths, origin, tile_distribution};
1128  w.init_raw();
1129  return w;
1130 }
1131 
1132 template <typename TensorView_,
1133  typename WindowLengths_,
1134  typename StaticTileDistribution_,
1135  index_t NumCoord>
1138  WindowLengths_,
1139  StaticTileDistribution_,
1140  NumCoord>& window,
1141  const typename tile_window_with_static_distribution<TensorView_,
1142  WindowLengths_,
1143  StaticTileDistribution_,
1144  NumCoord>::BottomTensorIndex& step)
1145 {
1146  window.move(step);
1147 }
1148 
1149 template <typename TensorView_,
1150  typename WindowLengths_,
1151  typename StaticTileDistribution_,
1152  index_t NumCoord>
1155  WindowLengths_,
1156  StaticTileDistribution_,
1157  NumCoord>>& window,
1158  const typename tile_window_with_static_distribution<TensorView_,
1159  WindowLengths_,
1160  StaticTileDistribution_,
1161  NumCoord>::BottomTensorIndex& step)
1162 {
1163  using T = tuple<tile_window_with_static_distribution<TensorView_,
1164  WindowLengths_,
1165  StaticTileDistribution_,
1166  NumCoord>>;
1167 
1168  static constexpr auto N = T::size();
1169  static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
1170 }
1171 
1172 template <typename TileWindowWithStaticDistributionType,
1173  typename StepType,
1174  typename std::enable_if_t<
1176 CK_TILE_DEVICE void move_tile_window(TileWindowWithStaticDistributionType& window, StepType& step)
1177 {
1178  static constexpr auto N = TileWindowWithStaticDistributionType::size();
1179  static_for<0, N, 1>{}([&](auto Is) { window[number<Is>{}].move(step); });
1180 }
1181 
1190 template <typename BottomTensorView_, typename WindowLengths_>
1192  : public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
1193  BottomTensorView_,
1194  WindowLengths_>
1195 {
1196  using Base =
1198  BottomTensorView_,
1199  WindowLengths_>;
1200 
1202 
1204  const typename Base::BottomTensorView& bottom_tensor_view,
1205  const typename Base::WindowLengths& window_lengths,
1206  const typename Base::BottomTensorIndex& window_origin)
1207  {
1208  this->window_origin_ = window_origin;
1209  this->window_lengths_ = window_lengths;
1210  this->bottom_tensor_view_ = bottom_tensor_view;
1211  }
1212 
1226  template <typename DataType>
1228  index_t end_i,
1229  index_t start_j,
1230  index_t end_j,
1231  const char* label = "") const
1232  {
1233  const auto& tensor_view = this->get_bottom_tensor_view();
1234  const auto window_origin = this->get_window_origin();
1235 
1236  printf("%s Window Range [%d:%d, %d:%d] (origin: %d, %d):\n",
1237  label,
1238  start_i,
1239  end_i - 1,
1240  start_j,
1241  end_j - 1,
1242  window_origin[0],
1243  window_origin[1]);
1244 
1245  for(index_t i = start_i; i < end_i; i++)
1246  {
1247  for(index_t j = start_j; j < end_j; j++)
1248  {
1249  // Create coordinate for this element relative to window origin
1250  auto coord =
1252  make_tuple(window_origin[0] + i, window_origin[1] + j));
1253 
1254  // Get the element using thread buffer type directly
1255  using ThreadBuf = thread_buffer<DataType, 2>;
1256  auto buf = tensor_view.template get_vectorized_elements<ThreadBuf>(coord, 0);
1257  auto value = buf.at(number<0>{}); // Extract first element from thread buffer
1258  printf(" %s[%d,%d] = %f", label, i, j, type_convert<float>(value));
1259  }
1260  printf("\n");
1261  }
1262  printf("\n");
1263  }
1264 };
1265 
1266 template <typename TensorView_, typename WindowLengths_>
1267 CK_TILE_DEVICE constexpr auto
1268 make_tile_window(const TensorView_& tensor_view,
1269  const WindowLengths_& window_lengths,
1270  const multi_index<TensorView_::get_num_of_dimension()>& origin)
1271 {
1273  "wrong! lengths should be static");
1274 
1277  tensor_view, window_lengths, origin};
1278 }
1279 
1280 // duplicate tile window and replace its origin
1281 template <typename TensorView, typename WindowLengths>
1282 CK_TILE_DEVICE constexpr auto
1284  const multi_index<TensorView::get_num_of_dimension()>& origin)
1285 {
1287  tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
1288 }
1289 
1290 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1291 CK_TILE_DEVICE constexpr auto
1293  const multi_index<TensorView::get_num_of_dimension()>& origin,
1294  const StaticTileDistribution& tile_distribution)
1295 {
1296  return make_tile_window(tile_window.get_bottom_tensor_view(),
1297  tile_window.get_window_lengths(),
1298  origin,
1300 }
1301 
1302 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1303 CK_TILE_DEVICE constexpr auto
1305  const StaticTileDistribution& tile_distribution)
1306 {
1307  return make_tile_window(tile_window.get_bottom_tensor_view(),
1308  tile_window.get_window_lengths(),
1309  tile_window.get_window_origin(),
1311 }
1312 
1313 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1314 CK_TILE_DEVICE constexpr auto
1316  const StaticTileDistribution& tile_distribution,
1317  decltype(get_partition_index(tile_distribution)) partition_index)
1318 {
1319  return make_tile_window(tile_window.get_bottom_tensor_view(),
1320  tile_window.get_window_lengths(),
1321  tile_window.get_window_origin(),
1323  partition_index);
1324 }
1325 
1326 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
1327 CK_TILE_DEVICE constexpr auto
1329  const StaticTileDistribution& tile_distribution)
1330 {
1331  auto w = make_tile_window(tile_window, tile_distribution);
1332  w.init_raw();
1333  return w;
1334 }
1335 
1336 template <typename TensorView_, typename WindowLengths_>
1340  step)
1341 {
1342  window.move(step);
1343 }
1344 
1345 template <typename NewTensorView_,
1346  typename OldTensorView_,
1347  typename WindowLengths_,
1348  typename StaticTileDistribution_,
1349  index_t NumCoord = 1>
1350 CK_TILE_DEVICE auto
1351 replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
1352  const tile_window_with_static_distribution<OldTensorView_,
1353  WindowLengths_,
1354  StaticTileDistribution_,
1355  NumCoord>& tile_window)
1356 {
1357  return make_tile_window(new_tensor_view,
1358  tile_window.get_window_lengths(),
1359  tile_window.get_window_origin(),
1360  tile_window.get_tile_distribution());
1361 }
1362 
1363 template <typename NewTensorView_, typename OldTensorView_, typename WindowLengths_>
1365  const NewTensorView_& new_tensor_view,
1367 {
1368  return make_tile_window(
1369  new_tensor_view, tile_window.get_window_lengths(), tile_window.get_window_origin());
1370 }
1371 
1379 template <typename T>
1381 {
1382 };
1383 
1392 template <typename BottomTensorView_,
1393  typename WindowLengths_,
1394  typename StaticTileDistribution_,
1395  index_t NumCoord>
1397  tile_window_with_static_distribution<BottomTensorView_,
1398  WindowLengths_,
1399  StaticTileDistribution_,
1400  NumCoord>> : std::true_type
1401 {
1402 };
1403 
1411 template <typename T>
1414 
1422 template <typename T>
1424 {
1425 };
1426 
1433 template <typename BottomTensorView_, typename WindowLengths_>
1435  tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>> : std::true_type
1436 {
1437 };
1438 
1446 template <typename T>
1449 
1450 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_LDS_ADDR
Definition: config.hpp:62
Definition: cluster_descriptor.hpp:13
constexpr decltype(auto) apply(F &&f, Tuple &&t)
Definition: tuple.hpp:526
constexpr CK_TILE_HOST_DEVICE auto to_array(const std::vector< X > &x)
Definition: array.hpp:286
constexpr bool is_tile_window_with_static_distribution_v
Helper variable template to check if a type is a tile window with static distribution.
Definition: tile_window.hpp:1412
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition: tile_scatter_gather.hpp:1041
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition: tile_window.hpp:1117
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:56
constant< b > bool_constant
Definition: integral_constant.hpp:43
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto to_multi_index(const T &x)
Definition: multi_index.hpp:33
constexpr bool is_tile_window_with_static_lengths_v
Helper variable template to check if a type is a tile window with static lengths.
Definition: tile_window.hpp:1447
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
bool_constant< false > false_type
Definition: integral_constant.hpp:63
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
bool_constant< true > true_type
Definition: integral_constant.hpp:62
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Type trait to determine if a type is a tile window with static distribution.
Definition: tile_window.hpp:1381
Type trait to determine if a type is a tile window with static lengths.
Definition: tile_window.hpp:1424
Definition: coordinate_transform.hpp:1392
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:81
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
constexpr CK_TILE_HOST_DEVICE auto & get_tensor_descriptor() const
Definition: tensor_view.hpp:61
Definition: debug.hpp:27
Definition: tile_distribution.hpp:70
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:124
This class provides description of tile windowed view on the device memory.
Definition: tile_window_base.hpp:31
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window_base.hpp:67
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:47
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition: tile_window.hpp:800
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition: tile_window.hpp:1000
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution, decltype(get_partition_index(tile_distribution)) partition_index)
Definition: tile_window.hpp:67
CK_TILE_DEVICE void async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:457
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition: tile_window.hpp:1019
array< tuple< typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_window.hpp:1062
CK_TILE_DEVICE auto load_transpose(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:641
CK_TILE_DEVICE auto load_transpose_with_offset(index_t offset, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:649
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:931
CK_TILE_DEVICE void load_with_offset(offset_t offset, static_distributed_tensor< DataType, StaticTileDistribution > &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:310
constexpr CK_TILE_DEVICE tile_window_with_static_distribution()=default
CK_TILE_DEVICE void async_load_with_offset(index_t offset, LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< static_move_ys >={}) const
Definition: tile_window.hpp:542
constexpr CK_TILE_DEVICE auto get_load_offset(offset_t={}) const
Definition: tile_window.hpp:297
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:152
CK_TILE_DEVICE void load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:288
constexpr CK_TILE_DEVICE auto prepare_coords(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution, decltype(get_partition_index(tile_distribution)) partition_index) const
Definition: tile_window.hpp:108
static constexpr auto I0
Definition: tile_window.hpp:57
CK_TILE_DEVICE auto load(const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Load tile with elementwise function.
Definition: tile_window.hpp:189
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:385
CK_TILE_DEVICE void load(DistributedTensor &dst_tensor, const TileWindow_ &tile_window, ElementWise_ elementwise, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:209
static constexpr auto I1
Definition: tile_window.hpp:58
CK_TILE_DEVICE auto load_with_offset(offset_t offset, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:162
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:863
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition: tile_window.hpp:94
std::conditional_t< Base::BottomTensorView::buffer_view::get_address_space()==address_space_enum::global, array< tuple< typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord >, NumCoord >, std::byte > pre_computed_warp_coords_
Definition: tile_window.hpp:1068
static constexpr index_t NumAccessPerCoord
Definition: tile_window.hpp:63
CK_TILE_DEVICE void load_transpose_with_offset(index_t offset, DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:666
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:730
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:1195
constexpr CK_TILE_DEVICE tile_window_with_static_lengths()=default
CK_TILE_DEVICE void print_tile_window_range(index_t start_i, index_t end_i, index_t start_j, index_t end_j, const char *label="") const
Definition: tile_window.hpp:1227
constexpr CK_TILE_DEVICE tile_window_with_static_lengths(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin)
Definition: tile_window.hpp:1203
Definition: tile_window_base.hpp:94
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window_base.hpp:129
Definition: tuple.hpp:192