13 namespace tensor_operation {
 
   14 namespace element_wise {
 
   // Primary declaration of the templated call operator: one mutable reference
   // of type Y followed by three const-reference inputs (X0, X1, X2). It is
   // const-qualified, returns void, and is usable from host and device code.
   // Only the declaration appears here; an explicit specialization for
   // <float, float, float, float> begins right after it in this listing.
   // NOTE(review): the enclosing functor's name is outside this view.
   36     template <
typename Y, 
typename X0, 
typename X1, 
typename X2>
 
   37     __host__ __device__ constexpr 
void operator()(Y&, 
const X0&, 
const X1&, 
const X2&) 
const;
 
   49     __host__ __device__ constexpr 
void operator()<float, float, float, 
float>(
float& y,
 
   52                                                                               const float& x2) 
const 
   55         float b = a > 0 ? a : 0;
 
   65         float b = a > 0 ? a : 0;
 
   75         float b = a > 0 ? a : 0;
 
   90 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4 
  // Primary declaration of the templated call operator: one mutable reference
  // of type Y followed by three const-reference inputs (X0, X1, X2). It is
  // const-qualified, returns void, and is usable from host and device code.
  // No body is given here; an explicit specialization for
  // <float, float, float, float> begins right after it in this listing.
  // NOTE(review): the enclosing functor's name is outside this view.
  105     template <
typename Y, 
typename X0, 
typename X1, 
typename X2>
 
  106     __host__ __device__ constexpr 
void operator()(Y&, 
const X0&, 
const X1&, 
const X2&) 
const;
 
  109     __host__ __device__ constexpr 
void operator()<float, float, float, 
float>(
float& y,
 
  112                                                                               const float& x2) 
const 
  115         float b = a + 
float{3};
 
  116         float c = (b > 0) * (b > 
float{6} ? 
float{6} : b) * a * 
float{0.166667};
 
  126         float b = a + 
float{3};
 
  127         float c = (b > 0) * (b > 
float{6} ? 
float{6} : b) * a * 
float{0.166667};
 
  137     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  138     __host__ __device__ 
void operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1)
 const 
  143                       "Data type is not supported by this operation!");
 
  147                       "Data type is not supported by this operation!");
 
  151                       "Data type is not supported by this operation!");
 
  155                       "Data type is not supported by this operation!");
 
  157         const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
 
  158         e         = type_convert<E>(y);
 
  // Primary declaration of the templated call operator taking a mutable
  // reference `e` and three const-reference inputs `c`, `d0`, `d1`.
  // Const-qualified, returns void, host+device; unlike some sibling operators
  // in this listing it is not marked constexpr.
  // Declaration only: an explicit specialization for
  // <float, float, half_t, half_t> is visible further down.
  // NOTE(review): the body fragments that follow appear to compute
  // (c + d0) * d1 — confirm against the full file.
  166     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  167     __host__ __device__ 
void operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  175         const half_t y = (c + d0) * d1;
 
  184         const half_t y = (type_convert<half_t>(c) + d0) * d1;
 
  188     __host__ __device__ 
