/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp Source File
cshuffle_epilogue.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
10 
11 #include <optional>
12 
13 namespace ck_tile {
14 
15 template <typename ADataType_,
16  typename BDataType_,
17  typename DsDataType_,
18  typename AccDataType_,
19  typename ODataType_,
20  typename DsLayout_,
21  typename ELayout_,
22  typename CDElementwise_,
23  index_t kM_,
24  index_t kN_,
25  index_t MWave_,
26  index_t NWave_,
27  index_t MPerXdl_,
28  index_t NPerXdl_,
29  index_t KPerXdl_,
30  bool isCTransposed_,
31  memory_operation_enum MemoryOperation_,
32  index_t kNumWaveGroups_ = 1,
33  bool FixedVectorSize_ = false,
34  index_t VectorSizeC_ = 1>
36 {
45  static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
46  static constexpr index_t kMPerBlock = kM_;
47  static constexpr index_t kNPerBlock = kN_;
48  static constexpr index_t MWave = MWave_;
49  static constexpr index_t NWave = NWave_;
50  static constexpr index_t MPerXdl = MPerXdl_;
51  static constexpr index_t NPerXdl = NPerXdl_;
52  static constexpr index_t KPerXdl = KPerXdl_;
53  static constexpr index_t isCTransposed = isCTransposed_;
54  static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
55  static constexpr bool FixedVectorSize = FixedVectorSize_;
56  static constexpr index_t VectorSizeC = VectorSizeC_;
57  static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
58  static constexpr index_t NumDTensor = DsDataType::size();
59 
60  static_assert(NumDTensor == DsLayout::size(),
61  "The size of DsDataType and DsLayout should be the same");
62 };
63 
64 template <typename Problem_, typename Policy_ = void>
66 {
74  using ATypeToUse =
75  std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
76  // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
77  using BTypeToUse =
78  std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
81  static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
82  static constexpr index_t kBlockSize = Problem::kBlockSize;
83  static constexpr index_t kMPerBlock = Problem::kMPerBlock;
84  static constexpr index_t kNPerBlock = Problem::kNPerBlock;
85  static constexpr index_t MWave = Problem::MWave;
86  static constexpr index_t NWave = Problem::NWave;
87  static constexpr index_t MPerXdl = Problem::MPerXdl;
88  static constexpr index_t NPerXdl = Problem::NPerXdl;
89  static constexpr index_t KPerXdl = Problem::KPerXdl;
90  static constexpr index_t isCTransposed = Problem::isCTransposed;
91  static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
92  static constexpr index_t VectorSizeC = Problem::VectorSizeC;
93  static constexpr index_t MPerIteration = MPerXdl * MWave;
94  static constexpr index_t NPerIteration = NPerXdl * NWave;
95  static constexpr index_t NumDTensor = Problem::NumDTensor;
96 
97  static_assert(NumDTensor == DsLayout::size(),
98  "The size of DsDataType and DsLayout should be the same");
110  {
111  if constexpr(FixedVectorSize)
112  {
113  return VectorSizeC;
114  }
115  constexpr index_t max_vector_size = 16;
116  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
117  {
118  return std::min(static_cast<int>(NPerIteration),
119  static_cast<int>(max_vector_size / sizeof(ODataType)));
120  }
121  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
122  {
123  return std::min(static_cast<int>(MPerIteration),
124  static_cast<int>(max_vector_size / sizeof(ODataType)));
125  }
126  else
127  {
128  static_assert(false, "Unsupported ELayout!");
129  }
130  }
131 
137  template <index_t I>
139  {
140  constexpr index_t max_vector_size = 16;
141  using DiDataType = remove_cvref_t<std::tuple_element_t<index.value, DsDataType>>;
142  using DiLayout = remove_cvref_t<std::tuple_element_t<index.value, DsLayout>>;
143  if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
144  {
145  return std::min(static_cast<int>(NPerIteration),
146  static_cast<int>(max_vector_size / sizeof(DiDataType)));
147  }
148  else if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::ColumnMajor>)
149  {
150  return std::min(static_cast<int>(MPerIteration),
151  static_cast<int>(max_vector_size / sizeof(DiDataType)));
152  }
153  else
154  {
155  static_assert(false, "Unsupported DLayout!");
156  }
157  return max_vector_size / sizeof(DiDataType);
158  }
// Shuffle tile configuration: how many XDL tiles per wave are grouped into one shuffle step,
// as the pair (NumMXdlPerWavePerShuffle, NumNXdlPerWavePerShuffle) unpacked below.
// Presumably the grouping exists so one step stages enough elements per lane to feed a
// GetVectorSizeC()-wide store — TODO confirm against the original design notes.
167  static constexpr auto shuffle_tile_tuple = [] {
// Accumulator elements owned by each lane for a single MPerXdl x NPerXdl XDL tile.
168  constexpr index_t elem_per_thread = MPerXdl * NPerXdl / get_warp_size();
169  if constexpr(elem_per_thread >= GetVectorSizeC())
170  {
// One XDL tile already covers the C vector size: shuffle one XDL at a time in both M and N.
171  return std::make_tuple(1, 1);
172  }
173  else
174  {
// Number of XDL tiles to group per step (integer division, so it truncates if
// GetVectorSizeC() is not a multiple of elem_per_thread).
175  constexpr index_t num_xdl_shuffles = GetVectorSizeC() / elem_per_thread;
176  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
177  {
// Row-major output: group along M, capped at kMPerBlock / (MPerXdl * MWave), the number
// of M-direction XDL repeats in the block. `min` here is ck_tile::min.
178  static_assert((kMPerBlock % (MPerXdl * MWave) == 0) &&
179  (kMPerBlock % num_xdl_shuffles == 0),
180  "kMPerBlock must be divisible by MPerXdl*MWave and "
181  "num_xdl_shuffles for CShuffleEpilogue");
182  return std::make_tuple(min(num_xdl_shuffles, kMPerBlock / (MPerXdl * MWave)), 1);
183  }
184  else
185  {
// Any other layout: group along N, capped at kNPerBlock / (NPerXdl * NWave).
186  static_assert((kNPerBlock % (NPerXdl * NWave) == 0) &&
187  (kNPerBlock % num_xdl_shuffles == 0),
188  "kNPerBlock must be divisible by NPerXdl*NWave and "
189  "num_xdl_shuffles for CShuffleEpilogue");
190  return std::make_tuple(1, min(num_xdl_shuffles, kNPerBlock / (NPerXdl * NWave)));
191  }
192  }
193  }();
// XDL tiles per wave grouped into one shuffle step along M and along N respectively.
194  static constexpr index_t NumMXdlPerWavePerShuffle = std::get<0>(shuffle_tile_tuple);
195  static constexpr index_t NumNXdlPerWavePerShuffle = std::get<1>(shuffle_tile_tuple);
196 
197  static constexpr auto MNPerIterationShuffle = [] {
198  constexpr index_t m_val = MPerXdl * MWave * NumMXdlPerWavePerShuffle;
199  constexpr index_t n_val = NPerXdl * NWave * NumNXdlPerWavePerShuffle;
200  if constexpr(kMPerBlock % m_val != 0 || kNPerBlock % n_val != 0)
202  else
203  return std::make_tuple(m_val, n_val);
204  }();
// (M, N) extent of the sub-tile handled in one shuffle step, unpacked from MNPerIterationShuffle.
// operator() sizes its LDS windows MPerIterationShuffle x NPerIterationShuffle, and
// slice_acc_tile divides the space-filling-curve index by these values to pick the
// accumulator slice.
205  static constexpr index_t MPerIterationShuffle = std::get<0>(MNPerIterationShuffle);
206  static constexpr index_t NPerIterationShuffle = std::get<1>(MNPerIterationShuffle);
207 
209  BTypeToUse,
210  AccDataType,
211  MPerXdl,
212  NPerXdl,
213  KPerXdl,
214  isCTransposed>;
215 
// Per-warp accumulator (C) distribution types taken from the dispatched warp GEMM `WG`.
// MakeLdsDistributionEncode embeds CWarpDstr's encoding into the block-level LDS encoding.
// slice_acc_tile uses CWarpDstr's Y lengths to size each per-step accumulator slice.
216  using CWarpDstr = typename WG::CWarpDstr;
217  using CWarpTensor = typename WG::CWarpTensor;
218  using CWarpDstrEncoding = typename WG::CWarpDstrEncoding;
222 
223  template <typename Problem>
225  {
226  // N is contiguous dimension
227  if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
228  {
232  }
233  // M is contiguous dimension
234  else if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::ColumnMajor>)
235  {
239  }
240  else
241  {
242  static_assert(false, "Unsupported ELayout!");
243  }
244  }
245 
247  {
248  constexpr auto block_outer_dstr_encoding =
255  sequence<0, 0>>{};
256  constexpr auto block_dstr_encoding = detail::make_embed_tile_distribution_encoding(
257  block_outer_dstr_encoding, typename CWarpDstr::DstrEncode{});
258 
259  return block_dstr_encoding;
260  }
261 
263  {
265  }
266 
267  template <auto iAccess, typename LdsTile, typename ScaleM, typename ScaleN>
268  CK_TILE_DEVICE void
269  scale_tile(LdsTile& lds_tile, ScaleM& scale_m_window, ScaleN& scale_n_window)
270  {
271  // Load tiles
272  const auto scale_m_tile = load_tile(scale_m_window);
273  const auto scale_n_tile = load_tile(scale_n_window);
274 
275  // Compute element-wise product in-place i.e. lds_tile = lds_tile * scale_m * scale_n
277  element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile);
278 
279  // Move scale windows
280  constexpr index_t num_access = SFC::get_num_of_access();
281  if constexpr(iAccess != num_access - 1)
282  {
283  constexpr auto step = SFC::get_forward_step(iAccess);
284 
285  move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})});
286  move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})});
287  }
288  }
289 
290  template <auto iAccess, typename OAccTile, typename LdsTile>
291  CK_TILE_DEVICE void slice_acc_tile(const OAccTile& o_acc_tile, LdsTile& lds_tile)
292  {
293  constexpr auto idx_y_start = SFC::get_index(iAccess);
294 
295  constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (MPerIterationShuffle)>{};
296  constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (NPerIterationShuffle)>{};
297  constexpr auto c_warp_y_lengths =
298  to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
299  constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
300 
301  lds_tile.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
304  c_warp_y_index_zeros),
306  c_warp_y_lengths));
307  }
308 
309  template <typename LdsTile, typename InLdsWindow>
310  CK_TILE_DEVICE void cast_lds_tile(LdsTile& lds_tile, InLdsWindow& in_lds_window)
311  {
312  const auto c_warptile_in_tensor_casted = cast_tile<ODataType>(lds_tile);
313 
314  store_tile(in_lds_window, c_warptile_in_tensor_casted);
315  }
316 
317  template <typename DramWindows, typename COutTensor>
318  CK_TILE_DEVICE void apply_d_tensors(DramWindows& d_dram_windows, COutTensor& c_out_tensor)
319  {
320  const auto ds_tensor = generate_tuple(
321  [&](auto idx) { return load_tile(d_dram_windows[idx]); }, number<NumDTensor>{});
322 
323  const auto c_ds_tiles = concat_tuple_of_reference(
324  tie(c_out_tensor, c_out_tensor),
325  generate_tie([&](auto idx) -> const auto& { return ds_tensor[idx]; },
326  number<NumDTensor>{}));
327 
328  tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_ds_tiles);
329  }
330 
331  template <typename OutDramWindow, typename COutTensor>
332  CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
333  const COutTensor& c_out_tensor)
334  {
335  if constexpr(MemoryOperation == memory_operation_enum::set)
336  {
337  store_tile(out_dram_window, c_out_tensor);
338  }
339  else
340  {
341  update_tile(out_dram_window, c_out_tensor);
342  }
343  }
344 
348  template <auto iAccess, typename OutDramWindow, typename DDramWindows>
349  CK_TILE_DEVICE void move_windows(OutDramWindow& out_dram_window, DDramWindows& d_dram_windows)
350  {
351  constexpr index_t num_access = SFC::get_num_of_access();
352  if constexpr(iAccess != num_access - 1)
353  {
354  constexpr auto step = SFC::get_forward_step(iAccess);
355 
356  // move the output dram window
357  move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
358 
359  // move windows for each of the D matrices (inputs for element-wise)
360  static_for<0, NumDTensor, 1>{}([&](auto idx) {
361  move_tile_window(d_dram_windows[idx], {step.at(number<0>{}), step.at(number<1>{})});
362  });
363  }
364  }
365 
366 // TODO: Check if there would be nicer ways to overload rather than with EmptyScale or nullptr_t
// Empty placeholder type. It is the default for operator()'s ScaleM / ScaleN template
// arguments. When has_scales is false, operator() also returns EmptyScale{} in place of the
// scale tile windows and never calls scale_tile. The has_scales initializer is not visible
// in this listing; presumably it checks whether ScaleM/ScaleN are EmptyScale — TODO confirm.
367  struct EmptyScale
368  {
369  };
370  template <typename ODramWindow,
371  typename OAccTile,
372  typename DsDramWindows,
373  typename ScaleM = EmptyScale,
374  typename ScaleN = EmptyScale>
375  CK_TILE_DEVICE auto operator()(ODramWindow& out_dram_window,
376  const OAccTile& o_acc_tile,
377  const DsDramWindows& ds_dram_windows,
378  void* p_smem,
379  const ScaleM& scale_m = {},
380  const ScaleN& scale_n = {})
381  {
382  constexpr auto LdsTileDistr = make_static_tile_distribution(MakeLdsDistributionEncode());
383 
384  auto lds_tile = make_static_distributed_tensor<AccDataType>(LdsTileDistr);
385 
386  constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
387  auto o_lds_block = make_tensor_view<address_space_enum::lds>(
388  static_cast<ODataType*>(p_smem), lds_block_desc);
389 
390  auto in_lds_window = make_tile_window(
391  o_lds_block,
392  make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
393  {0, 0},
394  LdsTileDistr);
395 
396  auto out_lds_window = make_tile_window(
397  o_lds_block,
398  make_tuple(number<MPerIterationShuffle>{}, number<NPerIterationShuffle>{}),
399  {0, 0});
400 
401  constexpr index_t num_access = SFC::get_num_of_access();
402 
403  static_assert(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>,
404  "Currently, the CShuffle Epilogue only supports the Row Major Output layout");
405 
406  using TileEncodingPattern =
407  tile_distribution_encoding_pattern_2d<kBlockSize,
410  GetVectorSizeC(),
412  Problem::kNumWaveGroups>;
413  constexpr auto dram_tile_distribution =
414  TileEncodingPattern::make_2d_static_tile_distribution();
415 
416  auto d_dram_windows = generate_tuple(
417  [&](auto idx) {
418  return make_tile_window(ds_dram_windows[idx], dram_tile_distribution);
419  },
420  number<NumDTensor>{});
421 
422  constexpr bool has_scales =
424  auto scale_m_window = [&]() {
425  if constexpr(has_scales)
426  {
427  return make_tile_window(scale_m, lds_tile.get_tile_distribution());
428  }
429  else
430  {
431  return EmptyScale{};
432  }
433  }();
434  auto scale_n_window = [&]() {
435  if constexpr(has_scales)
436  {
437  return make_tile_window(scale_n, lds_tile.get_tile_distribution());
438  }
439  else
440  {
441  return EmptyScale{};
442  }
443  }();
444 
445  static_for<0, num_access, 1>{}([&](auto iAccess) {
446  block_sync_lds();
447  slice_acc_tile<iAccess>(o_acc_tile, lds_tile);
448 
449  if constexpr(has_scales)
450  {
451  scale_tile<iAccess>(lds_tile, scale_m_window, scale_n_window);
452  }
453 
454  cast_lds_tile(lds_tile, in_lds_window);
455  block_sync_lds();
456 
457  auto c_out_tensor = load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
458 
459  apply_d_tensors(d_dram_windows, c_out_tensor);
460  store_to_dram(out_dram_window, c_out_tensor);
461  move_windows<iAccess>(out_dram_window, d_dram_windows);
462  });
463  }
464 };
465 } // namespace ck_tile
CK_TILE_DEVICE void block_sync_lds()
Definition: arch.hpp:190
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:457
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:268
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:376
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:435
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition: tile_elementwise.hpp:71
@ thread_raked
Thread raked pattern.
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1052
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:823
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:95
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition: tuple.hpp:443
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:480
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:178
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1023
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:42
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1350
Definition: cshuffle_epilogue.hpp:368
Definition: cshuffle_epilogue.hpp:66
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:82
CK_TILE_DEVICE void scale_tile(LdsTile &lds_tile, ScaleM &scale_m_window, ScaleN &scale_n_window)
Definition: cshuffle_epilogue.hpp:269
CK_TILE_DEVICE void slice_acc_tile(const OAccTile &o_acc_tile, LdsTile &lds_tile)
Definition: cshuffle_epilogue.hpp:291
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsBlockDescriptor()
Definition: cshuffle_epilogue.hpp:224
typename WG::CWarpTensor CWarpTensor
Definition: cshuffle_epilogue.hpp:217
typename WG::CWarpDstrEncoding CWarpDstrEncoding
Definition: cshuffle_epilogue.hpp:218
remove_cvref_t< Problem_ > Problem
Definition: cshuffle_epilogue.hpp:67
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:87
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue.hpp:91
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeD(number< I > index)
Get the vector store size for Di tensor.
Definition: cshuffle_epilogue.hpp:138
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: cshuffle_epilogue.hpp:71
CK_TILE_DEVICE void store_to_dram(OutDramWindow &out_dram_window, const COutTensor &c_out_tensor)
Definition: cshuffle_epilogue.hpp:332
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:84
remove_cvref_t< typename Problem::ELayout > ELayout
Definition: cshuffle_epilogue.hpp:79
static constexpr memory_operation_enum MemoryOperation
Definition: cshuffle_epilogue.hpp:81
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition: cshuffle_epilogue.hpp:73
static constexpr CK_TILE_DEVICE auto MakeLdsDistributionEncode()
Definition: cshuffle_epilogue.hpp:246
static constexpr index_t MPerIteration
Definition: cshuffle_epilogue.hpp:93
static constexpr auto MNPerIterationShuffle
Definition: cshuffle_epilogue.hpp:197
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:90
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: cshuffle_epilogue.hpp:262
CK_TILE_DEVICE void apply_d_tensors(DramWindows &d_dram_windows, COutTensor &c_out_tensor)
Definition: cshuffle_epilogue.hpp:318
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:85
static constexpr CK_TILE_HOST_DEVICE index_t GetVectorSizeC()
Get the vector store size for C tensor.
Definition: cshuffle_epilogue.hpp:109
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue.hpp:92
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition: cshuffle_epilogue.hpp:72
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition: cshuffle_epilogue.hpp:78
CK_TILE_DEVICE void move_windows(OutDramWindow &out_dram_window, DDramWindows &d_dram_windows)
Move both the output and D tensors windows for the next access.
Definition: cshuffle_epilogue.hpp:349
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition: cshuffle_epilogue.hpp:80
static constexpr index_t NPerIterationShuffle
Definition: cshuffle_epilogue.hpp:206
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: cshuffle_epilogue.hpp:70
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue.hpp:95
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue.hpp:89
static constexpr index_t NumMXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:194
static constexpr index_t NumNXdlPerWavePerShuffle
Definition: cshuffle_epilogue.hpp:195
CK_TILE_DEVICE auto operator()(ODramWindow &out_dram_window, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *p_smem, const ScaleM &scale_m={}, const ScaleN &scale_n={})
Definition: cshuffle_epilogue.hpp:375
WarpGemmDispatcher< ATypeToUse, BTypeToUse, AccDataType, MPerXdl, NPerXdl, KPerXdl, isCTransposed > WG
Definition: cshuffle_epilogue.hpp:214
static constexpr index_t MPerIterationShuffle
Definition: cshuffle_epilogue.hpp:205
CK_TILE_DEVICE void cast_lds_tile(LdsTile &lds_tile, InLdsWindow &in_lds_window)
Definition: cshuffle_epilogue.hpp:310
remove_cvref_t< typename Problem::BDataType > BDataType
Definition: cshuffle_epilogue.hpp:69
static constexpr auto shuffle_tile_tuple
Shuffle tile configuration parameters.
Definition: cshuffle_epilogue.hpp:167
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:86
remove_cvref_t< typename Problem::ADataType > ADataType
Definition: cshuffle_epilogue.hpp:68
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:83
static constexpr index_t NPerIteration
Definition: cshuffle_epilogue.hpp:94
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:88
typename WG::CWarpDstr CWarpDstr
Definition: cshuffle_epilogue.hpp:216
std::conditional_t< std::is_same_v< ADataType, pk_int4_t >, BDataType, ADataType > ATypeToUse
Definition: cshuffle_epilogue.hpp:75
Definition: cshuffle_epilogue.hpp:36
remove_cvref_t< DsDataType_ > DsDataType
Definition: cshuffle_epilogue.hpp:41
remove_cvref_t< BDataType_ > BDataType
Definition: cshuffle_epilogue.hpp:38
remove_cvref_t< AccDataType_ > AccDataType
Definition: cshuffle_epilogue.hpp:39
static constexpr index_t NPerXdl
Definition: cshuffle_epilogue.hpp:51
static constexpr index_t kBlockSize
Definition: cshuffle_epilogue.hpp:45
static constexpr index_t kMPerBlock
Definition: cshuffle_epilogue.hpp:46
static constexpr index_t KPerXdl
Definition: cshuffle_epilogue.hpp:52
remove_cvref_t< ELayout_ > ELayout
Definition: cshuffle_epilogue.hpp:43
static constexpr index_t MPerXdl
Definition: cshuffle_epilogue.hpp:50
remove_cvref_t< ADataType_ > ADataType
Definition: cshuffle_epilogue.hpp:37
static constexpr index_t MWave
Definition: cshuffle_epilogue.hpp:48
static constexpr index_t kNPerBlock
Definition: cshuffle_epilogue.hpp:47
remove_cvref_t< CDElementwise_ > CDElementwise
Definition: cshuffle_epilogue.hpp:44
static constexpr bool FixedVectorSize
Definition: cshuffle_epilogue.hpp:55
static constexpr memory_operation_enum MemoryOperation
Definition: cshuffle_epilogue.hpp:54
static constexpr index_t isCTransposed
Definition: cshuffle_epilogue.hpp:53
static constexpr index_t NumDTensor
Definition: cshuffle_epilogue.hpp:58
remove_cvref_t< DsLayout_ > DsLayout
Definition: cshuffle_epilogue.hpp:42
remove_cvref_t< ODataType_ > ODataType
Definition: cshuffle_epilogue.hpp:40
static constexpr index_t kNumWaveGroups
Definition: cshuffle_epilogue.hpp:57
static constexpr index_t NWave
Definition: cshuffle_epilogue.hpp:49
static constexpr index_t VectorSizeC
Definition: cshuffle_epilogue.hpp:56
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: unary_element_wise_operation.hpp:300
Definition: sequence.hpp:49
Definition: space_filling_curve.hpp:20
static constexpr CK_TILE_HOST_DEVICE auto get_forward_step(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:70
static constexpr CK_TILE_HOST_DEVICE auto get_index(number< AccessIdx1d >)
Definition: space_filling_curve.hpp:158
static constexpr CK_TILE_HOST_DEVICE index_t get_num_of_access()
Definition: space_filling_curve.hpp:46
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192