/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/store_tile.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/store_tile.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/tensor/store_tile.hpp Source File
store_tile.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
16 
17 namespace ck_tile {
18 
19 template <typename BottomTensorView_,
20  typename WindowLengths_,
21  typename TileDistribution_,
22  typename DataType_>
23 CK_TILE_DEVICE void
26 {
28  using TileDstr = remove_cvref_t<TileDistribution_>;
29 
30  static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
31 
32  constexpr auto tile_dstr = TileDstr{};
33 
34  auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
35  tile_window_tmp.get_window_lengths(),
36  tile_window_tmp.get_window_origin(),
37  tile_dstr);
38 
39  tile_window.store(dstr_tensor);
40 }
41 
42 template <typename BottomTensorView_,
43  typename WindowLengths_,
44  typename TileDistribution_,
45  typename DataType_>
46 CK_TILE_DEVICE void
49  decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index)
50 {
52  using TileDstr = remove_cvref_t<TileDistribution_>;
53 
54  static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
55 
56  constexpr auto tile_dstr = TileDstr{};
57 
58  auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
59  tile_window_tmp.get_window_lengths(),
60  tile_window_tmp.get_window_origin(),
61  tile_dstr,
62  partition_index);
63 
64  tile_window.store(dstr_tensor);
65 }
66 
67 template <typename BottomTensorView_,
68  typename WindowLengths_,
69  typename TileDistribution_,
70  typename DataType_>
71 CK_TILE_DEVICE void
74 {
76  using TileDstr = remove_cvref_t<TileDistribution_>;
77 
78  static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
79 
80  constexpr auto tile_dstr = TileDstr{};
81 
82  auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
83  tile_window_tmp.get_window_lengths(),
84  tile_window_tmp.get_window_origin(),
85  tile_dstr);
86 
87  tile_window.store_raw(dstr_tensor);
88 }
89 
90 template <typename BottomTensorView_,
91  typename WindowLengths_,
92  typename TileDistribution_,
93  typename DataType_>
94 CK_TILE_DEVICE void
97  decltype(get_partition_index(dstr_tensor.get_tile_distribution())) partition_index)
98 {
100  using TileDstr = remove_cvref_t<TileDistribution_>;
101 
102  static_assert(std::is_same_v<remove_cvref_t<DataType_>, DataType>, "wrong!");
103 
104  constexpr auto tile_dstr = TileDstr{};
105 
106  auto tile_window = make_tile_window(tile_window_tmp.get_bottom_tensor_view(),
107  tile_window_tmp.get_window_lengths(),
108  tile_window_tmp.get_window_origin(),
109  tile_dstr,
110  partition_index);
111 
112  tile_window.store_raw(dstr_tensor);
113 }
114 
115 template <typename BottomTensorView_,
116  typename WindowLengths_,
117  typename TileDistribution_,
118  index_t NumCoord,
119  typename DataType_>
120 CK_TILE_DEVICE void
122  WindowLengths_,
123  TileDistribution_,
124  NumCoord>& tile_window,
126 {
127  tile_window.store(dstr_tensor, number<-1>{});
128 }
129 
130 template <typename BottomTensorView_,
131  typename WindowLengths_,
132  typename TileDistribution_,
133  index_t NumCoord,
134  typename DataType_>
135 CK_TILE_DEVICE void
137  WindowLengths_,
138  TileDistribution_,
139  NumCoord>& tile_window,
141 {
142  tile_window.store_raw(dstr_tensor, number<-1>{});
143 }
144 
145 template <typename BottomTensorView_,
146  typename WindowLengths_,
147  typename TileDistribution_,
148  typename LinearBottomDims_,
149  typename DataType_>
152  tile_window,
154 {
155  tile_window.store(dstr_tensor, number<-1>{});
156 }
157 
158 template <typename BottomTensorView_,
159  typename WindowLengths_,
160  typename TileDistribution_,
161  typename LinearBottomDims_,
162  typename DataType_>
165  tile_window,
167 {
168  tile_window.store_raw(dstr_tensor, number<-1>{});
169 }
170 
171 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:72
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
constexpr bool is_same_v
Definition: type.hpp:283
Definition: integral_constant.hpp:13
Definition: static_distributed_tensor.hpp:21
static constexpr CK_TILE_HOST_DEVICE auto get_tile_distribution()
Definition: static_distributed_tensor.hpp:46
constexpr CK_TILE_DEVICE auto get_window_origin() const
Definition: tile_window_base.hpp:45
constexpr CK_TILE_DEVICE auto get_bottom_tensor_view() const
Definition: tile_window_base.hpp:47
constexpr CK_TILE_DEVICE auto get_window_lengths() const
Definition: tile_window_base.hpp:46
Definition: tile_window_linear.hpp:55
CK_TILE_DEVICE void store_raw(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}) const
Definition: tile_window_linear.hpp:710
CK_TILE_DEVICE void store(const static_distributed_tensor< typename Base::DataType, typename Base::TileDstr > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}) const
Definition: tile_window_linear.hpp:657
This class provides a tile (windowed) view of, and access to, device memory.
Definition: tile_window.hpp:47
This class provides a description of a tile (windowed) view on device memory.
Definition: tile_window.hpp:1195