/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp Source File
reference_layernorm2d_fwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
8 
9 namespace ck_tile {
10 
11 // Note: for simplicity, each functor only cares about a single M
// Epilogue functor: copies one row of an accumulator tensor into an output tensor,
// converting each element to the output element type.
// NOTE(review): the embedded listing numbers jump from 11 to 13, so the line that declares this
// type is missing from this listing; it is presumably reference_layernorm2d_default_epilogue,
// the name used as the default Epilogue argument of the layernorm reference function below.
13 {
14  template <typename OutDataType, typename AccDataType>
 // NOTE(review): listing line 15 (the declarator of this overload) is missing here. The symbol
 // index at the end of this page gives it as
 //   void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
 // Writes row m of acc into row m of o.
16  {
 // Row length is taken from the second dimension of acc.
17  const int N = acc.mDesc.get_lengths()[1];
18  for(int n = 0; n < N; ++n)
19  {
 // Element-wise conversion from AccDataType to OutDataType.
20  o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
21  }
22  }
23 
 // Convenience overload: fills row m of a tensor o via the overload above and returns o by value.
24  template <typename OutDataType, typename AccDataType>
25  auto operator()(int m, const HostTensor<AccDataType>& acc)
26  {
 // NOTE(review): listing line 27, which must declare o, is missing from this listing;
 // how o is sized cannot be confirmed from here.
28  operator()(m, o, acc);
29  return o;
30  }
31 };
32 
// Host reference implementation of the 2D layernorm forward pass.
//
// For each row m of x_m_n it computes the row mean and the variance (divided by the element
// count, i.e. not the N-1 form), derives divisor = 1 / sqrt(variance + epsilon), and produces
//   (x - mean) * divisor * gamma(n) + beta(n)
// in ComputeDataType. That result is handed to epilogue_functor, which is given y_m_n as its
// output argument. The mean and the divisor are also stored to mean_m / invStd_m unless the
// corresponding data type is ck_tile::null_type.
//
// NOTE(review): listing lines 41, 45 and 46 are missing from this listing. According to the
// symbol index at the end of this page, the full signature is
//   void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
//                                  const HostTensor<GammaDataType>& gamma_n,
//                                  const HostTensor<BetaDataType>& beta_n,
//                                  HostTensor<YDataType>& y_m_n,
//                                  HostTensor<MeanDataType>& mean_m,
//                                  HostTensor<InvStdDataType>& invStd_m,
//                                  ComputeDataType epsilon,
//                                  Epilogue epilogue_functor = {})
33 template <typename XDataType,
34  typename GammaDataType,
35  typename BetaDataType,
36  typename ComputeDataType,
37  typename YDataType,
38  typename MeanDataType,
39  typename InvStdDataType,
40  typename Epilogue = reference_layernorm2d_default_epilogue>
42  const HostTensor<GammaDataType>& gamma_n,
43  const HostTensor<BetaDataType>& beta_n,
44  HostTensor<YDataType>& y_m_n,
47  ComputeDataType epsilon,
48  Epilogue epilogue_functor = {})
49 {
 // Per-row worker; m is the row index supplied by the parallel functor below.
50  auto layernorm2d_fwd_func = [&](auto m) {
 // Row length is taken from the second dimension of x_m_n.
51  const int N = x_m_n.mDesc.get_lengths()[1];
52 
53  int count = 0;
54  ComputeDataType mean = 0;
 // Until the division after the loop, `variance` holds the running sum of squared deviations.
55  ComputeDataType variance = 0;
56  ComputeDataType divisor = 0;
57 
 // Single-pass running mean / sum of squared deviations (Welford's online update):
 // the mean is updated first, then the deviation from the *updated* mean is used.
58  for(int n = 0; n < N; ++n)
59  {
60  ++count;
61  ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
62  ComputeDataType delta = x - mean;
63  mean += delta / count;
64  ComputeDataType delta2 = x - mean;
65  variance += delta * delta2;
66  }
67 
68  // actual variance
 // NOTE(review): count is still 0 here when N is 0, so this divides by zero for empty rows.
69  variance = variance / count;
 // divisor is the inverse standard deviation, with epsilon added under the square root.
70  divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(variance + epsilon);
71 
 // The statistics outputs are skipped at compile time when their type is ck_tile::null_type.
72  if constexpr(!std::is_same_v<MeanDataType, ck_tile::null_type>)
73  mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);
74 
75  if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
76  invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
77 
 // NOTE(review): acc is constructed with the lengths and strides of the whole x_m_n tensor on
 // every invocation of this lambda, although only row m is written below. The epilogue is
 // called with (m, y_m_n, acc) and the default one reads acc(m, n), so the shape is part of
 // the epilogue contract.
78  HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
79  for(int n = 0; n < N; ++n)
80  {
81  ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
82  ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
83  ComputeDataType beta = ck_tile::type_convert<ComputeDataType>(beta_n(n));
 // Normalize, then apply the per-column scale (gamma) and shift (beta).
84  auto a_ = (x - mean) * divisor;
85  a_ = a_ * gamma + beta;
86 
87  acc(m, n) = a_;
88  }
89 
 // The epilogue receives the row index, the output tensor and the accumulator.
90  epilogue_functor(m, y_m_n, acc);
91  };
92 
 // Runs the worker over the first dimension of mean_m (note: not of x_m_n).
 // The value of std::thread::hardware_concurrency() is passed to the functor's call operator,
 // presumably as the number of worker threads - TODO confirm against host_tensor.hpp.
93  make_ParallelTensorFunctor(layernorm2d_fwd_func,
94  mean_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
95 }
96 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:417
void reference_layernorm2d_fwd(const HostTensor< XDataType > &x_m_n, const HostTensor< GammaDataType > &gamma_n, const HostTensor< BetaDataType > &beta_n, HostTensor< YDataType > &y_m_n, HostTensor< MeanDataType > &mean_m, HostTensor< InvStdDataType > &invStd_m, ComputeDataType epsilon, Epilogue epilogue_functor={})
Definition: reference_layernorm2d_fwd.hpp:41
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
Definition: host_tensor.hpp:336
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
decltype(auto) get_strides() const
Definition: host_tensor.hpp:394
Descriptor mDesc
Definition: host_tensor.hpp:800
Definition: reference_layernorm2d_fwd.hpp:13
auto operator()(int m, const HostTensor< AccDataType > &acc)
Definition: reference_layernorm2d_fwd.hpp:25
void operator()(int m, HostTensor< OutDataType > &o, const HostTensor< AccDataType > &acc)
Definition: reference_layernorm2d_fwd.hpp:15