32 template <
typename BottomTensorView_,
 
   33           typename WindowLengths_,
 
   34           typename StaticTileDistribution_,
 
   35           typename StaticPageIndexArray_,
 
   36           typename StaticValidArray_,
 
   55     static constexpr 
index_t NDimP = TileDstr::get_num_of_dimension_p();
 
   56     static constexpr 
index_t NDimY = TileDstr::get_num_of_dimension_y();
 
   60     static_assert(NumCoord == 1);
 
   65                   "wrong! lengths should be static");
 
   68     static_assert(
NDimBottomTensor == WindowAdaptor::get_num_of_bottom_dimension(),
 
   69                   "wrong! inconsistent # of diemsnions");
 
   83         static constexpr 
auto get_vector_dim_y_scalar_per_vector()
 
   85             const auto [ys_vector_lengths, ys_vector_strides] =
 
   93                 if(ys_vector_strides[i] == 1 && ys_vector_lengths[i] > ScalarPerVector_)
 
   95                     ScalarPerVector_ = ys_vector_lengths[i];
 
  100             return make_tuple(VectorDimY_, ScalarPerVector_);
 
  108             get_vector_dim_y_scalar_per_vector().template at<1>();
 
  115         static constexpr 
auto scalars_per_access_ = [] {
 
  120             constexpr 
auto NDimY_ = 
NDimY;
 
  122             return TO_SEQUENCE(scalars_per_access_arr, NDimY_);
 
  125         static constexpr 
auto get_space_filling_curve()
 
  127             constexpr 
auto tile_dstr = 
TileDstr{};
 
  129             constexpr 
auto thread_tensor_lengths_ys =
 
  130                 to_sequence(tile_dstr.get_ys_to_d_descriptor().get_lengths());
 
  137                                        decltype(scalars_per_access_)>{};
 
  141         using SFC_Ys = decltype(get_space_filling_curve());
 
  145         static_assert(0 < 
NumAccess, 
"Wrong! NumAccess should be larger than 0");
 
  146         static_assert(
NumAccess % NumCoord == 0, 
"wrong! # of access is not divisible by NumCoord");
 
  171         static_assert(
NDimP == 1 or 
NDimP == 2, 
"wrong!");
 
  175         if constexpr(
NDimP == 1)
 
  180         else if constexpr(
NDimP == 2)
 
  182             window_adaptor_thread_coord_tmp =
 
  196             window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
 
  197         bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
 
  204         using SFC_Ys = 
typename Traits::SFC_Ys;
 
  207             auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
 
  208             auto bottom_tensor_thread_coord  = bottom_tensor_thread_coord_tmp;
 
  210             constexpr 
auto idx_diff_ys =
 
  217                 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 
  220                 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
 
  247     template <
typename ATopIndex>
 
  251         const ATopIndex& idx_diff_adaptor_top)
 const 
  256                                        window_adaptor_thread_coord,
 
  257                                        idx_diff_adaptor_top,
 
  258                                        idx_diff_adaptor_bottom);
 
  261                                bottom_tensor_thread_coord,
 
  262                                idx_diff_adaptor_bottom);
 
  269         const auto [bottom_tensor_top_dim_vector_lengths, bottom_tensor_top_dim_vector_strides] =
 
  270             BottomTensorDesc::get_top_dimension_safe_vector_length_strides();
 
  273         const auto window_adaptor_bottom_dim_vector_lengths = bottom_tensor_top_dim_vector_lengths;
 
  274         const auto window_adaptor_bottom_dim_vector_strides = bottom_tensor_top_dim_vector_strides;
 
  277         array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_lengths{
 
  279         array<
index_t, WindowAdaptor::get_num_of_hidden_dimension()> window_adaptor_vector_strides{
 
  282         constexpr 
auto window_adaptor_bottom_dims =
 
  283             WindowAdaptor::get_bottom_dimension_hidden_ids();
 
  286                              window_adaptor_bottom_dims,
 
  287                              window_adaptor_bottom_dim_vector_lengths);
 
  289                              window_adaptor_bottom_dims,
 
  290                              window_adaptor_bottom_dim_vector_strides);
 
  292         const auto [window_adaptor_ps_ys_vector_lengths, window_adaptor_ps_ys_vector_strides] =
 
  293             WindowAdaptor{}.get_top_dimension_safe_vector_length_strides(
 
  294                 window_adaptor_vector_lengths, window_adaptor_vector_strides);
 
  307     template <