void operator()<float, float, 
half_t, 
half_t>(
float& e,
 
  193         const float y = (c + d0) * d1;
 
  // Primary declaration of the templated call operator taking a mutable
  // reference `e` and three const-reference inputs `c`, `d0`, `d1`.
  // Const-qualified, returns void, host+device, not constexpr.
  // Declaration only: explicit specializations for
  // <float, float, half_t, half_t> and <half_t, float, float, float> are
  // visible further down.
  // NOTE(review): the body fragments that follow appear to compute
  // c * d0 + d1 — confirm against the full file.
  202     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  203     __host__ __device__ 
void operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  211         const half_t y = (c * d0) + d1;
 
  220         const half_t y = type_convert<half_t>(c) * d0 + d1;
 
  229         const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
 
  233     __host__ __device__ 
void operator()<float, float, 
half_t, 
half_t>(
float& e,
 
  238         const float y = c * d0 + d1;
 
  242     __host__ __device__ 
void operator()<
half_t, float, float, 
float>(
half_t& e,
 
  245                                                                      const float& d1) 
const 
  247         const float y = c * d0 + d1;
 
  // Primary declaration of the templated, constexpr call operator taking a
  // mutable reference `e` and three const-reference inputs `c`, `d0`, `d1`.
  // Const-qualified, returns void, host+device.
  // Declaration only: the explicit specializations visible below cover
  // <ck::half_t, float, float, float>, <ck::bhalf_t, float, float, float>,
  // <ck::half_t, int, float, float> and <ck::bhalf_t, int, float, float>.
  // Their visible lines form a float product of c, d0 and d1 and store it
  // into `e` through ck::type_convert.
  254     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  255     __host__ __device__ constexpr 
void 
  256     operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  259     __host__ __device__ constexpr 
void operator()<
ck::half_t, float, float, 
float>(
 
  260         ck::half_t& e, 
const float& c, 
const float& d0, 
const float& d1) 
const 
  262         const float x0_f = c * d0 * d1;
 
  264         e = ck::type_convert<ck::half_t>(x0_f);
 
  268     __host__ __device__ constexpr 
void operator()<
ck::bhalf_t, float, float, 
float>(
 
  269         ck::bhalf_t& e, 
const float& c, 
const float& d0, 
const float& d1) 
const 
  271         const float x0_f = c * d0 * d1;
 
  273         e = ck::type_convert<ck::bhalf_t>(x0_f);
 
  281             ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
 
  283         e = ck::type_convert<ck::half_t>(x0_f);
 
  287     __host__ __device__ constexpr 
void operator()<
ck::half_t, int, float, 
float>(
 
  288         ck::half_t& e, 
const int& c, 
const float& d0, 
const float& d1) 
const 
  291             ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
 
  293         e = ck::type_convert<ck::half_t>(x0_f);
 
  297     __host__ __device__ constexpr 
void operator()<
ck::bhalf_t, int, float, 
float>(
 
  298         ck::bhalf_t& e, 
const int& c, 
const float& d0, 
const float& d1) 
const 
  301             ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
 
  303         e = ck::type_convert<ck::bhalf_t>(x0_f);
 
  // Primary declaration of the templated, constexpr call operator taking a
  // mutable reference `e` and three const-reference inputs `c`, `d0`, `d1`.
  // Const-qualified, returns void, host+device. Declaration only.
  // NOTE(review): the fragments that follow show a float value
  // c * d0 + d1 being passed through FastGelu and converted to ck::bhalf_t;
  // the specialization's signature is not visible here — confirm against the
  // full file.
  309     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  310     __host__ __device__ constexpr 
void 
  311     operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  317         const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
 
  321         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  323         e = ck::type_convert<ck::bhalf_t>(x1_f);
 
  // Primary declaration of the templated, constexpr call operator taking a
  // mutable reference `e` and three const-reference inputs `c`, `d0`, `d1`.
  // Const-qualified, returns void, host+device. Declaration only: an explicit
  // specialization for <float, float, float, float> begins right after it.
  // The visible lines of that specialization sum c + d0 + d1 and pass the
  // result to FastGelu, which writes into `e`.
  330     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  331     __host__ __device__ constexpr 
void 
  332     operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  335     __host__ __device__ constexpr 
void operator()<float, float, float, 
float>(
float& e,
 
  338                                                                               const float& d1) 
const 
  340         const float x = c + d0 + d1;
 
  342         FastGelu{}.template operator()<float, 
float>(e, x);
 
  349         const half_t x = c + d0 + d1;
 
  358         const float x0_f = c + d0 + d1;
 
  365         e = type_convert<half_t>(x1_f);
 
  372         const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
 
  379         e = type_convert<bhalf_t>(x1_f);
 
  387             type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
 
  394         e = type_convert<int8_t>(x1_f);
 
  // Primary declaration of the templated, constexpr call operator taking a
  // mutable reference `e` and three const-reference inputs `c`, `d0`, `d1`.
  // Const-qualified, returns void, host+device. Declaration only: explicit
  // specializations for <float, float, float, float> and
  // <int8_t, int8_t, float, float> are visible further down.
  // The visible body lines compute x = c * alpha1_ + alpha2_ * d0 + d1,
  // replace negative x with 0, and convert the result into `e`.
  407     template <
typename E, 
typename C, 
typename D0, 
typename D1>
 
  408     __host__ __device__ constexpr 
void 
  409     operator()(E& e, 
const C& c, 
const D0& d0, 
const D1& d1) 
const;
 
  412     __host__ __device__ constexpr 
void operator()<float, float, float, 
float>(
float& e,
 
  415                                                                               const float& d1) 
const 
  425         const float x = type_convert<float>(c) * 
alpha1_ + 
alpha2_ * type_convert<float>(d0) +
 
  426                         type_convert<float>(d1);
 
  429         result       = x > 0 ? x : 0;
 
  431         e = type_convert<half_t>(result);
 
  438         const float x = type_convert<float>(c) * 
alpha1_ + 
alpha2_ * type_convert<float>(d0) +
 
  439                         type_convert<float>(d1);
 
  442         result       = x > 0 ? x : 0;
 
  444         e = type_convert<bhalf_t>(result);
 
  448     __host__ __device__ constexpr 
void operator()<
int8_t, 
int8_t, float, 
float>(
 
  449         int8_t& e, 
const int8_t& c, 
const float& d0, 
const float& d1) 
const 
  451         const float x = type_convert<float>(c) * 
alpha1_ + 
alpha2_ * d0 + d1;
 
  454         result       = x > 0 ? x : 0;
 
  456         e = type_convert<int8_t>(result);
 
  468     template <
typename T1, 
typename T2, 
typename T3>
 
  472                                                   const T2& mean_square,
 
  474                                                   const T3& beta) 
const;
 
  480                                                                          const float& mean_square,
 
  484         using ck::math::sqrt;
 
  486         float variance = mean_square - (mean * mean);
 
  488         float tmp_x     = type_convert<float>(x);
 
  489         float tmp_gamma = type_convert<float>(gamma);
 
  490         float tmp_beta  = type_convert<float>(beta);
 
  493             ((tmp_x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * tmp_gamma +
 
  496         y = type_convert<half_t>(tmp_y);
 
  500     __host__ __device__ constexpr 
void operator()<float, float, 
float>(
float& y,
 
  503                                                                        const float& mean_square,
 
  505                                                                        const float& beta) 
const 
  507         using ck::math::sqrt;
 
  509         float variance = mean_square - (mean * mean);
 
  510         y = ((x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * gamma + beta;
 
  514     __host__ __device__ constexpr 
void operator()<double, double, 
double>(
double& y,
 
  517                                                                           const double& mean_square,
 
  519                                                                           const double& beta) 
const 
  521         using ck::math::sqrt;
 
  523         double variance = mean_square - (mean * mean);
 
  524         y               = ((x - mean) / sqrt(variance + 
epsilon_)) * gamma + beta;
 
  538     template <
typename T1, 
typename T2, 
typename T3, 
typename T4>
 
  544                                                   const T4& beta)
 const 
  547                       "Data type is not supported by this operation!");
 
  550         using ck::math::sqrt;
 
  554         tmp_x = type_convert<T2>(x);
 
  556         tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(
epsilon_))) *
 
  557                     type_convert<T2>(gamma) +
 
  558                 type_convert<T2>(beta);
 
  559         y = type_convert<T1>(tmp_y);
 
  565 template <
typename Y, 
typename X>
 
  573         y = ck::type_convert<float, ck::bhalf_t>(x);
 
  582         y = ck::type_convert<ck::bhalf_t, float>(x);
 
int8_t int8_t
Definition: int8.hpp:20
 
int32_t int32_t
Definition: integer.hpp:10
 
_Float16 half_t
Definition: data_type.hpp:30
 
ushort bhalf_t
Definition: data_type.hpp:29
 
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:98
 
_BitInt(4) int4_t
Definition: data_type.hpp:31
 
Definition: element_wise_operation.hpp:329
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
Definition: element_wise_operation.hpp:136
 
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:138
 
Definition: element_wise_operation.hpp:104
 
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
 
Definition: element_wise_operation.hpp:165
 
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
Definition: element_wise_operation.hpp:35
 
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
 
Definition: unary_element_wise_operation.hpp:866
 
Definition: element_wise_operation.hpp:308
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
Definition: element_wise_operation.hpp:201
 
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
Definition: element_wise_operation.hpp:253
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
Definition: element_wise_operation.hpp:464
 
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:466
 
double epsilon_
Definition: element_wise_operation.hpp:525
 
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
 
Definition: element_wise_operation.hpp:535
 
double epsilon_
Definition: element_wise_operation.hpp:560
 
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:539
 
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:536
 
Definition: element_wise_operation.hpp:400
 
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:402
 
const float alpha2_
Definition: element_wise_operation.hpp:460
 
const float alpha1_
Definition: element_wise_operation.hpp:459
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
 
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:580
 
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:571
 
Definition: element_wise_operation.hpp:566