Composable Kernel: include/ck_tile/core/tensor/tile_scatter_gather.hpp Source File
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
18 
19 namespace ck_tile {
20 
32 template <typename BottomTensorView_,
33  typename WindowLengths_,
34  typename StaticTileDistribution_,
35  typename StaticPageIndexArray_,
36  typename StaticValidArray_,
37  index_t HsGatherDim = 0,
38  index_t NumCoord = 1,
39  index_t YsGatherDim = 0>
 40 struct tile_scatter_gather
 41 {
 42  using BottomTensorView = remove_reference_t<BottomTensorView_>;
 43  using WindowLengths = remove_cvref_t<WindowLengths_>;
 44  using TileDstr = remove_cvref_t<StaticTileDistribution_>;
 45  using PageIdxArray = remove_cvref_t<StaticPageIndexArray_>;
 46  using ValidArray = remove_cvref_t<StaticValidArray_>;
 47  using WindowAdaptor = typename TileDstr::PsYs2XsAdaptor;
48  using BottomTensorDesc = typename BottomTensorView::TensorDesc;
49 
 50  using DataType = remove_cvref_t<typename BottomTensorView::DataType>;
 51 
52  static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
53  static constexpr index_t NDimBottomTensor = BottomTensorDesc::get_num_of_dimension();
54 
55  static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
56  static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();
57 
58  static constexpr auto I0 = number<0>{};
59  static constexpr auto I1 = number<1>{};
60  static_assert(NumCoord == 1);
61 
62  // TODO: check WindowLengths and StaticTileDistribution are consistent
63 
65  "wrong! lengths should be static");
66  static_assert(TileDstr::is_static(), "wrong!");
67 
68  static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
69  "wrong! inconsistent # of diemsnions");
70 
 71  using AdaptorTopIndex = array<index_t, NDimWindowAdaptorTop>;
 72  using BottomTensorIndex = array<index_t, NDimBottomTensor>;
 73 
 74  using WindowAdaptorCoord =
 75  decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));
 76 
 77  using BottomTensorCoord =
 78  decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));
 79 
 80  struct load_store_traits
 81  {
82  private:
83  static constexpr auto get_vector_dim_y_scalar_per_vector()
84  {
85  const auto [ys_vector_lengths, ys_vector_strides] =
 86  get_window_adaptor_ys_safe_vector_length_strides();
 87 
88  index_t VectorDimY_ = 0;
89  index_t ScalarPerVector_ = 1;
90 
91  for(index_t i = 0; i < NDimY; ++i)
92  {
93  if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
94  {
95  ScalarPerVector_ = ys_vector_lengths[i];
96  VectorDimY_ = i;
97  }
98  }
99 
100  return make_tuple(VectorDimY_, ScalarPerVector_);
101  }
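 // Illustrative note (not part of the original source): for a hypothetical
 // distribution whose Y lengths are {4, 8} with Y strides {8, 1}, the loop above
 // selects VectorDimY_ = 1 and ScalarPerVector_ = 8, i.e. the stride-1 Y
 // dimension with the largest length becomes the contiguous (vector) dimension.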
102 
103  public:
104  static constexpr index_t PackedSize =
 105  numeric_traits<remove_cvref_t<DataType>>::PackedSize;
 106  static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
107  static constexpr index_t ScalarPerVector =
108  get_vector_dim_y_scalar_per_vector().template at<1>();
109 
110  // using vector_type_t = vector_type_maker_t<DataType, ScalarPerVector>;
111  // using vector_t = typename vector_type_t::type;
 112  using vector_t = thread_buffer<DataType, ScalarPerVector>;
 113 
114  private:
115  static constexpr auto scalars_per_access_ = [] {
116  constexpr auto scalars_per_access_arr = generate_array(
117  [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});
118 
120  constexpr auto NDimY_ = NDimY;
121 
122  return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
123  }();
124 
125  static constexpr auto get_space_filling_curve()
126  {
127  constexpr auto tile_dstr = TileDstr{};
128 
129  constexpr auto thread_tensor_lengths_ys =
130  to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
131 
132  // FIXME: need logic to judge dim access order
133  using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;
134 
135  return space_filling_curve<decltype(thread_tensor_lengths_ys),
136  DimAccessOrder,
137  decltype(scalars_per_access_)>{};
138  }
139 
140  public:
141  using SFC_Ys = decltype(get_space_filling_curve());
142 
143  static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();
144 
145  static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");
146  static_assert(NumAccess % NumCoord == 0, "wrong! # of access is not divisible by NumCoord");
147  };
148 
 149  static constexpr index_t NumAccessPerCoord = load_store_traits::NumAccess / NumCoord;
 150 
151  CK_TILE_DEVICE constexpr tile_scatter_gather() = default;
152 
153  CK_TILE_DEVICE constexpr tile_scatter_gather(const BottomTensorView& bottom_tensor_view,
154  const WindowLengths& window_lengths,
155  const BottomTensorIndex& window_origin,
 156  const TileDstr& tile_distribution,
 157  const PageIdxArray& page_idx,
158  const ValidArray& valids)
159  : bottom_tensor_view_{bottom_tensor_view},
160  window_lengths_{window_lengths},
161  window_origin_{window_origin},
 162  tile_dstr_{tile_distribution},
 163  page_idx_{page_idx},
164  valids_{valids},
 165  pre_computed_coords_{}
 166  {
167 #if 0 // debug
 168  // TODO: this uses more registers for FA, but fewer for GEMM;
 169  // needs investigation
170  // only support warp-tile and block-tile
171  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
172 
173  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
174 
175  if constexpr(NDimP == 1)
176  {
177  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
179  }
180  else if constexpr(NDimP == 2)
181  {
182  window_adaptor_thread_coord_tmp =
184  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
185  }
186 #else
 187  // TODO: this uses fewer registers for FA, but more for GEMM;
 188  // needs investigation
189  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
192 #endif
193 
194  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
195  window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
196  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
197  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
198  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
199 
200  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
201  // future load/store() calls (might allocate more registers)
202  using Traits = load_store_traits;
203  using SFC_Ys = typename Traits::SFC_Ys;
204 
205  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
206  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
207  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
208 
209  constexpr auto idx_diff_ys =
210  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
211 
212  constexpr auto idx_diff_ps_ys = container_concat(
213  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
214 
 215  move_window_adaptor_and_bottom_tensor_thread_coordinate(
 216  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
217 
218  pre_computed_coords_(iCoord) =
219  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
220  });
221  }
222 
 223  static constexpr CK_TILE_DEVICE index_t get_num_of_dimension() { return NDimBottomTensor; }
 224 
 225  static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
 226  {
227  return TileDstr::is_static();
228  }
229 
230  CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }
231 
232  CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }
233 
234  CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }
235 
236  CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }
237 
238  CK_TILE_DEVICE constexpr void
239  set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
240  {
241  bottom_tensor_view_.buf_.p_data_ = data;
242  }
243 
244  // move thread's window adaptor coordinate and bottom tensor coordinate
245  // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
246  template <typename ATopIndex>
 247  CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
 248  WindowAdaptorCoord& window_adaptor_thread_coord,
249  BottomTensorCoord& bottom_tensor_thread_coord,
250  const ATopIndex& idx_diff_adaptor_top) const
251  {
252  array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;
253 
254  move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
255  window_adaptor_thread_coord,
256  idx_diff_adaptor_top,
257  idx_diff_adaptor_bottom);
258 
259  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
260  bottom_tensor_thread_coord,
261  idx_diff_adaptor_bottom);
262  }
263 
264  // return vector dimension among [y0, y1, ...]
 265  static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
 266  {
267  // bottom tensor top dimension vector lengths and strides
268  const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
269  BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
270 
271  // window vector lengths/strides
272  const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
273  const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
274 
275  // window adaptor [p0, p1, ..., y0, y1, ...]
276  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
277  -1};
278  array<index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
279  -1};
280 
281  constexpr auto window_adaptor_bottom_dims =
282  WindowAdaptor::get_bottom_dimension_hidden_ids();
283 
284  set_container_subset(window_adaptor_vector_lengths,
285  window_adaptor_bottom_dims,
286  window_adaptor_bottom_dim_vector_lengths);
287  set_container_subset(window_adaptor_vector_strides,
288  window_adaptor_bottom_dims,
289  window_adaptor_bottom_dim_vector_strides);
290 
291  const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
292  WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
293  window_adaptor_vector_lengths, window_adaptor_vector_strides);
294 
295  // [y0, y1, ...]
296  constexpr auto y_dims = typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
 297  TileDstr::get_num_of_dimension_p() + TileDstr::get_num_of_dimension_y(),
 298  1>::type{};
299 
300  return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
301  get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
302  }
303 
 304  constexpr CK_TILE_DEVICE auto get_num_of_access() const { return load_store_traits::NumAccess; }
 305 
306  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
 307  CK_TILE_DEVICE auto load(number<i_access_unsupport_> = {},
 308  bool_constant<oob_conditional_check> = {}) const
 309  {
310  constexpr auto tile_dstr = TileDstr{};
311  auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
312  load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
313  return dst_tensor;
314  }
315 
316  template <typename DistributedTensor,
317  index_t i_access_unsupport_ = -1,
318  bool oob_conditional_check = true>
319  CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
 320  number<i_access_unsupport_> = {},
 321  bool_constant<oob_conditional_check> = {}) const
 322  {
323  using Traits = load_store_traits;
324  using vector_t = typename Traits::vector_t;
325  using SFC_Ys = typename Traits::SFC_Ys;
326 
327  constexpr auto tile_dstr = TileDstr{};
328 
329  // loop over thread tensor space [y0, y1, ...]
330  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
332  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
333  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
334 
335  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
336  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
337 
338  // data index [y0, y1, ...]
339  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
340  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
341  const auto page_offset = page_idx_[idx_gather];
342 
343  // read from bottom tensor
344  const vector_t vec_value = [&]() {
345  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
346  {
347  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
348  bottom_tensor_thread_coord,
349  page_offset,
350  bool_constant<oob_conditional_check>{});
351  }
352  else
353  {
354  return get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
355  bottom_tensor_thread_coord,
356  page_offset,
357  valids_[idx_gather],
358  bool_constant<oob_conditional_check>{});
359  }
360  }();
361 #if 1
362  // write into distributed tensor
363  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
364  constexpr auto idx_ys = generate_tuple(
365  [&](auto jj) {
366  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
367  : idx_ys_start[jj];
368  },
369  number<NDimY>{});
370 
371  constexpr index_t d =
372  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
373  Traits::PackedSize;
374 
375  dst_tensor.get_thread_buffer().template at<d>() =
376  vec_value.template get_as<DataType>()[j / Traits::PackedSize];
377  });
378 #else
379  constexpr index_t d =
380  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
381  static_assert(d % Traits::ScalarPerVector == 0);
382 
383  dst_tensor.get_thread_buffer().template get_as<vector_t>()(
384  number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
385 #endif
386  // move thread coordinate
387  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
388  {
389  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
390 
391  constexpr auto forward_step_scatter = generate_tuple(
392  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
393  number<NDimY>{});
394 
395  constexpr auto idx_diff_ps_ys = container_concat(
396  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
397  forward_step_scatter);
398 
 399  move_window_adaptor_and_bottom_tensor_thread_coordinate(
 400  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
401  }
402  });
403  });
404  }
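 // Usage sketch (illustrative only; `gather_window` is an assumed name, not part
 // of this header). Either load() overload above can be driven the same way once
 // the window has been built by make_tile_scatter_gather() (see the factories at
 // the end of this file); per access, the Y index along YsGatherDim selects
 // page_idx_[...] as an extra bottom-tensor offset:
 //
 //   auto reg_tile = gather_window.load();   // allocate + fill a distributed tensor
 //   gather_window.load(reg_tile);           // or fill an existing one in place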
405 
406  template <typename LdsTileWindow_,
407  index_t i_access_unsupport_ = -1,
408  bool oob_conditional_check = true>
409  CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
 410  number<i_access_unsupport_> = {},
 411  bool_constant<oob_conditional_check> = {}) const
 412  {
413  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
414  using LdsDataType = typename LdsTileWindow::DataType;
415  using Traits = load_store_traits;
416  using vector_t = typename Traits::vector_t;
417  using SFC_Ys = typename Traits::SFC_Ys;
418 
419  constexpr auto tile_dstr = TileDstr{};
420 
421  // Precompute invariant values outside loops
422  const auto window_origin = lds_tile.get_window_origin();
423  const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
424  const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
425  auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
426 
427  // loop over thread tensor space [y0, y1, ...]
428  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
430  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
431  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
432 
433  auto lds_window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
434  auto lds_bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
435 
436  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
437  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
438 
439  // Use precomputed window origin
440  auto lds_bottom_tensor_thread_idx =
441  window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
442  // Use precomputed tensor descriptor
443  const auto lds_coord =
444  make_tensor_coordinate(tensor_descriptor, lds_bottom_tensor_thread_idx);
445  // Calculate SMEM address using base pointer
446  CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
447 
448  // data index [y0, y1, ...]
449  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
450  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
451  const auto page_offset = page_idx_[idx_gather];
452 
453  // merge page_offset into bottom_coord
454  auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
455  mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
456 
457  // read from bottom tensor
458  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
459  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
460  smem,
461  mixed_bottom_thread_coord,
462  number<0>{},
463  bool_constant<oob_conditional_check>{});
464  else
465  this->get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
466  smem,
467  mixed_bottom_thread_coord,
468  number<0>{},
469  valids_[idx_gather],
470  bool_constant<oob_conditional_check>{});
471 
472  // move thread coordinate
473  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
474  {
475  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
476 
477  constexpr auto forward_step_scatter = generate_tuple(
478  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
479  number<NDimY>{});
480 
481  constexpr auto idx_diff_ps_ys = container_concat(
482  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
483  forward_step_scatter);
 484  // the LDS step does not need to mask out the gather-dim difference.
485  constexpr auto lds_idx_diff_ps_ys = container_concat(
486  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
487  idx_diff_ys);
488 
 489  move_window_adaptor_and_bottom_tensor_thread_coordinate(
 490  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 491  move_window_adaptor_and_bottom_tensor_thread_coordinate(
 492  lds_window_adaptor_thread_coord,
493  lds_bottom_tensor_thread_coord,
494  lds_idx_diff_ps_ys);
495  }
496  });
497  });
498  }
499 
500  // TODO: currently async load only implemented in inline asm
501  template <typename LdsTileWindow_,
502  index_t i_access_unsupport_ = -1,
503  bool oob_conditional_check = true,
504  bool pre_nop = false>
505  CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
 506  number<i_access_unsupport_> = {},
 507  bool_constant<oob_conditional_check> = {},
 508  bool_constant<pre_nop> = {}) const
509  {
510  using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
511  // using LdsTensorView = typename LdsTileWindow::BottomTensorView;
512  using LdsDataType = typename LdsTileWindow::DataType;
513  // using LdsDescriptor = typename LdsTileWindow::BottomTensorDesc;
514 
515  // issues * warps * lanes
516  static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded
517 
518  const index_t size_per_buf =
519  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
520  make_tuple(number<0>{}, number<0>{}, number<0>{})) *
521  sizeof(LdsDataType);
522 
523  const index_t size_per_wave =
524  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
525  make_tuple(number<0>{}, number<1>{}, number<0>{})) *
526  sizeof(LdsDataType) -
527  size_per_buf;
528 
529  const index_t size_per_issue =
530  lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
531  make_tuple(number<1>{}, number<0>{}, number<0>{})) *
532  sizeof(LdsDataType) -
533  size_per_buf;
534 
535  const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
536  m0_set_with_memory(m0_init_value); // This should be wave independent
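 // Worked example for the size_per_* values above (illustrative; assumes a packed
 // row-major LDS descriptor of shape (issues, warps, lane_elems) = (2, 4, 64) and
 // 2-byte elements):
 //   size_per_buf   = offset(0,0,0) * 2           = 0
 //   size_per_wave  = offset(0,1,0) * 2 - 0 =  64 * 2 = 128 bytes
 //   size_per_issue = offset(1,0,0) * 2 - 0 = 256 * 2 = 512 bytes
 // m0 therefore starts at size_per_buf + size_per_wave * warp_id and is advanced
 // by size_per_issue after each issued async copy (see m0_inc_with_memory below).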
537 
538  using Traits = load_store_traits;
539 
540  // using vector_type_t = typename Traits::vector_type_t;
541  using vector_t = typename Traits::vector_t;
542  using SFC_Ys = typename Traits::SFC_Ys;
543 
544  LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
545 
546  // loop over thread tensor space [y0, y1, ...]
547  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
549  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
550  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
551 
552  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
553  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
554  constexpr auto pre_nop_ = [&]() {
555  if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
556  return bool_constant<true>{};
557  else
558  return bool_constant<false>{};
559  }();
560 
561  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
562  constexpr auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
563  const auto page_offset = page_idx_[idx_gather];
564 
565  // read from bottom tensor
566  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
567  {
568  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
569  smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
570  }
571  else
572  {
573  get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
574  smem,
575  bottom_tensor_thread_coord,
576  page_offset,
577  valids_[idx_gather],
578  0,
579  pre_nop_);
580  }
581 
582  // move thread coordinate
583  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
584  {
585  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
586 
587  constexpr auto forward_step_scatter = generate_tuple(
588  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
589  number<NDimY>{});
590 
591  constexpr auto idx_diff_ps_ys = container_concat(
592  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
593  forward_step_scatter);
594 
 595  move_window_adaptor_and_bottom_tensor_thread_coordinate(
 596  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
597 
598  m0_inc_with_memory(size_per_issue);
599  }
600  });
601  });
602  }
603 
604  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
 605  CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
 606  number<i_access_unsupport_> = {},
 607  bool_constant<oob_conditional_check> = {}) const
 608  {
609  using Traits = load_store_traits;
610 
611  // using vector_type_t = typename Traits::vector_type_t;
612  using vector_t = typename Traits::vector_t;
613  using SFC_Ys = typename Traits::SFC_Ys;
614 
615  constexpr auto tile_dstr = TileDstr{};
616 
617  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
618  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
619  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
620 
621  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
622  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
623 
624  // data index [y0, y1, ...]
625  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
626  constexpr auto idx_gather = idx_ys_start[number<0>{}];
627  const auto page_offset = page_idx_[idx_gather];
628 
629  // read from distributed tensor
630  vector_t vec_value;
631 
632  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
633  constexpr auto idx_ys = generate_tuple(
634  [&](auto jj) {
635  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
636  : idx_ys_start[jj];
637  },
638  number<NDimY>{});
639 
640  constexpr index_t d =
641  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
642  Traits::PackedSize;
643 
644  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
645  dstr_tensor.get_thread_buffer().template at<d>();
646  });
647 
648  // write into bottom tensor
649  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
650  {
651  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
652  bottom_tensor_thread_coord,
653  page_offset,
654  vec_value,
655  bool_constant<oob_conditional_check>{});
656  }
657  else
658  {
659  get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
660  bottom_tensor_thread_coord,
661  page_offset,
662  valids_[idx_gather],
663  vec_value,
664  bool_constant<oob_conditional_check>{});
665  }
666 
667  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
668  {
669  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
670 
671  constexpr auto forward_step_scatter = generate_tuple(
672  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
673  number<NDimY>{});
674 
675  constexpr auto idx_diff_ps_ys = container_concat(
676  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
677  forward_step_scatter);
678 
 679  move_window_adaptor_and_bottom_tensor_thread_coordinate(
 680  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
681  }
682  });
683  });
684  }
685 
686  template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true>
 687  CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
 688  number<i_access_unsupport_> = {},
 689  bool_constant<oob_conditional_check> = {}) const
 690  {
691  using Traits = load_store_traits;
692 
693  // using vector_type_t = typename Traits::vector_type_t;
694  using vector_t = typename Traits::vector_t;
695  using SFC_Ys = typename Traits::SFC_Ys;
696 
697  constexpr auto tile_dstr = TileDstr{};
698  // printf("off %d\n", page_idx_[I0]);
699  // loop over thread tensor space [y0, y1, ...]
700  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
701  auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
702  auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
703 
704  static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
705  constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
706 
707  // data index [y0, y1, ...]
708  constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
709  constexpr auto idx_gather = idx_ys_start[number<0>{}];
710  const auto page_offset = page_idx_[idx_gather];
711 
712  // printf("idx_ys_start[0], idx_ys_start[1](%d, %d) \n",
713  // idx_ys_start[number<0>{}]+0, idx_ys_start[number<1>{}]+0);
714 
715  // read from distributed tensor
716  // vector_type_t vec;
717  vector_t vec_value;
718 
719  static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) {
720  constexpr auto idx_ys = generate_tuple(
721  [&](auto jj) {
722  return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
723  : idx_ys_start[jj];
724  },
725  number<NDimY>{});
726 
727  constexpr index_t d =
728  tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
729  Traits::PackedSize;
730  // printf("thread_idx_m: %d j: %d\n", idx_ys[number<0>{}] + 0, 0+j);
731  vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
732  dstr_tensor.get_thread_buffer().template at<d>();
733  });
734 
735  // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
736 
737  // write into bottom tensor
738  if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
739  {
740  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
741  bottom_tensor_thread_coord,
742  page_offset,
743  vec_value,
744  bool_constant<oob_conditional_check>{});
745  }
746  else
747  {
748  get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
749  bottom_tensor_thread_coord,
750  page_offset,
751  valids_[idx_gather],
752  vec_value,
753  bool_constant<oob_conditional_check>{});
754  }
755 
756  // printf("coord_offset:%d, scatter_offset:%d \n",
757  // bottom_tensor_thread_coord.get_offset(), offset); move thread coordinate
758  if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
759  {
760  constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
761 
762  constexpr auto forward_step_scatter = generate_tuple(
763  [&](auto i) { return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
764  number<NDimY>{});
765 
766  constexpr auto idx_diff_ps_ys = container_concat(
767  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
768  forward_step_scatter);
769 
 770  move_window_adaptor_and_bottom_tensor_thread_coordinate(
 771  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
772  }
773  });
774  });
775  }
776 
 777  // move thread's bottom tensor coordinate
778  // [x0', x1', ... ] ==> [offset]
779  // also move window-origin
 780  CK_TILE_DEVICE void move(const BottomTensorIndex& step)
 781  {
782  window_origin_ += step;
783  BottomTensorIndex step_new = step;
784  step_new(HsGatherDim) = 0;
785  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
786  move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
787  pre_computed_coords_(iCoord)(I1),
788  step_new);
789  });
790  }
791 
792  CK_TILE_DEVICE void update_page_idx(const PageIdxArray& new_idx) { page_idx_ = new_idx; }
793 
794  CK_TILE_DEVICE void update_valids(const ValidArray& new_valids)
795  {
796  if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == false)
797  {
798  valids_ = new_valids;
799  }
800  }
801 
 802  CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray& new_idx,
 803  const ValidArray& new_valids)
804  {
805  update_page_idx(new_idx);
806  update_valids(new_valids);
807  }
808 
809  CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
810  {
811  window_origin_ = new_window_origin;
812 
813 #if 0 // debug
 814  // TODO: this uses more registers for FA, but fewer for GEMM;
 815  // needs investigation
816  // only support warp-tile and block-tile
817  static_assert(NDimP == 1 or NDimP == 2, "wrong!");
818 
819  WindowAdaptorCoord window_adaptor_thread_coord_tmp;
820 
821  if constexpr(NDimP == 1)
822  {
823  window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
824  tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
825  }
826  else if constexpr(NDimP == 2)
827  {
828  window_adaptor_thread_coord_tmp =
829  make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
830  AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
831  }
832 #else
 833  // TODO: this uses fewer registers for FA, but more for GEMM;
 834  // needs investigation
835  const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
836  tile_dstr_.get_ps_ys_to_xs_adaptor(),
838 #endif
839 
840  BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
841  window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
842 
843  bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
844  const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
845  bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
846 
847  // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
848  // future load/store() calls (might allocate more registers)
849  using Traits = load_store_traits;
850  using SFC_Ys = typename Traits::SFC_Ys;
851 
852  static_for<0, NumCoord, 1>{}([&](auto iCoord) {
853  auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
854  auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
855 
856  constexpr auto idx_diff_ys =
857  SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
858 
859  constexpr auto idx_diff_ps_ys = container_concat(
860  generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}), idx_diff_ys);
861 
 862  move_window_adaptor_and_bottom_tensor_thread_coordinate(
 863  window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
864 
865  pre_computed_coords_(iCoord) =
866  make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
867  });
868  }
869 
 870  CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
 871 
872  // this is the bottom tensor view
873  // [x0', x1', ...] ==> [offset]
 874  BottomTensorView bottom_tensor_view_;
 875 
876  //
 877  WindowLengths window_lengths_;
 878 
879  // origin ([x0', x1', ...]) of window on bottom tensor
 880  BottomTensorIndex window_origin_;
 881 
882  // Tile tensor distribution, which contains:
883  // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
884  // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
 885  TileDstr tile_dstr_;
 886 
 887  PageIdxArray page_idx_;
 888  ValidArray valids_;
 889 
890  // this contains:
891  // per-thread coordinate for window adaptor
892  // per-thread coordinate for bottom tensor
 893  array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
 894 };
895 
896 // TODO: use strategy
897 template <typename TensorView_,
898  typename WindowLengths_,
899  typename StaticTileDistribution_,
900  typename StaticPageIndexArray_,
901  index_t HsGatherDim = 0,
902  index_t NumCoord = 1>
903 CK_TILE_DEVICE constexpr auto
 904 make_tile_scatter_gather(const TensorView_& tensor_view,
 905  const WindowLengths_& window_lengths,
906  const multi_index<TensorView_::get_num_of_dimension()>& origin,
907  const StaticTileDistribution_& tile_distribution,
908  const StaticPageIndexArray_& page_idx,
909  number<HsGatherDim> = {},
910  number<NumCoord> = {})
911 {
912  return tile_scatter_gather<remove_cvref_t<TensorView_>,
913  remove_cvref_t<WindowLengths_>,
914  remove_cvref_t<StaticTileDistribution_>,
915  remove_cvref_t<StaticPageIndexArray_>,
916  std::nullptr_t,
917  HsGatherDim,
918  NumCoord>{
919  tensor_view, window_lengths, origin, tile_distribution, page_idx, nullptr};
920 }
921 
922 template <typename TensorView,
923  typename WindowLengths,
924  typename StaticTileDistribution,
925  typename StaticPageIndexArray,
926  index_t HsGatherDim>
929  const multi_index<TensorView::get_num_of_dimension()>& origin,
930  const StaticTileDistribution& tile_distribution,
931  const StaticPageIndexArray& page_idx,
932  number<HsGatherDim> = {})
933 {
 934  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
 935  tile_window.get_window_lengths(),
936  origin,
937  tile_distribution,
938  page_idx,
939  number<HsGatherDim>{});
940 }
941 
942 template <typename TensorView,
943  typename WindowLengths,
944  typename StaticTileDistribution,
945  typename StaticPageIndexArray,
946  index_t HsGatherDim>
949  const StaticTileDistribution& tile_distribution,
950  const StaticPageIndexArray& page_idx,
951  number<HsGatherDim> = {})
952 {
 953  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
 954  tile_window.get_window_lengths(),
955  tile_window.get_window_origin(),
956  tile_distribution,
957  page_idx,
958  number<HsGatherDim>{});
959 }
960 
961 template <typename TensorView_,
962  typename WindowLengths_,
963  typename StaticTileDistribution_,
964  typename StaticPageIndexArray_,
965  typename StaticValidArray_,
966  index_t HsGatherDim = 0,
967  index_t NumCoord = 1>
968 CK_TILE_DEVICE constexpr auto
 969 make_tile_scatter_gather(const TensorView_& tensor_view,
 970  const WindowLengths_& window_lengths,
971  const multi_index<TensorView_::get_num_of_dimension()>& origin,
972  const StaticTileDistribution_& tile_distribution,
973  const StaticPageIndexArray_& page_idx,
974  const StaticValidArray_& valids,
975  number<HsGatherDim> = {},
976  number<NumCoord> = {})
977 {
978  return tile_scatter_gather<remove_cvref_t<TensorView_>,
979  remove_cvref_t<WindowLengths_>,
980  remove_cvref_t<StaticTileDistribution_>,
981  remove_cvref_t<StaticPageIndexArray_>,
982  remove_cvref_t<StaticValidArray_>,
983  HsGatherDim,
984  NumCoord>{
985  tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
986 }
987 
988 template <typename TensorView,
989  typename WindowLengths,
990  typename StaticTileDistribution,
991  typename StaticPageIndexArray,
992  typename StaticValidArray,
993  index_t HsGatherDim>
996  const multi_index<TensorView::get_num_of_dimension()>& origin,
997  const StaticTileDistribution& tile_distribution,
998  const StaticPageIndexArray& page_idx,
999  const StaticValidArray& valids,
1000  number<HsGatherDim> = {})
1001 {
1002  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1003  tile_window.get_window_lengths(),
1004  origin,
1005  tile_distribution,
1006  page_idx,
1007  valids,
1008  number<HsGatherDim>{});
1009 }
1010 
1011 template <typename TensorView,
1012  typename WindowLengths,
1013  typename StaticTileDistribution,
1014  typename StaticPageIndexArray,
1015  typename StaticValidArray,
1016  index_t HsGatherDim>
1019  const StaticTileDistribution& tile_distribution,
1020  const StaticPageIndexArray& page_idx,
1021  const StaticValidArray& valids,
1022  number<HsGatherDim> = {})
1023 {
1024  return make_tile_scatter_gather(tile_window.get_bottom_tensor_view(),
1025  tile_window.get_window_lengths(),
1026  tile_window.get_window_origin(),
1027  tile_distribution,
1028  page_idx,
1029  valids,
1030  number<HsGatherDim>{});
1031 }
1032 
1033 template <typename NewTensorView_,
1034  typename OldTensorView_,
1035  typename WindowLengths_,
1036  typename StaticTileDistribution_,
1037  typename StaticPageIndexArray_,
1038  typename StaticValidArray_,
1039  index_t HsGatherDim = 0,
1040  index_t NumCoord = 1>
1041 CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_& new_tensor_view,
1042  const tile_scatter_gather<OldTensorView_,
1043  WindowLengths_,
1044  StaticTileDistribution_,
1045  StaticPageIndexArray_,
1046  StaticValidArray_,
1047  HsGatherDim,
1048  NumCoord>& tile_window)
1049 {
1050  return make_tile_scatter_gather(new_tensor_view,
1051  tile_window.window_lengths_,
1052  tile_window.window_origin_,
1053  tile_window.tile_dstr_,
1054  tile_window.page_idx_,
1055  tile_window.valids_);
1056 }
1057 
1058 } // namespace ck_tile
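Usage sketch (not part of the header above): the snippet below shows how the factory and member functions defined in this file are typically combined. The tensor view, window lengths, origin, tile distribution, and page-index array are assumed to be built elsewhere with the usual ck_tile helpers; all names are illustrative, not definitions from this file.

// Minimal sketch, assuming kv_view, window_lengths, origin, dstr, page_idx and
// next_page_idx exist; only entry points declared in this header are exercised.
auto gather_window = ck_tile::make_tile_scatter_gather(
    kv_view,        // bottom tensor view
    window_lengths, // window lengths
    origin,         // window origin [x0, x1, ...]
    dstr,           // static tile distribution
    page_idx);      // per-Y-slice page offsets (gather indices)

auto reg_tile = gather_window.load();         // gather: page_idx[] offsets each access
// ... compute on reg_tile ...
gather_window.update_page_idx(next_page_idx); // retarget the window at new pages
gather_window.store(reg_tile);                // scatter back through the same indices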