/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File
default_2d_epilogue.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // This epilogue simply stores out an M*N matrix, row major.
13 
// Compile-time problem descriptor, presumably the Problem consumed by Default2DEpilogue
// (which reads Problem::kPadM, kPadN, UseRawStore and MemoryOperation). The struct-name
// line, its AccDataType/ODataType aliases and the MemoryOperation_ template parameter line
// are not visible in this listing.
//   AccDataType_  : accumulator data type.
//   ODataType_    : output data type (Default2DEpilogue casts the accumulator tile to it).
//   kPadM_/kPadN_ : M / N padding flags; combined with UseRawStore_ they select the
//                   raw store path in Default2DEpilogue.
//   UseRawStore_  : defaults to true.
// NOTE(review): DefaultGemm2DEpilogue::GetVectorSizeC also handles a ColumnMajor CLayout,
// so "row major" above describes only the non-GEMM use — confirm.
14 template <typename AccDataType_,
15  typename ODataType_,
16  bool kPadM_,
17  bool kPadN_,
18  bool UseRawStore_ = true,
21 {
24  static constexpr bool kPadM = kPadM_;
25  static constexpr bool kPadN = kPadN_;
26  static constexpr bool UseRawStore = UseRawStore_;
27  static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
28 };
29 
30 template <typename ADataType_,
31  typename BDataType_,
32  typename AccDataType_,
33  typename ODataType_,
34  typename CLayout_,
35  bool kPadM_,
36  bool kPadN_,
37  index_t kMPerXdl_,
38  index_t kNPerXdl_,
39  index_t kKPerXdl_,
40  bool isCTransposed_,
41  bool UseRawStore_ = true,
44  ODataType_,
45  kPadM_,
46  kPadN_,
47  UseRawStore_,
48  MemoryOperation_>
49 {
53  static constexpr index_t kMPerXdl = kMPerXdl_;
54  static constexpr index_t kNPerXdl = kNPerXdl_;
55  static constexpr index_t kKPerXdl = kKPerXdl_;
56  static constexpr index_t isCTransposed = isCTransposed_;
57 };
58 
59 template <typename Problem_, typename Policy_ = void>
61 {
65  static constexpr bool kPadM = Problem::kPadM;
66  static constexpr bool kPadN = Problem::kPadN;
67  static constexpr bool UseRawStore = Problem::UseRawStore;
68  static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
69 
70  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
71 
72  // TODO: this function assumes the store-out vector size equals the size of OAccTile's last dimension.
73  // How can this restriction be removed?
// Casts o_acc_tile to ODataType (cast_tile) and writes the result into o_dram_window_tmp.
//  - UseRawStore && (kPadM || kPadN): store_tile_raw or update_tile_raw.
//  - otherwise:                       store_tile or update_tile.
// Inside each path, the store-vs-update choice is made by an `if constexpr` whose condition
// line is missing from this listing (presumably a test on MemoryOperation) — TODO confirm.
// The trailing void* parameter is unnamed and unused.
74  template <typename ODramWindowTmp, typename OAccTile>
75  CK_TILE_DEVICE auto
76  operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
77  {
78  // TODO: this is ugly — the raw and non-raw paths duplicate the same store/update dispatch.
79  if constexpr(UseRawStore && (kPadM || kPadN))
80  {
// (condition selecting store vs. update is not shown in this listing)
82  {
83  store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
84  }
85  else
86  {
87  update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
88  }
// NOTE(review): one statement after this if/else is not shown in this listing. The page's
// symbol index lists buffer_store_fence, which may be what is called here — confirm.
90  }
91  else
92  {
// (condition selecting store vs. update is not shown in this listing)
94  {
95  store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
96  }
97  else
98  {
99  update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
100  }
101  }
102  }
103 
104  template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
105  CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
106  const OAccTile& o_acc_tile,
107  const DsDramWindows& /* unused */,
108  void* = nullptr)
109  {
110  return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
111  }
112 };
113 
114 template <typename Problem_, typename Policy_ = void>
115 struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
116 {
122  // Used by weight-only quantization kernels: when B is pk_int4_t it is dequantized to A's data type, so B's type for the warp GEMM becomes ADataType.
123  using BTypeToUse =
124  std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
// Per-XDL (warp GEMM) M/N/K sizes and the C-transpose flag, copied from the Problem.
// NOTE(review): isCTransposed is a flag but is declared index_t here (as in the Problem).
128  static constexpr index_t kMPerXdl = Problem::kMPerXdl;
129  static constexpr index_t kNPerXdl = Problem::kNPerXdl;
130  static constexpr index_t kKPerXdl = Problem::kKPerXdl;
131  static constexpr index_t isCTransposed = Problem::isCTransposed;
132 
// (The opening line of the following declaration is not shown in this listing. The page's
// symbol index gives it as: WG = WarpGemmMfmaDispatcher<ADataType, BTypeToUse, AccDataType,
// kMPerXdl, kNPerXdl, kKPerXdl, isCTransposed>.)
134  BTypeToUse,
135  AccDataType,
136  kMPerXdl,
137  kNPerXdl,
138  kKPerXdl,
139  isCTransposed>;
140 
// C-tile distribution of a single warp GEMM; GetVectorSizeC inspects its Y lengths.
141  using CWarpDstr = typename WG::CWarpDstr;
142 
143  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
144  {
145  // N is contiguous dimension
146  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
147  {
148  if constexpr(isCTransposed)
149  {
150  // In this case each thread has multiple consecutive elements in
151  // N dimension, however consecutive threads' elements have stride.
152  constexpr index_t NDimY = CWarpDstr::NDimY;
153  constexpr auto c_warp_y_lengths =
154  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
155  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
156  c_warp_y_lengths.get(number<NDimY - 1>{}));
157  return c_warp_y_lengths.get(number<NDimY - 1>{});
158  }
159  else
160  {
161  // In this case each thread has just a single item in Ndim
162  return (WG::WarpGemmAttribute::Impl::kCNLane *
163  WG::WarpGemmAttribute::Impl::kBNBlock) /
164  WG::kN;
165  }
166  }
167  // M is contiguous dimension
168  else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
169  {
170  if constexpr(isCTransposed)
171  {
172  // In this case each thread has just a single item in Mdim
173  return (WG::WarpGemmAttribute::Impl::kCNLane *
174  WG::WarpGemmAttribute::Impl::kAMBlock) /
175  WG::kN;
176  }
177  else
178  {
179  // In this case each thread has multiple consecutive elements in
180  // M dimension, however consecutive threads' elements have stride.
181  constexpr index_t NDimY = CWarpDstr::NDimY;
182  constexpr auto c_warp_y_lengths =
183  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
184  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
185  c_warp_y_lengths.get(number<NDimY - 1>{}));
186  return c_warp_y_lengths.get(number<NDimY - 1>{});
187  }
188  }
189  else
190  {
191  static_assert(false, "Unsupported CLayout!");
192  }
193  }
194 
195  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
196 };
197 
198 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
typename impl::WarpGemmMfmaDispatcher< AType, BType, AccType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity >::Type WarpGemmMfmaDispatcher
Definition: warp_gemm_dispatcher.hpp:115
memory_operation_enum
Definition: arch.hpp:44
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:1004
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:46
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
Definition: default_2d_epilogue.hpp:61
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:64
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const OAccTile &o_acc_tile, void *=nullptr)
Definition: default_2d_epilogue.hpp:76
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:66
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:62
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:65
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:63
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:67
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const OAccTile &o_acc_tile, const DsDramWindows &, void *=nullptr)
Definition: default_2d_epilogue.hpp:105
static constexpr memory_operation_enum MemoryOperation
Definition: default_2d_epilogue.hpp:68
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: default_2d_epilogue.hpp:70
Definition: default_2d_epilogue.hpp:21
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:26
remove_cvref_t< ODataType_ > ODataType
Definition: default_2d_epilogue.hpp:23
remove_cvref_t< AccDataType_ > AccDataType
Definition: default_2d_epilogue.hpp:22
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:24
static constexpr memory_operation_enum MemoryOperation
Definition: default_2d_epilogue.hpp:27
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:25
Definition: default_2d_epilogue.hpp:116
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:120
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:128
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:129
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:130
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeC()
Definition: default_2d_epilogue.hpp:143
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:131
remove_cvref_t< typename Problem::CLayout > CLayout
Definition: default_2d_epilogue.hpp:127
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:121
WarpGemmMfmaDispatcher< ADataType, BTypeToUse, AccDataType, kMPerXdl, kNPerXdl, kKPerXdl, isCTransposed > WG
Definition: default_2d_epilogue.hpp:139
remove_cvref_t< typename Problem::BDataType > BDataType
Definition: default_2d_epilogue.hpp:119
std::conditional_t< std::is_same_v< BDataType, pk_int4_t >, ADataType, BDataType > BTypeToUse
Definition: default_2d_epilogue.hpp:124
remove_cvref_t< typename Problem::ADataType > ADataType
Definition: default_2d_epilogue.hpp:118
typename WG::CWarpDstr CWarpDstr
Definition: default_2d_epilogue.hpp:141
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeD()
Definition: default_2d_epilogue.hpp:195
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:117
Definition: default_2d_epilogue.hpp:49
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:56
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:54
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:53
remove_cvref_t< BDataType_ > BDataType
Definition: default_2d_epilogue.hpp:51
remove_cvref_t< CLayout_ > CLayout
Definition: default_2d_epilogue.hpp:52
remove_cvref_t< ADataType_ > ADataType
Definition: default_2d_epilogue.hpp:50
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:55
Definition: integral_constant.hpp:13
Definition: tuple.hpp:192