22 #define WINDOW_DISPATCH_ISSUE()                                     \ 
   23     if constexpr(i_access < 0)                                      \ 
   25         static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \ 
   29         static_assert(i_access < NumAccess);                        \ 
   30         issue(number<i_access>{});                                  \ 
   43 template <
typename BottomTensorView_,
 
   44           typename WindowLengths_,
 
   45           typename StaticTileDistribution_,
 
   46           typename LinearBottomDims_>
 
   50                                                                 StaticTileDistribution_,
 
   54                                              StaticTileDistribution_>
 
   58                                                                     StaticTileDistribution_,
 
   62                                                  StaticTileDistribution_>;
 
   66     static_assert(LinearBottomDims::size() == Base::BottomTensorView::get_num_of_dimension());
 
   74         static constexpr 
auto get_num_non_linear_access()
 
   76             constexpr 
auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
 
   77             using ys_to_rhs_major =
 
   79                                       .get_static_tile_distribution_encoding())::Ys2RHsMajor;
 
   81             constexpr 
auto non_linear = [&]() {
 
   84                     constexpr 
auto rhs_major    = ys_to_rhs_major{}[i_dim_y];
 
   85                     constexpr 
auto target_h_dim = 
number<rhs_major - 1>{}; 
 
   88                         cnt *= sfc_access_lens[i_dim_y];
 
  110         static constexpr 
auto get_non_linear_access_map()
 
  112             constexpr 
auto sfc_access_lens = Base::Traits::SFC_Ys::access_lengths;
 
  113             using ys_to_rhs_major =
 
  115                                       .get_static_tile_distribution_encoding())::Ys2RHsMajor;
 
  116             constexpr 
auto non_linear_map = [&]() {
 
  119                 index_t cumulative_non_linear_len_ = 1;
 
  122                     constexpr 
auto rhs_major     = ys_to_rhs_major{}[i_dim_y];
 
  123                     constexpr 
auto target_h_dim  = 
number<rhs_major - 1>{}; 
 
  127                     constexpr 
auto current_len_ = sfc_access_lens[i_dim_y];
 
  130                     for(
auto i_ = 0; i_ < cumulative_len_; i_++)
 
  132                         current_m_(i_) = m_[i_];
 
  134                     for(
auto j_ = 0; j_ < current_len_; j_++)
 
  136                         auto j_offset_ = is_linear_dim ? 0 : j_ * cumulative_non_linear_len_;
 
  137                         for(
auto i_ = 0; i_ < cumulative_len_; i_++)
 
  139                             m_(j_ * cumulative_len_ + i_) = current_m_[i_] + j_offset_;
 
  142                     cumulative_len_ *= current_len_;
 
  144                         cumulative_non_linear_len_ *= current_len_;
 
  149             return TO_SEQUENCE(non_linear_map, Base::Traits::NumAccess);
 
  152         static constexpr 
auto get_non_linear_access_histogram()
 
  154             constexpr 
auto m_ = get_non_linear_access_map();
 
  164         static constexpr 
auto get_non_linear_access_histogram_prefix_sum()
 
  166             constexpr 
auto h_            = get_non_linear_access_histogram();
 
  168             return h_prefix_sum_;
 
  204             window_origin + window_adaptor_thread_coord_tmp.get_bottom_index();
 
  210         using SFC_Ys = 
typename Base::Traits::SFC_Ys;
 
  214             constexpr 
auto need_save_non_linear_coord =
 
  217             if constexpr(need_save_non_linear_coord)
 
  229             if constexpr(i_access != (
NumAccess - 1))
 
  231                 constexpr 
auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); 
 
  237                     window_adaptor_thread_coord_tmp,
 
  238                     bottom_tensor_thread_coord_tmp,
 
  244     template <index_t i_access>
 
  247         using SFC_Ys          = 
typename Base::Traits::SFC_Ys;
 
  249         using ys_to_rhs_major =
 
  251                                   .get_static_tile_distribution_encoding())::Ys2RHsMajor;
 
  255                 constexpr 
auto rhs_major    = ys_to_rhs_major{}[i_dim_y];
 
  256                 constexpr 
auto target_h_dim = 
number<rhs_major - 1>{}; 
 
  268         constexpr 
auto adaptor_ = 
typename Base::TileDstr{}.get_ps_ys_to_xs_adaptor();
 
  269         constexpr 
