/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/tile_window.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/tile_window.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/tile_window.hpp Source File
tile_window.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
19 
20 namespace ck_tile {
21 
33 template <typename BottomTensorView_,
34  typename WindowLengths_,
35  typename StaticTileDistribution_,
36  index_t NumCoord>
39  tile_window_with_static_distribution<BottomTensorView_,
40  WindowLengths_,
41  StaticTileDistribution_,
42  NumCoord>,
43  BottomTensorView_,
44  WindowLengths_,
45  StaticTileDistribution_>
46 {
48  tile_window_with_static_distribution<BottomTensorView_,
49  WindowLengths_,
50  StaticTileDistribution_,
51  NumCoord>,
52  BottomTensorView_,
53  WindowLengths_,
54  StaticTileDistribution_>;
55 
56  static constexpr auto I0 = number<0>{};
57  static constexpr auto I1 = number<1>{};
58  static_assert(NumCoord == 1);
59 
60  static_assert(Base::Traits::NumAccess % NumCoord == 0,
61  "wrong! # of access is not divisible by NumCoord");
62  static constexpr index_t NumAccessPerCoord = Base::Traits::NumAccess / NumCoord;
63 
65 
67  const typename Base::BottomTensorView& bottom_tensor_view,
68  const typename Base::WindowLengths& window_lengths,
69  const typename Base::BottomTensorIndex& window_origin,
70  const typename Base::TileDstr& tile_distribution)
72  {
73 
74  this->window_origin_ = window_origin;
75  this->window_lengths_ = window_lengths;
76  this->bottom_tensor_view_ = bottom_tensor_view;
77  this->tile_dstr_ = tile_distribution;
78  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
82 
83  typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
84  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
85 
86  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
87  bottom_tensor_view.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
88 
89  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
90  // future load/store() calls (might allocate more registers)
91  using Traits = typename Base::Traits;
92  using SFC_Ys = typename Traits::SFC_Ys;
93 
94  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
95  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
96  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
97 
98  constexpr auto idx_diff_ys =
99  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
100 
101  constexpr auto idx_diff_ps_ys = container_concat(
102  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
103  idx_diff_ys);
104 
106  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
107 
108  pre_computed_coords_(iCoord) =
109  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
110  });
111  }
112 
113  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
116  {
117  constexpr auto tile_dstr = typename Base::TileDstr{};
118  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
119  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
120  return dst_tensor;
121  }
122 
123  template <typename DistributedTensor,
124  index_t i_access_unsupport_ = -1,
125  bool oob_conditional_check = true>
126  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
129  {
130  using Traits = typename Base::Traits;
131  using vector_t = typename Traits::vector_t;
132  using SFC_Ys = typename Traits::SFC_Ys;
133 
134  constexpr auto tile_dstr = typename Base::TileDstr{};
135 
136  // loop over thread tensor space [y0, y1, ...]
137  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
139  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
140  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
141 
142  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
143  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
144 
145  // data index [y0, y1, ...]
146  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
147 
148  // read from bottom tensor
149  const vector_t vec_value =
150  this->get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
151  bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
152  // write into distributed tensor
153  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
154  constexpr auto idx_ys = generate_tuple(
155  [&](auto jj) {
156  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
157  : idx_ys_start[jj];
158  },
159  number<Base::NDimY>{});
160 
161  constexpr index_t d =
162  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
163  Traits::PackedSize;
164 
165  dst_tensor.get_thread_buffer().template at<d>() =
166  vec_value
167  .template get_as<typename Base::DataType>()[j / Traits::PackedSize];
168  });
169  // move thread coordinate
170  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
171  {
172  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
173 
174  constexpr auto idx_diff_ps_ys = container_concat(
175  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
176  idx_diff_ys);
177 
179  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
180  }
181  });
182  });
183  }
184 
185  template <typename DstTile,
186  index_t i_access_unsupport_ = -1,
187  bool oob_conditional_check = true,
188  bool pre_nop = false>
189  CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
192  bool_constant<pre_nop> = {}) const
193  {
194  using Traits = typename Base::Traits;
195  using vector_t = typename Traits::vector_t;
196  using SFC_Ys = typename Traits::SFC_Ys;
197  static constexpr index_t YElementSize =
198  typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
199  static_assert(YElementSize % (Traits::PackedSize * Traits::ScalarPerVector) == 0);
200  using vectorized_tbuf =
201  array<vector_t, YElementSize / (Traits::PackedSize * Traits::ScalarPerVector)>;
202 
203  constexpr auto tile_dstr = typename Base::TileDstr{};
204 
205  auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());
206 
207  // loop over thread tensor space [y0, y1, ...]
208  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
210  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
211  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
212 
213  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
214  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
215  constexpr auto pre_nop_ = [&]() {
216  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
217  return bool_constant<true>{};
218  else
219  return bool_constant<false>{};
220  }();
221 
222  // data index [y0, y1, ...]
223  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
224  constexpr index_t d =
225  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
226  Traits::PackedSize;
227  static_assert(d % Traits::ScalarPerVector == 0);
228 
229  this->get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
230  dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
231  bottom_tensor_thread_coord,
232  0 ,
233  bool_constant<oob_conditional_check>{},
234  pre_nop_);
235 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
236  CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
237  asm volatile(
238  ""); // this is starting from rocm-6.2, but same symptom, reuse this flag
239 #endif
240  // move thread coordinate
241  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
242  {
243  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
244 
245  constexpr auto idx_diff_ps_ys = container_concat(
246  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
247  idx_diff_ys);
248 
250  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
251  }
252  });
253  });
254  }
255 
256  // TODO: currently async load only implemented in inline asm
257  template <typename LdsTileWindow_,
258  index_t i_access_unsupport_ = -1,
259  bool oob_conditional_check = true,
260  bool pre_nop = false>
261  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
264  bool_constant<pre_nop> = {}) const
265  {
266  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
267  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
268  using LdsDataType = typename LdsTileWindow::DataType;
269  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
270 
271  // issues * warps * lanes
272  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
273 
274  const index_t size_per_buf =
275  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
276  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
277  sizeof(LdsDataType);
278 
279  const index_t size_per_wave =
280  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
281  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
282  sizeof(LdsDataType) -
283  size_per_buf;
284 
285  const index_t size_per_issue =
286  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
287  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
288  sizeof(LdsDataType) -
289  size_per_buf;
290 
291  const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
292  m0_set_with_memory(m0_init_value); // This should be wave independent
293 
294  using Traits = typename Base::Traits;
295 
296  using vector_t = typename Traits::vector_t;
297  using SFC_Ys = typename Traits::SFC_Ys;
298 
299  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
300 
301  // loop over thread tensor space [y0, y1, ...]
302  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
304  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
305  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
306 
307  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
308  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
309  constexpr auto pre_nop_ = [&]() {
310  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
311  return bool_constant<true>{};
312  else
313  return bool_constant<false>{};
314  }();
315 
316  // read from bottom tensor
317  this->get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
318  smem, bottom_tensor_thread_coord, 0, pre_nop_);
319 
320  // move thread coordinate
321  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
322  {
323  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
324 
325  constexpr auto idx_diff_ps_ys = container_concat(
326  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
327  idx_diff_ys);
328 
330  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
331 
332  m0_inc_with_memory(size_per_issue);
333  }
334  });
335  });
336  }
337 
338  template <typename LdsTileWindow_,
339  index_t i_access_unsupport_ = -1,
340  bool oob_conditional_check = true>
341  CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
344  {
345  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
346  using LdsDataType = typename LdsTileWindow::DataType;
347  using Traits = typename Base::Traits;
348 
349  using vector_t = typename Traits::vector_t;
350  using SFC_Ys = typename Traits::SFC_Ys;
351 
352  // Precompute invariant values outside loops
353  const auto window_origin = lds_tile.get_window_origin();
354  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
355  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
356  auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
357 
358  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
359  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
360  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
361 
362  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
363  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
364 
365  // Use precomputed window origin
366  auto lds_bottom_tensor_thread_idx =
367  window_origin + window_adaptor_thread_coord.get_bottom_index();
368 
369  // Use precomputed tensor descriptor
370  const auto lds_coord =
371  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
372 
373  // Calculate SMEM address using base pointer
374  CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
375 
376  // Asynchronously read from the bottom tensor into the LDS address computed above
377  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
378  smem,
379  bottom_tensor_thread_coord,
380  number<0>{},
381  bool_constant<oob_conditional_check>{});
382 
383  // Move thread coordinate if not last access
384  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
385  {
386  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
387  constexpr auto idx_diff_ps_ys = container_concat(
388  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
389  idx_diff_ys);
390 
392  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
393  }
394  });
395  });
396  }
397 
398  template <typename Policy, index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
400  {
401  constexpr auto tile_dstr = typename Base::TileDstr{};
402  auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
403  this->template load_transpose<Policy>(
405  return dst_tensor;
406  }
407 
408  template <typename Policy,
409  typename DistributedTensor,
410  index_t i_access_unsupport_ = -1,
411  bool oob_conditional_check = true>
412  CK_TILE_DEVICE auto load_transpose(DistributedTensor& dst_tensor,
415  {
416  using Traits = typename Base::Traits;
417  using vector_t = typename Traits::vector_t;
418  using SFC_Ys = typename Traits::SFC_Ys;
419 
420  constexpr auto tile_dstr = typename Base::TileDstr{};
421 
422  constexpr auto group_func = Policy::group_func;
423 
424  // loop over thread tensor space [y0, y1, ...]
425  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
427  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
428  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
429 
430  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
431  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
432 
433  // data index [y0, y1, ...]
434  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
435 
436  // read from bottom tensor
437  const vector_t vec_value =
438  this->get_bottom_tensor_view()
439  .template get_transpose_vectorized_elements<vector_t>(
440  bottom_tensor_thread_coord, 0);
441  // write into distributed tensor
442  static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
443  constexpr auto orig_idx_ys = generate_tuple(
444  [&](auto jj) {
445  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
446  : idx_ys_start[jj];
447  },
448  number<Base::NDimY>{});
449 
450  constexpr auto grouped_idx_ys = group_func(orig_idx_ys);
451 
452  constexpr index_t linear_distributed_index =
453  tile_dstr.get_ys_to_d_descriptor().calculate_offset(grouped_idx_ys);
454 
455  dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
456  vec_value.template get_as<typename Base::DataType>()[j];
457  });
458  // move thread coordinate
459  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
460  {
461  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
462 
463  constexpr auto idx_diff_ps_ys = container_concat(
464  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
465  idx_diff_ys);
466 
468  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
469  }
470  });
471  });
472  }
473 
474  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
476  typename Base::TileDstr>& dstr_tensor,
479  {
480  using Traits = typename Base::Traits;
481 
482  using vector_t = typename Traits::vector_t;
483  using SFC_Ys = typename Traits::SFC_Ys;
484 
485  constexpr auto tile_dstr = typename Base::TileDstr{};
486 
487  // loop over thread tensor space [y0, y1, ...]
488  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
489  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
490  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
491 
492  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
493  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
494 
495  // data index [y0, y1, ...]
496  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
497 
498  // read from distributed tensor
499  // vector_type_t vec;
500  vector_t vec_value;
501 
502  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
503  constexpr auto idx_ys = generate_tuple(
504  [&](auto jj) {
505  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
506  : idx_ys_start[jj];
507  },
508  number<Base::NDimY>{});
509 
510  constexpr index_t d =
511  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
512  Traits::PackedSize;
513 
514  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
515  dstr_tensor.get_thread_buffer().template at<d>();
516  });
517 
518  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
519 
520  // write into bottom tensor
521  this->get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
522  bottom_tensor_thread_coord,
523  0,
524  vec_value,
525  bool_constant<oob_conditional_check>{});
526 
527  // move thread coordinate
528  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
529  {
530  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
531 
532  constexpr auto idx_diff_ps_ys = container_concat(
533  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
534  idx_diff_ys);
535 
537  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
538  }
539  });
540  });
541  }
542 
543  template <index_t i_access_unsupport_ = -1>
544  CK_TILE_DEVICE void
546  dstr_tensor,
547  number<i_access_unsupport_> = {}) const
548  {
549  using Traits = typename Base::Traits;
550 
551  using vector_t = typename Traits::vector_t;
552  using SFC_Ys = typename Traits::SFC_Ys;
553 
554  constexpr auto tile_dstr = typename Base::TileDstr{};
555  static constexpr bool oob_conditional_check = true;
556 
557  // loop over thread tensor space [y0, y1, ...]
558  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
560  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
561  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
562 
563  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
564  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
565 
566  // data index [y0, y1, ...]
567  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
568 
569  // read from distributed tensor
570  vector_t vec_value;
571  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
572  constexpr auto idx_ys = generate_tuple(
573  [&](auto jj) {
574  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
575  : idx_ys_start[jj];
576  },
577  number<Base::NDimY>{});
578  constexpr index_t d =
579  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
580  Traits::PackedSize;
581  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
582  dstr_tensor.get_thread_buffer().template at<d>();
583  });
584 
585  // write into bottom tensor
586  this->get_bottom_tensor_view()
587  .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
588  bottom_tensor_thread_coord, 0, vec_value);
589 
590  // move thread coordinate
591  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
592  {
593  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
594 
595  constexpr auto idx_diff_ps_ys = container_concat(
596  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
597  idx_diff_ys);
598 
600  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
601  }
602  });
603  });
604  }
605 
606  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
607  CK_TILE_DEVICE void
609  dstr_tensor,
612  {
613  using Traits = typename Base::Traits;
614 
615  using vector_t = typename Traits::vector_t;
616  using SFC_Ys = typename Traits::SFC_Ys;
617 
618  constexpr auto tile_dstr = typename Base::TileDstr{};
619 
620  // loop over thread tensor space [y0, y1, ...]
621  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
623  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
624  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
625 
626  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
627  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
628 
629  // data index [y0, y1, ...]
630  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
631 
632  // read from distributed tensor
633  vector_t vec_value;
634 
635  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
636  constexpr auto idx_ys = generate_tuple(
637  [&](auto jj) {
638  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
639  : idx_ys_start[jj];
640  },
641  number<Base::NDimY>{});
642 
643  constexpr index_t d =
644  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
645  Traits::PackedSize;
646 
647  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
648  dstr_tensor.get_thread_buffer().template at<d>();
649  });
650 
651  // write into bottom tensor
652  this->get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
653  bottom_tensor_thread_coord,
654  0,
655  vec_value,
656  bool_constant<oob_conditional_check>{});
657 
658  // move thread coordinate
659  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
660  {
661  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
662 
663  constexpr auto idx_diff_ps_ys = container_concat(
664  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
665  idx_diff_ys);
666 
668  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
669  }
670  });
671  });
672  }
673 
674  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
675  CK_TILE_DEVICE void
677  dstr_tensor,
680  bool_constant<pre_nop> = {}) const
681  {
682  using Traits = typename Base::Traits;
683 
684  using vector_t = typename Traits::vector_t;
685  using SFC_Ys = typename Traits::SFC_Ys;
686 
687  constexpr auto tile_dstr = typename Base::TileDstr{};
688 
689  // loop over thread tensor space [y0, y1, ...]
690  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
692  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
693  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
694 
695  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
696  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
697 
698  // data index [y0, y1, ...]
699  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
700 
701  // read from distributed tensor
702  vector_t vec_value;
703 
704  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
705  constexpr auto idx_ys = generate_tuple(
706  [&](auto jj) {
707  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
708  : idx_ys_start[jj];
709  },
710  number<Base::NDimY>{});
711 
712  constexpr index_t d =
713  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
714  Traits::PackedSize;
715 
716  vec_value.template get_as<typename Base::DataType>()(j / Traits::PackedSize) =
717  dstr_tensor.get_thread_buffer().template at<d>();
718  });
719 
720  // write into bottom tensor
721  this->get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
722  bottom_tensor_thread_coord,
723  0,
724  vec_value,
725  bool_constant<oob_conditional_check>{},
726  bool_constant<pre_nop>{});
727 
728  // move thread coordinate
729  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
730  {
731  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
732 
733  constexpr auto idx_diff_ps_ys = container_concat(
734  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
735  idx_diff_ys);
736 
738  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
739  }
740  });
741  });
742  }
743 
744  // Custom move behavior
746  {
747  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
748  move_tensor_coordinate(this->bottom_tensor_view_.get_tensor_descriptor(),
749  pre_computed_coords_(iCoord)(I1),
750  step);
751  });
752  }
753 
755  {
756  // TODO: this uses fewer registers for FA but more registers for GEMM;
757  // needs investigation
758  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
759  this->tile_dstr_.get_ps_ys_to_xs_adaptor(),
762 
763  typename Base::BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
764  this->window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
765 
766  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
767  this->bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
768 
769  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
770  // future load/store() calls (might allocate more registers)
771  using Traits = typename Base::Traits;
772  using SFC_Ys = typename Traits::SFC_Ys;
773 
774  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
775  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
776  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
777 
778  constexpr auto idx_diff_ys =
779  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
780 
781  constexpr auto idx_diff_ps_ys = container_concat(
782  generate_tuple([&](auto) { return number<0>{}; }, number<Base::NDimP>{}),
783  idx_diff_ys);
784 
786  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
787 
788  pre_computed_coords_(iCoord) =
789  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
790  });
791  }
792 
793  // this contains:
794  // per-thread coordinate for window adaptor
795  // per-thread coordinate for bottom tensor
798 };
799 
800 // TODO: use strategy
// Factory for a tile window that carries a static tile distribution.
//
// Constructs and returns (by value) a tile_window_with_static_distribution whose
// template arguments are the cv/ref-stripped TensorView_, WindowLengths_ and
// StaticTileDistribution_ types plus NumCoord.
//
// tensor_view       - passed to the window constructor as the bottom tensor view
// window_lengths    - passed to the window constructor as the window lengths
// origin            - passed to the window constructor as the window origin; its rank is
//                     fixed to TensorView_::get_num_of_dimension()
// tile_distribution - passed to the window constructor as the tile distribution
// number<NumCoord>  - unnamed tag argument; its value is never read here, it only lets
//                     the caller select NumCoord (default 1) through deduction
//
// NOTE(review): the window class in this file contains static_assert(NumCoord == 1), so
// any other NumCoord is expected to fail to compile - confirm before relying on it.
801 template <typename TensorView_,
802  typename WindowLengths_,
803  typename StaticTileDistribution_,
804  index_t NumCoord = 1>
805 CK_TILE_DEVICE constexpr auto
806 make_tile_window(const TensorView_& tensor_view,
807  const WindowLengths_& window_lengths,
808  const multi_index<TensorView_::get_num_of_dimension()>& origin,
809  const StaticTileDistribution_& tile_distribution,
810  number<NumCoord> = {})
811 {
// Brace-initialise the window directly from the four arguments.
812  return tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
813  remove_cvref_t<WindowLengths_>,
814  remove_cvref_t<StaticTileDistribution_>,
815  NumCoord>{
816  tensor_view, window_lengths, origin, tile_distribution};
817 }
818 
819 // this version can't be called in a constexpr context
820 template <typename TensorView_,
821  typename WindowLengths_,
822  typename StaticTileDistribution_,
823  index_t NumCoord = 1>
824 CK_TILE_DEVICE auto
826  const WindowLengths_& window_lengths,
827  const multi_index<TensorView_::get_num_of_dimension()>& origin,
828  const StaticTileDistribution_& tile_distribution,
829  number<NumCoord> = {})
830 {
831  auto w = tile_window_with_static_distribution<remove_cvref_t<TensorView_>,
832  remove_cvref_t<WindowLengths_>,
833  remove_cvref_t<StaticTileDistribution_>,
834  NumCoord>{
835  tensor_view, window_lengths, origin, tile_distribution};
836  w.init_raw();
837  return w;
838 }
839 
840 template <typename TensorView_,
841  typename WindowLengths_,
842  typename StaticTileDistribution_,
843  index_t NumCoord>
846  WindowLengths_,
847  StaticTileDistribution_,
848  NumCoord>& window,
849  const typename tile_window_with_static_distribution<TensorView_,
850  WindowLengths_,
851  StaticTileDistribution_,
852  NumCoord>::BottomTensorIndex& step)
853 {
854  window.move(step);
855 }
856 
865 template <typename BottomTensorView_, typename WindowLengths_>
867  : public tile_window_base<tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>,
868  BottomTensorView_,
869  WindowLengths_>
870 {
871  using Base =
873  BottomTensorView_,
874  WindowLengths_>;
875 
877 
879  const typename Base::BottomTensorView& bottom_tensor_view,
880  const typename Base::WindowLengths& window_lengths,
881  const typename Base::BottomTensorIndex& window_origin)
882  {
883  this->window_origin_ = window_origin;
884  this->window_lengths_ = window_lengths;
885  this->bottom_tensor_view_ = bottom_tensor_view;
886  }
887 };
888 
889 template <typename TensorView_, typename WindowLengths_>
890 CK_TILE_DEVICE constexpr auto
891 make_tile_window(const TensorView_& tensor_view,
892  const WindowLengths_& window_lengths,
893  const multi_index<TensorView_::get_num_of_dimension()>& origin)
894 {
896  "wrong! lengths should be static");
897 
900  tensor_view, window_lengths, origin};
901 }
902 
903 // duplicate tile window and replace its origin
904 template <typename TensorView, typename WindowLengths>
905 CK_TILE_DEVICE constexpr auto
907  const multi_index<TensorView::get_num_of_dimension()>& origin)
908 {
910  tile_window.get_bottom_tensor_view(), tile_window.get_window_lengths(), origin};
911 }
912 
913 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
914 CK_TILE_DEVICE constexpr auto
916  const multi_index<TensorView::get_num_of_dimension()>& origin,
917  const StaticTileDistribution& tile_distribution)
918 {
919  return make_tile_window(tile_window.get_bottom_tensor_view(),
920  tile_window.get_window_lengths(),
921  origin,
923 }
924 
925 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
926 CK_TILE_DEVICE constexpr auto
928  const StaticTileDistribution& tile_distribution)
929 {
930  return make_tile_window(tile_window.get_bottom_tensor_view(),
931  tile_window.get_window_lengths(),
932  tile_window.get_window_origin(),
934 }
935 
936 template <typename TensorView, typename WindowLengths, typename StaticTileDistribution>
937 CK_TILE_DEVICE constexpr auto
939  const StaticTileDistribution& tile_distribution)
940 {
941  auto w = make_tile_window(tile_window.get_bottom_tensor_view(),
942  tile_window.get_window_lengths(),
943  tile_window.get_window_origin(),
945  w.init_raw();
946  return w;
947 }
948 
949 template <typename TensorView_, typename WindowLengths_>
953  step)
954 {
955  window.move(step);
956 }
957 
965 template <typename T>
967 {
968 };
969 
978 template <typename BottomTensorView_,
979  typename WindowLengths_,
980  typename StaticTileDistribution_,
981  index_t NumCoord>
983  tile_window_with_static_distribution<BottomTensorView_,
984  WindowLengths_,
985  StaticTileDistribution_,
986  NumCoord>> : std::true_type
987 {
988 };
989 
997 template <typename T>
1000 
1008 template <typename T>
1010 {
1011 };
1012 
1019 template <typename BottomTensorView_, typename WindowLengths_>
1021  tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>> : std::true_type
1022 {
1023 };
1024 
1032 template <typename T>
1035 
1036 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_LDS_ADDR
Definition: config.hpp:57
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
Definition: cluster_descriptor.hpp:13
constexpr bool is_tile_window_with_static_distribution_v
Helper variable template to check if a type is a tile window with static distribution.
Definition: tile_window.hpp:998
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto make_tile_window_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, number< NumCoord >={})
Definition: tile_window.hpp:825
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
constant< b > bool_constant
Definition: integral_constant.hpp:39
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE index_t get_warp_id()
Definition: arch.hpp:74
constexpr bool is_tile_window_with_static_lengths_v
Helper variable template to check if a type is a tile window with static lengths.
Definition: tile_window.hpp:1033
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:406
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Type trait to determine if a type is a tile window with static distribution.
Definition: tile_window.hpp:967
Type trait to determine if a type is a tile window with static lengths.
Definition: tile_window.hpp:1010
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:72
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
This class provides description of tile windowed view on the device memory.
Definition: tile_window_base.hpp:31
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window_base.hpp:67
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:46
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}) const
Definition: tile_window.hpp:545
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition: tile_window.hpp:745
CK_TILE_DEVICE auto load_transpose() const
Definition: tile_window.hpp:399
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition: tile_window.hpp:754
array< tuple< typename Base::WindowAdaptorCoord, typename Base::BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_window.hpp:797
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:676
constexpr CK_TILE_DEVICE tile_window_with_static_distribution()=default
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:341
CK_TILE_DEVICE auto load_transpose(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:412
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:126
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:114
static constexpr auto I0
Definition: tile_window.hpp:56
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:189
static constexpr auto I1
Definition: tile_window.hpp:57
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window.hpp:261
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:608
constexpr CK_TILE_DEVICE tile_window_with_static_distribution(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition: tile_window.hpp:66
static constexpr index_t NumAccessPerCoord
Definition: tile_window.hpp:62
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window.hpp:475
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:870
constexpr CK_TILE_DEVICE tile_window_with_static_lengths()=default
constexpr CK_TILE_DEVICE tile_window_with_static_lengths(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin)
Definition: tile_window.hpp:878
Definition: tile_window_base.hpp:94
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window_base.hpp:129