14     template <typename OutDataType, typename AccDataType>
 
   18         for(int n = 0; n < N; ++n)
 
   20             o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
 
   24     template <typename OutDataType, typename AccDataType>
 
   33 template <typename XDataType,
 
   34           typename GammaDataType,
 
   35           typename BetaDataType,
 
   36           typename ComputeDataType,
 
   38           typename MeanDataType,
 
   39           typename InvStdDataType,
 
   40           typename Epilogue = reference_layernorm2d_default_epilogue>
 
   47                                ComputeDataType epsilon,
 
   48                                Epilogue epilogue_functor = {})
 
   50     auto layernorm2d_fwd_func = [&](auto m) {
 
   54         ComputeDataType mean     = 0;
 
   55         ComputeDataType variance = 0;
 
   56         ComputeDataType divisor  = 0;
 
   58         for(int n = 0; n < N; ++n)
 
   61             ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
 
   62             ComputeDataType delta = x - mean;
 
   63             mean += delta / count;
 
   64             ComputeDataType delta2 = x - mean;
 
   65             variance += delta * delta2;
 
   69         variance = variance / count;
 
   70         divisor  = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(variance + epsilon);
 
   72         if constexpr(!std::is_same_v<MeanDataType, ck_tile::null_type>)
 
   73             mean_m(m) = ck_tile::type_convert<MeanDataType>(mean);
 
   75         if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
 
   76             invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);
 
   79         for(int n = 0; n < N; ++n)
 
   81             ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
 
   82             ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
 
   83             ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));
 
   84             auto a_               = (x - mean) * divisor;
 
   85             a_                    = a_ * gamma + beta;
 
   90         epilogue_functor(m, y_m_n, acc);
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_HOST auto make_ParallelTensorFunctor(F f, Xs... xs)
Definition: host_tensor.hpp:329
 
CK_TILE_DEVICE bfloat16_t sqrt(bfloat16_t x)
Definition: bfloat16.hpp:406
 
void reference_layernorm2d_fwd(const HostTensor< XDataType > &x_m_n, const HostTensor< GammaDataType > &gamma_n, const HostTensor< BetaDataType > &beta_n, HostTensor< YDataType > &y_m_n, HostTensor< MeanDataType > &mean_m, HostTensor< InvStdDataType > &invStd_m, ComputeDataType epsilon, Epilogue epilogue_functor={})
Definition: reference_layernorm2d_fwd.hpp:41
 
const std::vector< std::size_t > & get_lengths() const
Definition: host_tensor.hpp:198
 
Definition: host_tensor.hpp:336
 
decltype(auto) get_lengths() const
Definition: host_tensor.hpp:390
 
decltype(auto) get_strides() const
Definition: host_tensor.hpp:394
 
Descriptor mDesc
Definition: host_tensor.hpp:742
 
Definition: reference_layernorm2d_fwd.hpp:13
 
auto operator()(int m, const HostTensor< AccDataType > &acc)
Definition: reference_layernorm2d_fwd.hpp:25
 
void operator()(int m, HostTensor< OutDataType > &o, const HostTensor< AccDataType > &acc)
Definition: reference_layernorm2d_fwd.hpp:15