auto idx_ =
 
  272         return adaptor_.calculate_bottom_index(idx_);
 
  275     template <index_t i_access>
 
  279         constexpr 
auto is_pure_linear_tensor =
 
  281         if constexpr(is_pure_linear_tensor)
 
  287             return bottom_tensor_coord.get_offset();
 
  295             constexpr 
index_t linear_offset = [&]() {
 
  296                 constexpr 
auto x_idx_ = linear_coord;
 
  298                 static_assert(x_idx_.size() == x_len_.size());
 
  299                 constexpr 
index_t x_dims_ = x_idx_.size();
 
  303                     auto r_i_ = 
number<x_dims_ - i_ - 1>{};
 
  304                     cu_offset_ += x_idx_[r_i_] * cu_stride_;
 
  305                     cu_stride_ *= x_len_[r_i_];
 
  309             return linear_offset;
 
  313     template <
index_t i_access = -1, 
bool oob_conditional_check = 
true>
 
  316         using vector_t = 
typename Base::Traits::vector_t;
 
  317         using SFC_Ys   = 
typename Base::Traits::SFC_Ys;
 
  321         auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
 
  323         auto issue = [&](
auto i_access_) {
 
  324             constexpr 
auto IAccess = number<i_access_>{};
 
  333             const vector_t vec_value =
 
  335                     bottom_tensor_thread_coord,
 
  338                     bool_constant<oob_conditional_check>{});
 
  341             constexpr 
auto idx_diff_ys = SFC_Ys::get_index(IAccess);
 
  343             static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](
auto j) {
 
  346                         return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
 
  349                     number<Base::NDimY>{});
 
  351                 constexpr 
index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 
  352                                       Base::Traits::PackedSize;
 
  354                 dst_tensor.get_thread_buffer().template at<d>() =
 
  356                         .template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
 
  365     template <
typename DstTile, 
index_t i_access = -1, 
bool oob_conditional_check = 
true>
 
  370         using vector_t = 
typename Base::Traits::vector_t;
 
  371         using SFC_Ys   = 
typename Base::Traits::SFC_Ys;
 
  377         auto issue = [&](
auto i_access_) {
 
  378             constexpr 
auto IAccess = number<i_access_>{};
 
  387             const vector_t vec_value =
 
  389                     bottom_tensor_thread_coord,
 
  392                     bool_constant<oob_conditional_check>{});
 
  394             constexpr 
auto idx_diff_ys = SFC_Ys::get_index(IAccess);
 
  396             static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](
auto j) {
 
  399                         return jj == Base::Traits::VectorDimY ? (idx_diff_ys[jj] + j)
 
  402                     number<Base::NDimY>{});
 
  404                 constexpr 
index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 
  405                                       Base::Traits::PackedSize;
 
  407                 dst_tensor.get_thread_buffer().template at<d>() =
 
  409                         .template get_as<typename Base::DataType>()[j / Base::Traits::PackedSize];
 
  418     template <
typename DstTile,
 
  420               bool oob_conditional_check = 
true,
 
  421               bool pre_nop               = 
false>
 
  425                                  bool_constant<pre_nop>               = {}) 
const 
  427         using vector_t = 
typename Base::Traits::vector_t;
 
  428         using SFC_Ys   = 
typename Base::Traits::SFC_Ys;
 
  429         static constexpr 