index_t i_access_unsupport_ = -1, 
bool oob_conditional_check = 
true>
 
  311         constexpr 
auto tile_dstr = 
TileDstr{};
 
  312         auto dst_tensor          = make_static_distributed_tensor<DataType>(tile_dstr);
 
  313         load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});
 
  317     template <
typename DistributedTensor,
 
  318               index_t i_access_unsupport_ = -1,
 
  319               bool oob_conditional_check  = 
true>
 
  324         using Traits   = load_store_traits;
 
  325         using vector_t = 
typename Traits::vector_t;
 
  326         using SFC_Ys   = 
typename Traits::SFC_Ys;
 
  328         constexpr 
auto tile_dstr = 
TileDstr{};
 
  331         static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
 
  336             static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
 
  337                 constexpr 
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 
  340                 constexpr 
auto idx_ys_start = SFC_Ys::get_index(iAccess);
 
  341                 constexpr 
auto idx_gather   = idx_ys_start[number<YsGatherDim>{}];
 
  342                 const auto page_offset      = 
page_idx_[idx_gather];
 
  345                 const vector_t vec_value = [&]() {
 
  346                     if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
 
  349                             bottom_tensor_thread_coord,
 
  351                             bool_constant<oob_conditional_check>{});
 
  356                             bottom_tensor_thread_coord,
 
  359                             bool_constant<oob_conditional_check>{});
 
  364                 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
 
  367                             return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 
  373                         tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 
  376                     dst_tensor.get_thread_buffer().template at<d>() =
 
  377                         vec_value.template get_as<DataType>()[j / Traits::PackedSize];
 
  381                     tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
 
  382                 static_assert(d % Traits::ScalarPerVector == 0);
 
  384                 dst_tensor.get_thread_buffer().template get_as<vector_t>()(
 
  385                     number<d / Traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
 
  390                     constexpr 
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 
  393                         [&](
auto i) { 
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
 
  397                         generate_tuple([&](
auto) { 
return number<0>{}; }, number<NDimP>{}),
 
  398                         forward_step_scatter);
 
  401                         window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 
  408     template <
typename LdsTileWindow_,
 
  409               index_t i_access_unsupport_ = -1,
 
  410               bool oob_conditional_check  = 
true,
 
  411               bool pre_nop                = 
false>
 
  415                                        bool_constant<pre_nop>               = {}) 
const 
  417         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
 
  419         using LdsDataType = 
typename LdsTileWindow::DataType;
 
  423         static_assert(LdsTileWindow::get_num_of_dimension() == 3); 
 
  426             lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 
  427                 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
 
  431             lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 
  432                 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
 
  433                 sizeof(LdsDataType) -
 
  437             lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 
  438                 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
 
  439                 sizeof(LdsDataType) -
 
  442         const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
 
  445         using Traits = load_store_traits;
 
  448         using vector_t = 
typename Traits::vector_t;
 
  449         using SFC_Ys   = 
typename Traits::SFC_Ys;
 
  451         LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
 
  454         static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
 
  459             static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
 
  460                 constexpr 
auto iAccess  = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 
  461                 constexpr 
auto pre_nop_ = [&]() {
 
  462                     if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
 
  465                         return bool_constant<false>{};
 
  468                 constexpr 
auto idx_ys_start = SFC_Ys::get_index(iAccess);
 
  469                 constexpr 
auto idx_gather   = idx_ys_start[number<YsGatherDim>{}];
 
  470                 const auto page_offset      = 
page_idx_[idx_gather];
 
  473                 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
 
  476                         smem, bottom_tensor_thread_coord, page_offset, 0, pre_nop_);
 
  482                         bottom_tensor_thread_coord,
 
  492                     constexpr 
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 
  495                         [&](
auto i) { 
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
 
  499                         generate_tuple([&](
auto) { 
return number<0>{}; }, number<NDimP>{}),
 
  500                         forward_step_scatter);
 
  503                         window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 
  511     template <
index_t i_access_unsupport_ = -1, 
bool oob_conditional_check = 
true>
 
  516         using Traits = load_store_traits;
 
  519         using vector_t = 
typename Traits::vector_t;
 
  520         using SFC_Ys   = 
typename Traits::SFC_Ys;
 
  522         constexpr 
