/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File
rmsnorm2d_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 // host side args
14 {
15  const void* p_x; // [m ,n], input, fp16/bf16
16  const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
17  const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
18  const void* p_gamma; // [1, n], gamma, prec same as input
19 
20  void* p_y; // [m, n], output, fp16/bf16
21  void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
22  void* p_y_scale; // [m, 1], output a dynamic quant per row, nullptr if not used
23  void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
24 
25  float epsilon;
26 
29  index_t x_stride; // x row_stride
30  index_t xr_stride; // x residule row stride
31  index_t y_stride; // y row stride
32  index_t yr_stride; // y residule row stride
33 };
34 
35 // TODO: Extract some type to wrapper class
36 template <typename Pipeline_, typename Epilogue_>
38 {
41  using Problem = typename Pipeline::Problem;
42 
50 
51  // for simplicity, shortcut input/output type is same as X
54 
55  static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
56  static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
57 
58  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
59  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
60  static constexpr bool kPadM = false; // always no need to pad along M
61  static constexpr bool kPadN = Problem::Traits::kPadN;
62  static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
63  static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
64  static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
65 
66  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
67  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
68  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
69 
70  static constexpr auto I0 = number<0>{};
71  static constexpr auto I1 = number<1>{};
72 
73  struct Kargs
74  {
75  const void* p_x;
76  const void* p_x_residual;
77  const void* p_sm_scale;
78  const void* p_gamma;
79 
80  void* p_y;
81  void* p_y_residual;
82  void* p_y_scale;
83  void* p_invRms;
84 
85  float epsilon;
86 
89  index_t x_stride; // x row_stride
90  index_t xr_stride; // x residule row stride
91  index_t y_stride; // y row stride
92  index_t yr_stride; // y residule row stride
93  };
95 
96  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
97  {
98  return Kargs{hargs.p_x,
99  hargs.p_x_residual,
100  hargs.p_sm_scale,
101  hargs.p_gamma,
102  hargs.p_y,
103  hargs.p_y_residual,
104  hargs.p_y_scale,
105  hargs.p_invRms,
106  hargs.epsilon,
107  hargs.m,
108  hargs.n,
109  hargs.x_stride,
110  hargs.xr_stride,
111  hargs.y_stride,
112  hargs.yr_stride};
113  }
114 
115  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
116  {
117  return dim3(integer_divide_ceil(hargs.m, Block_M));
118  }
119 
120  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
121 
122  // clang-format off
123  template <typename T> struct t2s;
124  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
125  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
126  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
127  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
128  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
129  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
130  // clang-format on
131 
132  // in byte
133  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
134 
135  CK_TILE_HOST static std::string GetName()
136  {
137 #define _SS_ std::string
138 #define _TS_ std::to_string
139  // clang-format off
140  using S_ = typename Problem::BlockShape;
141  auto surfix = [&] () {
142  std::string n;
145  if (kPadN) n += "_pn";
146  if (kSaveInvRms) n += "_rms";
147  if (kTwoPass) n += "_2p";
148  return n; }();
149 
150  auto prec_str = [&] () {
151  std::string base_str = _SS_(t2s<XDataType>::name);
152  if (!std::is_same_v<XDataType, YDataType>) {
153  base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
154  }
156  base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
157  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
158  }
160  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
161  }
162  return base_str;
163  }();
164 
165  return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
166  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
167  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
168  _SS_(Pipeline::name) + surfix;
169  // clang-format on
170 #undef _SS_
171 #undef _TS_
172  }
173 
174  CK_TILE_DEVICE void operator()(Kargs kargs) const
175  {
176  const auto iM = get_block_id() * Block_M;
177 
178  const auto x_window = [&]() {
179  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
180  static_cast<const XDataType*>(kargs.p_x),
181  make_tuple(kargs.m, kargs.n),
182  make_tuple(kargs.x_stride, 1),
184  number<1>{});
185 
186  const auto tmp2_ = pad_tensor_view(
188  return make_tile_window(
189  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
190  }();
191 
192  const auto x_residual_window = [&]() {
193  if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
195  {
196  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
197  static_cast<const XResidualDataType*>(kargs.p_x_residual),
198  make_tuple(kargs.m, kargs.n),
199  make_tuple(kargs.xr_stride, 1),
201  number<1>{});
202 
203  const auto tmp2_ = pad_tensor_view(tmp_,
206  return make_tile_window(
207  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
208  }
209  else
210  {
212  }
213  }();
214 
215  const auto gamma_window = [&]() {
216  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
217  static_cast<const GammaDataType*>(kargs.p_gamma),
218  make_tuple(kargs.n),
219  make_tuple(1),
221  number<1>{});
222 
223  const auto tmp2_ =
225 
226  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
227  }();
228 
229  auto y_window = [&]() {
230  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
231  static_cast<YDataType*>(kargs.p_y),
232  make_tuple(kargs.m, kargs.n),
233  make_tuple(kargs.y_stride, 1),
235  number<1>{});
236 
237  auto tmp2_ = pad_tensor_view(
239  return make_tile_window(
240  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
241  }();
242 
243  auto y_residual_window = [&]() {
245  {
246  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
247  static_cast<YResidualDataType*>(kargs.p_y_residual),
248  make_tuple(kargs.m, kargs.n),
249  make_tuple(kargs.yr_stride, 1),
251  number<1>{});
252 
253  auto tmp2_ = pad_tensor_view(tmp_,
256  return make_tile_window(
257  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
258  }
259  else
260  {
262  }
263  }();
264 
265  auto inv_rms_window = [&]() {
266  if constexpr(kSaveInvRms)
267  {
268  const auto inv_rms_m = [&]() {
269  const auto inv_rms_dram_naive =
270  make_naive_tensor_view_packed<address_space_enum::global>(
271  static_cast<InvRmsDataType*>(kargs.p_invRms),
272  make_tuple(kargs.m),
273  number<1>{});
274 
275  return pad_tensor_view(
276  inv_rms_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
277  }();
278  return make_tile_window(inv_rms_m, make_tuple(number<Block_M>{}), {iM});
279  }
280  else
282  }();
283 
284  auto sm_scale_window = [&]() {
286  {
287  const auto win_ = [&]() {
288  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
289  static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
290  make_tuple(kargs.n),
291  number<Vector_N>{});
292 
293  return pad_tensor_view(tmp_0_,
295  sequence<false>{}); // sm_scale no need pad
296  }();
297  return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
298  }
299  else
300  {
302  }
303  }();
304 
305  auto y_scale_window = [&]() {
308  {
309  const auto win_ = [&]() {
310  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
311  static_cast<YScaleDataType*>(kargs.p_y_scale),
312  make_tuple(kargs.m),
313  number<1>{});
314 
315  return pad_tensor_view(
317  }();
318  return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
319  }
320  else
321  {
323  }
324  }();
325 
326  __shared__ char smem[GetSmemSize()];
327 
328  Pipeline{}(x_window,
329  x_residual_window,
330  gamma_window,
331  y_window,
332  y_residual_window,
333  inv_rms_window,
334  sm_scale_window,
335  y_scale_window,
336  static_cast<const ComputeDataType>(kargs.epsilon),
337  kargs.n,
338  smem,
339  Epilogue{});
340  }
341 };
342 
343 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE index_t get_block_id()
Definition: arch.hpp:78
#define _TS_
#define _SS_
Definition: rmsnorm2d_fwd_traits.hpp:20
Definition: rmsnorm2d_fwd_traits.hpp:34
struct Kargs
Definition: rmsnorm2d_fwd_kernel.hpp:74
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:83
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:82
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:88
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:75
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:92
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:91
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:80
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:90
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:77
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:81
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:87
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:78
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:85
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:76
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:89
template <typename T> struct t2s
Definition: rmsnorm2d_fwd_kernel.hpp:123
Definition: rmsnorm2d_fwd_kernel.hpp:14
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:23
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:30
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:21
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:16
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:22
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:25
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:20
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:32
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:29
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:31
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:27
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:28
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:17
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:15
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:18
Definition: rmsnorm2d_fwd_kernel.hpp:38
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: rmsnorm2d_fwd_kernel.hpp:174
XDataType XResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:52
remove_cvref_t< Epilogue_ > Epilogue
Definition: rmsnorm2d_fwd_kernel.hpp:40
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:49
static constexpr bool kTwoPass
Definition: rmsnorm2d_fwd_kernel.hpp:62
static constexpr bool kSaveInvRms
Definition: rmsnorm2d_fwd_kernel.hpp:56
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:115
static constexpr auto I0
Definition: rmsnorm2d_fwd_kernel.hpp:70
typename Pipeline::Problem Problem
Definition: rmsnorm2d_fwd_kernel.hpp:41
remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition: rmsnorm2d_fwd_kernel.hpp:47
static constexpr bool kPadN
Definition: rmsnorm2d_fwd_kernel.hpp:61
static CK_TILE_HOST std::string GetName()
Definition: rmsnorm2d_fwd_kernel.hpp:135
remove_cvref_t< typename Problem::YDataType > YDataType
Definition: rmsnorm2d_fwd_kernel.hpp:46
static constexpr auto kFusedQuant
Definition: rmsnorm2d_fwd_kernel.hpp:64
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: rmsnorm2d_fwd_kernel.hpp:45
remove_cvref_t< Pipeline_ > Pipeline
Definition: rmsnorm2d_fwd_kernel.hpp:39
static constexpr auto I1
Definition: rmsnorm2d_fwd_kernel.hpp:71
static constexpr bool kPadM
Definition: rmsnorm2d_fwd_kernel.hpp:60
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:48
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: rmsnorm2d_fwd_kernel.hpp:133
static constexpr index_t Block_M
Definition: rmsnorm2d_fwd_kernel.hpp:58
static constexpr auto kFusedAdd
Definition: rmsnorm2d_fwd_kernel.hpp:63
XDataType YResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:53
static constexpr bool kHasGamma
Definition: rmsnorm2d_fwd_kernel.hpp:55
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: rmsnorm2d_fwd_kernel.hpp:43
static constexpr index_t ThreadPerWarp_N
Definition: rmsnorm2d_fwd_kernel.hpp:66
static constexpr CK_TILE_HOST auto BlockSize()
Definition: rmsnorm2d_fwd_kernel.hpp:120
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:96
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition: rmsnorm2d_fwd_kernel.hpp:44
static constexpr index_t Block_N
Definition: rmsnorm2d_fwd_kernel.hpp:59
static constexpr index_t Repeat_N
Definition: rmsnorm2d_fwd_kernel.hpp:68
static constexpr index_t Vector_N
Definition: rmsnorm2d_fwd_kernel.hpp:67
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52