/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp Source File
reference_rmsnorm2d_fwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
9 
10 namespace ck_tile {
11 
12 // Note: for simplicity, each functor only cares about a single row index m.
// Default epilogue for reference_rmsnorm2d_fwd: converts one accumulator row
// element-wise (ck_tile::type_convert) into the output tensor.
// NOTE(review): source lines 13, 16 and 28 were dropped by this documentation
// extraction. Line 13 presumably declares `struct reference_rmsnorm2d_default_epilogue`
// (the name used as the default Epilogue at source line 40). Per the cross-reference
// footer, line 16 is `void operator()(int m, HostTensor<OutDataType>& o,
// const HostTensor<AccDataType>& acc)`. Line 28 presumably declares the local `o`
// returned by the second overload -- confirm against the original header.
14 {
// Writes row m only: o(m, n) = type_convert<OutDataType>(acc(m, n)) for every n in
// [0, N), where N is the length of acc's second dimension (acc.mDesc.get_lengths()[1]).
// Only this 3-argument form (m, out, acc) exists here; callers that pass an extra
// output tensor need a custom epilogue.
15  template <typename OutDataType, typename AccDataType>
17  {
18  const int N = acc.mDesc.get_lengths()[1];
19  for(int n = 0; n < N; ++n)
20  {
21  o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
22  }
23  }
24 
// Convenience overload: fills row m of a local output tensor `o` through the
// overload above and returns `o` by value. OutDataType cannot be deduced from the
// arguments, so callers must spell it explicitly.
// NOTE(review): the declaration of `o` (source line 28) is missing from this
// extraction; its shape cannot be confirmed from here.
25  template <typename OutDataType, typename AccDataType>
26  auto operator()(int m, const HostTensor<AccDataType>& acc)
27  {
29  operator()(m, o, acc);
30  return o;
31  }
32 };
33 
// Host reference for 2-D RMSNorm forward, computed independently for each row m.
// For row m, with N = length of x_m_n's second dimension:
//   mean_square = (sum over n of x(m,n)^2) / N
//   divisor     = 1 / sqrt(mean_square + epsilon)   -- the inverse RMS
// The row is then normalized into a ComputeDataType accumulator `acc`, and
// epilogue_functor writes it out (see the comments inside the body).
// Template parameters: Epilogue defaults to reference_rmsnorm2d_default_epilogue.
// Runtime parameter use_model_sensitive_rmsnorm selects the normalization variant:
// 0 (NO_SPECIFIC_MODEL) or 1 (T5_MODEL_LIKE).
// NOTE(review): source lines 41, 44 and 49 were dropped by this documentation
// extraction. Per the cross-reference footer they are:
// `void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,`,
// `HostTensor<InvRmsDataType>& invRms_m,` and the default value
// `static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))`.
34 template <typename XDataType,
35  typename GammaDataType,
36  typename ComputeDataType,
37  typename YDataType,
38  typename InvRmsDataType,
39  typename UnquantYDataType,
40  typename Epilogue = reference_rmsnorm2d_default_epilogue>
42  const HostTensor<GammaDataType>& gamma_n,
43  HostTensor<YDataType>& y_m_n,
45  HostTensor<UnquantYDataType>& unquant_y_m_n,
46  ComputeDataType epsilon,
47  Epilogue epilogue_functor = {},
48  const int use_model_sensitive_rmsnorm =
50 {
 // Per-row worker; every tensor, epsilon and the epilogue are captured by reference.
51  auto rmsnorm2d_fwd_func = [&](auto m) {
52  const int N = x_m_n.mDesc.get_lengths()[1];
53 
54  ComputeDataType mean_square = 0;
55  ComputeDataType divisor = 0;
56 
 // First pass: accumulate the sum of squares of row m in ComputeDataType.
57  for(int n = 0; n < N; ++n)
58  {
59  ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
60  mean_square += x * x;
61  }
62 
63  mean_square = mean_square / N;
 // divisor is the inverse RMS of row m; epsilon keeps the sqrt argument positive.
64  divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(mean_square + epsilon);
65 
 // The inverse RMS is stored only when an invRms output type is requested.
66  if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
67  invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);
68 
 // NOTE(review): this allocates a tensor with x_m_n's full lengths/strides on every
 // row invocation, but only row m is written and read below. Total allocation work is
 // therefore about M times the size of x_m_n. Hoisting one shared `acc` above the
 // lambda looks safe if every epilogue touches only row m -- confirm for custom
 // epilogues before changing.
69  HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
 // Second pass: normalize and scale by gamma into acc(m, :).
 // NOTE(review): if use_model_sensitive_rmsnorm is neither 0 nor 1, neither branch
 // assigns acc(m, n), so the epilogue reads whatever the HostTensor constructor left
 // there (not shown in this file).
70  for(int n = 0; n < N; ++n)
71  {
72  ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
73  ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
 // Mode 0: plain x * inv_rms * gamma, entirely in ComputeDataType.
74  if(use_model_sensitive_rmsnorm ==
75  static_cast<int>(
76  Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
77  {
78  acc(m, n) = x * divisor * gamma;
79  }
 // Mode 1 (T5-like): round x * inv_rms to the input precision before applying gamma.
80  else if(use_model_sensitive_rmsnorm ==
81  static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
82  {
83  if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
84  {
 // bf16 input: round to bf16 (standard rounding) both after the normalization and
 // after multiplying by gamma.
85  const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
86  const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
87  type_convert<ComputeDataType>(tmp0) * gamma);
88  const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
89  acc(m, n) = rmsn_;
90  }
91  else
92  {
 // Other input types: round only the normalized value to XDataType; the
 // gamma product stays in ComputeDataType.
93  const auto tmp = type_convert<XDataType>(x * divisor);
94  const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
95  acc(m, n) = rmsn_;
96  }
97  }
98  }
99 
 // Output row m through the epilogue. When UnquantYDataType is not null_type, the
 // 4-argument form (m, unquant_y_m_n, y_m_n, acc) is used. The default epilogue has
 // no such overload, so that case requires a custom Epilogue.
100  if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
101  {
102  epilogue_functor(m, unquant_y_m_n, y_m_n, acc);
103  }
104  else
105  {
106  epilogue_functor(m, y_m_n, acc);
107  }
108  };
109 
 // Run the worker over all rows. The row count is taken from invRms_m's first
 // dimension, so invRms_m must be sized with M rows even when InvRmsDataType is
 // null_type -- NOTE(review): confirm callers always do this.
 // The argument std::thread::hardware_concurrency() is presumably the thread count.
 // NOTE(review): hardware_concurrency() may return 0 when the value is not computable;
 // how make_ParallelTensorFunctor handles 0 is not visible here.
110  make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
111  std::thread::hardware_concurrency());
112 }
113 
114 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:405
void reference_rmsnorm2d_fwd(const HostTensor< XDataType > &x_m_n, const HostTensor< GammaDataType > &gamma_n, HostTensor< YDataType > &y_m_n, HostTensor< InvRmsDataType > &invRms_m, HostTensor< UnquantYDataType > &unquant_y_m_n, ComputeDataType epsilon, Epilogue epilogue_functor={}, const int use_model_sensitive_rmsnorm=static_cast< int >(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
Definition: reference_rmsnorm2d_fwd.hpp:41
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
decltype(auto) get_strides() const
Definition: host_tensor.hpp:394
Descriptor mDesc
Definition: host_tensor.hpp:800
Definition: reference_rmsnorm2d_fwd.hpp:14
auto operator()(int m, const HostTensor< AccDataType > &acc)
Definition: reference_rmsnorm2d_fwd.hpp:26
void operator()(int m, HostTensor< OutDataType > &o, const HostTensor< AccDataType > &acc)
Definition: reference_rmsnorm2d_fwd.hpp:16