/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/load_tile_transpose.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/load_tile_transpose.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/tensor/load_tile_transpose.hpp Source File
load_tile_transpose.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
17 
18 namespace ck_tile {
19 
20 namespace util {
// Compile-time predicate: `value` is true when Suffix is no longer than Sequence and the
// trailing Suffix::size() elements extracted from Sequence compare equal (operator==) to Suffix.
// NOTE(review): the declaration line between the template header and `{` (listing line 22) is
// absent from this extract; presumably `struct is_sequence_suffix` -- confirm against the header.
21 template <typename Suffix, typename Sequence>
23 {
    // Length guard: the candidate suffix must not be longer than the sequence.
24  static constexpr bool size_check = (Suffix::size() <= Sequence::size());
25 
    // Position in Sequence where the tail to be compared starts.
    // NOTE(review): start_pos and extract_indices are still formed when size_check is false
    // (start_pos is then negative) -- confirm arithmetic_sequence_gen accepts that case.
26  static constexpr index_t start_pos = Sequence::size() - Suffix::size();
    // Index sequence generated from start_pos up to Sequence::size() with step 1.
27  using extract_indices = typename arithmetic_sequence_gen<start_pos, Sequence::size(), 1>::type;
28 
    // Both conditions must hold: the length guard and equality of Suffix with the extracted tail.
29  static constexpr bool value =
30  size_check && (Suffix{} == decltype(Sequence::extract(extract_indices{})){});
31 };
32 
// Specialization whose `value` is unconditionally true.
// NOTE(review): the declaration line (listing line 34) is absent from this extract, so the
// specialized template arguments are not visible here; presumably this is the empty-suffix
// case matched against sequence<Xs...> -- confirm against the header.
33 template <index_t... Xs>
35 {
36  static constexpr bool value = true;
37 };
38 
// NOTE(review): the declaration that follows this template header (listing line 40) is absent
// from this extract; the cross-reference index at the end of the page lists
// `constexpr bool is_sequence_suffix_v` as defined on that line.
39 template <typename Suffix, typename Sequence>
41 
42 } // namespace util
43 
44 // Default policy: Retains original 2D transpose behavior
// Bundles, for element type DataType, the quad input/output encodings, the transposed
// dimension order, an element grouping function and the ValidationTraits check that is applied
// to an input distribution encoding.
// NOTE(review): the declaration line (listing line 46) is absent from this extract; from its
// use further down the file (`DefaultTranspose<DataType_>`) this is presumably
// `struct DefaultTranspose` -- confirm against the header.
45 template <typename DataType>
47 {
    // Encodings selected when sizeof(DataType) == 2 (see QuadInputEncoding/QuadOutputEncoding).
    // NOTE(review): only the closing `sequence<...>>;` of each alias survives here (listing
    // lines 50-54 and 57-61 are absent); presumably InputEncoding followed by OutputEncoding,
    // the two member names referenced below -- confirm.
48  struct Quad16
49  {
55  sequence<1>>;
56 
62  sequence<0>>;
63  };
64 
    // Encodings selected for every other element size (the fallback branch of the
    // conditional_t selections below).
    // NOTE(review): listing lines 67-71 and 74-78 are absent; same caveat as for Quad16.
65  struct Quad8
66  {
72  sequence<1>>;
73 
79  sequence<0>>;
80  };
81 
82  // Select based on data size
83  using QuadInputEncoding = std::conditional_t<sizeof(DataType) == 2,
84  typename Quad16::InputEncoding,
85  typename Quad8::InputEncoding>;
86 
87  using QuadOutputEncoding = std::conditional_t<sizeof(DataType) == 2,
88  typename Quad16::OutputEncoding,
89  typename Quad8::OutputEncoding>;
90 
91  // Always swap last two dimensions
    // (not referenced anywhere else in the visible part of this file)
92  static constexpr auto transpose_dims = sequence<1, 0>{};
93 
94  // Programmable: Element grouping function
    // (returns its argument unchanged; not referenced anywhere else in the visible part of
    // this file)
95  static constexpr auto group_func = [](auto idx) {
96  return idx; // Identity mapping
97  };
98 
    // Validates an input distribution encoding InDstrEncode against this policy; `value` is
    // the overall verdict.
    // NOTE(review): the declaration line (listing line 100) is absent from this extract; other
    // code in this file reaches it as `Policy::template ValidationTraits<InDstrEncode>`.
99  template <typename InDstrEncode>
101  {
102  static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
103  static constexpr auto quad_hs_lengthss = QuadInputEncoding::hs_lengthss_;
104  // 1. Must be 2D tensor
105  static constexpr bool dims_valid = (InDstrEncode::NDimX == 2);
106  // 2. Quad pattern must be suffix of input pattern
    // Checked separately for entry 0 and entry 1 of the hs_lengthss tuples.
107  static constexpr bool suffix_valid_dim0 =
108  util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<0>()),
109  decltype(input_hs_lengthss.template get<0>())>;
110  static constexpr bool suffix_valid_dim1 =
111  util::is_sequence_suffix_v<decltype(quad_hs_lengthss.template get<1>()),
112  decltype(input_hs_lengthss.template get<1>())>;
113 
114  // 3. PS→RHS mapping constraints
115  static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
116  static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
117 
    // Index of the last entry of ps_to_rhss_major.
118  static constexpr index_t ndimp_outer = input_ps_to_rhss_major.size() - 1;
    // NOTE(review): the initializer of ndimp_inner (listing line 120) is absent from this
    // extract.
119  static constexpr index_t ndimp_inner =
121 
    // Visible conjuncts, all on the last P entry (index ndimp_outer):
    //  - minor at [ndimp_inner] equals input_hs_lengthss[1].size() - 2
    //  - major at [ndimp_inner - 1] equals 1
    //  - minor at [ndimp_inner - 1] equals input_hs_lengthss[0].size() - 1
    // NOTE(review): the first conjunct (listing line 123) is absent; presumably the matching
    // major check at [ndimp_inner] -- confirm.
122  static constexpr bool ps_mapping_valid =
124  (input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner>{}] ==
125  input_hs_lengthss[number<1>{}].size() - 2) &&
126  (input_ps_to_rhss_major[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] == 1) &&
127  (input_ps_to_rhss_minor[number<ndimp_outer>{}][number<ndimp_inner - 1>{}] ==
128  input_hs_lengthss[number<0>{}].size() - 1);
129 
130  // 4. YS→RHS mapping constraints
131  static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
132  static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
133 
    // The last Y must map to major 2 with minor equal to the last index of
    // input_hs_lengthss[1]; the second-to-last Y must map to major 1, and a further operand
    // is compared with input_hs_lengthss[0].size() - 2.
    // NOTE(review): the left-hand side of that final comparison (listing line 138) is absent;
    // presumably the second-to-last entry of input_ys_to_rhs_minor -- confirm.
134  static constexpr bool ys_mapping_valid =
135  (input_ys_to_rhs_major.back() == 2) &&
136  (input_ys_to_rhs_minor.back() == input_hs_lengthss[number<1>{}].size() - 1) &&
137  (input_ys_to_rhs_major[input_ys_to_rhs_major.size() - 2] == 1) &&
139  input_hs_lengthss[number<0>{}].size() - 2);
140 
    // Conjunction of the individual checks.
    // NOTE(review): the continuation of this expression (listing line 142) is absent;
    // presumably `ps_mapping_valid && ys_mapping_valid;` -- confirm.