auto tile_dstr = 
TileDstr{};
 
  525         static_for<0, NumCoord, 1>{}([&](
auto iCoord) {
 
  529             static_for<0, NumAccessPerCoord, 1>{}([&](
auto iCoordAccess) {
 
  530                 constexpr 
auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
 
  533                 constexpr 
auto idx_ys_start = SFC_Ys::get_index(iAccess);
 
  534                 constexpr 
auto idx_gather   = idx_ys_start[number<0>{}];
 
  535                 const auto page_offset      = 
page_idx_[idx_gather];
 
  544                 static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](
auto j) {
 
  547                             return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
 
  553                         tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 
  556                     vec_value.template get_as<DataType>()(j / Traits::PackedSize) =
 
  563                 if constexpr(std::is_same_v<ValidArray, std::nullptr_t>)
 
  566                         bottom_tensor_thread_coord,
 
  569                         bool_constant<oob_conditional_check>{});
 
  574                         bottom_tensor_thread_coord,
 
  578                         bool_constant<oob_conditional_check>{});
 
  585                     constexpr 
auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
 
  588                         [&](
auto i) { 
return i == YsGatherDim ? 0 : idx_diff_ys[i]; },
 
  592                         generate_tuple([&](
auto) { 
return number<0>{}; }, number<NDimP>{}),
 
  593                         forward_step_scatter);
 
  596                         window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 
  609         step_new(HsGatherDim)      = 0;
 
  621         if constexpr(std::is_same_v<ValidArray, std::nullptr_t> == 
false)
 
  642         static_assert(
NDimP == 1 or 
NDimP == 2, 
"wrong!");
 
  646         if constexpr(
NDimP == 1)
 
  651         else if constexpr(
NDimP == 2)
 
  653             window_adaptor_thread_coord_tmp =
 
  666             window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
 
  668         bottom_tensor_thread_origin_idx_tmp(HsGatherDim) = 0;
 
  675         using SFC_Ys = 
typename Traits::SFC_Ys;
 
  678             auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
 
  679             auto bottom_tensor_thread_coord  = bottom_tensor_thread_coord_tmp;
 
  681             constexpr 
auto idx_diff_ys =
 
  688                 window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
 
  691                 make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
 
  722 template <
typename TensorView_,
 
  723           typename WindowLengths_,
 
  724           typename StaticTileDistribution_,
 
  725           typename StaticPageIndexArray_,
 
  730                          const WindowLengths_& window_lengths,
 
  731                          const multi_index<TensorView_::get_num_of_dimension()>& origin,
 
  733                          const StaticPageIndexArray_& page_idx,
 
  735                          number<NumCoord>    = {})
 
  737     return tile_scatter_gather<remove_cvref_t<TensorView_>,
 
  738                                remove_cvref_t<WindowLengths_>,
 
  739                                remove_cvref_t<StaticTileDistribution_>,
 
  740                                remove_cvref_t<StaticPageIndexArray_>,
 
  744         tensor_view, window_lengths, origin, tile_distribution, page_idx, 
nullptr};
 
  747 template <
typename TensorView,
 
  748           typename WindowLengths,
 
  749           typename StaticTileDistribution,
 
  750           typename StaticPageIndexArray,
 
  754     const multi_index<TensorView::get_num_of_dimension()>& origin,
 
  756     const StaticPageIndexArray& page_idx,
 
  764                                     number<HsGatherDim>{});
 
  767 template <
typename TensorView,
 
  768           typename WindowLengths,
 
  769           typename StaticTileDistribution,
 
  770           typename StaticPageIndexArray,
 
  775     const StaticPageIndexArray& page_idx,
 
  783                                     number<HsGatherDim>{});
 
  786 template <
typename TensorView_,
 
  787           typename WindowLengths_,
 
  788           typename StaticTileDistribution_,
 
  789           typename StaticPageIndexArray_,
 
  790           typename StaticValidArray_,
 
  795                          const WindowLengths_& window_lengths,
 
  796                          const multi_index<TensorView_::get_num_of_dimension()>& origin,
 
  798                          const StaticPageIndexArray_& page_idx,
 
  799                          const StaticValidArray_& valids,
 
  801                          number<NumCoord>    = {})
 
  803     return tile_scatter_gather<remove_cvref_t<TensorView_>,
 
  804                                remove_cvref_t<WindowLengths_>,
 
  805                                remove_cvref_t<StaticTileDistribution_>,
 
  806                                remove_cvref_t<StaticPageIndexArray_>,
 
  807                                remove_cvref_t<StaticValidArray_>,
 
  810         tensor_view, window_lengths, origin, tile_distribution, page_idx, valids};
 
  813 template <