index_t YElementSize =
 
  430             typename Base::TileDstr{}.get_ys_to_d_descriptor().get_element_space_size();
 
  431         static_assert(YElementSize % (Base::Traits::PackedSize * Base::Traits::ScalarPerVector) ==
 
  433         using vectorized_tbuf =
 
  435                   YElementSize / (Base::Traits::PackedSize * Base::Traits::ScalarPerVector)>;
 
  439         auto& dst_vec_tbuf = 
reinterpret_cast<vectorized_tbuf&
>(dst_tensor.get_thread_buffer());
 
  441         auto issue = [&](
auto i_access_) {
 
  442             constexpr 
auto IAccess  = number<i_access_>{};
 
  443             constexpr 
auto pre_nop_ = [&]() {
 
  444                 if constexpr(pre_nop && i_access_ == 0 &&
 
  445                              Base::BottomTensorView::buffer_view::get_address_space() ==
 
  446                                  address_space_enum::global)
 
  447                     return bool_constant<true>{};
 
  449                     return bool_constant<false>{};
 
  458             constexpr 
auto idx_ys_start = SFC_Ys::get_index(IAccess);
 
  460                 tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start) /
 
  461                 Base::Traits::PackedSize;
 
  462             static_assert(d % Base::Traits::ScalarPerVector == 0);
 
  465                 dst_vec_tbuf.template at<d / Base::Traits::ScalarPerVector>(),
 
  466                 bottom_tensor_thread_coord,
 
  469                 bool_constant<oob_conditional_check>{},
 
  471 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \ 
  472     CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 
  481     template <
typename LdsTileWindow_,
 
  483               bool oob_conditional_check = 
true,
 
  484               bool pre_nop               = 
false>
 
  488                                        bool_constant<pre_nop>               = {}) 
const 
  490         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
 
  491         using LdsDataType   = 
typename LdsTileWindow::DataType;
 
  496         static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
 
  497                       address_space_enum::global);
 
  500         static_assert(LdsTileWindow::get_num_of_dimension() == 3); 
 
  503             lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 
  504                 make_tuple(number<0>{}, number<0>{}, number<0>{})) *
 
  508             lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 
  509                 make_tuple(number<0>{}, number<1>{}, number<0>{})) *
 
  510                 sizeof(LdsDataType) -
 
  514             lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
 
  515                 make_tuple(number<1>{}, number<0>{}, number<0>{})) *
 
  516                 sizeof(LdsDataType) -
 
  519         const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();
 
  522         using vector_t = 
