/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File
default_2d_epilogue.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // this epilogue just stores out an M*N matrix, row major
13 
// Compile-time problem descriptor for the plain 2D store epilogue. It republishes its
// template arguments as static constants and fixes NumDTensor to 0.
// NOTE(review): the embedded listing numbers jump 19 -> 21 and 21 -> 24, so the struct
// declaration line and two member lines are absent from this extract. The index at the
// bottom of this page lists AccDataType / ODataType aliases at listing lines 22-23;
// confirm the exact text against the upstream header.
14 template <typename AccDataType_,
15  typename ODataType_,
16  bool kPadM_,
17  bool kPadN_,
18  bool UseRawStore_ = true,
19  memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
21 {
// Padding flags, forwarded unchanged from the template arguments.
24  static constexpr bool kPadM = kPadM_;
25  static constexpr bool kPadN = kPadN_;
// Raw-store toggle and memory operation, forwarded unchanged (defaults: true / set).
26  static constexpr bool UseRawStore = UseRawStore_;
27  static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
// This descriptor always reports zero D tensors.
28  static constexpr index_t NumDTensor = 0;
29 };
30 
// Compile-time problem descriptor for the GEMM 2D epilogue. In addition to the padding /
// store options it publishes the block tile sizes (kM_, kN_), the per-XDL sizes, the
// C-transposed flag, and the number of D tensors taken from DsDataType::size().
// NOTE(review): listing lines 49 and 56-61 are absent from this extract. Lines 50-54
// below look like the tail of a base-class template argument list, and the index at the
// bottom of this page lists ADataType / BDataType / CLayout / DsDataType / CDElementwise /
// DsLayout aliases at lines 56-61 -- confirm against the upstream header.
31 template <typename ADataType_,
32  typename BDataType_,
33  typename DsDataType_,
34  typename AccDataType_,
35  typename ODataType_,
36  typename DsLayout_,
37  typename CLayout_,
38  typename CDElementwise_,
39  index_t kM_,
40  index_t kN_,
41  bool kPadM_,
42  bool kPadN_,
43  index_t kMPerXdl_,
44  index_t kNPerXdl_,
45  index_t kKPerXdl_,
46  bool isCTransposed_,
47  bool UseRawStore_ = true,
48  memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
50  ODataType_,
51  kPadM_,
52  kPadN_,
53  UseRawStore_,
54  MemoryOperation_>
55 {
// Block tile extents, forwarded from kM_ / kN_.
62  static constexpr index_t kMPerBlock = kM_;
63  static constexpr index_t kNPerBlock = kN_;
// Per-XDL extents, forwarded unchanged.
64  static constexpr index_t kMPerXdl = kMPerXdl_;
65  static constexpr index_t kNPerXdl = kNPerXdl_;
66  static constexpr index_t kKPerXdl = kKPerXdl_;
// NOTE(review): isCTransposed_ is a bool template parameter but is stored as index_t here.
67  static constexpr index_t isCTransposed = isCTransposed_;
68 
// Number of D tensors is the size of the DsDataType type list.
69  static constexpr index_t NumDTensor = DsDataType::size();
70 
// DsDataType and DsLayout must describe the same number of D tensors.
71  static_assert(NumDTensor == DsLayout::size(),
72  "The size of DsDataType and DsLayout should be the same");
73 };
74 
// Epilogue functor: casts a tile to ODataType and writes it to the output window. When D
// windows are supplied and Problem::NumDTensor >= 1, it first runs Problem::CDElementwise
// over the accumulator tile and the loaded D tiles and writes that result instead.
// NOTE(review): listing lines 76 and 78-80 are absent from this extract. Line 76 is the
// struct declaration (the name Default2DEpilogue is used as a base class further down),
// and the index at the bottom of this page lists Problem / AccDataType / ODataType
// aliases at lines 78-80 -- confirm against the upstream header.
75 template <typename Problem_, typename Policy_ = void>
77 {
// Options mirrored from the problem descriptor.
81  static constexpr bool kPadM = Problem::kPadM;
82  static constexpr bool kPadN = Problem::kPadN;
83  static constexpr bool UseRawStore = Problem::UseRawStore;
84  static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
85 
// Always returns 0 (presumably the shared-memory byte count this epilogue needs).
86  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
87 
88  // TODO: this function assumes the store-out vector size equals the size of the last
89  // dimension of OAccTile; how do we fix this?
//
// Parameters:
//   o_dram_window_tmp - destination window handed to store_tile*/update_tile*.
//   o_acc_tile        - accumulator tile; cast to ODataType before being written.
//   ds_dram_windows   - indexable collection of D windows, or std::nullptr_t to skip the
//                       element-wise stage.
//   (unnamed void*)   - defaults to nullptr and is not used in the visible body.
90  template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
91  CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
92  const OAccTile& o_acc_tile,
93  const DsDramWindows& ds_dram_windows,
94  void* = nullptr)
95  {
// Writes o_tile to the output window. The raw variants are used only when UseRawStore is
// set and at least one dimension is padded; MemoryOperation == set selects store,
// anything else selects update.
96  const auto storeOrUpdateTile = [&](const auto& o_tile) {
97  // TODO: this is ugly
98  if constexpr(UseRawStore && (kPadM || kPadN))
99  {
100  if constexpr(MemoryOperation == memory_operation_enum::set)
101  {
102  store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
103  }
104  else
105  {
106  update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
107  }
// NOTE(review): listing line 108 is absent here. The index at the bottom of this page
// lists buffer_store_fence, which is not called in any visible line -- presumably it
// belongs at this point; confirm against the upstream header.
109  }
110  else
111  {
112  if constexpr(MemoryOperation == memory_operation_enum::set)
113  {
114  store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
115  }
116  else
117  {
118  update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
119  }
120  }
121  };
122 
// Element-wise path: taken only when D windows were actually passed and the problem
// declares at least one D tensor.
123  if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && Problem::NumDTensor >= 1)
124  {
// Result tile type: whatever load_tile yields for a (kMPerBlock, kNPerBlock) window over
// D0's tensor view using the accumulator tile's distribution.
125  using elementwise_result_t = decltype(load_tile(
126  make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
127  make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
128  ds_dram_windows[number<0>{}].get_window_origin(),
129  o_acc_tile.get_tile_distribution())));
130 
131  elementwise_result_t elementwise_result;
132 
// Load every D tile through a window that reuses the accumulator tile's distribution.
133  const auto d_tensor_tuple = generate_tuple(
134  [&](auto idx) {
135  const auto d_tile_window =
136  make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
137  return load_tile(d_tile_window);
138  },
// NOTE(review): listing line 139 (the closing argument of generate_tuple) is absent.
140 
// Tuple of references laid out as (result, accumulator, D0, D1, ...).
141  const auto c_d_tuple = concat_tuple_of_reference(
142  tie(elementwise_result, o_acc_tile),
143  generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
// NOTE(review): listing line 144 (the closing argument of generate_tie) is absent.
145 
// Apply the problem's CDElementwise functor across the referenced tiles.
146  tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);
147 
148  storeOrUpdateTile(elementwise_result);
149  }
150  else
151  {
// No D tensors: write the accumulator tile directly.
152  storeOrUpdateTile(o_acc_tile);
153  }
154  }
155 };
156 
157 template <typename Problem_, typename Policy_ = void>
158 struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
159 {
165  // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
166  using BTypeToUse =
167  std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
172  static constexpr index_t kMPerXdl = Problem::kMPerXdl;
173  static constexpr index_t kNPerXdl = Problem::kNPerXdl;
174  static constexpr index_t kKPerXdl = Problem::kKPerXdl;
175  static constexpr index_t isCTransposed = Problem::isCTransposed;
176 
178  BTypeToUse,
179  AccDataType,
180  kMPerXdl,
181  kNPerXdl,
182  kKPerXdl,
183  isCTransposed>;
184 
185  using CWarpDstr = typename WG::CWarpDstr;
186 
187  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
188  {
189  // N is contiguous dimension
190  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
191  {
192  if constexpr(isCTransposed)
193  {
194  // In this case each thread has multiple consecutive elements in
195  // N dimension, however consecutive threads' elements have stride.
196  constexpr index_t NDimY = CWarpDstr::NDimY;
197  constexpr auto c_warp_y_lengths =
198  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
199  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
200  c_warp_y_lengths.get(number<NDimY - 1>{}));
201  return c_warp_y_lengths.get(number<NDimY - 1>{});
202  }
203  else
204  {
205  // In this case each thread has just a single item in Ndim
206  return (WG::WarpGemmAttribute::Impl::kCNLane *
207  WG::WarpGemmAttribute::Impl::kBNBlock) /
208  WG::kN;
209  }
210  }
211  // M is contiguous dimension
212  else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
213  {
214  if constexpr(isCTransposed)
215  {
216  // In this case each thread has just a single item in Mdim
217  return (WG::WarpGemmAttribute::Impl::kCNLane *
218  WG::WarpGemmAttribute::Impl::kAMBlock) /
219  WG::kN;
220  }
221  else
222  {
223  // In this case each thread has multiple consecutive elements in
224  // M dimension, however consecutive threads' elements have stride.
225  constexpr index_t NDimY = CWarpDstr::NDimY;
226  constexpr auto c_warp_y_lengths =
227  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
228  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
229  c_warp_y_lengths.get(number<NDimY - 1>{}));
230  return c_warp_y_lengths.get(number<NDimY - 1>{});
231  }
232  }
233  else
234  {
235  static_assert(false, "Unsupported CLayout!");
236  }
237  }
238 
239  template <index_t I>
240  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
241  {
242  return GetVectorSizeC();
243  }
244 };
245 
246 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:376
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:435
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition: tile_elementwise.hpp:71
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:1000
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:46
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition: tuple.hpp:443
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:22
typename impl::WarpGemmDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:178
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
Definition: default_2d_epilogue.hpp:77
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:80
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *=nullptr)
Definition: default_2d_epilogue.hpp:91
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:82
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:78
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:81
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:79
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:83
static constexpr memory_operation_enum MemoryOperation
Definition: default_2d_epilogue.hpp:84
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: default_2d_epilogue.hpp:86
Definition: default_2d_epilogue.hpp:21
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:26
remove_cvref_t< ODataType_ > ODataType
Definition: default_2d_epilogue.hpp:23
remove_cvref_t< AccDataType_ > AccDataType
Definition: default_2d_epilogue.hpp:22
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:24
static constexpr memory_operation_enum MemoryOperation
Definition: default_2d_epilogue.hpp:27
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:25
static constexpr index_t NumDTensor
Definition: default_2d_epilogue.hpp:28
Definition: default_2d_epilogue.hpp:159
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:163
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:172
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:173
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:174
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition: default_2d_epilogue.hpp:170
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeC()
Definition: default_2d_epilogue.hpp:187
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:175
remove_cvref_t< typename Problem::CLayout > CLayout
Definition: default_2d_epilogue.hpp:171
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:164
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition: default_2d_epilogue.hpp:168
remove_cvref_t< typename Problem::BDataType > BDataType
Definition: default_2d_epilogue.hpp:162
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition: default_2d_epilogue.hpp:167
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeD([[maybe_unused]] number< I > index)
Definition: default_2d_epilogue.hpp:240
WarpGemmDispatcher< ADataType, BTypeToUse, AccDataType, kMPerXdl, kNPerXdl, kKPerXdl, isCTransposed > WG
Definition: default_2d_epilogue.hpp:183
remove_cvref_t< typename Problem::ADataType > ADataType
Definition: default_2d_epilogue.hpp:161
typename WG::CWarpDstr CWarpDstr
Definition: default_2d_epilogue.hpp:185
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition: default_2d_epilogue.hpp:169
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:160
Definition: default_2d_epilogue.hpp:55
remove_cvref_t< DsLayout_ > DsLayout
Definition: default_2d_epilogue.hpp:61
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:65
remove_cvref_t< CLayout_ > CLayout
Definition: default_2d_epilogue.hpp:58
static constexpr index_t kMPerBlock
Definition: default_2d_epilogue.hpp:62
static constexpr index_t kNPerBlock
Definition: default_2d_epilogue.hpp:63
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:67
remove_cvref_t< CDElementwise_ > CDElementwise
Definition: default_2d_epilogue.hpp:60
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:66
remove_cvref_t< ADataType_ > ADataType
Definition: default_2d_epilogue.hpp:56
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:64
static constexpr index_t NumDTensor
Definition: default_2d_epilogue.hpp:69
remove_cvref_t< BDataType_ > BDataType
Definition: default_2d_epilogue.hpp:57
remove_cvref_t< DsDataType_ > DsDataType
Definition: default_2d_epilogue.hpp:59
Definition: integral_constant.hpp:13