/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File
default_2d_epilogue.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // This epilogue just stores out an M*N matrix, row major.
13 
// Problem descriptor for the plain 2D store epilogue. Default2DEpilogue reads
// Problem::kPadM / kPadN / UseRawStore / MemoryOperation from its Problem_ argument
// (presumably this struct is that argument in the non-GEMM case - confirm against callers).
//   AccDataType_, ODataType_ : re-exported (via remove_cvref_t, per the symbol index) as
//                              AccDataType / ODataType; the epilogue casts its output to ODataType.
//   kPadM_, kPadN_           : if either is set and UseRawStore_ is true, the epilogue uses the
//                              *_raw store/update functions.
//   UseRawStore_             : defaults to true.
//   MemoryOperation_         : `set` (default) stores the tile; any other value updates it.
// NOTE(review): listing lines 20 (the struct name) and 22-23 (the AccDataType / ODataType
// aliases) are missing from this render.
14 template <typename AccDataType_,
15  typename ODataType_,
16  bool kPadM_,
17  bool kPadN_,
18  bool UseRawStore_ = true,
19  memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
21 {
24  static constexpr bool kPadM = kPadM_;
25  static constexpr bool kPadN = kPadN_;
26  static constexpr bool UseRawStore = UseRawStore_;
27  static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
28  static constexpr index_t NumDTensor = 0; // no auxiliary D tensors in this problem
29 };
30 
// GEMM-flavoured problem descriptor. It derives from a 2D epilogue problem and adds the
// GEMM-specific parameters below. The base-class head on listing line 49 is missing from this
// render; the visible base arguments are ODataType_, kPadM_, kPadN_, UseRawStore_ and
// MemoryOperation_.
//   AsDataType_, BsDataType_ : A/B element type(s); DefaultGemm2DEpilogue takes element 0 of
//                              each (per the symbol index) to select its warp GEMM.
//   DsDataType_, DsLayout_   : element types and layouts of the auxiliary D tensors; they must
//                              have the same number of entries (checked below).
//   CLayout_                 : layout of the output C tensor.
//   CDElementwise_           : functor the epilogue applies to combine the accumulator with the
//                              D tiles.
//   kM_, kN_                 : block tile extent, exposed as kMPerBlock / kNPerBlock.
//   kMPerXdl_, kNPerXdl_, kKPerXdl_, isCTransposed_ : warp GEMM shape and C orientation.
// NOTE(review): listing lines 56-61 (the AsDataType, BsDataType, CLayout, DsDataType, DsLayout
// and CDElementwise aliases, per the symbol index) are missing from this render.
31 template <typename AsDataType_,
32  typename BsDataType_,
33  typename DsDataType_,
34  typename AccDataType_,
35  typename ODataType_,
36  typename DsLayout_,
37  typename CLayout_,
38  typename CDElementwise_,
39  index_t kM_,
40  index_t kN_,
41  bool kPadM_,
42  bool kPadN_,
43  index_t kMPerXdl_,
44  index_t kNPerXdl_,
45  index_t kKPerXdl_,
46  bool isCTransposed_,
47  bool UseRawStore_ = true,
48  memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
50  ODataType_,
51  kPadM_,
52  kPadN_,
53  UseRawStore_,
54  MemoryOperation_>
55 {
 // Block tile extent; Default2DEpilogue uses it as the window size when deriving the
 // type of the fused (D-tensor) result tile.
62  static constexpr index_t kMPerBlock = kM_;
63  static constexpr index_t kNPerBlock = kN_;
 // Warp GEMM shape; DefaultGemm2DEpilogue forwards these to WarpGemmDispatcher.
64  static constexpr index_t kMPerXdl = kMPerXdl_;
65  static constexpr index_t kNPerXdl = kNPerXdl_;
66  static constexpr index_t kKPerXdl = kKPerXdl_;
 // NOTE(review): declared index_t although the template parameter is bool. In the visible
 // code it is only used as an `if constexpr` condition and as WarpGemmDispatcher's TransposeC
 // argument, so `bool` looks like the intended type.
67  static constexpr index_t isCTransposed = isCTransposed_;
68 
 // Hides any NumDTensor inherited from the base problem (presumably 0) with the real count.
69  static constexpr index_t NumDTensor = DsDataType::size();
70 
71  static_assert(NumDTensor == DsLayout::size(),
72  "The size of DsDataType and DsLayout should be the same");
73 };
74 
75 template <typename Problem_, typename Policy_ = void>
77 {
81  static constexpr bool kPadM = Problem::kPadM;
82  static constexpr bool kPadN = Problem::kPadN;
83  static constexpr bool UseRawStore = Problem::UseRawStore;
84  static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
85 
86  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
87 
// Writes the output tile to o_dram_window_tmp after converting it with cast_tile<ODataType>.
//   o_dram_window_tmp : destination tile window, passed by non-const reference to the
//                       store/update call.
//   o_acc_tile        : accumulator tile; its tile distribution is also used to load D tiles.
//   ds_dram_windows   : interpreted at compile time as
//                       - a partition index, if convertible to the type returned by
//                         get_partition_index(o_acc_tile.get_tile_distribution()); it is then
//                         forwarded to store_tile / store_tile_raw / update_tile;
//                       - D tile windows (indexed with number<i>{}), if it is not
//                         std::nullptr_t and Problem::NumDTensor >= 1; the D tiles are then
//                         combined with o_acc_tile by Problem::CDElementwise before storing;
//                       - otherwise unused, and o_acc_tile is written as-is.
//   (void*)           : unnamed and unused.
// Store selection: the *_raw functions are used only when UseRawStore && (kPadM || kPadN);
// MemoryOperation == set stores, any other value updates. Returns nothing.
88  // TODO: this function assumes the store-out vector size is the same as the OAccTile last
89  // dimension size; how do we fix this?
90  template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
91  CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
92  const OAccTile& o_acc_tile,
93  const DsDramWindows& ds_dram_windows,
94  void* = nullptr) const
95  {
 // True when the third argument is a partition index rather than a set of D windows.
96  constexpr bool is_partition_index =
97  std::is_convertible_v<decltype(ds_dram_windows),
98  decltype(get_partition_index(
99  o_acc_tile.get_tile_distribution()))>;
100 
 // Casts o_tile to ODataType and stores or updates it into o_dram_window_tmp; every
 // branch below is selected at compile time.
101  const auto storeOrUpdateTile = [&](const auto& o_tile) {
102  // TODO: this is ugly
103  if constexpr(UseRawStore && (kPadM || kPadN))
104  {
105  if constexpr(MemoryOperation == memory_operation_enum::set)
106  {
107  if constexpr(is_partition_index)
108  {
109  store_tile_raw(o_dram_window_tmp,
110  cast_tile<ODataType>(o_tile),
111  /*partition_index=*/ds_dram_windows);
112  }
113  else
114  {
115  store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
116  }
117  }
118  else
119  {
 // NOTE(review): unlike the non-raw update branch below, this path never forwards
 // ds_dram_windows as a partition index, even when is_partition_index is true -
 // confirm this is intended.
120  update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
121  }
 // NOTE(review): listing line 122 is missing from this render; the symbol index lists
 // buffer_store_fence, which presumably is called here - confirm against the header.
123  }
124  else
125  {
126  if constexpr(MemoryOperation == memory_operation_enum::set)
127  {
128  if constexpr(is_partition_index)
129  {
130  store_tile(o_dram_window_tmp,
131  cast_tile<ODataType>(o_tile),
132  /*partition_index=*/ds_dram_windows);
133  }
134  else
135  {
136  store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
137  }
138  }
139  else
140  {
141  if constexpr(is_partition_index)
142  {
143  update_tile(o_dram_window_tmp,
144  cast_tile<ODataType>(o_tile),
145  /*partition_index=*/ds_dram_windows);
146  }
147  else
148  {
149  update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
150  }
151  }
152  }
153  };
154 
 // Fused path: load every D tile with o_acc_tile's distribution and apply
 // Problem::CDElementwise to (elementwise_result, o_acc_tile, d0, d1, ...), presumably
 // writing into elementwise_result (the "inout" element), which is then stored.
155  if constexpr(!std::is_same_v<DsDramWindows, std::nullptr_t> && !is_partition_index &&
156  Problem::NumDTensor >= 1)
157  {
 // Result tile type: what load_tile would return for a kMPerBlock x kNPerBlock window over
 // D0's tensor view with o_acc_tile's distribution (only the type is used; nothing is
 // loaded here).
158  using elementwise_result_t = decltype(load_tile(
159  make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
160  make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
161  ds_dram_windows[number<0>{}].get_window_origin(),
162  o_acc_tile.get_tile_distribution())));
163 
164  elementwise_result_t elementwise_result;
165 
166  const auto d_tensor_tuple = generate_tuple(
167  [&](auto idx) {
168  const auto d_tile_window =
169  make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
170  return load_tile(d_tile_window);
171  },
 // NOTE(review): listing line 172 (the closing argument of generate_tuple, presumably
 // number<Problem::NumDTensor>{}) is missing from this render.
173 
174  const auto c_d_tuple = concat_tuple_of_reference(
175  tie(elementwise_result, o_acc_tile),
176  generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
 // NOTE(review): listing line 177 (the closing argument of generate_tie, presumably
 // number<Problem::NumDTensor>{}) is missing from this render.
178 
179  tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);
180 
181  storeOrUpdateTile(elementwise_result);
182  }
183  else
184  {
 // No D tensors to fuse (nullptr, a partition index, or NumDTensor == 0): write the
 // accumulator tile directly.
185  storeOrUpdateTile(o_acc_tile);
186  }
187  }
188 };
189 
190 template <typename Problem_, typename Policy_ = void>
191 struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
192 {
200 
204 
208 
211  // Used for weight-only quantization kernel, B would be dequantized to the same data type as A
212  using BTypeToUse =
213  std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
214 
219  static constexpr index_t kMPerXdl = Problem::kMPerXdl;
220  static constexpr index_t kNPerXdl = Problem::kNPerXdl;
221  static constexpr index_t kKPerXdl = Problem::kKPerXdl;
222  static constexpr index_t isCTransposed = Problem::isCTransposed;
223 
225  BTypeToUse,
226  AccDataType,
227  kMPerXdl,
228  kNPerXdl,
229  kKPerXdl,
230  isCTransposed>;
231 
232  using CWarpDstr = typename WG::CWarpDstr;
233 
234  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
235  {
236  // N is contiguous dimension
237  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
238  {
239  if constexpr(isCTransposed)
240  {
241  // In this case each thread has multiple consecutive elements in
242  // N dimension, however consecutive threads' elements have stride.
243  constexpr index_t NDimY = CWarpDstr::NDimY;
244  constexpr auto c_warp_y_lengths =
245  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
246  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
247  c_warp_y_lengths.get(number<NDimY - 1>{}));
248  return c_warp_y_lengths.get(number<NDimY - 1>{});
249  }
250  else
251  {
252  // In this case each thread has just a single item in Ndim
253  return (WG::WarpGemmAttribute::Impl::kCNLane *
254  WG::WarpGemmAttribute::Impl::kBNBlock) /
255  WG::kN;
256  }
257  }
258  // M is contiguous dimension
259  else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
260  {
261  if constexpr(isCTransposed)
262  {
263  // In this case each thread has just a single item in Mdim
264  return (WG::WarpGemmAttribute::Impl::kCNLane *
265  WG::WarpGemmAttribute::Impl::kAMBlock) /
266  WG::kN;
267  }
268  else
269  {
270  // In this case each thread has multiple consecutive elements in
271  // M dimension, however consecutive threads' elements have stride.
272  constexpr index_t NDimY = CWarpDstr::NDimY;
273  constexpr auto c_warp_y_lengths =
274  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
275  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
276  c_warp_y_lengths.get(number<NDimY - 1>{}));
277  return c_warp_y_lengths.get(number<NDimY - 1>{});
278  }
279  }
280  else
281  {
282  static_assert(false, "Unsupported CLayout!");
283  }
284  }
285 
286  template <index_t I>
287  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
288  {
289  return GetVectorSizeC();
290  }
291 };
292 
293 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
constexpr tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:376
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto generate_tie(F &&f, number< N >)
Definition: tuple.hpp:435
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
typename impl::warp_gemm_dispatcher::Dispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity, AttrNumAccess >::Type WarpGemmDispatcher
Definition: warp_gemm_dispatcher.hpp:177
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc &in_element_func, const Tuple &t, std::index_sequence< I... >)
Template function that "unpacks" a tuple and applies an element-wise operation.
Definition: tile_elementwise.hpp:71
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:1063
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:72
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE auto concat_tuple_of_reference(const tuple< X &... > &tx, const tuple< Y &... > &ty)
Definition: tuple.hpp:443
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:24
CK_TILE_DEVICE auto load_tile(const TileWindow_ &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:36
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
Definition: default_2d_epilogue.hpp:77
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:80
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const OAccTile &o_acc_tile, const DsDramWindows &ds_dram_windows, void *=nullptr) const
Definition: default_2d_epilogue.hpp:91
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:82
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:78
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:81
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:79
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:83
static constexpr memory_operation_enum MemoryOperation
Definition: default_2d_epilogue.hpp:84
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: default_2d_epilogue.hpp:86
Definition: default_2d_epilogue.hpp:21
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:26
remove_cvref_t< ODataType_ > ODataType
Definition: default_2d_epilogue.hpp:23
remove_cvref_t< AccDataType_ > AccDataType
Definition: default_2d_epilogue.hpp:22
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:24
static constexpr memory_operation_enum MemoryOperation
Definition: default_2d_epilogue.hpp:27
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:25
static constexpr index_t NumDTensor
Definition: default_2d_epilogue.hpp:28
Definition: default_2d_epilogue.hpp:192
remove_cvref_t< typename Problem::BsDataType > BsDataType
Definition: default_2d_epilogue.hpp:195
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:196
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:219
static constexpr bool ADataTypeIsTuple
Definition: default_2d_epilogue.hpp:198
remove_cvref_t< std::tuple_element_t< number< 0 >{}, BsDataTypeTuple > > BDataType
Definition: default_2d_epilogue.hpp:210
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:220
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< BsDataType >, remove_cvref_t< tuple< BsDataType > >> BsDataTypeTuple
Definition: default_2d_epilogue.hpp:207
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:221
remove_cvref_t< typename Problem::CDElementwise > CDElementwise
Definition: default_2d_epilogue.hpp:217
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeC()
Definition: default_2d_epilogue.hpp:234
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:222
remove_cvref_t< typename Problem::CLayout > CLayout
Definition: default_2d_epilogue.hpp:218
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:197
remove_cvref_t< typename Problem::DsDataType > DsDataType
Definition: default_2d_epilogue.hpp:215
remove_cvref_t< typename Problem::AsDataType > AsDataType
Definition: default_2d_epilogue.hpp:194
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition: default_2d_epilogue.hpp:213
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeD([[maybe_unused]] number< I > index)
Definition: default_2d_epilogue.hpp:287
WarpGemmDispatcher< ADataType, BTypeToUse, AccDataType, kMPerXdl, kNPerXdl, kKPerXdl, isCTransposed > WG
Definition: default_2d_epilogue.hpp:230
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< AsDataType >, remove_cvref_t< tuple< AsDataType > >> AsDataTypeTuple
Definition: default_2d_epilogue.hpp:203
static constexpr bool BDataTypeIsTuple
Definition: default_2d_epilogue.hpp:199
remove_cvref_t< std::tuple_element_t< number< 0 >{}, AsDataTypeTuple > > ADataType
Definition: default_2d_epilogue.hpp:209
typename WG::CWarpDstr CWarpDstr
Definition: default_2d_epilogue.hpp:232
remove_cvref_t< typename Problem::DsLayout > DsLayout
Definition: default_2d_epilogue.hpp:216
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:193
Definition: default_2d_epilogue.hpp:55
remove_cvref_t< DsLayout_ > DsLayout
Definition: default_2d_epilogue.hpp:61
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:66
remove_cvref_t< DsDataType_ > DsDataType
Definition: default_2d_epilogue.hpp:59
remove_cvref_t< AsDataType_ > AsDataType
Definition: default_2d_epilogue.hpp:56
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:64
static constexpr index_t kNPerBlock
Definition: default_2d_epilogue.hpp:63
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:65
static constexpr index_t kMPerBlock
Definition: default_2d_epilogue.hpp:62
remove_cvref_t< CLayout_ > CLayout
Definition: default_2d_epilogue.hpp:58
remove_cvref_t< BsDataType_ > BsDataType
Definition: default_2d_epilogue.hpp:57
remove_cvref_t< CDElementwise_ > CDElementwise
Definition: default_2d_epilogue.hpp:60
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:67
static constexpr index_t NumDTensor
Definition: default_2d_epilogue.hpp:69
Definition: integral_constant.hpp:13