typename Base::Traits::vector_t;
 
  524         LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;
 
  527         auto issue = [&](
auto i_access_) {
 
  528             constexpr 
auto IAccess  = number<i_access_>{};
 
  529             constexpr 
auto pre_nop_ = [&]() {
 
  530                 if constexpr(pre_nop && i_access_ == 0)
 
  533                     return bool_constant<false>{};
 
  542                 smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);
 
  545             if constexpr(i_access_ != (
NumAccess - 1))
 
  554     template <
typename LdsTileWindow_, 
index_t i_access = -1, 
bool oob_conditional_check = 
true>
 
  559         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
 
  560         using LdsDataType   = 
typename LdsTileWindow::DataType;
 
  561         using vector_t      = 
typename traits::vector_t;
 
  564         static_assert(Base::BottomTensorView::buffer_view::get_address_space() ==
 
  565                           address_space_enum::global,
 
  566                       "Requires global memory");
 
  569         const auto window_origin       = lds_tile.get_window_origin();
 
  570         const auto& bottom_tensor_view = lds_tile.get_bottom_tensor_view();
 
  571         const auto& tensor_descriptor  = bottom_tensor_view.get_tensor_descriptor();
 
  572         auto smem_base_ptr             = bottom_tensor_view.get_buffer_view().p_data_;
 
  574         auto issue = [&](
auto i_access_) {
 
  575             constexpr 
auto IAccess       = number<i_access_>{};
 
  583             auto lds_bottom_tensor_thread_idx =
 
  584                 window_origin + window_adaptor_coord.get_bottom_index();
 
  585             const auto lds_coord =
 
  588             CK_TILE_LDS_ADDR LdsDataType* smem = smem_base_ptr + lds_coord.get_offset();
 
  593                 bottom_tensor_thread_coord,
 
  596                 bool_constant<oob_conditional_check>{});
 
  602     template <
typename Policy, 
index_t i_access_unsupport_ = -1, 
bool oob_conditional_check = 
true>
 
  606         auto dst_tensor = make_static_distributed_tensor<typename Base::DataType>(tile_dstr);
 
  607         this->
template load_transpose_linear<Policy>(
 
  612     template <
typename Policy,
 
  613               typename DistributedTensor,
 
  615               bool oob_conditional_check = 
true>
 
  620         using vector_t = 
typename traits::vector_t;
 
  621         using SFC_Ys   = 
typename traits::SFC_Ys;
 
  625         constexpr 
auto group_func = Policy::group_func;
 
  627         auto issue = [&](
auto i_access_) {
 
  628             constexpr 
auto IAccess          = number<i_access_>{};
 
  633             constexpr 
auto idx_ys_start = SFC_Ys::get_index(IAccess);
 
  636             const vector_t vec_value =
 
  638                     bottom_tensor_thread_coord, 0);
 
  640             static_for<0, traits::ScalarPerVector, 1>{}([&](
auto j) {
 
  643                         return jj == traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj];
 
  645                     number<Base::NDimY>{});
 
  647                 constexpr 
index_t linear_distributed_index =
 
  648                     tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
 
  649                 dst_tensor.get_thread_buffer().template at<linear_distributed_index>() =
 
  650                     vec_value.template get_as<typename Base::DataType>()[j];
 
  656     template <
index_t i_access = -1, 
bool oob_conditional_check = 
true>
 
  663         using vector_t = 
typename Base::Traits::vector_t;
 
  664         using SFC_Ys   = 
typename Base::Traits::SFC_Ys;
 
  669         auto issue = [&](
auto i_access_) {
 
  670             constexpr 
auto IAccess          = number<i_access_>{};
 
  676             constexpr 
auto idx_ys_start = SFC_Ys::get_index(IAccess);
 
  681             static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](
auto j) {
 
  684                         return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
 
  687                     number<Base::NDimY>{});
 
  689                 constexpr 
index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 
  690                                       Base::Traits::PackedSize;
 
  692                 vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
 
  693                     dstr_tensor.get_thread_buffer().template at<d>();
 
  698                 bottom_tensor_thread_coord,
 
  702                 bool_constant<oob_conditional_check>{});
 
  708     template <
index_t i_access = -1>
 
  714         using vector_t = 
typename Base::Traits::vector_t;
 
  715         using SFC_Ys   = 
typename Base::Traits::SFC_Ys;
 
  718         static constexpr 
bool oob_conditional_check = 
true;
 
  721         auto issue = [&](
auto i_access_) {
 
  722             constexpr 
auto IAccess          = number<i_access_>{};
 
  729             constexpr 
auto idx_ys_start = SFC_Ys::get_index(IAccess);
 
  733             static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](
auto j) {
 
  736                         return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
 
  739                     number<Base::NDimY>{});
 
  740                 constexpr 
index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 
  741                                       Base::Traits::PackedSize;
 
  742                 vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
 
  748                 .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
 
  749                     bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
 
  755     template <
index_t i_access = -1, 
bool oob_conditional_check = 
true>
 
  763         using vector_t = 
typename Base::Traits::vector_t;
 
  764         using SFC_Ys   = 
typename Base::Traits::SFC_Ys;
 
  769         auto issue = [&](
auto i_access_) {
 
  770             constexpr 
auto IAccess          = number<i_access_>{};
 
  777             constexpr 
auto idx_ys_start = SFC_Ys::get_index(IAccess);
 
  782             static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](
auto j) {
 
  785                         return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
 
  788                     number<Base::NDimY>{});
 
  790                 constexpr 
index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 
  791                                       Base::Traits::PackedSize;
 
  793                 vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
 
  799                 bottom_tensor_thread_coord,
 
  803                 bool_constant<oob_conditional_check>{});
 
  809     template <
index_t i_access = -1, 
bool oob_conditional_check = 
true, 
bool pre_nop = 
false>
 
  815                bool_constant<pre_nop>               = {}) 
const 
  818         using vector_t = 
typename Base::Traits::vector_t;
 
  819         using SFC_Ys   = 
typename Base::Traits::SFC_Ys;
 
  824         auto issue = [&](
auto i_access_) {
 
  825             constexpr 
auto IAccess          = number<i_access_>{};
 
  832             constexpr 
auto idx_ys_start = SFC_Ys::get_index(IAccess);
 
  837             static_for<0, Base::Traits::ScalarPerVector, Base::Traits::PackedSize>{}([&](
auto j) {
 
  840                         return jj == Base::Traits::VectorDimY ? (idx_ys_start[jj] + j)
 
  843                     number<Base::NDimY>{});
 
  845                 constexpr 
index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) /
 
  846                                       Base::Traits::PackedSize;
 
  848                 vec_value.template get_as<typename Base::DataType>()(j / Base::Traits::PackedSize) =
 
  854                 bottom_tensor_thread_coord,
 
  858                 bool_constant<oob_conditional_check>{},
 
  859                 bool_constant<pre_nop>{});
 
  871             constexpr 
auto need_update_non_linear_coord =
 
  874             if constexpr(need_update_non_linear_coord)
 
  901             this->
window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
 
  907         using SFC_Ys = 
