32 template <
typename BottomTensorView_,
33 typename WindowLengths_,
34 typename StaticTileDistribution_,
35 typename StaticPageIndexArray_,
36 typename StaticValidArray_,
55 static constexpr
index_t NDimP = TileDstr::get_num_of_dimension_p();
56 static constexpr
index_t NDimY = TileDstr::get_num_of_dimension_y();
60 static_assert(NumCoord == 1);
65 "wrong! lengths should be static");
68 static_assert(
NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
69 "wrong! inconsistent # of diemsnions");
83 static constexpr
auto get_vector_dim_y_scalar_per_vector()
85 const auto [ys_vector_lengths, ys_vector_strides] =
93 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
95 ScalarPerVector_ = ys_vector_lengths[i];
100 return make_tuple(VectorDimY_, ScalarPerVector_);
108 get_vector_dim_y_scalar_per_vector().template at<1>();
115 static constexpr
auto scalars_per_access_ = [] {
120 constexpr
auto NDimY_ =
NDimY;
122 return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
125 static constexpr
auto get_space_filling_curve()
127 constexpr
auto tile_dstr =
TileDstr{};
129 constexpr
auto thread_tensor_lengths_ys =
130 to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
137 decltype(scalars_per_access_)>{};
141 using SFC_Ys = decltype(get_space_filling_curve());
145 static_assert(0 <
NumAccess,
"Wrong! NumAccess should be larger than 0");
146 static_assert(
NumAccess % NumCoord == 0,
"wrong! # of access is not divisible by NumCoord");
171 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
175 if constexpr(
NDimP == 1)
180 else if constexpr(
NDimP == 2)
182 window_adaptor_thread_coord_tmp =
195 window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
196 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
203 using SFC_Ys =
typename Traits::SFC_Ys;
206 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
207 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
209 constexpr
auto idx_diff_ys =
216 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
219 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
246 template <
typename ATopIndex>
250 const ATopIndex& idx_diff_adaptor_top)
const
255 window_adaptor_thread_coord,
256 idx_diff_adaptor_top,
257 idx_diff_adaptor_bottom);
260 bottom_tensor_thread_coord,
261 idx_diff_adaptor_bottom);
268 const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
269 BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
272 const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
273 const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
276 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
278 array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
281 constexpr
auto window_adaptor_bottom_dims =
282 WindowAdaptor::get_bottom_dimension_hidden_ids();
285 window_adaptor_bottom_dims,
286 window_adaptor_bottom_dim_vector_lengths);
288 window_adaptor_bottom_dims,
289 window_adaptor_bottom_dim_vector_strides);
291 const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
292 WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
293 window_adaptor_vector_lengths, window_adaptor_vector_strides);
306 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
310 constexpr
auto tile_dstr =
TileDstr{};
311 auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
312 load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
316 template <
typename DistributedTensor,
317 index_t i_access_unsupport_ = -1,
318 bool oob_conditional_check =
true>
323 using Traits = load_store_traits;
324 using vector_t =
typename Traits::vector_t;
325 using SFC_Ys =
typename Traits::SFC_Ys;
327 constexpr
auto tile_dstr =
TileDstr{};
330 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
335 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
336 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
339 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
340 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
341 const auto page_offset =
page_idx_[idx_gather];
344 const vector_t vec_value = [&]() {
345 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
348 bottom_tensor_thread_coord,
350 bool_constant<oob_conditional_check>{});
355 bottom_tensor_thread_coord,
358 bool_constant<oob_conditional_check>{});
363 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
366 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
372 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
375 dst_tensor.get_thread_buffer().template at<d>() =
376 vec_value.template get_as<DataType>()[j / Traits::PackedSize];
380 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
381 static_assert(d % Traits::ScalarPerVector == 0);
383 dst_tensor.get_thread_buffer().template get_as<vector_t>()(
384 number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
389 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
392 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
396 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
397 forward_step_scatter);
400 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
406 template <
typename LdsTileWindow_,
407 index_t i_access_unsupport_ = -1,
408 bool oob_conditional_check =
true>
413 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
414 using LdsDataType =
typename LdsTileWindow::DataType;
415 using Traits = load_store_traits;
416 using vector_t =
typename Traits::vector_t;
417 using SFC_Ys =
typename Traits::SFC_Ys;
419 constexpr
auto tile_dstr =
TileDstr{};
422 const auto window_origin = lds_tile.get_window_origin();
423 const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
424 const auto& tensor_descriptor = bottom_tensor_view.get_tensor_descriptor();
425 auto smem_base_ptr = bottom_tensor_view.get_buffer_view().p_data_;
428 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
436 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
437 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
440 auto lds_bottom_tensor_thread_idx =
441 window_origin + lds_window_adaptor_thread_coord.get_bottom_index();
443 const auto lds_coord =
446 CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
449 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
450 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
451 const auto page_offset =
page_idx_[idx_gather];
454 auto mixed_bottom_thread_coord = bottom_tensor_thread_coord;
455 mixed_bottom_thread_coord.get_hidden_index()[number<0>{}] += page_offset;
458 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
461 mixed_bottom_thread_coord,
463 bool_constant<oob_conditional_check>{});
467 mixed_bottom_thread_coord,
470 bool_constant<oob_conditional_check>{});
475 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
478 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
482 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
483 forward_step_scatter);
486 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
490 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
492 lds_window_adaptor_thread_coord,
493 lds_bottom_tensor_thread_coord,
501 template <
typename LdsTileWindow_,
502 index_t i_access_unsupport_ = -1,
503 bool oob_conditional_check =
true,
504 bool pre_nop =
false>
508 bool_constant<pre_nop> = {})
const
510 using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
512 using LdsDataType =
typename LdsTileWindow::DataType;
516 static_assert(LdsTileWindow::get_num_of_dimension() == 3);
519 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
520 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
524 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
525 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
526 sizeof(LdsDataType) -
530 lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
531 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
532 sizeof(LdsDataType) -
535 const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
538 using Traits = load_store_traits;
541 using vector_t =
typename Traits::vector_t;
542 using SFC_Ys =
typename Traits::SFC_Ys;
544 LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
547 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
552 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
553 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
554 constexpr
auto pre_nop_ = [&]() {
555 if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
558 return bool_constant<false>{};
561 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
562 constexpr
auto idx_gather = idx_ys_start[number<YsGatherDim>{}];
563 const auto page_offset =
page_idx_[idx_gather];
566 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
569 smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
575 bottom_tensor_thread_coord,
585 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
588 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
592 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
593 forward_step_scatter);
596 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
604 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
609 using Traits = load_store_traits;
612 using vector_t =
typename Traits::vector_t;
613 using SFC_Ys =
typename Traits::SFC_Ys;
615 constexpr
auto tile_dstr =
TileDstr{};
617 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
621 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
622 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
625 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
626 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
627 const auto page_offset =
page_idx_[idx_gather];
632 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
635 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
641 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
644 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
649 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
652 bottom_tensor_thread_coord,
655 bool_constant<oob_conditional_check>{});
660 bottom_tensor_thread_coord,
664 bool_constant<oob_conditional_check>{});
669 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
672 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
676 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
677 forward_step_scatter);
680 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
686 template <
index_t i_access_unsupport_ = -1,
bool oob_conditional_check =
true>
691 using Traits = load_store_traits;
694 using vector_t =
typename Traits::vector_t;
695 using SFC_Ys =
typename Traits::SFC_Ys;
697 constexpr
auto tile_dstr =
TileDstr{};
700 static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
704 static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
705 constexpr
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
708 constexpr
auto idx_ys_start = SFC_Ys::get_index(iAccess);
709 constexpr
auto idx_gather = idx_ys_start[number<0>{}];
710 const auto page_offset =
page_idx_[idx_gather];
719 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
722 return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
728 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
731 vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
738 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
741 bottom_tensor_thread_coord,
744 bool_constant<oob_conditional_check>{});
749 bottom_tensor_thread_coord,
753 bool_constant<oob_conditional_check>{});
760 constexpr
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
763 [&](
auto i) {
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
767 generate_tuple([&](
auto) {
return number<0>{}; }, number<NDimP>{}),
768 forward_step_scatter);
771 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
784 step_new(HsGatherDim) = 0;
796 if constexpr(std::is_same_v<ValidArray, std::nullptr_t> ==
false)
817 static_assert(
NDimP == 1 or
NDimP == 2,
"wrong!");
821 if constexpr(
NDimP == 1)
826 else if constexpr(
NDimP == 2)
828 window_adaptor_thread_coord_tmp =
841 window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
843 bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
850 using SFC_Ys =
typename Traits::SFC_Ys;
853 auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
854 auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp;
856 constexpr
auto idx_diff_ys =
863 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
866 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
897 template <
typename TensorView_,
898 typename WindowLengths_,
899 typename StaticTileDistribution_,
900 typename StaticPageIndexArray_,
905 const WindowLengths_& window_lengths,
906 const multi_index<TensorView_::get_num_of_dimension()>& origin,
908 const StaticPageIndexArray_& page_idx,
910 number<NumCoord> = {})
912 return tile_scatter_gather<remove_cvref_t<TensorView_>,
913 remove_cvref_t<WindowLengths_>,
914 remove_cvref_t<StaticTileDistribution_>,
915 remove_cvref_t<StaticPageIndexArray_>,
919 tensor_view, window_lengths, origin, tile_distribution, page_idx,
nullptr};
922 template <
typename TensorView,
923 typename WindowLengths,
924 typename StaticTileDistribution,
925 typename StaticPageIndexArray,
929 const multi_index<TensorView::get_num_of_dimension()>& origin,
931 const StaticPageIndexArray& page_idx,
939 number<HsGatherDim>{});
942 template <
typename TensorView,
943 typename WindowLengths,
944 typename StaticTileDistribution,
945 typename StaticPageIndexArray,
950 const StaticPageIndexArray& page_idx,
958 number<HsGatherDim>{});
961 template <
typename TensorView_,
962 typename WindowLengths_,
963 typename StaticTileDistribution_,
964 typename StaticPageIndexArray_,
965 typename StaticValidArray_,
970 const WindowLengths_& window_lengths,
971 const multi_index<TensorView_::get_num_of_dimension()>& origin,
973 const StaticPageIndexArray_& page_idx,
974 const StaticValidArray_& valids,
976 number<NumCoord> = {})
978 return tile_scatter_gather<remove_cvref_t<TensorView_>,
979 remove_cvref_t<WindowLengths_>,
980 remove_cvref_t<StaticTileDistribution_>,
981 remove_cvref_t<StaticPageIndexArray_>,
982 remove_cvref_t<StaticValidArray_>,
985 tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
988 template <
typename TensorView,
989 typename WindowLengths,
990 typename StaticTileDistribution,
991 typename StaticPageIndexArray,
992 typename StaticValidArray,
996 const multi_index<TensorView::get_num_of_dimension()>& origin,
998 const StaticPageIndexArray& page_idx,
999 const StaticValidArray& valids,
1008 number<HsGatherDim>{});
1011 template <
typename TensorView,
1012 typename WindowLengths,
1013 typename StaticTileDistribution,
1014 typename StaticPageIndexArray,
1015 typename StaticValidArray,
1020 const StaticPageIndexArray& page_idx,
1021 const StaticValidArray& valids,
1030 number<HsGatherDim>{});
1033 template <
typename NewTensorView_,
1034 typename OldTensorView_,
1035 typename WindowLengths_,
1036 typename StaticTileDistribution_,
1037 typename StaticPageIndexArray_,
1038 typename StaticValidArray_,
1044 StaticTileDistribution_,
1045 StaticPageIndexArray_,
1048 NumCoord>& tile_window)
1053 tile_window.tile_dstr_,
1054 tile_window.page_idx_,
1055 tile_window.valids_);
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_LDS_ADDR
Definition: config.hpp:62
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
CK_TILE_DEVICE auto replace_bottom_tensor_view(const NewTensorView_ &new_tensor_view, const tile_scatter_gather< OldTensorView_, WindowLengths_, StaticTileDistribution_, StaticPageIndexArray_, StaticValidArray_, HsGatherDim, NumCoord > &tile_window)
Definition: tile_scatter_gather.hpp:1041
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:56
constant< b > bool_constant
Definition: integral_constant.hpp:43
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1126
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constant< v > number
Definition: integral_constant.hpp:37
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1066
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:904
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:98
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
Definition: sequence.hpp:298
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:313
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:76
Definition: numeric.hpp:81
Definition: space_filling_curve.hpp:20
Definition: static_distributed_tensor.hpp:21
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
Definition: functional.hpp:43
Definition: tensor_view.hpp:41
Definition: tile_distribution.hpp:70
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:124
Definition: tile_scatter_gather.hpp:81
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:104
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:143
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:141
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:106
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:107
This class provides a tile (windowed) view of, and access to, device memory.
Definition: tile_scatter_gather.hpp:41
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:780
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:149
static constexpr auto I1
Definition: tile_scatter_gather.hpp:59
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:153
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:880
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:307
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:877
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:232
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:304
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:53
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:265
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:72
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:887
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:409
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:43
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:809
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:893
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:236
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:44
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:247
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:319
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:687
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:802
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:48
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:885
ValidArray valids_
Definition: tile_scatter_gather.hpp:888
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:56
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:50
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:52
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:225
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:46
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:55
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:42
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
CK_TILE_DEVICE void update(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:605
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:45
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:230
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:870
static constexpr auto I0
Definition: tile_scatter_gather.hpp:58
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:78
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:234
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:47
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:75
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:239
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:874
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:794
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:505
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:71
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:792
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:223
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
This class provides a description of a tile (windowed) view of device memory.
Definition: tile_window.hpp:1195
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10