typename TensorView,
 
  814           typename WindowLengths,
 
  815           typename StaticTileDistribution,
 
  816           typename StaticPageIndexArray,
 
  817           typename StaticValidArray,
 
  821     const multi_index<TensorView::get_num_of_dimension()>& origin,
 
  823     const StaticPageIndexArray& page_idx,
 
  824     const StaticValidArray& valids,
 
  833                                     number<HsGatherDim>{});
 
  836 template <
typename TensorView,
 
  837           typename WindowLengths,
 
  838           typename StaticTileDistribution,
 
  839           typename StaticPageIndexArray,
 
  840           typename StaticValidArray,
 
  845     const StaticPageIndexArray& page_idx,
 
  846     const StaticValidArray& valids,
 
  855                                     number<HsGatherDim>{});
 
#define CK_TILE_DEVICE
Definition: config.hpp:40
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
 
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
 
Definition: cluster_descriptor.hpp:13
 
constexpr CK_TILE_HOST_DEVICE void set_container_subset(array< T, N > &y, sequence< Is... > picks, const array< T, sizeof...(Is)> &x)
Definition: container_helper.hpp:420
 
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
 
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
 
constant< b > bool_constant
Definition: integral_constant.hpp:39
 
constexpr CK_TILE_HOST_DEVICE auto generate_array(F &&f, number< N >)
Definition: sequence.hpp:1106
 
int32_t index_t
Definition: integer.hpp:9
 
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
 
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
 
constant< v > number
Definition: integral_constant.hpp:33
 
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
 
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:412
 
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
 
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:15
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
 
constexpr CK_TILE_DEVICE auto make_tile_scatter_gather(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, const StaticPageIndexArray_ &page_idx, number< HsGatherDim >={}, number< NumCoord >={})
Definition: tile_scatter_gather.hpp:729
 
constexpr CK_TILE_HOST_DEVICE void move_tensor_adaptor_coordinate(const Adaptor &adaptor, AdaptorCoord &coord, const TopIndex &idx_diff_top, BottomIndex &idx_diff_bottom)
Definition: tensor_adaptor_coordinate.hpp:97
 
constexpr CK_TILE_HOST_DEVICE auto get_container_subset(const array< T, N > &arr, sequence< Is... >)
Definition: container_helper.hpp:389
 
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
 
impl::is_static_impl< remove_cvref_t< T > > is_static
Definition: type_traits.hpp:87
 
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
 
Definition: sequence.hpp:278
 
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:293
 
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
 
Definition: integral_constant.hpp:13
 
Definition: type_traits.hpp:76
 
Definition: numeric.hpp:81
 
Definition: space_filling_curve.hpp:20
 
Definition: static_distributed_tensor.hpp:21
 
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
 
Definition: functional.hpp:43
 
Definition: tensor_view.hpp:41
 
Definition: tile_distribution.hpp:72
 
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
 
Definition: tile_scatter_gather.hpp:81
 
static constexpr index_t PackedSize
Definition: tile_scatter_gather.hpp:104
 
static constexpr index_t NumAccess
Definition: tile_scatter_gather.hpp:143
 
decltype(get_space_filling_curve()) SFC_Ys
Definition: tile_scatter_gather.hpp:141
 
static constexpr index_t VectorDimY
Definition: tile_scatter_gather.hpp:106
 
static constexpr index_t ScalarPerVector
Definition: tile_scatter_gather.hpp:107
 
This class provides tile (windowed) view and access to the device memory.
Definition: tile_scatter_gather.hpp:41
 
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_scatter_gather.hpp:605
 
static constexpr index_t NumAccessPerCoord
Definition: tile_scatter_gather.hpp:149
 
static constexpr auto I1
Definition: tile_scatter_gather.hpp:59
 
constexpr CK_TILE_DEVICE tile_scatter_gather(const BottomTensorView &bottom_tensor_view, const WindowLengths &window_lengths, const BottomTensorIndex &window_origin, const TileDstr &tile_distribution, const PageIdxArray &page_idx, const ValidArray &valids)
Definition: tile_scatter_gather.hpp:153
 
BottomTensorIndex window_origin_
Definition: tile_scatter_gather.hpp:705
 
