/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File
rmsnorm2d_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 // host side args: caller-provided pointers (wrapped as global-memory views by the kernel), epsilon and row strides
14 { // NOTE(review): the struct header (source line 13) is not shown in this listing
15  const void* p_x; // [m, n], input, fp16/bf16
16  const void* p_x_residual; // [m, n], shortcut input, prec same as input, nullptr if not used
17  const void* p_sm_scale; // [1, n], smooth scale input, fp32, nullptr if not used
18  const void* p_gamma; // [1, n], gamma, prec same as input
19 
20  void* p_y; // [m, n], output, fp16/bf16
21  void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
22  void* p_y_scale; // [m, 1], per-row dynamic-quant scale output, nullptr if not used
23  void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
24  void* p_y_unquant; // [m, n], output result before quant, nullptr if not used; the kernel indexes it with y_stride
25 
26  float epsilon; // forwarded to the pipeline after conversion to ComputeDataType
27 // NOTE(review): source lines 28-29 (index_t m; index_t n;) are not shown in this listing
30  index_t x_stride; // x row stride
31  index_t xr_stride; // x residual row stride
32  index_t y_stride; // y row stride
33  index_t yr_stride; // y residual row stride
34 };
35 
36 // TODO: Extract some type to wrapper class
37 template <typename Pipeline_, typename Epilogue_>
39 {
42  using Problem = typename Pipeline::Problem; // problem descriptor: source of the data types, traits and block shape used below
43 
52 // NOTE(review): source lines 44-51 (XDataType ... UnquantYDataType aliases taken from Problem) are not shown in this listing
53  // for simplicity, the shortcut (residual) input/output types are the same as X (aliases on source lines 54-55, not shown)
56 
57  static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>; // gamma is optional: a null_type GammaDataType means "no gamma"
58  static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms; // when true, operator() builds a real inv-rms output window
59  static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant; // part of the condition for the pre-quant output window in operator()
60 
61  static constexpr index_t Block_M = Problem::BlockShape::Block_M; // rows per block tile (also the grid step along m)
62  static constexpr index_t Block_N = Problem::BlockShape::Block_N; // columns per block tile window
63  static constexpr bool kPadM = false; // always no need to pad along M
64  static constexpr bool kPadN = Problem::Traits::kPadN; // shows up in GetName() as "_pn"
65  static constexpr bool kTwoPass = Problem::Traits::kTwoPass; // shows up in GetName() as "_2p"
66  static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd; // Rmsnorm2dFusedAddEnum; PRE_ADD (among others) enables the residual input window
67  static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant; // presumably gates the quant-related windows (their conditions are not shown) -- confirm
68  static constexpr auto kUseModelSensitiveRMSNorm = Problem::Traits::kUseModelSensitiveRMSNorm; // not referenced in the lines shown in this listing
69 
70  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
71  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; // passed as the vector-length argument of the sm_scale packed view
72  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
73 
74  static constexpr auto I0 = number<0>{}; // compile-time index constant 0
75  static constexpr auto I1 = number<1>{}; // compile-time index constant 1
76 
77  struct Kargs // device-side kernel arguments; member order must match the positional brace-initializer in MakeKargs
78  {
79  const void* p_x;
80  const void* p_x_residual;
81  const void* p_sm_scale;
82  const void* p_gamma;
83 
84  void* p_y;
85  void* p_y_residual;
86  void* p_y_scale;
87  void* p_invRms;
88  void* p_y_unquant;
89 
90  float epsilon;
91 // NOTE(review): source lines 92-93 (index_t m; index_t n;) are not shown in this listing
94  index_t x_stride; // x row stride
95  index_t xr_stride; // x residual row stride
96  index_t y_stride; // y row stride (also used for p_y_unquant)
97  index_t yr_stride; // y residual row stride
98  };
100 
101  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) // copy host-side args into the device-side Kargs
102  {
103  return Kargs{hargs.p_x, // positional aggregate init: order must match the Kargs member order
104  hargs.p_x_residual,
105  hargs.p_sm_scale,
106  hargs.p_gamma,
107  hargs.p_y,
108  hargs.p_y_residual,
109  hargs.p_y_scale,
110  hargs.p_invRms,
111  hargs.p_y_unquant,
112  hargs.epsilon,
113  hargs.m, // Kargs::m (declared on source line 92)
114  hargs.n, // Kargs::n (declared on source line 93)
115  hargs.x_stride,
116  hargs.xr_stride,
117  hargs.y_stride,
118  hargs.yr_stride};
119  }
120 
121  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) // 1-D grid: one block per Block_M rows
122  {
123  return dim3(integer_divide_ceil(hargs.m, Block_M)); // ceil(m / Block_M) blocks along x
124  }
125 
126  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } // launch block size from Problem::BlockShape (presumably threads per block)
127 
128  // clang-format off
129  template <typename T> struct t2s; // element type -> short name used by GetName(); primary is declared only, so unlisted types fail to compile
130  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
131  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
132  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
133  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
134  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
135  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; }; // NOTE(review): class-scope explicit specializations rely on CWG 727; Clang accepts them, some GCC versions reject them
136  // clang-format on
137 
138  // shared-memory size in bytes, delegated to Pipeline; also sizes the __shared__ buffer in operator()
139  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
140 
141  CK_TILE_HOST static std::string GetName() // human-readable kernel name: precisions, tile shape, pipeline name and trait suffix
142  { // NOTE(review): _SS_/_TS_ begin with '_' + uppercase letter, which C++ reserves for the implementation; consider renaming
143 #define _SS_ std::string
144 #define _TS_ std::to_string
145  // clang-format off
146  using S_ = typename Problem::BlockShape;
147  auto surfix = [&] () { // trait suffix (sic: "surfix"); source lines 149-150 and 154-155 are not shown in this listing
148  std::string n;
151  if (kPadN) n += "_pn";
152  if (kSaveInvRms) n += "_rms";
153  if (kTwoPass) n += "_2p";
156  return n; }();
157 
158  auto prec_str = [&] () { // X precision, then Y precision if it differs, then quant-scale precisions
159  std::string base_str = _SS_(t2s<XDataType>::name);
160  if (!std::is_same_v<XDataType, YDataType>) {
161  base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
162  }
164  base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name); // enclosing condition (source line 163) is not shown
165  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
166  }
168  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name); // enclosing condition (source line 167) is not shown
169  }
170  return base_str;
171  }();
172 // format: rmsnorm2d_fwd_<prec>_<Block_M>x<Block_N>_<WarpPerBlock_M>x<WarpPerBlock_N>_<Warp_M>x<Warp_N>_<Vector_M>x<Vector_N>_<pipeline><suffix>
173  return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
174  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
175  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
176  _SS_(Pipeline::name) + surfix;
177  // clang-format on
178 #undef _SS_
179 #undef _TS_
180  }
181 
182  CK_TILE_DEVICE void operator()(Kargs kargs) const // device entry: build global-memory tile windows for rows [iM, iM + Block_M), then run Pipeline
183  { // NOTE(review): several source lines of this function (e.g. 191, 195, 202, 219, 252, 289, 293, 314-315, 335-336) are not shown in this listing
184  const auto iM = get_block_id() * Block_M; // first row of this block's Block_M-row tile
185 
186  const auto x_window = [&]() { // X input: [m, n] view with row stride x_stride; Block_M x Block_N window at (iM, 0)
187  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
188  static_cast<const XDataType*>(kargs.p_x),
189  make_tuple(kargs.m, kargs.n),
190  make_tuple(kargs.x_stride, 1),
192  number<1>{});
193 
194  const auto tmp2_ = pad_tensor_view(
196  return make_tile_window(
197  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
198  }();
199 
200  const auto x_residual_window = [&]() { // residual input: real [m, n] window (row stride xr_stride) only for the fused-add modes below
201  if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD || // the second alternative (source line 202) is not shown
203  {
204  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
205  static_cast<const XResidualDataType*>(kargs.p_x_residual),
206  make_tuple(kargs.m, kargs.n),
207  make_tuple(kargs.xr_stride, 1),
209  number<1>{});
210 
211  const auto tmp2_ = pad_tensor_view(tmp_,
214  return make_tile_window(
215  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
216  }
217  else
218  { // fallback (source line 219) not shown; presumably a null tile window (make_null_tile_window is in this page's cross-references)
220  }
221  }();
222 
223  const auto gamma_window = [&]() { // gamma: 1-D [n] view with unit stride; Block_N window at column 0
224  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
225  static_cast<const GammaDataType*>(kargs.p_gamma),
226  make_tuple(kargs.n),
227  make_tuple(1),
229  number<1>{});
230 
231  const auto tmp2_ =
233 
234  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
235  }();
236 
237  auto y_window = [&]() { // Y output: [m, n] view with row stride y_stride; Block_M x Block_N window at (iM, 0)
238  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
239  static_cast<YDataType*>(kargs.p_y),
240  make_tuple(kargs.m, kargs.n),
241  make_tuple(kargs.y_stride, 1),
243  number<1>{});
244 
245  auto tmp2_ = pad_tensor_view(
247  return make_tile_window(
248  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
249  }();
250 
251  auto y_residual_window = [&]() { // residual output: [m, n] view with row stride yr_stride; its condition (source line 252) is not shown
253  {
254  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
255  static_cast<YResidualDataType*>(kargs.p_y_residual),
256  make_tuple(kargs.m, kargs.n),
257  make_tuple(kargs.yr_stride, 1),
259  number<1>{});
260 
261  auto tmp2_ = pad_tensor_view(tmp_,
264  return make_tile_window(
265  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
266  }
267  else
268  {
270  }
271  }();
272 
273  auto inv_rms_window = [&]() { // per-row inv-rms output, only when kSaveInvRms: packed [m] view, Block_M window at iM
274  if constexpr(kSaveInvRms)
275  {
276  const auto inv_rms_m = [&]() {
277  const auto inv_rms_dram_naive =
278  make_naive_tensor_view_packed<address_space_enum::global>(
279  static_cast<InvRmsDataType*>(kargs.p_invRms),
280  make_tuple(kargs.m),
281  number<1>{});
282 
283  return pad_tensor_view(
284  inv_rms_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{}); // kPadM is false: no padding along M
285  }();
286  return make_tile_window(inv_rms_m, make_tuple(number<Block_M>{}), {iM});
287  }
288  else
290  }();
291 
292  auto sm_scale_window = [&]() { // smooth-scale input: packed [n] view, Block_N window at 0; its condition (source line 293) is not shown
294  {
295  const auto win_ = [&]() {
296  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
297  static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
298  make_tuple(kargs.n),
299  number<Vector_N>{});
300 
301  return pad_tensor_view(tmp_0_,
303  sequence<false>{}); // sm_scale is never padded
304  }();
305  return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
306  }
307  else
308  {
310  }
311  }();
312 
313  auto y_scale_window = [&]() { // per-row y-scale output: packed [m] view, Block_M window at iM; its condition (source lines 314-315) is not shown
316  {
317  const auto win_ = [&]() {
318  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
319  static_cast<YScaleDataType*>(kargs.p_y_scale),
320  make_tuple(kargs.m),
321  number<1>{});
322 
323  return pad_tensor_view(
325  }();
326  return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
327  }
328  else
329  {
331  }
332  }();
333 
334  auto unquant_y_window = [&]() { // pre-quantization Y output; only the tail of its condition is shown (source lines 335-336 missing)
337  kSaveUnquant)
338  {
339  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
340  static_cast<UnquantYDataType*>(kargs.p_y_unquant),
341  make_tuple(kargs.m, kargs.n),
342  make_tuple(kargs.y_stride, 1), // NOTE(review): reuses y_stride -- the args carry no separate stride for p_y_unquant, so it must share Y's row layout
344  number<1>{});
345 
346  auto tmp2_ = pad_tensor_view(tmp_,
349  return make_tile_window(
350  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
351  }
352  else
353  {
355  }
356  }();
357 
358  __shared__ char smem[GetSmemSize()]; // per-block shared-memory scratch of GetSmemSize() bytes, handed to the pipeline
359 
360  Pipeline{}(x_window, // pass every window plus epsilon, row length n, scratch and the epilogue to the pipeline
361  x_residual_window,
362  gamma_window,
363  y_window,
364  y_residual_window,
365  inv_rms_window,
366  sm_scale_window,
367  y_scale_window,
368  unquant_y_window,
369  static_cast<const ComputeDataType>(kargs.epsilon), // the const qualifier in this cast has no effect
370  kargs.n,
371  smem,
372  Epilogue{});
373  }
374 };
375 
376 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:529
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
#define _TS_
#define _SS_
Definition: rmsnorm2d_fwd_traits.hpp:20
Definition: rmsnorm2d_fwd_traits.hpp:34
Definition: rmsnorm2d_fwd_kernel.hpp:78
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:87
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:86
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:93
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:79
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:97
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:96
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:84
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:95
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:81
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:85
void * p_y_unquant
Definition: rmsnorm2d_fwd_kernel.hpp:88
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:92
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:82
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:90
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:80
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:94
Definition: rmsnorm2d_fwd_kernel.hpp:129
Definition: rmsnorm2d_fwd_kernel.hpp:14
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:23
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:31
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:21
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:16
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:22
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:26
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:20
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:33
void * p_y_unquant
Definition: rmsnorm2d_fwd_kernel.hpp:24
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:30
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:32
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:28
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:29
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:17
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:15
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:18
Definition: rmsnorm2d_fwd_kernel.hpp:39
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: rmsnorm2d_fwd_kernel.hpp:182
XDataType XResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:54
remove_cvref_t< typename Problem::UnquantYDataType > UnquantYDataType
Definition: rmsnorm2d_fwd_kernel.hpp:51
remove_cvref_t< Epilogue_ > Epilogue
Definition: rmsnorm2d_fwd_kernel.hpp:41
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:50
static constexpr bool kTwoPass
Definition: rmsnorm2d_fwd_kernel.hpp:65
static constexpr bool kSaveInvRms
Definition: rmsnorm2d_fwd_kernel.hpp:58
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:121
static constexpr auto I0
Definition: rmsnorm2d_fwd_kernel.hpp:74
typename Pipeline::Problem Problem
Definition: rmsnorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition: rmsnorm2d_fwd_kernel.hpp:48
static constexpr bool kPadN
Definition: rmsnorm2d_fwd_kernel.hpp:64
static CK_TILE_HOST std::string GetName()
Definition: rmsnorm2d_fwd_kernel.hpp:141
remove_cvref_t< typename Problem::YDataType > YDataType
Definition: rmsnorm2d_fwd_kernel.hpp:47
static constexpr auto kFusedQuant
Definition: rmsnorm2d_fwd_kernel.hpp:67
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: rmsnorm2d_fwd_kernel.hpp:46
remove_cvref_t< Pipeline_ > Pipeline
Definition: rmsnorm2d_fwd_kernel.hpp:40
static constexpr auto I1
Definition: rmsnorm2d_fwd_kernel.hpp:75
static constexpr bool kPadM
Definition: rmsnorm2d_fwd_kernel.hpp:63
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: rmsnorm2d_fwd_kernel.hpp:139
static constexpr index_t Block_M
Definition: rmsnorm2d_fwd_kernel.hpp:61
static constexpr auto kFusedAdd
Definition: rmsnorm2d_fwd_kernel.hpp:66
XDataType YResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:55
static constexpr bool kHasGamma
Definition: rmsnorm2d_fwd_kernel.hpp:57
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: rmsnorm2d_fwd_kernel.hpp:44
static constexpr index_t ThreadPerWarp_N
Definition: rmsnorm2d_fwd_kernel.hpp:70
static constexpr CK_TILE_HOST auto BlockSize()
Definition: rmsnorm2d_fwd_kernel.hpp:126
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:101
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition: rmsnorm2d_fwd_kernel.hpp:45
static constexpr index_t Block_N
Definition: rmsnorm2d_fwd_kernel.hpp:62
static constexpr bool kSaveUnquant
Definition: rmsnorm2d_fwd_kernel.hpp:59
static constexpr index_t Repeat_N
Definition: rmsnorm2d_fwd_kernel.hpp:72
static constexpr auto kUseModelSensitiveRMSNorm
Definition: rmsnorm2d_fwd_kernel.hpp:68
static constexpr index_t Vector_N
Definition: rmsnorm2d_fwd_kernel.hpp:71
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52