typename Base::Traits::SFC_Ys;
 
  911             constexpr 
auto need_save_non_linear_coord =
 
  914             if constexpr(need_save_non_linear_coord)
 
  920             if constexpr(i_access != (
NumAccess - 1))
 
  922                 constexpr 
auto idx_diff_ys = SFC_Ys::get_forward_step(i_access); 
 
  928                     window_adaptor_thread_coord_tmp,
 
  929                     bottom_tensor_thread_coord_tmp,
 
  942 #undef WINDOW_DISPATCH_ISSUE 
  945 template <address_space_enum, index_t len_>
 
  951 template <index_t len_>
 
  959 template <index_t len_>
 
  967 template <
typename TensorView_>
 
  970                                                    TensorView_::get_num_of_dimension()>::type;
 
  988 template <
typename TensorView_,
 
  989           typename WindowLengths_,
 
  990           typename StaticTileDistribution_,
 
  994                         const WindowLengths_& window_lengths,
 
  995                         const multi_index<TensorView_::get_num_of_dimension()>& origin,
 
  997                         LinearBottomDims_ = {})
 
  999     static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
 
 1000     return tile_window_linear<remove_cvref_t<TensorView_>,
 
 1001                               remove_cvref_t<WindowLengths_>,
 
 1002                               remove_cvref_t<StaticTileDistribution_>,
 
 1003                               remove_cvref_t<LinearBottomDims_>>{
 
 1004         tensor_view, window_lengths, origin, tile_distribution};
 
 1008     typename TileWindow_,
 
 1009     typename StaticTileDistribution_,
 
 1010     typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
 
 1014                         LinearBottomDims_ = {})
 
 1017                                    tile_window.get_window_lengths(),
 
 1018                                    tile_window.get_window_origin(),
 
 1020                                    LinearBottomDims_{});
 
 1024 template <
typename TensorView_,
 
 1025           typename WindowLengths_,
 
 1026           typename StaticTileDistribution_,
 
 1027           typename LinearBottomDims_ = default_linear_bottom_dims<TensorView_>>
 
 1030                             const WindowLengths_& window_lengths,
 
 1031                             const multi_index<TensorView_::get_num_of_dimension()>& origin,
 
 1033                             LinearBottomDims_ = {})
 
 1035     static_assert(LinearBottomDims_::size() == TensorView_::get_num_of_dimension());
 
 1036     auto w = tile_window_linear<remove_cvref_t<TensorView_>,
 
 1037                                 remove_cvref_t<WindowLengths_>,
 
 1038                                 remove_cvref_t<StaticTileDistribution_>,
 
 1039                                 remove_cvref_t<LinearBottomDims_>>{
 
 1040         tensor_view, window_lengths, origin, tile_distribution};
 
 1046     typename TileWindow_,
 
 1047     typename StaticTileDistribution_,
 
 1048     typename LinearBottomDims_ = default_linear_bottom_dims<typename TileWindow_::BottomTensorView>>
 
 1052                             LinearBottomDims_ = {})
 
 1055                                        tile_window.get_window_lengths(),
 
 1056                                        tile_window.get_window_origin(),
 
 1058                                        LinearBottomDims_{});
 
 1061 template <
typename TensorView_,
 
 1062           typename WindowLengths_,
 
 1063           typename StaticTileDistribution_,
 
 1064           typename LinearBottomDims_>
 
 1070                                       StaticTileDistribution_,
 
 1071                                       LinearBottomDims_>::BottomTensorIndex& step)
 
 1084 template <
typename T>
 
 1100 template <
typename BottomTensorView_,
 
 1101           typename WindowLengths_,
 
 1102           typename StaticTileDistribution_,
 
 1103           typename LinearBottomDims_>
 
 1106                                                 StaticTileDistribution_,
 
 1118 template <
typename T>
 
#define CK_TILE_DEVICE
Definition: config.hpp:40
 
#define CK_TILE_LDS_ADDR
Definition: config.hpp:57
 
Definition: cluster_descriptor.hpp:13
 
typename impl::default_linear_bottom_dims_impl< TensorView_::buffer_view::get_address_space(), TensorView_::get_num_of_dimension()>::type default_linear_bottom_dims
Definition: tile_window_linear.hpp:970
 
