tile_window_linear.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

#define WINDOW_DISPATCH_ISSUE()                                         \
    if constexpr(i_access < 0)                                          \
    {                                                                   \
        static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); });     \
    }                                                                   \
    else                                                                \
    {                                                                   \
        static_assert(i_access < NumAccess);                            \
        issue(number<i_access>{});                                      \
    }
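
// Dispatch helper: a negative i_access issues every access in sequence, while a
// non-negative i_access issues exactly one. A hedged usage sketch (the window
// object `win` is hypothetical):
//
//   win.load();            // i_access = -1, issues all NumAccess accesses
//   win.load(number<2>{}); // issues only the access with index 2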

//
// This version of tile window pre-caches offsets/flags based on need.
//
// LinearBottomDims_, e.g. seq<0, 1> for a 2d tensor: the last dim is the linear dim,
// so it can be indexed with an immediate offset, which saves registers.
// TODO: if using this struct, prefer load_raw()/store_raw(), which can control
// the immediate offset on the fly.
// the space-filling curve is non-snaked here!
//
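// A minimal sketch of the intent (tensor view `view`, `lengths` and `dstr` are
// hypothetical): for a row-major 2d global tensor, seq<0, 1> keeps dim-0
// coordinates in registers while dim-1 steps become compile-time immediate offsets:
//
//   auto win = make_tile_window_linear(view, lengths, {0, 0}, dstr,
//                                      sequence<0, 1>{});
//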
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_>
struct tile_window_linear
{
    using BottomTensorView = remove_reference_t<BottomTensorView_>;
    using WindowLengths    = remove_cvref_t<WindowLengths_>;
    using TileDstr         = remove_cvref_t<StaticTileDistribution_>;

    using WindowAdaptor    = typename TileDstr::PsYs2XsAdaptor;
    using BottomTensorDesc = typename BottomTensorView::TensorDesc;

    using DataType         = remove_cvref_t<typename BottomTensorView::DataType>;
    using LinearBottomDims = remove_cvref_t<LinearBottomDims_>;

    static_assert(LinearBottomDims::size() == BottomTensorView::get_num_of_dimension());

    static constexpr index_t NDimWindowAdaptorTop = WindowAdaptor::get_num_of_top_dimension();
    static constexpr index_t NDimBottomTensor     = BottomTensorDesc::get_num_of_dimension();

    static constexpr index_t NDimP = TileDstr::get_num_of_dimension_p();
    static constexpr index_t NDimY = TileDstr::get_num_of_dimension_y();

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    // TODO: check WindowLengths and StaticTileDistribution are consistent

    static_assert(is_known_at_compile_time<WindowLengths>::value,
                  "wrong! lengths should be static");
    static_assert(TileDstr::is_static(), "wrong!");

    static_assert(NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
                  "wrong! inconsistent # of dimensions");

    using AdaptorTopIndex   = array<index_t, NDimWindowAdaptorTop>;
    using BottomTensorIndex = array<index_t, NDimBottomTensor>;

    using WindowAdaptorCoord =
        decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{}));

    using BottomTensorCoord =
        decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{}));

    struct traits
    {
        private:
        // return vector dimension among [y0, y1, ...]
        CK_TILE_DEVICE static constexpr auto get_window_adaptor_ys_safe_vector_length_strides()
        {
            // bottom tensor top dimension vector lengths and strides
            const auto [bottom_tensor_top_dim_vector_lengths,
                        bottom_tensor_top_dim_vector_strides] =
                BottomTensorDesc::get_top_dimension_safe_vector_length_strides();

            // window vector lengths/strides
            const auto window_adaptor_bottom_dim_vector_lengths =
                bottom_tensor_top_dim_vector_lengths;
            const auto window_adaptor_bottom_dim_vector_strides =
                bottom_tensor_top_dim_vector_strides;

            // window adaptor [p0, p1, ..., y0, y1, ...]
            array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
                window_adaptor_vector_lengths{-1};
            array<index_t, WindowAdaptor::get_num_of_hidden_dimension()>
                window_adaptor_vector_strides{-1};

            constexpr auto window_adaptor_bottom_dims =
                WindowAdaptor::get_bottom_dimension_hidden_ids();

            set_container_subset(window_adaptor_vector_lengths,
                                 window_adaptor_bottom_dims,
                                 window_adaptor_bottom_dim_vector_lengths);
            set_container_subset(window_adaptor_vector_strides,
                                 window_adaptor_bottom_dims,
                                 window_adaptor_bottom_dim_vector_strides);

            const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
                WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
                    window_adaptor_vector_lengths, window_adaptor_vector_strides);

            // [y0, y1, ...]
            constexpr auto y_dims =
                typename arithmetic_sequence_gen<TileDstr::get_num_of_dimension_p(),
                                                 TileDstr::get_num_of_dimension_p() +
                                                     TileDstr::get_num_of_dimension_y(),
                                                 1>::type{};

            return make_tuple(get_container_subset(window_adaptor_ps_ys_vector_lengths, y_dims),
                              get_container_subset(window_adaptor_ps_ys_vector_strides, y_dims));
        }

        static constexpr auto get_vector_dim_y_scalar_per_vector()
        {
            const auto [ys_vector_lengths, ys_vector_strides] =
                get_window_adaptor_ys_safe_vector_length_strides();

            index_t VectorDimY_      = 0;
            index_t ScalarPerVector_ = 1;

            for(index_t i = 0; i < NDimY; ++i)
            {
                if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
                {
                    ScalarPerVector_ = ys_vector_lengths[i];
                    VectorDimY_      = i;
                }
            }

            return make_tuple(VectorDimY_, ScalarPerVector_);
        }
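
        // Illustrative example (hypothetical numbers): if the ys vector lengths are
        // [4, 8] with strides [16, 1], only dim 1 has unit stride, so VectorDimY = 1
        // and ScalarPerVector = 8, i.e. each access moves 8 contiguous scalars.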

        public:
        static constexpr index_t VectorDimY = get_vector_dim_y_scalar_per_vector().template at<0>();
        static constexpr index_t ScalarPerVector =
            get_vector_dim_y_scalar_per_vector().template at<1>();

        using vector_t = thread_buffer<DataType, ScalarPerVector>;

        private:
        static constexpr auto scalars_per_access_ = [] {
            constexpr auto scalars_per_access_arr = generate_array(
                [&](auto i) { return (i == VectorDimY) ? ScalarPerVector : 1; }, number<NDimY>{});

            constexpr auto NDimY_ = NDimY;

            return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
        }();

        static constexpr auto get_space_filling_curve()
        {
            constexpr auto thread_tensor_lengths_ys =
                to_sequence(TileDstr{}.get_ys_to_d_descriptor().get_lengths());

            // FIXME: need logic to judge dim access order
            using DimAccessOrder = typename arithmetic_sequence_gen<0, NDimY, 1>::type;

            return space_filling_curve<decltype(thread_tensor_lengths_ys),
                                       DimAccessOrder,
                                       decltype(scalars_per_access_),
                                       false>{}; // non-snaked
        }

        public:
        using SFC_Ys = decltype(get_space_filling_curve());

        static constexpr index_t NumAccess = SFC_Ys::get_num_of_access();

        static_assert(0 < NumAccess, "Wrong! NumAccess should be larger than 0");

        private:
        static constexpr auto get_num_non_linear_access()
        {
            constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
            using ys_to_rhs_major =
                typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;

            constexpr auto non_linear = [&]() {
                index_t cnt = 1;
                static_for<0, NDimY, 1>{}([&](auto i_dim_y) {
                    constexpr auto rhs_major    = ys_to_rhs_major{}[i_dim_y];
                    constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
                    if constexpr(LinearBottomDims{}[target_h_dim] == 0)
                    {
                        cnt *= sfc_access_lens[i_dim_y];
                    }
                });
                return cnt;
            }();

            return non_linear;
        }

        // example:
        // non_linear_access_map: sequence<0, 0, 0, 0, 1, 1, 1, 1> for 8 accesses,
        // 2 registers used in total
        //   -> histogram  : sequence<4, 4>
        //   -> prefix sum : sequence<0, 4, 8>
        // non_linear_access_map: sequence<0, 1, 2, 3, 4, 5, 6, 7> for 8 accesses,
        // 8 registers used in total, will pre-cache 8
        //   -> histogram  : sequence<1, 1, 1, 1, 1, 1, 1, 1>
        //   -> prefix sum : sequence<0, 1, 2, 3, 4, 5, 6, 7, 8>
        // non_linear_access_map: sequence<0, 0, 1, 1, 2, 2, 3, 3> for 8 accesses,
        // 4 registers used in total, will pre-cache 4
        //   -> histogram  : sequence<2, 2, 2, 2>
        //   -> prefix sum : sequence<0, 2, 4, 6, 8>
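        // Illustrative walk-through (hypothetical 2d case): with SFC access lengths
        // [4, 2] and only the last (fastest-moving) dim marked linear, the linear dim
        // contributes offset 0 while the non-linear dim enumerates 0..3, producing the
        // third map above, sequence<0, 0, 1, 1, 2, 2, 3, 3>: consecutive access pairs
        // share a cached coordinate, so only 4 coordinates are pre-cached.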
        static constexpr auto get_non_linear_access_map()
        {
            constexpr auto sfc_access_lens = SFC_Ys::access_lengths;
            using ys_to_rhs_major =
                typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;
            constexpr auto non_linear_map = [&]() {
                array<index_t, NumAccess> m_{0};
                index_t cumulative_len_            = 1;
                index_t cumulative_non_linear_len_ = 1;
                static_for<0, NDimY, 1>{}([&](auto i_y) {
                    constexpr auto i_dim_y       = number<NDimY - i_y - 1>{}; // from right to left
                    constexpr auto rhs_major     = ys_to_rhs_major{}[i_dim_y];
                    constexpr auto target_h_dim  = number<rhs_major - 1>{}; // no r dim here!
                    constexpr auto is_linear_dim = LinearBottomDims{}[target_h_dim];

                    array<index_t, NumAccess> current_m_{0};
                    constexpr auto current_len_ = sfc_access_lens[i_dim_y];

                    // copy cumulative length as current pattern
                    for(auto i_ = 0; i_ < cumulative_len_; i_++)
                    {
                        current_m_(i_) = m_[i_];
                    }
                    for(auto j_ = 0; j_ < current_len_; j_++)
                    {
                        auto j_offset_ = is_linear_dim ? 0 : j_ * cumulative_non_linear_len_;
                        for(auto i_ = 0; i_ < cumulative_len_; i_++)
                        {
                            m_(j_ * cumulative_len_ + i_) = current_m_[i_] + j_offset_;
                        }
                    }
                    cumulative_len_ *= current_len_;
                    if(!is_linear_dim)
                        cumulative_non_linear_len_ *= current_len_;
                });
                return m_;
            }();

            return TO_SEQUENCE(non_linear_map, NumAccess);
        }

        static constexpr auto get_non_linear_access_histogram()
        {
            constexpr auto m_ = get_non_linear_access_map();

            constexpr auto r_ =
                typename arithmetic_sequence_gen<0, get_num_non_linear_access() + 1, 1>::type{};

            constexpr auto h_ = histogram_sorted_sequence(m_, r_);

            return h_;
        }

        static constexpr auto get_non_linear_access_histogram_prefix_sum()
        {
            constexpr auto h_            = get_non_linear_access_histogram();
            constexpr auto h_prefix_sum_ = prefix_sum_sequence(h_);
            return h_prefix_sum_;
        }

        public:
        static constexpr index_t NumAccess_NonLinear = get_num_non_linear_access();
        using AccessMap_NonLinear       = decltype(get_non_linear_access_map()); // sequence
        using AccessHistogram_NonLinear = decltype(get_non_linear_access_histogram());
        using AccessPrefixSum_NonLinear = decltype(get_non_linear_access_histogram_prefix_sum());
    };

    static constexpr index_t NumAccess           = traits::NumAccess;
    static constexpr index_t NumAccess_NonLinear = traits::NumAccess_NonLinear;
    using AccessMap_NonLinear       = typename traits::AccessMap_NonLinear;
    using AccessHistogram_NonLinear = typename traits::AccessHistogram_NonLinear;
    using AccessPrefixSum_NonLinear = typename traits::AccessPrefixSum_NonLinear;

    CK_TILE_DEVICE constexpr tile_window_linear() = default;

    CK_TILE_DEVICE constexpr tile_window_linear(const BottomTensorView& bottom_tensor_view,
                                                const WindowLengths& window_lengths,
                                                const BottomTensorIndex& window_origin,
                                                const TileDstr& tile_distribution)
        : bottom_tensor_view_{bottom_tensor_view},
          window_lengths_{window_lengths},
          window_origin_{window_origin},
          tile_dstr_{tile_distribution},
          cached_coords_{},
          cached_flags_{}
    {
        auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
            tile_distribution.get_ps_ys_to_xs_adaptor(),
            container_concat(make_tuple(get_warp_id(), get_lane_id()),
                             generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));

        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
            window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();

        auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);

        // pre-compute the cached coordinates and flags used by
        // future load/store() calls (might allocate more registers)
        using SFC_Ys = typename traits::SFC_Ys;

        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_save_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_save_non_linear_coord)
            {
                cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
            }

            // TODO: need pad_tensor_view to check which dim needs a flag check.
            // the cached flag is independent from the non-linear coord,
            // but must be updated in move_tile_window with the proper dims
            cached_flags_(i_access) = coordinate_has_valid_offset_assuming_top_index_is_valid(
                bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_coord_tmp);

            if constexpr(i_access != (NumAccess - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                    idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord_tmp,
                    bottom_tensor_thread_coord_tmp,
                    idx_diff_ps_ys);
            }
        });
    }

    static constexpr CK_TILE_DEVICE index_t get_num_of_dimension() { return NDimBottomTensor; }

    static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
    {
        return TileDstr::is_static();
    }

    CK_TILE_DEVICE constexpr auto get_window_lengths() const { return window_lengths_; }

    CK_TILE_DEVICE constexpr auto get_tile_distribution() const { return tile_dstr_; }

    CK_TILE_DEVICE constexpr auto get_bottom_tensor_view() const { return bottom_tensor_view_; }

    CK_TILE_DEVICE constexpr auto get_window_origin() const { return window_origin_; }

    CK_TILE_DEVICE constexpr void
    set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType* data)
    {
        bottom_tensor_view_.buf_.p_data_ = data;
    }

    // move thread's window adaptor coordinate and bottom tensor coordinate
    // [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
    template <typename ATopIndex>
    CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(
        WindowAdaptorCoord& window_adaptor_thread_coord,
        BottomTensorCoord& bottom_tensor_thread_coord,
        const ATopIndex& idx_diff_adaptor_top) const
    {
        array<index_t, NDimBottomTensor> idx_diff_adaptor_bottom;

        move_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
                                       window_adaptor_thread_coord,
                                       idx_diff_adaptor_top,
                                       idx_diff_adaptor_bottom);

        move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                               bottom_tensor_thread_coord,
                               idx_diff_adaptor_bottom);
    }

    template <index_t i_access>
    static constexpr CK_TILE_DEVICE auto get_bottom_linear_coordinate(number<i_access>)
    {
        using SFC_Ys          = typename traits::SFC_Ys;
        constexpr auto idx_ys = SFC_Ys::get_index(number<i_access>{});
        using ys_to_rhs_major =
            typename decltype(TileDstr{}.get_static_tile_distribution_encoding())::Ys2RHsMajor;

        constexpr auto modified_idx_ys = generate_tuple(
            [&](auto i_dim_y) {
                constexpr auto rhs_major    = ys_to_rhs_major{}[i_dim_y];
                constexpr auto target_h_dim = number<rhs_major - 1>{}; // no r dim here!
                if constexpr(LinearBottomDims{}[target_h_dim] == 0)
                {
                    return number<0>{};
                }
                else
                {
                    return number<idx_ys[i_dim_y]>{};
                }
            },
            number<NDimY>{});

        constexpr auto adaptor_ = TileDstr{}.get_ps_ys_to_xs_adaptor();
        constexpr auto idx_ =
            container_concat(make_tuple(number<0>{}, number<0>{}), modified_idx_ys);

        return adaptor_.calculate_bottom_index(idx_);
    }

    template <index_t i_access>
    static constexpr CK_TILE_DEVICE index_t get_bottom_linear_offset(number<i_access>)
    {
        constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});
        constexpr auto is_pure_linear_tensor =
            reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{}) == 1;
        if constexpr(is_pure_linear_tensor)
        {
            // this case usually is an LDS window, where everything is known at compile time.
            // we directly use the BottomTensorView transform to compute the offset, in case of padding
            auto bottom_tensor_coord =
                make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
            return bottom_tensor_coord.get_offset();
        }
        else
        {
            // this case usually is a global window, where the last dim can be linear.
            // we hack here: use the original TileDstr to compute the linear offset,
            // hoping that there is no extra padding between other dims, which makes sense,
            // since that would introduce a runtime length (so a linear offset couldn't be used)
            constexpr index_t linear_offset = [&]() {
                constexpr auto x_idx_ = linear_coord;
                constexpr auto x_len_ = TileDstr{}.get_lengths();
                static_assert(x_idx_.size() == x_len_.size());
                constexpr index_t x_dims_ = x_idx_.size();
                index_t cu_stride_ = 1;
                index_t cu_offset_ = 0;
                static_for<0, x_dims_, 1>{}([&](auto i_) {
                    auto r_i_ = number<x_dims_ - i_ - 1>{};
                    cu_offset_ += x_idx_[r_i_] * cu_stride_;
                    cu_stride_ *= x_len_[r_i_];
                });
                return cu_offset_;
            }();
            return linear_offset;
        }
    }

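    // Worked example for the row-major fold above (hypothetical numbers): with tile
    // lengths [4, 8] and linear_coord [1, 2], the loop visits dim 1 first
    // (cu_offset_ = 2, cu_stride_ = 8), then dim 0, giving 1 * 8 + 2 = 10.
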
    CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }

    template <index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess = number<i_access_>{};

            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            constexpr auto linear_offset = get_bottom_linear_offset(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord,
                    linear_offset,
                    bottom_tensor_flag,
                    bool_constant<oob_conditional_check>{});
#if 1
            // data index [y0, y1, ...]
            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
            // write into distributed tensor
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value.template get_as<DataType>()[j];
            });
#else
            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
            static_assert(d % traits::ScalarPerVector == 0);

            dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
        };

        WINDOW_DISPATCH_ISSUE();

        return dst_tensor;
    }

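    // Hedged usage sketch (window name `win` hypothetical): `auto t = win.load();`
    // returns a static_distributed_tensor holding this thread's fragment, while
    // `win.load(number<0>{})` issues only the first vectorized access.
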
    template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(DstTile& dst_tensor,
                             number<i_access>                     = {},
                             bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess = number<i_access_>{};

            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            constexpr auto linear_offset = get_bottom_linear_offset(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord,
                    linear_offset,
                    bottom_tensor_flag,
                    bool_constant<oob_conditional_check>{});
#if 1
            // data index [y0, y1, ...]
            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);
            // write into distributed tensor
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j) : idx_diff_ys[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value.template get_as<DataType>()[j];
            });
#else
            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
            static_assert(d % traits::ScalarPerVector == 0);

            dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
        };

        WINDOW_DISPATCH_ISSUE();

        return dst_tensor;
    }

    template <typename DstTile,
              index_t i_access           = -1,
              bool oob_conditional_check = true,
              bool pre_nop               = false>
    CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
                                 number<i_access> = {}, // negative means loop over all num_access
                                 bool_constant<oob_conditional_check> = {},
                                 bool_constant<pre_nop>               = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;
        static constexpr index_t YElementSize =
            TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
        static_assert(YElementSize % traits::ScalarPerVector == 0);
        using vectorized_tbuf = array<vector_t, YElementSize / traits::ScalarPerVector>;

        constexpr auto tile_dstr = TileDstr{};

        auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess  = number<i_access_>{};
            constexpr auto pre_nop_ = [&]() {
                if constexpr(pre_nop && i_access_ == 0 &&
                             BottomTensorView::buffer_view::get_address_space() ==
                                 address_space_enum::global)
                    return bool_constant<true>{};
                else
                    return bool_constant<false>{};
            }();

            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);
            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
            static_assert(d % traits::ScalarPerVector == 0);

            get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                dst_vec_tbuf.template at<d / traits::ScalarPerVector>(),
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                bool_constant<oob_conditional_check>{},
                pre_nop_);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
    CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
            asm volatile(""); // this started showing up in rocm-6.2, but it is the same symptom, so reuse this flag
#endif
        };

        WINDOW_DISPATCH_ISSUE();
    }

    // TODO: currently async load is only implemented in inline asm
    template <typename LdsTileWindow_,
              index_t i_access           = -1,
              bool oob_conditional_check = true,
              bool pre_nop               = false>
    CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
                                       number<i_access>                     = {},
                                       bool_constant<oob_conditional_check> = {},
                                       bool_constant<pre_nop>               = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        using LdsDataType   = typename LdsTileWindow::DataType;

        // currently we only support the case where every dim is non-linear.
        // a linear dim (e.g. the fast-changing dim) would not be performant here anyway
        static_assert(NumAccess_NonLinear == NumAccess);
        static_assert(BottomTensorView::buffer_view::get_address_space() ==
                      address_space_enum::global);

        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

        const index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType);

        const index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) *
                sizeof(LdsDataType) -
            size_per_buf;

        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
        m0_set_with_memory(m0_init_value); // this should be wave independent

        using vector_t = typename traits::vector_t;

        LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess  = number<i_access_>{};
            constexpr auto pre_nop_ = [&]() {
                if constexpr(pre_nop && i_access_ == 0)
                    return bool_constant<true>{};
                else
                    return bool_constant<false>{};
            }();

            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess]; // get this flag anyway

            // read from bottom tensor
            get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
                smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);

            // move thread coordinate
            if constexpr(i_access_ != (NumAccess - 1))
            {
                m0_inc_with_memory(size_per_issue);
            }
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
                                   number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {}) const
    {
        using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
        using LdsDataType   = typename LdsTileWindow::DataType;

        // currently we only support the case where every dim is non-linear.
        // a linear dim (e.g. the fast-changing dim) would not be performant here anyway
        static_assert(NumAccess_NonLinear == NumAccess);
        static_assert(BottomTensorView::buffer_view::get_address_space() ==
                      address_space_enum::global);

        // issues * warps * lanes
        static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

        // TODO: an LDS offset is not good for an intrinsic based implementation (the
        // compiler can't figure out the dependency), hence avoid an offset based
        // solution. size_per_buf should be zero (how to check?)
        constexpr index_t size_per_buf =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<0>{}, number<0>{}));

        constexpr index_t size_per_wave =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<0>{}, number<1>{}, number<0>{})) -
            size_per_buf;

        constexpr index_t size_per_issue =
            lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
                make_tuple(number<1>{}, number<0>{}, number<0>{})) -
            size_per_buf;

        const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();

        using vector_t = typename traits::vector_t;

        // TODO: we force CK_TILE_LDS_ADDR
        CK_TILE_LDS_ADDR LdsDataType* smem =
            lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // read from bottom tensor
            get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
                smem,
                bottom_tensor_thread_coord,
                0,
                bottom_tensor_flag,
                bool_constant<oob_conditional_check>{});

            // move thread coordinate
            if constexpr(i_access_ != (NumAccess - 1))
            {
                smem += size_per_issue; // note we manually increase the per-issue offset
            }
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                              number<i_access>                     = {},
                              bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];
            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;

            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1>
    CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                                  number<i_access> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr                    = TileDstr{};
        static constexpr bool oob_conditional_check = true;

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});
                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view()
                .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                    bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                               number<i_access>                     = {},
                               bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;

            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                                   number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess          = number<i_access_>{};
            constexpr auto non_linear_id    = number<AccessMap_NonLinear{}[IAccess]>{};
            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;

            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{},
                bool_constant<pre_nop>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }

    // move thread's bottom tensor coordinate
    // [x0', x1', ...] ==> [offset]
    // also move the window origin
    CK_TILE_DEVICE void move(const BottomTensorIndex& step)
    {
        window_origin_ += step;

        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto IAccess       = number<i_access>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_update_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_update_non_linear_coord)
            {
                move_tensor_coordinate(bottom_tensor_view_.get_tensor_descriptor(),
                                       cached_coords_(non_linear_id),
                                       step);
            }

            // move the current coord with the linear coords
            auto tmp_coords             = cached_coords_[non_linear_id];
            constexpr auto linear_coord = get_bottom_linear_coordinate(IAccess);
            move_tensor_coordinate(
                bottom_tensor_view_.get_tensor_descriptor(), tmp_coords, linear_coord);

            cached_flags_(IAccess) = coordinate_has_valid_offset_assuming_top_index_is_valid(
                bottom_tensor_view_.get_tensor_descriptor(), tmp_coords);
        });
    }

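    // Hedged usage sketch: calling move(step) with a step of, say, 64 along the last
    // dim slides the window right by 64 columns; only the cached non-linear
    // coordinates are moved, and the per-access validity flags are recomputed.
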
    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
    {
        window_origin_ = new_window_origin;

        auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
            TileDstr{}.get_ps_ys_to_xs_adaptor(),
            container_concat(make_tuple(get_warp_id(), get_lane_id()),
                             generate_tuple([&](auto) { return number<0>{}; }, number<NDimY>{})));

        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
            window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();

        auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);

        // pre-compute the cached coordinates used by
        // future load/store() calls (might allocate more registers)
        using SFC_Ys = typename traits::SFC_Ys;

        static_for<0, NumAccess, 1>{}([&](auto i_access) {
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[i_access]>{};
            constexpr auto need_save_non_linear_coord =
                bool_constant<AccessPrefixSum_NonLinear{}[non_linear_id] == i_access>{};

            if constexpr(need_save_non_linear_coord)
            {
                cached_coords_(non_linear_id) = bottom_tensor_thread_coord_tmp;
            }

            if constexpr(i_access != (NumAccess - 1))
            {
                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); // tuple of number
                constexpr auto idx_diff_ps_ys = container_concat(
                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                    idx_diff_ys);

                move_window_adaptor_and_bottom_tensor_thread_coordinate(
                    window_adaptor_thread_coord_tmp,
                    bottom_tensor_thread_coord_tmp,
                    idx_diff_ps_ys);
            }
        });
    }

    CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }

    // this is the bottom tensor view
    // [x0', x1', ...] ==> [offset]
    BottomTensorView bottom_tensor_view_;

    //
    WindowLengths window_lengths_;

    // origin ([x0', x1', ...]) of window on bottom tensor
    BottomTensorIndex window_origin_;

    // Tile tensor distribution, which contains:
    // 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
    // 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
    TileDstr tile_dstr_;

    // this contains the pre-cached coordinates and flags
    array<BottomTensorCoord, traits::NumAccess_NonLinear> cached_coords_;
    array<bool, traits::NumAccess> cached_flags_;
};

#undef WINDOW_DISPATCH_ISSUE

namespace impl {
template <address_space_enum, index_t len_>
struct default_linear_bottom_dims_impl
{
    using type = typename uniform_sequence_gen<len_, 0>::type;
};

template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::global, len_>
{
    // global defaults to seq<0, 0, ..., 1>
    using type = typename sequence_merge<typename uniform_sequence_gen<len_ - 1, 0>::type,
                                         sequence<1>>::type;
};

template <index_t len_>
struct default_linear_bottom_dims_impl<address_space_enum::lds, len_>
{
    // lds defaults to seq<1, 1, ..., 1>
    using type = typename uniform_sequence_gen<len_, 1>::type;
};
} // namespace impl

template <typename TensorView_>
using default_linear_bottom_dims =
    typename impl::default_linear_bottom_dims_impl<TensorView_::buffer_view::get_address_space(),
                                                   TensorView_::get_num_of_dimension()>::type;

// if using this API, a tile_window_linear will be created.
// this structure has the chance to use immediate values, saving registers.
// LinearBottomDims_ must be passed in properly to control which dim is linear,
// so that a constexpr offset can be generated as the linear_offset for that dim
// (and finally passed to the immediate offset of the buffer/lds instruction)
//
// Note: there is no internal check for which dim is OK to use a linear offset;
// users must make sure of this by themselves
//
// e.g.
// 2d global matrix: set LinearBottomDims_=seq<0, 1>, and the last dim will generate
// an immediate offset if each thread has multiple issues along the last dim
//
// 2d LDS buffer: set LinearBottomDims_=seq<1, 1>, and only one vgpr is used as the
// offset; everything else just uses immediate offsets.
//
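// A hedged usage sketch (the tensor view `view_2d`, origin values and the
// distribution `dstr` are hypothetical); LinearBottomDims_ defaults per address
// space, so for a global view this is equivalent to passing sequence<0, 1>{}:
//
//   auto win = make_tile_window_linear(
//       view_2d, make_tuple(number<64>{}, number<64>{}), {row0, col0}, dstr);
//   auto frag = win.load();
//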
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TensorView_& tensor_view,
                        const WindowLengths_& window_lengths,
                        const multi_index<TensorView_::get_num_of_dimension()>& origin,
                        const StaticTileDistribution_& tile_distribution,
                        LinearBottomDims_ = {})
{
    static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
    return tile_window_linear<remove_cvref_t<TensorView_>,
                              remove_cvref_t<WindowLengths_>,
                              remove_cvref_t<StaticTileDistribution_>,
                              remove_cvref_t<LinearBottomDims_>>{
        tensor_view, window_lengths, origin, tile_distribution};
}

template <
    typename TileWindow_,
    typename StaticTileDistribution_,
    typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear(const TileWindow_& tile_window,
                        const StaticTileDistribution_& tile_distribution,
                        LinearBottomDims_ = {})
{
    return make_tile_window_linear(tile_window.get_bottom_tensor_view(),
                                   tile_window.get_window_lengths(),
                                   tile_window.get_window_origin(),
                                   tile_distribution,
                                   LinearBottomDims_{});
}

// this version must not be called in a constexpr context
template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
CK_TILE_DEVICE auto
make_tile_window_linear_raw(const TensorView_& tensor_view,
                            const WindowLengths_& window_lengths,
                            const multi_index<TensorView_::get_num_of_dimension()>& origin,
                            const StaticTileDistribution_& tile_distribution,
                            LinearBottomDims_ = {})
{
    static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
    auto w = tile_window_linear<remove_cvref_t<TensorView_>,
                                remove_cvref_t<WindowLengths_>,
                                remove_cvref_t<StaticTileDistribution_>,
                                remove_cvref_t<LinearBottomDims_>>{
        tensor_view, window_lengths, origin, tile_distribution};
    w.init_raw();
    return w;
}

template <
    typename TileWindow_,
    typename StaticTileDistribution_,
    typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
CK_TILE_DEVICE constexpr auto
make_tile_window_linear_raw(const TileWindow_& tile_window,
                            const StaticTileDistribution_& tile_distribution,
                            LinearBottomDims_ = {})
{
    return make_tile_window_linear_raw(tile_window.get_bottom_tensor_view(),
                                       tile_window.get_window_lengths(),
                                       tile_window.get_window_origin(),
                                       tile_distribution,
                                       LinearBottomDims_{});
}

template <typename TensorView_,
          typename WindowLengths_,
          typename StaticTileDistribution_,
          typename LinearBottomDims_>
CK_TILE_DEVICE void move_tile_window(
    tile_window_linear<TensorView_, WindowLengths_, StaticTileDistribution_, LinearBottomDims_>&
        window,
    const typename tile_window_linear<TensorView_,
                                      WindowLengths_,
                                      StaticTileDistribution_,
                                      LinearBottomDims_>::BottomTensorIndex& step)
{
    window.move(step);
}

} // namespace ck_tile