/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File
rmsnorm2d_fwd_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
9 
10 namespace ck_tile {
11 
12 // host side args
// Host-side argument bundle for the rmsnorm2d forward kernel. Presumably this is the
// `Hargs` type taken by MakeKargs/GridSize (its fields map one-to-one onto Kargs).
// NOTE(review): the struct's declaration line (source line 13) and the `index_t m;` /
// `index_t n;` fields (source lines 28-29, per the symbol index) are missing from this
// documentation listing.
14 {
15  const void* p_x; // [m ,n], input, fp16/bf16
16  const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
17  const void* p_sm_scale; // [1 ,n], smooth scale input, fp32, nullptr if not used
18  const void* p_gamma; // [1, n], gamma, prec same as input
19 
20  void* p_y; // [m, n], output, fp16/bf16
21  void* p_y_residual; // [m, n], shortcut output, prec same as input, nullptr if not used
22  void* p_y_scale; // [m, 1], per-row dynamic-quant scale output, nullptr if not used
23  void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used
24  void* p_y_unquant; // [m, n], output result before quant, nullptr if not used
25 
26  float epsilon; // forwarded to the pipeline; operator() casts it to ComputeDataType
27 
30  index_t x_stride; // x row stride
31  index_t xr_stride; // x residual row stride
32  index_t y_stride; // y row stride (operator() also uses it, via Kargs, for p_y_unquant)
33  index_t yr_stride; // y residual row stride
34 };
35 
36 // TODO: Extract some type to wrapper class
// Rmsnorm2d forward kernel. It wraps a Pipeline_ (which supplies Problem, GetSmemSize() and
// name) and an Epilogue_, builds global-memory tile windows and invokes the pipeline.
// NOTE(review): several source lines are missing from this listing: the struct name
// (line 38); the `Pipeline` / `Epilogue` aliases (lines 40-41, remove_cvref_t of the
// template parameters per the symbol index); and the data-type aliases on lines 44-51 and
// 54-55. Those aliases are XDataType, GammaDataType, ComputeDataType, YDataType,
// InvRmsDataType, SmoothScaleDataType, YScaleDataType and UnquantYDataType.
// XResidualDataType and YResidualDataType are both aliases of XDataType.
37 template <typename Pipeline_, typename Epilogue_>
39 {
42  using Problem = typename Pipeline::Problem;
43 
52 
53  // for simplicity, shortcut input/output type is same as X
56 
// Compile-time feature switches. Gamma is optional: GammaDataType == null_type disables it.
57  static constexpr bool kHasGamma = !std::is_same_v<GammaDataType, null_type>;
58  static constexpr bool kSaveInvRms = Problem::Traits::kSaveInvRms;
59  static constexpr bool kSaveUnquant = Problem::Traits::kSaveUnquant;
60 
// Tile handled by one workgroup: Block_M rows x Block_N columns of the [m, n] tensors.
61  static constexpr index_t Block_M = Problem::BlockShape::Block_M;
62  static constexpr index_t Block_N = Problem::BlockShape::Block_N;
// kPadM is passed to the pad_tensor_view call of the inv-rms view in operator().
63  static constexpr bool kPadM = false; // always no need to pad along M
64  static constexpr bool kPadN = Problem::Traits::kPadN;
65  static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
66  static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
67  static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
68 
69  static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
70  static constexpr index_t Vector_N = Problem::BlockShape::Vector_N;
71  static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N;
72 
// Compile-time index constants (not referenced in the visible part of this listing).
73  static constexpr auto I0 = number<0>{};
74  static constexpr auto I1 = number<1>{};
75 
76  struct Kargs
// Device-side kernel arguments, passed by value to operator(). MakeKargs fills this struct
// with positional brace-initialization, so the field order must stay in sync with MakeKargs.
// NOTE(review): `index_t m;` and `index_t n;` (source lines 91-92, per the symbol index)
// are missing from this listing.
77  {
78  const void* p_x;
79  const void* p_x_residual;
80  const void* p_sm_scale;
81  const void* p_gamma;
82 
83  void* p_y;
84  void* p_y_residual;
85  void* p_y_scale;
86  void* p_invRms;
87  void* p_y_unquant;
88 
89  float epsilon;
90 
93  index_t x_stride; // x row stride
94  index_t xr_stride; // x residual row stride
95  index_t y_stride; // y row stride; operator() also uses it as the row stride of p_y_unquant
96  index_t yr_stride; // y residual row stride
97  };
99 
// Converts host args into kernel args. Kargs is an aggregate initialized positionally, so
// this list must follow Kargs' declaration order: pointers, epsilon, m, n, then the strides.
100  CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
101  {
102  return Kargs{hargs.p_x,
103  hargs.p_x_residual,
104  hargs.p_sm_scale,
105  hargs.p_gamma,
106  hargs.p_y,
107  hargs.p_y_residual,
108  hargs.p_y_scale,
109  hargs.p_invRms,
110  hargs.p_y_unquant,
111  hargs.epsilon,
112  hargs.m,
113  hargs.n,
114  hargs.x_stride,
115  hargs.xr_stride,
116  hargs.y_stride,
117  hargs.yr_stride};
118  }
119 
// Launch grid: one workgroup per Block_M rows (operator() starts at row
// get_block_id() * Block_M), i.e. ceil(m / Block_M) workgroups along x.
120  CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
121  {
122  return dim3(integer_divide_ceil(hargs.m, Block_M));
123  }
124 
// Threads per workgroup, taken from the problem's BlockShape.
125  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }
126 
127  // clang-format off
// Maps a data type to the short name used by GetName(). The primary template is declared
// but not defined, so naming a type without a specialization below fails to compile.
// NOTE(review): explicit specializations at class scope rely on CWG 727. Clang (hipcc)
// accepts them, but GCC has historically rejected them ("explicit specialization in
// non-namespace scope").
128  template <typename T> struct t2s;
129  template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
130  template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
131  template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
132  template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
133  template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
134  template <> struct t2s<ck_tile::int8_t> { static constexpr const char * name = "int8"; };
135  // clang-format on
136 
137  // in bytes; forwarded from the pipeline and used to size the __shared__ buffer in operator()
138  CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }
139 
// Builds a descriptive kernel-instance name:
//   rmsnorm2d_fwd_<prec>_<Block_M>x<Block_N>_<WarpPerBlock_M>x<WarpPerBlock_N>_
//   <Warp_M>x<Warp_N>_<Vector_M>x<Vector_N>_<Pipeline::name><suffix>
// <prec> is the X type name, plus "_<Y type>" when Y differs from X. It also gets the
// "_sx<type>" / "_sy<type>" scale tags in branches whose conditions sit on lines missing
// from this listing (presumably kFusedQuant).
// <suffix> adds "_pn" when kPadN is set, "_rms" when kSaveInvRms is set, and "_2p" when
// kTwoPass is set.
// NOTE(review): source lines 148-149, 160 and 164 are missing from this listing. They
// presumably hold further suffix tags and the `if` headers of the "_sx"/"_sy" branches.
// NOTE(review): _SS_ / _TS_ start with an underscore followed by an uppercase letter. C++
// reserves such identifiers for the implementation. They are #undef'd at the end of the
// function to limit their scope.
140  CK_TILE_HOST static std::string GetName()
141  {
142 #define _SS_ std::string
143 #define _TS_ std::to_string
144  // clang-format off
145  using S_ = typename Problem::BlockShape;
// "surfix" holds the name suffix built from the enabled feature flags.
146  auto surfix = [&] () {
147  std::string n;
150  if (kPadN) n += "_pn";
151  if (kSaveInvRms) n += "_rms";
152  if (kTwoPass) n += "_2p";
153  return n; }();
154 
// Precision tag: X type name, optionally followed by the Y and scale type names.
155  auto prec_str = [&] () {
156  std::string base_str = _SS_(t2s<XDataType>::name);
157  if (!std::is_same_v<XDataType, YDataType>) {
158  base_str += _SS_("_") + _SS_(t2s<YDataType>::name);
159  }
161  base_str += _SS_("_sx") + _SS_(t2s<SmoothScaleDataType>::name);
162  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
163  }
165  base_str += _SS_("_sy") + _SS_(t2s<YScaleDataType>::name);
166  }
167  return base_str;
168  }();
169 
170  return _SS_("rmsnorm2d_fwd_") + _SS_(prec_str) + "_" +
171  _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
172  _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
173  _SS_(Pipeline::name) + surfix;
174  // clang-format on
175 #undef _SS_
176 #undef _TS_
177  }
178 
// Device entry point. Workgroup get_block_id() handles the Block_M-row slab starting at row
// iM. It builds tile windows over the global tensors described by kargs, reserves
// GetSmemSize() bytes of __shared__ memory, and hands everything to Pipeline{} together
// with epsilon, n and an Epilogue{}.
// NOTE(review): many source lines are missing from this listing (e.g. 188, 192, 199, 216,
// 249, 286, 290, 311-312, 332-333). They include pad_tensor_view arguments and the
// `if constexpr` headers / else bodies of the optional windows. The else branches
// presumably return null tile windows (make_null_tile_window appears in the symbol index)
// — confirm against the header.
179  CK_TILE_DEVICE void operator()(Kargs kargs) const
180  {
181  const auto iM = get_block_id() * Block_M;
182 
// Input x: [m, n] global view with row stride x_stride and unit column stride.
// The window is Block_M x Block_N at (iM, 0).
183  const auto x_window = [&]() {
184  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
185  static_cast<const XDataType*>(kargs.p_x),
186  make_tuple(kargs.m, kargs.n),
187  make_tuple(kargs.x_stride, 1),
189  number<1>{});
190 
191  const auto tmp2_ = pad_tensor_view(
193  return make_tile_window(
194  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
195  }();
196 
// Residual input (type XResidualDataType), built only for fused-add modes. The first
// condition is PRE_ADD; the rest of the condition (line 199) is missing here.
197  const auto x_residual_window = [&]() {
198  if constexpr(kFusedAdd == Rmsnorm2dFusedAddEnum::PRE_ADD ||
200  {
201  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
202  static_cast<const XResidualDataType*>(kargs.p_x_residual),
203  make_tuple(kargs.m, kargs.n),
204  make_tuple(kargs.xr_stride, 1),
206  number<1>{});
207 
208  const auto tmp2_ = pad_tensor_view(tmp_,
211  return make_tile_window(
212  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
213  }
214  else
215  {
217  }
218  }();
219 
// gamma: 1-D [n] view with unit stride. The single Block_N window at column 0 is
// independent of iM, so every workgroup reads the same gamma.
220  const auto gamma_window = [&]() {
221  const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
222  static_cast<const GammaDataType*>(kargs.p_gamma),
223  make_tuple(kargs.n),
224  make_tuple(1),
226  number<1>{});
227 
228  const auto tmp2_ =
230 
231  return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
232  }();
233 
// Output y: [m, n] view with row stride y_stride; Block_M x Block_N window at (iM, 0).
234  auto y_window = [&]() {
235  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
236  static_cast<YDataType*>(kargs.p_y),
237  make_tuple(kargs.m, kargs.n),
238  make_tuple(kargs.y_stride, 1),
240  number<1>{});
241 
242  auto tmp2_ = pad_tensor_view(
244  return make_tile_window(
245  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
246  }();
247 
// Residual output: [m, n] view with row stride yr_stride. Its enabling condition
// (line 249) is missing here; presumably it is a fused-add mode.
248  auto y_residual_window = [&]() {
250  {
251  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
252  static_cast<YResidualDataType*>(kargs.p_y_residual),
253  make_tuple(kargs.m, kargs.n),
254  make_tuple(kargs.yr_stride, 1),
256  number<1>{});
257 
258  auto tmp2_ = pad_tensor_view(tmp_,
261  return make_tile_window(
262  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
263  }
264  else
265  {
267  }
268  }();
269 
// Optional per-row inv-rms output (kSaveInvRms): packed 1-D [m] view with a Block_M window
// at iM. It is padded along M only if kPadM is set.
270  auto inv_rms_window = [&]() {
271  if constexpr(kSaveInvRms)
272  {
273  const auto inv_rms_m = [&]() {
274  const auto inv_rms_dram_naive =
275  make_naive_tensor_view_packed<address_space_enum::global>(
276  static_cast<InvRmsDataType*>(kargs.p_invRms),
277  make_tuple(kargs.m),
278  number<1>{});
279 
280  return pad_tensor_view(
281  inv_rms_dram_naive, make_tuple(number<Block_M>{}), sequence<kPadM>{});
282  }();
283  return make_tile_window(inv_rms_m, make_tuple(number<Block_M>{}), {iM});
284  }
285  else
287  }();
288 
// Smooth-scale input: packed 1-D [n] view (built with number<Vector_N>{}, presumably the
// guaranteed vector length), with a Block_N window at column 0 and no padding. Its enabling
// condition (line 290) is missing here.
289  auto sm_scale_window = [&]() {
291  {
292  const auto win_ = [&]() {
293  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
294  static_cast<const SmoothScaleDataType*>(kargs.p_sm_scale),
295  make_tuple(kargs.n),
296  number<Vector_N>{});
297 
298  return pad_tensor_view(tmp_0_,
300  sequence<false>{}); // sm_scale no need pad
301  }();
302  return make_tile_window(win_, make_tuple(number<Block_N>{}), {0});
303  }
304  else
305  {
307  }
308  }();
309 
// Per-row y-scale output: packed 1-D [m] view with a Block_M window at iM. Its enabling
// condition (lines 311-312) is missing here.
310  auto y_scale_window = [&]() {
313  {
314  const auto win_ = [&]() {
315  const auto tmp_0_ = make_naive_tensor_view_packed<address_space_enum::global>(
316  static_cast<YScaleDataType*>(kargs.p_y_scale),
317  make_tuple(kargs.m),
318  number<1>{});
319 
320  return pad_tensor_view(
322  }();
323  return make_tile_window(win_, make_tuple(number<Block_M>{}), {iM});
324  }
325  else
326  {
328  }
329  }();
330 
// Pre-quantization y output; the visible part of its condition ends with kSaveUnquant.
// NOTE(review): reuses kargs.y_stride as its row stride, which assumes p_y_unquant has the
// same row layout as p_y (Kargs has no dedicated stride for it).
331  auto unquant_y_window = [&]() {
334  kSaveUnquant)
335  {
336  auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
337  static_cast<UnquantYDataType*>(kargs.p_y_unquant),
338  make_tuple(kargs.m, kargs.n),
339  make_tuple(kargs.y_stride, 1),
341  number<1>{});
342 
343  auto tmp2_ = pad_tensor_view(tmp_,
346  return make_tile_window(
347  tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
348  }
349  else
350  {
352  }
353  }();
354 
// Workgroup-shared scratch buffer, sized by the pipeline.
355  __shared__ char smem[GetSmemSize()];
356 
// NOTE(review): the top-level const in static_cast<const ComputeDataType> has no effect
// when ComputeDataType is a scalar type (the prvalue's cv-qualifier is dropped).
357  Pipeline{}(x_window,
358  x_residual_window,
359  gamma_window,
360  y_window,
361  y_residual_window,
362  inv_rms_window,
363  sm_scale_window,
364  y_scale_window,
365  unquant_y_window,
366  static_cast<const ComputeDataType>(kargs.epsilon),
367  kargs.n,
368  smem,
369  Epilogue{});
370  }
371 };
372 
373 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:63
_BitInt(8) fp8_t
Definition: float8.hpp:204
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
_Float16 fp16_t
Definition: half.hpp:110
int8_t int8_t
Definition: int8.hpp:20
bfloat16_t bf16_t
Definition: bfloat16.hpp:106
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:529
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
unsigned _BitInt(8) bf8_t
Definition: float8.hpp:206
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE index_t get_block_id()
Definition: arch.hpp:81
#define _TS_
#define _SS_
Definition: rmsnorm2d_fwd_traits.hpp:20
Definition: rmsnorm2d_fwd_traits.hpp:34
Definition: rmsnorm2d_fwd_kernel.hpp:77
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:86
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:85
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:92
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:78
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:96
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:95
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:83
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:94
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:80
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:84
void * p_y_unquant
Definition: rmsnorm2d_fwd_kernel.hpp:87
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:91
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:81
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:89
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:79
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:93
Definition: rmsnorm2d_fwd_kernel.hpp:128
Definition: rmsnorm2d_fwd_kernel.hpp:14
void * p_invRms
Definition: rmsnorm2d_fwd_kernel.hpp:23
index_t xr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:31
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:21
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:16
void * p_y_scale
Definition: rmsnorm2d_fwd_kernel.hpp:22
float epsilon
Definition: rmsnorm2d_fwd_kernel.hpp:26
void * p_y
Definition: rmsnorm2d_fwd_kernel.hpp:20
index_t yr_stride
Definition: rmsnorm2d_fwd_kernel.hpp:33
void * p_y_unquant
Definition: rmsnorm2d_fwd_kernel.hpp:24
index_t x_stride
Definition: rmsnorm2d_fwd_kernel.hpp:30
index_t y_stride
Definition: rmsnorm2d_fwd_kernel.hpp:32
index_t m
Definition: rmsnorm2d_fwd_kernel.hpp:28
index_t n
Definition: rmsnorm2d_fwd_kernel.hpp:29
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:17
const void * p_x
Definition: rmsnorm2d_fwd_kernel.hpp:15
const void * p_gamma
Definition: rmsnorm2d_fwd_kernel.hpp:18
Definition: rmsnorm2d_fwd_kernel.hpp:39
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: rmsnorm2d_fwd_kernel.hpp:179
XDataType XResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:54
remove_cvref_t< typename Problem::UnquantYDataType > UnquantYDataType
Definition: rmsnorm2d_fwd_kernel.hpp:51
remove_cvref_t< Epilogue_ > Epilogue
Definition: rmsnorm2d_fwd_kernel.hpp:41
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:50
static constexpr bool kTwoPass
Definition: rmsnorm2d_fwd_kernel.hpp:65
static constexpr bool kSaveInvRms
Definition: rmsnorm2d_fwd_kernel.hpp:58
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:120
static constexpr auto I0
Definition: rmsnorm2d_fwd_kernel.hpp:73
typename Pipeline::Problem Problem
Definition: rmsnorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition: rmsnorm2d_fwd_kernel.hpp:48
static constexpr bool kPadN
Definition: rmsnorm2d_fwd_kernel.hpp:64
static CK_TILE_HOST std::string GetName()
Definition: rmsnorm2d_fwd_kernel.hpp:140
remove_cvref_t< typename Problem::YDataType > YDataType
Definition: rmsnorm2d_fwd_kernel.hpp:47
static constexpr auto kFusedQuant
Definition: rmsnorm2d_fwd_kernel.hpp:67
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: rmsnorm2d_fwd_kernel.hpp:46
remove_cvref_t< Pipeline_ > Pipeline
Definition: rmsnorm2d_fwd_kernel.hpp:40
static constexpr auto I1
Definition: rmsnorm2d_fwd_kernel.hpp:74
static constexpr bool kPadM
Definition: rmsnorm2d_fwd_kernel.hpp:63
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: rmsnorm2d_fwd_kernel.hpp:138
static constexpr index_t Block_M
Definition: rmsnorm2d_fwd_kernel.hpp:61
static constexpr auto kFusedAdd
Definition: rmsnorm2d_fwd_kernel.hpp:66
XDataType YResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:55
static constexpr bool kHasGamma
Definition: rmsnorm2d_fwd_kernel.hpp:57
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: rmsnorm2d_fwd_kernel.hpp:44
static constexpr index_t ThreadPerWarp_N
Definition: rmsnorm2d_fwd_kernel.hpp:69
static constexpr CK_TILE_HOST auto BlockSize()
Definition: rmsnorm2d_fwd_kernel.hpp:125
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:100
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition: rmsnorm2d_fwd_kernel.hpp:45
static constexpr index_t Block_N
Definition: rmsnorm2d_fwd_kernel.hpp:62
static constexpr bool kSaveUnquant
Definition: rmsnorm2d_fwd_kernel.hpp:59
static constexpr index_t Repeat_N
Definition: rmsnorm2d_fwd_kernel.hpp:71
static constexpr index_t Vector_N
Definition: rmsnorm2d_fwd_kernel.hpp:70
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52