constexpr CK_TILE_HOST_DEVICE void move_tensor_coordinate(const TensorDesc &tensor_desc, TensorCoord &coord, const Index &coord_step)
Definition: tensor_coordinate.hpp:72
 
constexpr CK_TILE_HOST_DEVICE auto make_tensor_adaptor_coordinate(const Adaptor &adaptor, const TopIndex &idx_top)
Definition: tensor_adaptor_coordinate.hpp:55
 
constant< b > bool_constant
Definition: integral_constant.hpp:39
 
int32_t index_t
Definition: integer.hpp:9
 
constexpr CK_TILE_HOST_DEVICE auto make_tensor_coordinate(const TensorDesc &tensor_desc, const TopIndex &idx_top)
Definition: tensor_coordinate.hpp:60
 
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
 
constant< v > number
Definition: integral_constant.hpp:33
 
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:973
 
constexpr CK_TILE_HOST_DEVICE bool coordinate_has_valid_offset_assuming_top_index_is_valid(const TensorDesc &tensor_desc, const TensorCoord &coord)
Definition: tensor_coordinate.hpp:79
 
CK_TILE_DEVICE auto make_tile_window_linear_raw(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:1029
 
constexpr bool is_tile_window_linear_v
Helper variable template to check if a type is a linear tile window.
Definition: tile_window_linear.hpp:1119
 
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
 
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:993
 
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:412
 
CK_TILE_DEVICE void m0_set_with_memory(index_t v)
Definition: utility.hpp:19
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
 
CK_TILE_DEVICE void m0_inc_with_memory(index_t v)
Definition: utility.hpp:25
 
constexpr CK_TILE_HOST_DEVICE auto histogram_sorted_sequence(SeqSortedSamples, sequence< r, rs... >)
Definition: sequence.hpp:1093
 
constexpr CK_TILE_HOST_DEVICE auto container_concat(const X &x, const Ys &... ys)
Definition: container_helper.hpp:363
 
constexpr auto prefix_sum_sequence(Seq)
Definition: sequence.hpp:899
 
bool_constant< false > false_type
Definition: integral_constant.hpp:63
 
bool_constant< true > true_type
Definition: integral_constant.hpp:62
 
Definition: sequence.hpp:278
 
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
 
Definition: integral_constant.hpp:13
 
typename sequence_merge< typename uniform_sequence_gen< len_ - 1, 0 >::type, sequence< 1 > >::type type
Definition: tile_window_linear.hpp:956
 
typename uniform_sequence_gen< len_, 1 >::type type
Definition: tile_window_linear.hpp:963
 
Definition: tile_window_linear.hpp:947
 
typename uniform_sequence_gen< len_, 0 >::type type
Definition: tile_window_linear.hpp:948
 
Type trait to determine if a type is a linear tile window.
Definition: tile_window_linear.hpp:1086
 
Definition: sequence.hpp:227
 
Definition: sequence.hpp:52
 
Definition: static_distributed_tensor.hpp:21
 
constexpr CK_TILE_HOST_DEVICE const auto & get_thread_buffer() const
Definition: static_distributed_tensor.hpp:58
 
Definition: functional.hpp:43
 
Definition: tensor_view.hpp:41
 
Definition: tile_distribution.hpp:72
 
constexpr CK_TILE_HOST_DEVICE const auto & get_ps_ys_to_xs_adaptor() const
Definition: tile_distribution.hpp:126
 
BottomTensorView bottom_tensor_view_
Definition: tile_window_base.hpp:85
 
remove_cvref_t< typename BottomTensorView::DataType > DataType
Definition: tile_window_base.hpp:36
 
BottomTensorIndex window_origin_
Definition: tile_window_base.hpp:79
 
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
 
CK_TILE_DEVICE void move(const BottomTensorIndex &step)
Definition: tile_window_base.hpp:67
 
remove_reference_t< BottomTensorView_ > BottomTensorView
Definition: tile_window_base.hpp:33
 
remove_cvref_t< WindowLengths_ > WindowLengths
Definition: tile_window_base.hpp:34
 
WindowLengths window_lengths_
Definition: tile_window_base.hpp:81
 
Definition: tile_window_linear.hpp:72
 
decltype(get_non_linear_access_histogram_prefix_sum()) AccessPrefixSum_NonLinear
Definition: tile_window_linear.hpp:175
 
decltype(get_non_linear_access_map()) AccessMap_NonLinear
Definition: tile_window_linear.hpp:173
 
static constexpr index_t NumAccess_NonLinear
Definition: tile_window_linear.hpp:172
 
decltype(get_non_linear_access_histogram()) AccessHistogram_NonLinear
Definition: tile_window_linear.hpp:174
 
Definition: tile_window_linear.hpp:55
 
static constexpr auto I0
Definition: tile_window_linear.hpp:68
 
CK_TILE_DEVICE void set_window_origin_extended(const typename Base::BottomTensorIndex &)
Definition: tile_window_linear.hpp:892
 
CK_TILE_DEVICE auto load(number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window_linear.hpp:314
 
constexpr CK_TILE_DEVICE tile_window_linear()=default
 
array< typename Base::WindowAdaptorCoord, traits::NumAccess_NonLinear > cached_window_adaptor_coords_
Definition: tile_window_linear.hpp:938
 
CK_TILE_DEVICE auto async_load(LdsTileWindow_ &&lds_tile, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window_linear.hpp:555
 
CK_TILE_DEVICE void load_raw(DstTile &dst_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window_linear.hpp:422
 
static constexpr CK_TILE_DEVICE index_t get_bottom_linear_offset(number< i_access >)
Definition: tile_window_linear.hpp:276
 
CK_TILE_DEVICE auto load_transpose() const
Definition: tile_window_linear.hpp:603
 
typename traits::AccessHistogram_NonLinear AccessHistogram_NonLinear
Definition: tile_window_linear.hpp:181
 
typename traits::AccessMap_NonLinear AccessMap_NonLinear
Definition: tile_window_linear.hpp:180
 
constexpr CK_TILE_DEVICE tile_window_linear(const typename Base::BottomTensorView &bottom_tensor_view, const typename Base::WindowLengths &window_lengths, const typename Base::BottomTensorIndex &window_origin, const typename Base::TileDstr &tile_distribution)
Definition: tile_window_linear.hpp:186
 
static constexpr index_t NumAccess
Definition: tile_window_linear.hpp:178
 
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}) const
Definition: tile_window_linear.hpp:710
 
array< bool, Base::Traits::NumAccess > cached_flags_
Definition: tile_window_linear.hpp:939
 
static constexpr CK_TILE_DEVICE auto get_bottom_linear_coordinate(number< i_access >)
Definition: tile_window_linear.hpp:245
 
CK_TILE_DEVICE void update(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window_linear.hpp:757
 
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window_linear.hpp:657
 
CK_TILE_DEVICE void update_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window_linear.hpp:811
 
typename traits::AccessPrefixSum_NonLinear AccessPrefixSum_NonLinear
Definition: tile_window_linear.hpp:182
 
CK_TILE_DEVICE auto load_transpose_linear(DistributedTensor &dst_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window_linear.hpp:616
 
static constexpr index_t NumAccess_NonLinear
Definition: tile_window_linear.hpp:179
 
CK_TILE_DEVICE auto load(DstTile &dst_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window_linear.hpp:366
 
CK_TILE_DEVICE void move_extended(const typename Base::BottomTensorIndex &step)
Definition: tile_window_linear.hpp:866
 
array< typename Base::BottomTensorCoord, traits::NumAccess_NonLinear > cached_coords_
Definition: tile_window_linear.hpp:936
 
CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_ &&lds_tile, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={}) const
Definition: tile_window_linear.hpp:485
 
remove_cvref_t< LinearBottomDims_ > LinearBottomDims
Definition: tile_window_linear.hpp:64
 
static constexpr auto I1
Definition: tile_window_linear.hpp:69
 
Definition: tile_window_base.hpp:94
 
static constexpr index_t NDimY
Definition: tile_window_base.hpp:103
 
remove_cvref_t< StaticTileDistribution_ > TileDstr
Definition: tile_window_base.hpp:95
 
CK_TILE_DEVICE void move_window_adaptor_and_bottom_tensor_thread_coordinate(WindowAdaptorCoord &window_adaptor_thread_coord, BottomTensorCoord &bottom_tensor_thread_coord, const ATopIndex &idx_diff_adaptor_top) const
Definition: tile_window_base.hpp:129
 
TileDstr tile_dstr_
Definition: tile_window_base.hpp:253
 
#define WINDOW_DISPATCH_ISSUE()
Definition: tile_window_linear.hpp:22
 
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10