/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File
cshuffle_epilogue.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 template <typename ADataType_,
13  typename BDataType_,
14  typename DsDataType_,
15  typename AccDataType_,
16  typename ODataType_,
17  typename DsLayout_,
18  typename ELayout_,
19  typename CDElementwise_,
20  index_t kBlockSize_,
21  index_t kM_,
22  index_t kN_,
23  index_t MWave_,
24  index_t NWave_,
25  index_t MPerXdl_,
26  index_t NPerXdl_,
27  index_t KPerXdl_,
28  bool isCTransposed_,
29  memory_operation_enum MemoryOperation_,
30  index_t kNumWaveGroups_ = 1,
31  bool FixedVectorSize_ = false,
32  index_t VectorSizeC_ = 1>
34 {
43  static constexpr index_t kBlockSize = kBlockSize_;
44  static constexpr index_t kMPerBlock = kM_;
45  static constexpr index_t kNPerBlock = kN_;
46  static constexpr index_t MWave = MWave_;
47  static constexpr index_t NWave = NWave_;
48  static constexpr index_t MPerXdl = MPerXdl_;
49  static constexpr index_t NPerXdl = NPerXdl_;
50  static constexpr index_t KPerXdl = KPerXdl_;
51  static constexpr index_t isCTransposed = isCTransposed_;
52  static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
53  static constexpr bool FixedVectorSize = FixedVectorSize_;
54  static constexpr index_t VectorSizeC = VectorSizeC_;
55  static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
56  static constexpr index_t NumDTensor = DsDataType::size();
57 
58  static_assert(NumDTensor == DsLayout::size(),
59  "The size of DsDataType and DsLayout should be the same");
60 };
61 
62 template <typename Problem_, typename Policy_ = void>
64 {
72  // Used by weight-only quantization kernels: when BDataType is pk_int4_t, B is dequantized to ADataType, so ADataType is used for B here
73  using BTypeToUse =
74  std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
77  static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
78  static constexpr index_t kBlockSize = Problem::kBlockSize;
79  static constexpr index_t kMPerBlock = Problem::kMPerBlock;
80  static constexpr index_t kNPerBlock = Problem::kNPerBlock;
81  static constexpr index_t MWave = Problem::MWave;
82  static constexpr index_t NWave = Problem::NWave;
83  static constexpr index_t MPerXdl = Problem::MPerXdl;
84  static constexpr index_t NPerXdl = Problem::NPerXdl;
85  static constexpr index_t KPerXdl = Problem::KPerXdl;
86  static constexpr index_t isCTransposed = Problem::isCTransposed;
87  static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
88  static constexpr index_t VectorSizeC = Problem::VectorSizeC;
89  static constexpr index_t MPerIteration = MPerXdl * MWave;
90  static constexpr index_t NPerIteration = NPerXdl * NWave;
91  static constexpr index_t NumDTensor = Problem::NumDTensor;
92 
93  static_assert(NumDTensor == DsLayout::size(),
94  "The size of DsDataType and DsLayout should be the same");
106  {
107  if constexpr(FixedVectorSize)
108  {
109  return VectorSizeC;
110  }
111  constexpr index_t max_vector_size = 16;
112  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
113  {
114  return std::min(static_cast<int>(NPerIteration),
115  static_cast<int>(max_vector_size / sizeof(ODataType)));
116  }
117  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
118  {
119  return std::min(static_cast<int>(MPerIteration),
120  static_cast<int>(max_vector_size / sizeof(ODataType)));
121  }
122  else
123  {
124  static_assert(false, "Unsupported ELayout!");
125  }
126  }
127 
133  template <index_t I>
135  {
136  constexpr index_t max_vector_size = 16;
137  using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
138  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
139  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
140  {
141  return std::min(static_cast<int>(NPerIteration),
142  static_cast<int>(max_vector_size / sizeof(DiDataType)));
143  }
144  else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
145  {
146  return std::min(static_cast<int>(MPerIteration),
147  static_cast<int>(max_vector_size / sizeof(DiDataType)));
148  }
149  else
150  {
151  static_assert(false, "Unsupported DLayout!");
152  }
153  return max_vector_size / sizeof(DiDataType);
154  }
// shuffle_tile_tuple: compile-time pair (NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle),
// i.e. how many XDL output tiles per wave along M and along N are grouped into one
// shuffle iteration (MPerIterationShuffle / NPerIterationShuffle are derived from it).
163  static constexpr auto shuffle_tile_tuple = [] {
// Elements of one MPerXdl x NPerXdl XDL tile divided evenly over the lanes of a warp
// (assumes the warp-GEMM C layout gives every lane an equal share — TODO confirm).
164  constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
// A single XDL tile already supplies at least one full C vector per lane: no grouping needed.
165  if constexpr(elem_per_thread >= GetVectorSizeC())
166  {
167  return std::make_tuple(1, 1);
168  }
169  else
170  {
// Number of XDL tiles to combine so a lane holds GetVectorSizeC() elements
// (integer division: truncates if VectorSizeC is not a multiple of elem_per_thread).
171  constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
// RowMajor output: group tiles along M, capped at the number of M XDL tiles per block.
172  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
173  {
174  static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
175  (kMPerBlock % num_xdl_shuffles == 0),
176  "kMPerBlock must be divisible by MPerXdl*MWave and "
177  "num_xdl_shuffles for CShuffleEpilogue");
// NOTE(review): when the cap wins, fewer than num_xdl_shuffles tiles are grouped, so a lane
// may hold fewer than GetVectorSizeC() elements per shuffle — confirm this is intended.
178  return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
179  }
// Any non-RowMajor ELayout (not only ColumnMajor) takes this branch: group tiles along N,
// capped at the number of N XDL tiles per block.
180  else
181  {
182  static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
183  (kNPerBlock % num_xdl_shuffles == 0),
184  "kNPerBlock must be divisible by NPerXdl*NWave and "
185  "num_xdl_shuffles for CShuffleEpilogue");
186  return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
187  }
188  }
189  }();
// Unpacked per-dimension grouping factors consumed by MNPerIterationShuffle.
190  static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
191  static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
192 
193  static constexpr auto MNPerIterationShuffle = [] {
194  constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
195  constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
196  if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
198  else
199  return std::make_tuple(m_val, n_val);
200  }();
201  static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
202  static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
203 
205  BTypeToUse,
206  AccDataType,
207  MPerXdl,
208  NPerXdl,
209  KPerXdl,
210  isCTransposed>;
211 
212  using CWarpDstr = typename WG::CWarpDstr;
213  using CWarpTensor = typename WG::CWarpTensor;
214 
215  template <typename Problem>
217  {
218  // N is contiguous dimension
219  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
220  {
224  }
225  // M is contiguous dimension
226  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
227  {
231  }
232  else
233  {
234  static_assert(false, "Unsupported ELayout!");
235  }
236  }
237 
239  {
240  constexpr auto block_outer_dstr_encoding =
247  sequence<0, 0>>{};
248  constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
249  block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
250 
251  return block_dstr_encoding;
252  }
253 
255  {
257  }
258 
259  template <typename ODramWindow, typename OAccTile, typename DsDramWindows>
260  CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
261  const OAccTile& o_acc_tile,
262  const DsDramWindows& ds_dram_windows,
263  void* p_smem)
264  {
265  constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
266 
267  auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
268 
269  constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
270  auto o_lds_block = make_tensor_view<address_space_enum::lds>(
271  static_cast<ODataType*>(p_smem), lds_block_desc);
272 
273  auto in_lds_window = make_tile_window(
274  o_lds_block,
276  {0, 0},
277  LdsTileDistr);
278 
279  auto out_lds_window = make_tile_window(
280  o_lds_block,
282  {0, 0});
283 
287  constexpr index_t num_access = SFC::get_num_of_access();
288 
289  static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
290  "Currently, the CShuffle Epilogue only supports the Row Major Output layout");
291 
292  using TileEncodingPattern =
296  GetVectorSizeC(),
298  Problem::kNumWaveGroups>;
299  constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
300 
301  auto d_dram_windows = generate_tuple(
302  [&](auto idx) {
303  return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
304  },
306 
307  constexpr auto c_warp_y_lengths =
308  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
309  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
310 
311  static_for<0, num_access, 1>{}([&](auto iAccess) {
312  block_sync_lds();
313  constexpr auto idx_y_start = SFC::get_index(iAccess);
314 
315  constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
316  constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
317 
318  lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
321  c_warp_y_index_zeros),
323  c_warp_y_lengths));
324 
325  const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
326 
327  store_tile(in_lds_window, c_warptile_in_tensor_casted);
328  block_sync_lds();
329 
330  auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
331 
332  const auto ds_tensor = generate_tuple(
333  [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
334 
335  const auto c_ds_tiles = concat_tuple_of_reference(
336  tie(c_out_tensor, c_out_tensor),
337  generate_tie(
338  [&](auto idx) -> const auto& { return ds_tensor[idx]; }, number<NumDTensor>{}));
339 
340  tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
341 
343  {
344  store_tile(out_dram_window, c_out_tensor);
345  }
346  else
347  {
348  update_tile(out_dram_window, c_out_tensor);
349  }
350  if constexpr(iAccess != num_access - 1)
351  {
352  constexpr auto step = SFC::get_forward_step(iAccess);
353 
354  move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
355 
356  static_for<0, NumDTensor, 1>{}([&](auto idx) {
357  move_tile_window(d_dram_windows[idx],
358  {step.at(number<0>{}), step.at(number<1>{})});
359  });
360  }
361  });
362  }
363 };
364 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:539
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
typename impl::WarpGemmMfmaDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity >::Type WarpGemmMfmaDispatcher
Definition: warp_gemm_dispatcher.hpp:115
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:255
memory_operation_enum
Definition: arch.hpp:44
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:353
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:83
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:412
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition: tile_elementwise.hpp:71
@ thread_raked
Thread raked pattern.
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:406
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition: tuple.hpp:420
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1017
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
Definition: cshuffle_epilogue.hpp:64
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:78
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsBlockDescriptor()
Definition: cshuffle_epilogue.hpp:216
typename WG::CWarpTensor CWarpTensor
Definition: cshuffle_epilogue.hpp:213
remove_cvref_t< Problem_ > Problem
Definition: cshuffle_epilogue.hpp:65
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:83
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue.hpp:87
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeD(number< I > index)
Get the vector store size for Di tensor.
Definition: cshuffle_epilogue.hpp:134
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: cshuffle_epilogue.hpp:69
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:80
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *p_smem)
Definition: cshuffle_epilogue.hpp:260
remove_cvref_t< typename Problem::ELayout > ELayout
Definition: cshuffle_epilogue.hpp:75
static constexpr memory_operation_enum MemoryOperation
Definition: cshuffle_epilogue.hpp:77
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition: cshuffle_epilogue.hpp:71
static constexpr CK_TILE_DEVICE auto MakeLdsDistributionEncode()
Definition: cshuffle_epilogue.hpp:238
static constexpr index_t MPerIteration
Definition: cshuffle_epilogue.hpp:89
static constexpr auto MNPerIterationShuffle
Definition: cshuffle_epilogue.hpp:193
WarpGemmMfmaDispatcher< ADataType, BTypeToUse, AccDataType, MPerXdl, NPerXdl, KPerXdl, isCTransposed > WG
Definition: cshuffle_epilogue.hpp:210
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:86
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: cshuffle_epilogue.hpp:254
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:81
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeC()
Get the vector store size for C tensor.
Definition: cshuffle_epilogue.hpp:105
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue.hpp:88
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition: cshuffle_epilogue.hpp:70
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition: cshuffle_epilogue.hpp:74
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition: cshuffle_epilogue.hpp:76
static constexpr index_t NPerIterationShuffle
Definition: cshuffle_epilogue.hpp:202
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: cshuffle_epilogue.hpp:68
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue.hpp:91
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue.hpp:85
static constexpr index_t NumMXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:190
static constexpr index_t NumNXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:191
static constexpr index_t MPerIterationShuffle
Definition: cshuffle_epilogue.hpp:201
remove_cvref_t< typename Problem::BDataType > BDataType
Definition: cshuffle_epilogue.hpp:67
static constexpr auto shuffle_tile_tuple
Shuffle tile configuration parameters.
Definition: cshuffle_epilogue.hpp:163
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:82
remove_cvref_t< typename Problem::ADataType > ADataType
Definition: cshuffle_epilogue.hpp:66
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:79
static constexpr index_t NPerIteration
Definition: cshuffle_epilogue.hpp:90
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:84
typename WG::CWarpDstr CWarpDstr
Definition: cshuffle_epilogue.hpp:212
Definition: cshuffle_epilogue.hpp:34
remove_cvref_t< AccDataType_ > AccDataType
Definition: cshuffle_epilogue.hpp:37
static constexpr index_t kNumWaveGroups
Definition: cshuffle_epilogue.hpp:55
remove_cvref_t< ODataType_ > ODataType
Definition: cshuffle_epilogue.hpp:38
remove_cvref_t< ELayout_ > ELayout
Definition: cshuffle_epilogue.hpp:41
remove_cvref_t< CDElementwise_ > CDElementwise
Definition: cshuffle_epilogue.hpp:42
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:47
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:43
remove_cvref_t< ADataType_ > ADataType
Definition: cshuffle_epilogue.hpp:35
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:46
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:45
remove_cvref_t< DsDataType_ > DsDataType
Definition: cshuffle_epilogue.hpp:39
remove_cvref_t< DsLayout_ > DsLayout
Definition: cshuffle_epilogue.hpp:40
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue.hpp:50
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue.hpp:53
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:44
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue.hpp:54
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue.hpp:56
static constexpr memory_operation_enum MemoryOperation
Definition: cshuffle_epilogue.hpp:52
remove_cvref_t< BDataType_ > BDataType
Definition: cshuffle_epilogue.hpp:36
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:51
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:49
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:48
Class creating 2D static tile distribution with different load/store patterns.
Definition: static_encoding_pattern.hpp:129
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:52
Definition: space_filling_curve.hpp:20
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192