/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp Source File
default_2d_epilogue.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // this epilogue just store out a M*N matrix, row major
13 
14 template <typename AccDataType_,
15  typename ODataType_,
16  bool kPadM_,
17  bool kPadN_,
18  bool UseRawStore_ = true>
20 {
23  static constexpr bool kPadM = kPadM_;
24  static constexpr bool kPadN = kPadN_;
25  static constexpr bool UseRawStore = UseRawStore_;
26 };
27 
// GEMM flavour of the 2D epilogue problem: everything from Default2DEpilogueProblem plus the
// C layout type, the kMPerXdl / kNPerXdl / kKPerXdl sizes and the isCTransposed flag, all of
// which DefaultGemm2DEpilogue reads from its Problem parameter.
// NOTE(review): this is a Doxygen-extracted listing (leading numbers are the source line
// gutter). Source line 38 (the struct declaration) is missing and its name is not visible
// here — presumably `struct DefaultGemm2DEpilogueProblem`, confirm against upstream. Source
// line 41 is also missing; per the Doxygen tooltips it declares a CLayout alias of
// remove_cvref_t<CLayout_>.
28 template <typename AccDataType_,
29  typename ODataType_,
30  typename CLayout_,
31  bool kPadM_,
32  bool kPadN_,
33  index_t kMPerXdl_,
34  index_t kNPerXdl_,
35  index_t kKPerXdl_,
36  bool isCTransposed_,
37  bool UseRawStore_ = true>
39  : public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
40 {
42  static constexpr index_t kMPerXdl = kMPerXdl_;
43  static constexpr index_t kNPerXdl = kNPerXdl_;
44  static constexpr index_t kKPerXdl = kKPerXdl_;
    // NOTE(review): stored as index_t although isCTransposed_ is a bool template parameter;
    // consumers use it in `if constexpr` and as a template argument — consider bool.
45  static constexpr index_t isCTransposed = isCTransposed_;
46 };
47 
// Default2DEpilogue: converts an accumulator tile to Problem's ODataType with cast_tile and
// writes it into a DRAM tile window, picking the raw or regular store/update helper at
// compile time. It reports no Smem requirement.
// NOTE(review): this is a Doxygen-extracted listing (leading numbers are the source line
// gutter) with missing source lines: 49 (struct declaration — presumably
// `struct Default2DEpilogue`, the base named by DefaultGemm2DEpilogue), 51-53 (the Problem,
// AccDataType and ODataType aliases listed in the Doxygen tooltips), 64 (the last template
// parameter of operator(), presumably `memory_operation_enum out_memory_data_op`; its default
// value is not visible here) and 80 (presumably the buffer_store_fence() call that the
// tooltips list for this file). Confirm against the upstream header.
48 template <typename Problem_, typename Policy_ = void>
50 {
    // Padding and raw-store switches copied from the Problem.
54  static constexpr bool kPadM = Problem::kPadM;
55  static constexpr bool kPadN = Problem::kPadN;
56  static constexpr bool UseRawStore = Problem::UseRawStore;
57 
    // Always 0: this epilogue declares no Smem requirement.
58  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
59 
60  // TODO: this function assumes the store vector size equals the size of the last dimension
61  // of OAccTile; how do we fix this?
    // Writes cast_tile<ODataType>(o_acc_tile) into o_dram_window_tmp.
    // out_memory_data_op == memory_operation_enum::set selects store_tile(_raw); any other
    // value selects update_tile(_raw). The trailing unnamed void* argument is unused.
62  template <typename ODramWindowTmp,
63  typename OAccTile,
65  CK_TILE_DEVICE auto
66  operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
67  {
68 
69  // TODO: this is ugly
    // The *_raw helpers are used only when raw stores are enabled AND M or N padding is on;
    // every other configuration goes through store_tile / update_tile.
70  if constexpr(UseRawStore && (kPadM || kPadN))
71  {
72  if constexpr(out_memory_data_op == memory_operation_enum::set)
73  {
74  store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
75  }
76  else
77  {
78  update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
79  }
    // NOTE(review): source line 80 is missing from this extraction (presumably
    // buffer_store_fence(), which would apply to the raw path only — confirm).
81  }
82  else
83  {
84  if constexpr(out_memory_data_op == memory_operation_enum::set)
85  {
86  store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
87  }
88  else
89  {
90  update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
91  }
92  }
93  }
94 };
95 
96 template <typename Problem_, typename Policy_ = void>
97 struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
98 {
103  static constexpr index_t kMPerXdl = Problem::kMPerXdl;
104  static constexpr index_t kNPerXdl = Problem::kNPerXdl;
105  static constexpr index_t kKPerXdl = Problem::kKPerXdl;
106  static constexpr index_t isCTransposed = Problem::isCTransposed;
107 
109  ODataType,
110  AccDataType,
111  kMPerXdl,
112  kNPerXdl,
113  kKPerXdl,
114  isCTransposed>;
115 
116  using CWarpDstr = typename WG::CWarpDstr;
117 
118  CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
119  {
120  // N is contiguous dimension
121  if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
122  {
123  if constexpr(isCTransposed)
124  {
125  // In this case each thread has multiple consecutive elements in
126  // N dimension, however consecutive threads' elements have stride.
127  constexpr index_t NDimY = CWarpDstr::NDimY;
128  constexpr auto c_warp_y_lengths =
129  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
130  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
131  c_warp_y_lengths.get(number<NDimY - 1>{}));
132  return c_warp_y_lengths.get(number<NDimY - 1>{});
133  }
134  else
135  {
136  // In this case each thread has just a single item in Ndim
137  return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
138  }
139  }
140  // M is contiguous dimension
141  else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
142  {
143  if constexpr(isCTransposed)
144  {
145  // In this case each thread has just a single item in Mdim
146  return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
147  }
148  else
149  {
150  // In this case each thread has multiple consecutive elements in
151  // M dimension, however consecutive threads' elements have stride.
152  constexpr index_t NDimY = CWarpDstr::NDimY;
153  constexpr auto c_warp_y_lengths =
154  CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
155  static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
156  c_warp_y_lengths.get(number<NDimY - 1>{}));
157  return c_warp_y_lengths.get(number<NDimY - 1>{});
158  }
159  }
160  else
161  {
162  static_assert(false, "Unsupported CLayout!");
163  }
164  }
165 };
166 
167 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
memory_operation_enum
Definition: arch.hpp:44
int32_t index_t
Definition: integer.hpp:9
typename impl::WarpGemmMfmaDispatcher< AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA >::Type WarpGemmMfmaDispatcher
Definition: warp_gemm_dispatcher.hpp:81
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
CK_TILE_DEVICE void buffer_store_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:867
CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:46
CK_TILE_DEVICE void update_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: update_tile.hpp:22
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
Definition: default_2d_epilogue.hpp:50
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:53
CK_TILE_DEVICE auto operator()(ODramWindowTmp &o_dram_window_tmp, const OAccTile &o_acc_tile, void *=nullptr)
Definition: default_2d_epilogue.hpp:66
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:55
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:51
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:54
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:52
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:56
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: default_2d_epilogue.hpp:58
Definition: default_2d_epilogue.hpp:20
remove_cvref_t< ODataType_ > ODataType
Definition: default_2d_epilogue.hpp:22
static constexpr bool kPadM
Definition: default_2d_epilogue.hpp:23
remove_cvref_t< AccDataType_ > AccDataType
Definition: default_2d_epilogue.hpp:21
static constexpr bool kPadN
Definition: default_2d_epilogue.hpp:24
static constexpr bool UseRawStore
Definition: default_2d_epilogue.hpp:25
Definition: default_2d_epilogue.hpp:98
remove_cvref_t< typename Problem::AccDataType > AccDataType
Definition: default_2d_epilogue.hpp:100
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:103
WarpGemmMfmaDispatcher< ODataType, ODataType, AccDataType, kMPerXdl, kNPerXdl, kKPerXdl, isCTransposed > WG
Definition: default_2d_epilogue.hpp:114
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:104
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:105
static constexpr CK_TILE_HOST_DEVICE auto GetVectorSizeC()
Definition: default_2d_epilogue.hpp:118
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:106
remove_cvref_t< typename Problem::CLayout > CLayout
Definition: default_2d_epilogue.hpp:102
remove_cvref_t< typename Problem::ODataType > ODataType
Definition: default_2d_epilogue.hpp:101
typename WG::CWarpDstr CWarpDstr
Definition: default_2d_epilogue.hpp:116
remove_cvref_t< Problem_ > Problem
Definition: default_2d_epilogue.hpp:99
Definition: default_2d_epilogue.hpp:40
static constexpr index_t isCTransposed
Definition: default_2d_epilogue.hpp:45
static constexpr index_t kKPerXdl
Definition: default_2d_epilogue.hpp:44
remove_cvref_t< CLayout_ > CLayout
Definition: default_2d_epilogue.hpp:41
static constexpr index_t kMPerXdl
Definition: default_2d_epilogue.hpp:42
static constexpr index_t kNPerXdl
Definition: default_2d_epilogue.hpp:43
Definition: integral_constant.hpp:13