141  static constexpr bool value = dims_valid && suffix_valid_dim0 && suffix_valid_dim1 &&
143  };
144 };
// Exposes, as `distr_encoding_valid`, the result of Policy's ValidationTraits applied to the
// distribution encoding of TileDistribution_. DataType_ is not used in the visible lines.
// NOTE(review): the declaration line (listing line 146) and the InDstrEncode alias (listing
// line 148) are absent from this extract; load_tile_transpose refers to this type as
// `TransposeTileDistrChecker`, and the cross-reference index gives the alias as
// `typename remove_cvref_t<TileDistribution_>::DstrEncode`.
145 template <typename TileDistribution_, typename DataType_, typename Policy>
147 {
149 
    // The policy supplies the actual validation logic.
150  using Validator = typename Policy::template ValidationTraits<InDstrEncode>;
151 
152  static constexpr bool distr_encoding_valid = Validator::value;
153 };
154 
155 // this is used to generate the transposed output tile distribution encoding
156 // based on the input tile distribution encoding
// The result is published as the member alias OutDstrEncode.
// NOTE(review): several listing lines are absent from this extract, including the declaration
// (160; presumably `struct OutputTileDistributionTraits`, the name load_tile_transpose uses to
// reach `::OutDstrEncode`), the InDstrEncode alias (162; given by the cross-reference index as
// `typename remove_cvref_t<TileDistribution_>::DstrEncode`), reversed_quad_output_hs_lengthss
// (179), dst_out_hs_lengthss (192) and the remaining OutDstrEncode arguments (255-259).
157 template <typename TileDistribution_,
158  typename DataType_,
159  typename Policy = DefaultTranspose<DataType_>>
161 {
    // Local copies of the input encoding's fields and of the policy's quad encodings.
163  static constexpr auto input_hs_lengthss = InDstrEncode::hs_lengthss_;
164  static constexpr auto quad_input_hs_lengthss = Policy::QuadInputEncoding::hs_lengthss_;
165  static constexpr auto quad_output_hs_lengthss = Policy::QuadOutputEncoding::hs_lengthss_;
166 
167  static constexpr auto input_ps_to_rhss_major = InDstrEncode::ps_to_rhss_major_;
168  static constexpr auto input_ps_to_rhss_minor = InDstrEncode::ps_to_rhss_minor_;
169  static constexpr auto input_ys_to_rhs_major = InDstrEncode::ys_to_rhs_major_;
170  static constexpr auto input_ys_to_rhs_minor = InDstrEncode::ys_to_rhs_minor_;
171 
172  static constexpr auto quad_ps_to_rhss_major = Policy::QuadInputEncoding::ps_to_rhss_major_;
173  static constexpr auto quad_ps_to_rhss_minor = Policy::QuadInputEncoding::ps_to_rhss_minor_;
174 
175  // for transpose load
176  // append the reversed quad output hs lengths to the input hs lengthss after removing
177  // the quad_input_hs_lengthss
178  // then reverse the whole sequence to get the dst_out_hs_lengthss
180 
    // Per entry i: keep the leading (input size - quad input size) lengths of
    // input_hs_lengthss[i], then append reversed_quad_output_hs_lengthss[i].
    // NOTE(review): the element-count argument of generate_tuple (listing line 190) is absent.
181  static constexpr auto full_out_hs_lengthss = generate_tuple(
182  [](auto i) {
183  return input_hs_lengthss[i]
184  .extract(typename arithmetic_sequence_gen<0,
185  input_hs_lengthss[i].size() -
186  quad_input_hs_lengthss[i].size(),
187  1>::type{})
188  .push_back(reversed_quad_output_hs_lengthss[i]);
189  },
191 
193 
194  // For the PS→RHS mapping (both major and minor), only the last sequence is modified;
195  // all other sequences are passed through unchanged.
    // Major part: the last sequence is reduced via extract() and major index 2 is appended.
    // NOTE(review): the extract() arguments (listing line 203) are absent; current_size and
    // reduce_size are presumably what determines the kept range -- confirm.
196  static constexpr auto modified_ps_to_rhss_major = generate_tuple(
197  [](auto i) {
198  if constexpr(i == input_ps_to_rhss_major.size() - 1)
199  {
200  constexpr auto current_size = input_ps_to_rhss_major[i].size();
201  constexpr auto reduce_size = quad_ps_to_rhss_major[number<0>{}].size();
202  constexpr auto reduced_ps_to_rhss_major = input_ps_to_rhss_major[i].extract(
204  return reduced_ps_to_rhss_major.push_back(number<2>{});
205  }
206  else
207  {
208  // For all other sequences, keep them unchanged
209  return input_ps_to_rhss_major[i];
210  }
211  },
212  number<input_ps_to_rhss_major.size()>{});
213 
    // Last valid index of the final and of the first entry of full_out_hs_lengthss.
214  static constexpr auto minor_last_index =
215  full_out_hs_lengthss[number<InDstrEncode::NDimX - 1>{}].size() - 1;
216  static constexpr auto major_last_index = full_out_hs_lengthss[number<0>{}].size() - 1;
217 
    // Minor part: same shape as the major case above, but the appended value is
    // minor_last_index.
    // NOTE(review): the extract() arguments (listing line 225) are absent.
218  static constexpr auto dst_ps_to_rhss_minor = generate_tuple(
219  [](auto i) {
220  if constexpr(i == input_ps_to_rhss_minor.size() - 1)
221  {
222  constexpr auto current_size = input_ps_to_rhss_minor[i].size();
223  constexpr auto reduce_size = quad_ps_to_rhss_minor[number<0>{}].size();
224  constexpr auto reduced_ps_to_rhss_minor = input_ps_to_rhss_minor[i].extract(
226  return reduced_ps_to_rhss_minor.push_back(number<minor_last_index>{});
227  }
228  else
229  {
230  // For all other sequences, keep them unchanged
231  return input_ps_to_rhss_minor[i];
232  }
233  },
234  number<input_ps_to_rhss_minor.size()>{});
235 
236  // Because dst_out_hs_lengthss is reversed, the major indices also need to be reversed
    // (1 becomes 2, 2 becomes 1, every other value is returned unchanged).
237  static constexpr auto swap_one_and_two = [](const index_t idx) {
238  return (idx == 1) ? 2 : (idx == 2) ? 1 : idx;
239  };
    // Applies swap_one_and_two to every element of every modified major sequence.
    // NOTE(review): the element-count argument of generate_tuple (listing line 242) is absent.
240  static constexpr auto dst_ps_to_rhss_major = generate_tuple(
241  [](auto i) { return modified_ps_to_rhss_major[i].transform(swap_one_and_two); },
243 
    // YS→RHS major: replace the last element with 1, then (below) swap 1 and 2 element-wise.
244  static constexpr auto modified_input_ys_to_rhs_major =
245  input_ys_to_rhs_major.pop_back().push_back(number<1>{});
246 
    // NOTE(review): the length argument of generate_sequence_v2 (listing line 249) is absent.
247  static constexpr auto dst_ys_to_rhs_major = generate_sequence_v2(
248  [](auto i) { return number<swap_one_and_two(modified_input_ys_to_rhs_major[i])>{}; },
250 
    // YS→RHS minor: replace the last element with major_last_index.
251  static constexpr auto dst_ys_to_rhs_minor =
252  input_ys_to_rhs_minor.pop_back().push_back(number<major_last_index>{});
253 
    // Resulting encoding; it reuses the input's RsLengths.
    // NOTE(review): the remaining template arguments (listing lines 255-259) are absent;
    // presumably the dst_* members computed above -- confirm.
254  using OutDstrEncode = tile_distribution_encoding<typename InDstrEncode::RsLengths,
260 };
261 
// Builds an outer block-level encoding and embeds InnerEncode into it with
// detail::make_embed_tile_distribution_encoding, returning the combined encoding.
// NOTE(review): the signature line (listing line 267) is absent from this extract; the
// cross-reference index gives it as
// `constexpr CK_TILE_HOST_DEVICE auto InputTileDistributionEncoding()`.
262 template <typename InnerEncode,
263  index_t kLeadIterPerWarp,
264  index_t kSecondIterPerWarp,
265  index_t kLeadNumWarps,
266  index_t kSecondNumWarps>
268 {
    // NOTE(review): most of the outer encoding's arguments (listing lines 270-275) are absent;
    // presumably they are where the four index_t template parameters are consumed -- confirm.
269  constexpr auto block_outer_dst_encoding =
276  sequence<0, 0>>{};
    // Outer encoding first, inner encoding second.
277  constexpr auto blk_distr_encode =
278  detail::make_embed_tile_distribution_encoding(block_outer_dst_encoding, InnerEncode{});
279 
280  return blk_distr_encode;
281 }
282 
// Loads a tile through tile_window.load_transpose<Policy>() and returns it as a static
// distributed tensor whose distribution is built from
// OutputTileDistributionTraits<...>::OutDstrEncode.
// The defaulted last template parameter removes this overload from consideration unless
// TransposeTileDistrChecker<TileDistribution_, DataType, Policy>::distr_encoding_valid is true.
// NOTE(review): listing line 319 (function name and start of the parameter list) is absent
// from this extract; the cross-reference index gives the signature as
// `load_tile_transpose(const tile_window_with_static_distribution<BottomTensorView_,
// WindowLengths_, TileDistribution_, NumCoord>& tile_window)`.
308 template <
309  typename BottomTensorView_,
310  typename WindowLengths_,
311  typename TileDistribution_,
312  index_t NumCoord,
313  typename Policy = DefaultTranspose<typename BottomTensorView_::DataType>,
314  typename = std::enable_if_t<TransposeTileDistrChecker<TileDistribution_,
315  typename BottomTensorView_::DataType,
316  Policy>::distr_encoding_valid,
317  Policy>>
318 CK_TILE_DEVICE auto
320  WindowLengths_,
321  TileDistribution_,
322  NumCoord>& tile_window)
323 {
    // NOTE(review): Policy is not forwarded here, so OutputTileDistributionTraits falls back to
    // its default policy (DefaultTranspose<DataType>) even when the caller supplies a different
    // Policy, while validation above and load_transpose below both use Policy -- confirm this
    // is intended.
324  using OutTileDstrEncode =
325  typename OutputTileDistributionTraits<TileDistribution_,
326  typename BottomTensorView_::DataType>::OutDstrEncode;
    // Destination tensor, created with the output (transposed) distribution.
327  auto out_tensor = make_static_distributed_tensor<typename BottomTensorView_::DataType>(
328  make_static_tile_distribution(OutTileDstrEncode{}));
    // The transposing load itself is delegated to the tile window.
329  auto trans_tensor = tile_window.template load_transpose<Policy>();
330  constexpr auto input_distr = TileDistribution_{};
331  constexpr auto output_distr = make_static_tile_distribution(OutTileDstrEncode{});
332 
333  constexpr auto y_in_desc = input_distr.get_ys_to_d_descriptor();
334  constexpr auto y_out_desc = output_distr.get_ys_to_d_descriptor();
335 
336  constexpr index_t NDimYIn = input_distr.get_num_of_dimension_y();
337  constexpr index_t NDimYOut = output_distr.get_num_of_dimension_y();
338 
339  constexpr auto y_in_lengths = to_sequence(y_in_desc.get_lengths());
340  constexpr auto y_out_lengths = to_sequence(y_out_desc.get_lengths());
341 
    // Compile-time consistency checks between the input and output Y descriptors: equal
    // element space size and equal length of the last Y dimension.
342  constexpr auto y_in_element_space_size = y_in_desc.get_element_space_size();
343  constexpr auto y_out_element_space_size = y_out_desc.get_element_space_size();
344  static_assert(y_in_element_space_size == y_out_element_space_size,
345  "the element space size is not the same!");
346  static_assert(y_in_lengths[NDimYIn - 1] == y_out_lengths[NDimYOut - 1],
347  "the vector length is not the same!");
    // Vector width is the last input Y length; the access count is the product of all input Y
    // lengths divided by that width.
348  constexpr index_t vecLoadSize = y_in_lengths[NDimYIn - 1];
349  constexpr index_t num_of_access =
350  reduce_on_sequence(y_in_lengths, multiplies{}, number<1>{}) / vecLoadSize;
351 
    // Copy the loaded data into out_tensor one DataVec at a time, using the same access index
    // on both thread buffers.
    // NOTE(review): the DataVec alias (listing line 352) is absent from this extract;
    // presumably a vector of vecLoadSize elements -- confirm.
353  static_for<0, num_of_access, 1>{}([&](auto iAccess) {
354  out_tensor.get_thread_buffer().template set_as<DataVec>(
355  number<iAccess>{},
356  trans_tensor.get_thread_buffer().template get_as<DataVec>(number<iAccess>{}));
357  });
358 
359  return out_tensor;
360 }
361 
362 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:539
constexpr bool is_sequence_suffix_v
Definition: load_tile_transpose.hpp:40
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto generate_sequence_v2(F &&f, number< N >)
Definition: sequence.hpp:1036
constant< v > number
Definition: integral_constant.hpp:33
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:973
CK_TILE_DEVICE auto load_tile_transpose(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window)
transpose loads tile from a tensor and returns the resulting tensor with a new (transposed) tile distribution.
Definition: load_tile_transpose.hpp:319
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
constexpr CK_TILE_HOST_DEVICE auto tuple_reverse(const tuple< Ts... > &t)
Definition: tuple.hpp:547
constexpr CK_TILE_HOST_DEVICE auto InputTileDistributionEncoding()
Definition: load_tile_transpose.hpp:267
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:406
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
Definition: load_tile_transpose.hpp:49
Definition: load_tile_transpose.hpp:66
Definition: load_tile_transpose.hpp:101
static constexpr bool dims_valid
Definition: load_tile_transpose.hpp:105
static constexpr index_t ndimp_outer
Definition: load_tile_transpose.hpp:118
static constexpr bool suffix_valid_dim1
Definition: load_tile_transpose.hpp:110
static constexpr auto input_ys_to_rhs_major
Definition: load_tile_transpose.hpp:131
static constexpr auto input_hs_lengthss
Definition: load_tile_transpose.hpp:102
static constexpr auto input_ps_to_rhss_major
Definition: load_tile_transpose.hpp:115
static constexpr bool suffix_valid_dim0
Definition: load_tile_transpose.hpp:107
static constexpr bool ys_mapping_valid
Definition: load_tile_transpose.hpp:134
static constexpr auto input_ys_to_rhs_minor
Definition: load_tile_transpose.hpp:132
static constexpr index_t ndimp_inner
Definition: load_tile_transpose.hpp:119
static constexpr bool ps_mapping_valid
Definition: load_tile_transpose.hpp:122
static constexpr auto input_ps_to_rhss_minor
Definition: load_tile_transpose.hpp:116
static constexpr bool value
Definition: load_tile_transpose.hpp:141
static constexpr auto quad_hs_lengthss
Definition: load_tile_transpose.hpp:103
Definition: load_tile_transpose.hpp:47
static constexpr auto group_func
Definition: load_tile_transpose.hpp:95
std::conditional_t< sizeof(DataType)==2, typename Quad16::OutputEncoding, typename Quad8::OutputEncoding > QuadOutputEncoding
Definition: load_tile_transpose.hpp:89
static constexpr auto transpose_dims
Definition: load_tile_transpose.hpp:92
std::conditional_t< sizeof(DataType)==2, typename Quad16::InputEncoding, typename Quad8::InputEncoding > QuadInputEncoding
Definition: load_tile_transpose.hpp:85
Definition: load_tile_transpose.hpp:161
static constexpr auto quad_ps_to_rhss_minor
Definition: load_tile_transpose.hpp:173
static constexpr auto dst_ys_to_rhs_major
Definition: load_tile_transpose.hpp:247
static constexpr auto input_ps_to_rhss_major
Definition: load_tile_transpose.hpp:167
static constexpr auto swap_one_and_two
Definition: load_tile_transpose.hpp:237
static constexpr auto quad_output_hs_lengthss
Definition: load_tile_transpose.hpp:165
static constexpr auto input_hs_lengthss
Definition: load_tile_transpose.hpp:163
static constexpr auto reversed_quad_output_hs_lengthss
Definition: load_tile_transpose.hpp:179
static constexpr auto input_ps_to_rhss_minor
Definition: load_tile_transpose.hpp:168
typename remove_cvref_t< TileDistribution_ >::DstrEncode InDstrEncode
Definition: load_tile_transpose.hpp:162
static constexpr auto modified_input_ys_to_rhs_major
Definition: load_tile_transpose.hpp:244
static constexpr auto quad_input_hs_lengthss
Definition: load_tile_transpose.hpp:164
static constexpr auto input_ys_to_rhs_minor
Definition: load_tile_transpose.hpp:170
static constexpr auto quad_ps_to_rhss_major
Definition: load_tile_transpose.hpp:172
static constexpr auto input_ys_to_rhs_major
Definition: load_tile_transpose.hpp:169
static constexpr auto dst_out_hs_lengthss
Definition: load_tile_transpose.hpp:192
static constexpr auto dst_ys_to_rhs_minor
Definition: load_tile_transpose.hpp:251
static constexpr auto major_last_index
Definition: load_tile_transpose.hpp:216
static constexpr auto dst_ps_to_rhss_minor
Definition: load_tile_transpose.hpp:218
static constexpr auto minor_last_index
Definition: load_tile_transpose.hpp:214
static constexpr auto dst_ps_to_rhss_major
Definition: load_tile_transpose.hpp:240
static constexpr auto full_out_hs_lengthss
Definition: load_tile_transpose.hpp:181
static constexpr auto modified_ps_to_rhss_major
Definition: load_tile_transpose.hpp:196
Definition: load_tile_transpose.hpp:147
static constexpr bool distr_encoding_valid
Definition: load_tile_transpose.hpp:152
typename Policy::template ValidationTraits< InDstrEncode > Validator
Definition: load_tile_transpose.hpp:150
typename remove_cvref_t< TileDistribution_ >::DstrEncode InDstrEncode
Definition: load_tile_transpose.hpp:148
Definition: sequence.hpp:278
typename std::conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:293
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
Definition: integral_constant.hpp:13
Definition: math.hpp:98
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
This class provides tile (windowed) view and access to the device memory.
Definition: tile_window.hpp:46
Definition: tuple.hpp:192
Definition: load_tile_transpose.hpp:23
typename arithmetic_sequence_gen< start_pos, Sequence::size(), 1 >::type extract_indices
Definition: load_tile_transpose.hpp:27
static constexpr bool value
Definition: load_tile_transpose.hpp:29
static constexpr bool size_check
Definition: load_tile_transpose.hpp:24
static constexpr index_t start_pos
Definition: load_tile_transpose.hpp:26