Composable Kernel: layernorm2d_fwd_pipeline_two_pass.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = Layernorm2dFwdPipelineDefaultPolicy>
struct Layernorm2dFwdPipelineTwoPass
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using XBiasDataType   = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
    using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
    using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using YDataType       = ck_tile::remove_cvref_t<typename Problem::YDataType>;
    using MeanDataType    = ck_tile::remove_cvref_t<typename Problem::MeanDataType>;
    using InvStdDataType  = ck_tile::remove_cvref_t<typename Problem::InvStdDataType>;

    using XResidualDataType = XDataType;
    using YResidualDataType = XDataType;

    static constexpr bool kHasGamma   = !std::is_same_v<GammaDataType, ck_tile::null_type>;
    static constexpr bool kHasBeta    = !std::is_same_v<BetaDataType, ck_tile::null_type>;
    static constexpr bool kSaveMean   = Problem::Traits::kSaveMeanInvStd;
    static constexpr bool kSaveInvStd = Problem::Traits::kSaveMeanInvStd;

    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM       = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
    static constexpr bool kPadN       = Problem::Traits::kPadN;
    static constexpr bool kFastFDiv   = Problem::Traits::kFastFDiv;
    static constexpr bool kWelford    = Problem::Traits::kWelford;
    static constexpr auto kXbias      = Problem::Traits::kXbias;
    static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
    static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;

    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
            return "bpr_2p"; // block per row
        else
            return "wpr_2p"; // warp per row
    }();

    static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename XWindow,
              typename XResidualWindow,
              typename XBiasWindow,
              typename GammaWindow,
              typename BetaWindow,
              typename YWindow,
              typename YResidualWindow,
              typename MeanWindow,
              typename InvStdWindow,
              typename SmoothScaleWindow,
              typename YScaleWindow,
              typename Epilogue>
    CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                   const XResidualWindow& x_residual_window_,
                                   const XBiasWindow& x_bias_window_,
                                   const GammaWindow& gamma_window_,
                                   const BetaWindow& beta_window_,
                                   YWindow& y_window,
                                   const YResidualWindow& y_residual_window_,
                                   MeanWindow& mean_window,
                                   InvStdWindow& inv_std_window,
                                   const SmoothScaleWindow& /*sm_scale_window*/,
                                   YScaleWindow& /*y_scale_window*/,
                                   ComputeDataType epsilon,
                                   ck_tile::index_t row_size,
                                   void* smem,
                                   Epilogue) const
    {
        static_assert(kWelford == true, "2 pass only supports welford merge");
        auto x_window =
            make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
        auto x_bias_window = make_tile_window(
            x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
        auto gamma_window = make_tile_window(
            gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
        auto beta_window = make_tile_window(
            beta_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
        auto x_residual_window = make_tile_window(
            x_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());
        auto y_residual_window = make_tile_window(
            y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());

        // Problem::BlockShape
        static constexpr index_t Block_N = Problem::BlockShape::Block_N;
        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));

        // total count per iteration, assuming the current iteration has no padding
        // (only the last iteration may be padded)
        constexpr index_t count_per_iter =
            Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N;
        const index_t last_iter_n = row_size - (num_n_tile_iteration - 1) * Block_N;

        int cur_count = 0;
        int max_count =
            (num_n_tile_iteration - 1) * count_per_iter +
            block_tile_welford_calculate_max_count<typename Problem::BlockShape>(last_iter_n);
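        // Illustrative numbers (not from the source): with Block_N = 256 and
        // row_size = 1000, num_n_tile_iteration = 4 and last_iter_n = 232, so max_count
        // sums three full per-iteration counts plus the partial count
        // block_tile_welford_calculate_max_count reports for the padded last tile;
        // cur_count tracks how many valid elements have entered the Welford update so far.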
        auto block_norm_reduce      = Policy::template GetBlockNormReduce<Problem>();
        auto block_norm_reduce_sync = Policy::template GetBlockNormReduceSync<Problem>();
        auto block_norm_reduce_cross_warp_sync =
            Policy::template GetBlockNormReduceCrossWarpSync<Problem>();

        using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(x_window)));
        auto mean = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();
        auto var  = block_norm_reduce.template MakeMeanVarBlockTile<XTensorType>();

        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            auto x            = load_tile(x_window);
            auto x_resi       = load_tile(x_residual_window);
            const auto x_bias = load_tile(x_bias_window);

            move_tile_window(x_window, {0, Block_N});
            move_tile_window(x_residual_window, {0, Block_N});
            move_tile_window(x_bias_window, {Block_N});
            auto acc = cast_tile<ComputeDataType>(x);

            if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
            {
                sweep_tile(x, [&](auto idx) {
                    // compute x = bias + x
                    constexpr auto j_idx = make_tuple(idx[number<1>{}]);
                    acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
                });
            }

            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                         kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
            {
                sweep_tile(x_resi, [&](auto idx) {
                    // compute x = x_resi + x
                    acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
                });
                if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE)
                {
                    store_tile(y_residual_window, cast_tile<YResidualDataType>(acc));
                    move_tile_window(y_residual_window, {0, Block_N});
                }
            }
            block_norm_reduce(acc, mean, var, cur_count, max_count);
        }

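        // First pass done: each thread holds a partial Welford mean/variance and the
        // number of elements it has folded in. The next calls merge these partials
        // across lanes within a warp, then across warps through the shared-memory
        // buffer (smem).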
        block_norm_reduce_sync(mean, var, cur_count);
        block_norm_reduce_cross_warp_sync(mean, var, cur_count, smem);
        block_tile_welford_post_scale_var(var, cur_count, bool_constant<kFastFDiv>{});

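        // The kFastFDiv path below replaces the exact division by sqrt(var + epsilon)
        // with the hardware reciprocal approximation (__builtin_amdgcn_rcpf), trading a
        // small amount of accuracy for throughput; otherwise a full-precision divide is used.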
        // compute inv-std
        auto inv_std = tile_elementwise_in(
            [&](const auto& v_) {
                if(kFastFDiv && std::is_same_v<ComputeDataType, float>)
                {
                    return type_convert<ComputeDataType>(1.0f) *
                           __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
                }
                else
                {
                    return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
                }
            },
            var);
        if constexpr(kSaveMean)
            store_tile(mean_window, cast_tile<MeanDataType>(mean));
        if constexpr(kSaveInvStd)
            store_tile(inv_std_window, cast_tile<InvStdDataType>(inv_std));

        // reverse read x to reuse cache
        ck_tile::index_t stride_to_right_most_window =
            row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;
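        // Column offset of the left edge of the right-most tile. Illustrative numbers
        // (not from the source): row_size = 1000, Block_N = 256 gives 1000 - 1000 % 256 = 768;
        // row_size = 1024 gives 1024 - 256 = 768. gamma/beta/y jump straight to that tile,
        // while the x windows only step back one tile because the first pass already left
        // them one tile past the right-most position.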

        move_tile_window(x_window, {0, -Block_N});
        move_tile_window(x_residual_window, {0, -Block_N});
        move_tile_window(x_bias_window, {-Block_N});
        move_tile_window(gamma_window, {stride_to_right_most_window});
        move_tile_window(beta_window, {stride_to_right_most_window});
        move_tile_window(y_window, {0, stride_to_right_most_window});

        // layernorm computation
        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            auto x            = load_tile(x_window);
            auto x_resi       = load_tile(x_residual_window);
            const auto x_bias = load_tile(x_bias_window);
            auto acc          = cast_tile<ComputeDataType>(x);

            if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
            {
                sweep_tile(x, [&](auto idx) {
                    // compute x = bias + x
                    constexpr auto j_idx = make_tuple(idx[number<1>{}]);
                    acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
                });
            }

            if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                         kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
            {
                sweep_tile(x_resi, [&](auto idx) {
                    // compute x = x_resi + x
                    acc(idx) = type_convert<ComputeDataType>(x_resi(idx)) + acc(idx);
                });
            }
            // load gamma/beta (TODO: support no gamma/beta?)
            const auto gamma = load_tile(gamma_window);
            const auto beta  = load_tile(beta_window);

            auto ln = make_static_distributed_tensor<ComputeDataType>(acc.get_tile_distribution());

            sweep_tile(ln, [&, mean_ = mean](auto idx) {
                constexpr auto i_idx = make_tuple(idx[number<0>{}]);
                constexpr auto j_idx = make_tuple(idx[number<1>{}]);

                const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
                const auto beta_  = type_convert<ComputeDataType>(beta[j_idx]);

                auto ln_ = (acc(idx) - mean_[i_idx]) * inv_std[i_idx] * gamma_ + beta_;

                ln(idx) = ln_;
            });

            Epilogue{}(y_window, ln);

            move_tile_window(x_window, {0, -Block_N});
            move_tile_window(x_residual_window, {0, -Block_N});
            move_tile_window(x_bias_window, {-Block_N});
            move_tile_window(gamma_window, {-Block_N});
            move_tile_window(beta_window, {-Block_N});
            move_tile_window(y_window, {0, -Block_N});
        }
    }
};
} // namespace ck_tile
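For orientation, the arithmetic this pipeline distributes across a block can be summarized by a small host-side sketch. This is illustrative only and not part of the ck_tile API: the function name layernorm2d_reference is hypothetical, and it uses a naive two-pass sum instead of the Welford merge the device pipeline relies on, while applying the same per-row formula y = (x - mean) * inv_std * gamma + beta with inv_std = 1 / sqrt(var + epsilon).

#include <cmath>
#include <cstddef>
#include <vector>

// Illustrative host-side reference (not part of ck_tile). y must be pre-sized to x.size().
void layernorm2d_reference(const std::vector<float>& x, // [rows * row_size], row-major
                           const std::vector<float>& gamma,
                           const std::vector<float>& beta,
                           std::vector<float>& y,
                           std::size_t rows,
                           std::size_t row_size,
                           float epsilon = 1e-5f)
{
    for(std::size_t r = 0; r < rows; ++r)
    {
        const float* xr = x.data() + r * row_size;
        float* yr       = y.data() + r * row_size;

        // pass 1: mean and (biased) variance of the row
        float sum = 0.f, sq_sum = 0.f;
        for(std::size_t j = 0; j < row_size; ++j)
        {
            sum += xr[j];
            sq_sum += xr[j] * xr[j];
        }
        const float mean    = sum / row_size;
        const float var     = sq_sum / row_size - mean * mean;
        const float inv_std = 1.f / std::sqrt(var + epsilon);

        // pass 2: normalize and apply the affine transform
        for(std::size_t j = 0; j < row_size; ++j)
            yr[j] = (xr[j] - mean) * inv_std * gamma[j] + beta[j];
    }
}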