CK_TILE_DEVICE auto load(number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:308
 
WindowLengths window_lengths_
Definition: tile_scatter_gather.hpp:702
 
constexpr CK_TILE_DEVICE auto get_tile_distribution() const
Definition: tile_scatter_gather.hpp:233
 
constexpr CK_TILE_DEVICE auto get_num_of_access() const
Definition: tile_scatter_gather.hpp:305
 
static constexpr index_t NDimBottomTensor
Definition: tile_scatter_gather.hpp:53
 
static constexpr CK_TILE_DEVICE auto get_window_adaptor_ys_safe_vector_length_strides()
Definition: tile_scatter_gather.hpp:266
 
array< index_t, NDimBottomTensor > BottomTensorIndex
Definition: tile_scatter_gather.hpp:72
 
PageIdxArray page_idx_
Definition: tile_scatter_gather.hpp:712
 
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_scatter_gather.hpp:43
 
CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex &new_window_origin)
Definition: tile_scatter_gather.hpp:634
 
array< tuple< WindowAdaptorCoord, BottomTensorCoord >, NumCoord > pre_computed_coords_
Definition: tile_scatter_gather.hpp:718
 
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_scatter_gather.hpp:237
 
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_scatter_gather.hpp:44
 
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_scatter_gather.hpp:248
 
CK_TILE_DEVICE auto load(DistributedTensor &dst_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:320
 
CK_TILE_DEVICE void store(const static_distributed_tensor< DataType, TileDstr > &dstr_tensor, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_scatter_gather.hpp:512
 
CK_TILE_DEVICE void update_page_idx_and_valids(const PageIdxArray &new_idx, const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:627
 
typename BottomTensorView::TensorDesc BottomTensorDesc
Definition: tile_scatter_gather.hpp:48
 
TileDstr tile_dstr_
Definition: tile_scatter_gather.hpp:710
 
ValidArray valids_
Definition: tile_scatter_gather.hpp:713
 
static constexpr index_t NDimY
Definition: tile_scatter_gather.hpp:56
 
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_scatter_gather.hpp:50
 
static constexpr index_t NDimWindowAdaptorTop
Definition: tile_scatter_gather.hpp:52
 
static constexpr CK_TILE_DEVICE bool has_static_tile_distribution()
Definition: tile_scatter_gather.hpp:226
 
remove_cvref_t< StaticValidArray_ > ValidArray
Definition: tile_scatter_gather.hpp:46
 
static constexpr index_t NDimP
Definition: tile_scatter_gather.hpp:55
 
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_scatter_gather.hpp:42
 
constexpr CK_TILE_DEVICE tile_scatter_gather()=default
 
remove_cvref_t< StaticPageIndexArray_ > PageIdxArray
Definition: tile_scatter_gather.hpp:45
 
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_scatter_gather.hpp:231
 
CK_TILE_HOST_DEVICE void init_raw()
Definition: tile_scatter_gather.hpp:695
 
static constexpr auto I0
Definition: tile_scatter_gather.hpp:58
 
decltype(make_tensor_coordinate(BottomTensorDesc{}, BottomTensorIndex{})) BottomTensorCoord
Definition: tile_scatter_gather.hpp:78
 
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_scatter_gather.hpp:235
 
typename TileDstr::PsYs2XsAdaptor WindowAdaptor
Definition: tile_scatter_gather.hpp:47
 
decltype(make_tensor_adaptor_coordinate(WindowAdaptor{}, AdaptorTopIndex{})) WindowAdaptorCoord
Definition: tile_scatter_gather.hpp:75
 
constexpr CK_TILE_DEVICE void set_bottom_tensor_view_data_ptr(typename BottomTensorView::DataType *data)
Definition: tile_scatter_gather.hpp:240
 
BottomTensorView bottom_tensor_view_
Definition: tile_scatter_gather.hpp:699
 
CK_TILE_DEVICE void update_valids(const ValidArray &new_valids)
Definition: tile_scatter_gather.hpp:619
 
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access_unsupport_ >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_scatter_gather.hpp:412
 
array< index_t, NDimWindowAdaptorTop > AdaptorTopIndex
Definition: tile_scatter_gather.hpp:71
 
CK_TILE_DEVICE void update_page_idx(const PageIdxArray &new_idx)
Definition: tile_scatter_gather.hpp:617
 
static constexpr CK_TILE_DEVICE index_t get_num_of_dimension()
Definition: tile_scatter_gather.hpp:224
 
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
 
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
 
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
 
This class provides description of tile windowed view on the device memory.
Definition: tile_window.hpp:870
 
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10