10 namespace tensor_operation {
 
   11 namespace element_wise {
 
   15     template <
typename Y, 
typename X0, 
typename X1>
 
   16     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1) 
const;
 
   19     __host__ __device__ constexpr 
void 
   20     operator()<
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
   26     __host__ __device__ constexpr 
void 
   27     operator()<
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
   33     __host__ __device__ constexpr 
void 
   34     operator()<
float>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
   36         y = x0 + type_convert<half_t>(x1);
 
   40     __host__ __device__ constexpr 
void 
   41     operator()<
half_t>(
half_t& y, 
const float& x0, 
const float& x1) 
const 
   43         y = type_convert<half_t>(x0 + x1);
 
   47     __host__ __device__ constexpr 
void 
   50         y = type_convert<half_t>(x0) + x1;
 
   54     __host__ __device__ constexpr 
void 
   61     __host__ __device__ constexpr 
void 
   62     operator()<
float>(
float& y, 
const float& x0, 
const bhalf_t& x1) 
const 
   64         const float x1_tmp = ck::type_convert<float>(x1);
 
   69     __host__ __device__ constexpr 
void 
   72         const float x1_tmp = ck::type_convert<float>(x0);
 
   73         const float x2_tmp = ck::type_convert<float>(x1);
 
   74         const float y_tmp  = x1_tmp + x2_tmp;
 
   75         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
   79     __host__ __device__ constexpr 
void 
   82         const float x2_tmp = ck::type_convert<float>(x1);
 
   83         const float y_tmp  = x0 + x2_tmp;
 
   84         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
   88     __host__ __device__ constexpr 
void 
   97     template <
typename Y, 
typename X0, 
typename X1>
 
   98     __host__ __device__ 
void operator()(Y& y, 
const X0& x0, 
const X1& x1)
 const 
  100         const Y x0_converted = type_convert<Y>(x0);
 
  101         const Y x1_converted = type_convert<Y>(x1);
 
  108     template <
typename Y, 
typename X0, 
typename X1>
 
  109     __host__ __device__ 
void operator()(Y& y, 
const X0& x0, 
const X1& x1)
 const 
  111         const Y x0_converted = type_convert<Y>(x0);
 
  112         const Y x1_converted = type_convert<Y>(x1);
 
  119     template <
typename Y, 
typename X0, 
typename X1>
 
  120     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1) 
const;
 
  123     __host__ __device__ constexpr 
void 
  124     operator()<
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  130     __host__ __device__ constexpr 
void 
  131     operator()<
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  137     __host__ __device__ constexpr 
void 
  138     operator()<
float>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
  140         y = x0 * type_convert<half_t>(x1);
 
  144     __host__ __device__ constexpr 
void 
  147         y = type_convert<half_t>(x0 * x1);
 
  151     __host__ __device__ constexpr 
void 
  154         y = type_convert<half_t>(x0) * x1;
 
  158     __host__ __device__ constexpr 
void 
  165     __host__ __device__ constexpr 
void 
  166     operator()<
float>(
float& y, 
const float& x0, 
const bhalf_t& x1) 
const 
  168         const float x1_tmp = ck::type_convert<float>(x1);
 
  173     __host__ __device__ constexpr 
void 
  176         const float x1_tmp = ck::type_convert<float>(x0);
 
  177         const float x2_tmp = ck::type_convert<float>(x1);
 
  178         const float y_tmp  = x1_tmp * x2_tmp;
 
  179         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
  183     __host__ __device__ constexpr 
void 
  186         const float x1_tmp = ck::type_convert<float>(x0);
 
  187         const float x2_tmp = ck::type_convert<float>(x1);
 
  188         const float y_tmp  = x1_tmp * x2_tmp;
 
  189         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
  193     __host__ __device__ constexpr 
void 
  196         const float x2_tmp = ck::type_convert<float>(x1);
 
  197         const float y_tmp  = x0 * x2_tmp;
 
  198         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
  202     __host__ __device__ constexpr 
void 
  213     template <
typename Y, 
typename X0, 
typename X1>
 
  214     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1)
 const 
  216         y = ck::type_convert<Y>(
scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
 
  220     __host__ __device__ 
void 
  221     operator()<float, float, 
half_t>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
  223         y = 
scale_ * x0 + ck::type_convert<float>(x1);
 
  227     __host__ __device__ 
void 
  228     operator()<float, float, 
bhalf_t>(
float& y, 
const float& x0, 
const bhalf_t& x1) 
const 
  230         y = 
scale_ * x0 + ck::type_convert<float>(x1);
 
  238     template <
typename T>
 
  239     __host__ __device__ constexpr 
void operator()(T& y, 
const T& x0, 
const T& x1) 
const;
 
  242     __host__ __device__ constexpr 
void 
  243     operator()<
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  249     __host__ __device__ constexpr 
void 
  250     operator()<
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  256     __host__ __device__ constexpr 
void 
  263     __host__ __device__ constexpr 
void 
  266         const float x1_tmp = ck::type_convert<float>(x0);
 
  267         const float x2_tmp = ck::type_convert<float>(x1);
 
  268         const float y_tmp  = x1_tmp - x2_tmp;
 
  269         y                  = ck::type_convert<bhalf_t>(y_tmp);
 
  273     __host__ __device__ constexpr 
void 
  284     template <
typename Y, 
typename X0, 
typename X1>
 
  285     __host__ __device__ constexpr 
void operator()(Y&, 
const X0&, 
const X1&) 
const;
 
  288     __host__ __device__ constexpr 
void 
  289     operator()<double, double, 
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  295     __host__ __device__ constexpr 
void 
  296     operator()<float, float, 
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  302     __host__ __device__ constexpr 
void 
  305         y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
 
  306                                  beta_ * type_convert<float>(x1));
 
  310     __host__ __device__ constexpr 
void 
  313         y = type_convert<half_t>(
alpha_) * x0 + type_convert<half_t>(
beta_) * x1;
 
  317     __host__ __device__ constexpr 
void 
  320         y = type_convert<half_t>(
alpha_ * x0 + 
beta_ * ck::type_convert<float>(x1));
 
  324     __host__ __device__ constexpr 
void 
  327         const float x0_tmp = type_convert<float>(x0);
 
  328         const float x1_tmp = type_convert<float>(x1);
 
  329         const float y_tmp  = 
alpha_ * x0_tmp + 
beta_ * x1_tmp;
 
  330         y                  = type_convert<bhalf_t>(y_tmp);
 
  334     __host__ __device__ constexpr 
void 
  337         const float x1_tmp = ck::type_convert<float>(x1);
 
  343     __host__ __device__ constexpr 
void 
  346         y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
 
  347                                  beta_ * type_convert<float>(x1));
 
  359     template <
typename Y, 
typename X0, 
typename X1>
 
  360     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1) 
const;
 
  363     __host__ __device__ constexpr 
void 
  364     operator()<float, float, 
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  366         const float a = x0 + x1;
 
  371     __host__ __device__ constexpr 
void 
  372     operator()<double, double, 
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  374         const double a = x0 + x1;
 
  379     __host__ __device__ constexpr 
void 
  389     __host__ __device__ constexpr 
void 
  392         const float a = x0 + type_convert<float>(x1);
 
  394         y             = type_convert<half_t>(b);
 
  398     __host__ __device__ constexpr 
void 
  399     operator()<float, float, 
half_t>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
  401         const float a = x0 + type_convert<float>(x1);
 
  406     __host__ __device__ constexpr 
void 
  409         const float a = x0 + type_convert<float>(x1);
 
  411         y             = type_convert<bhalf_t>(b);
 
  415     __host__ __device__ constexpr 
void 
  418         const float a = type_convert<float>(x0) + type_convert<float>(x1);
 
  420         y             = type_convert<bhalf_t>(b);
 
  424     __host__ __device__ constexpr 
void 
  425     operator()<int, int, 
int8_t>(
int& y, 
const int& x0, 
const int8_t& x1) 
const 
  432     __host__ __device__ constexpr 
void 
  445     template <
typename Y, 
typename X0, 
typename X1>
 
  446     __host__ __device__ constexpr 
void operator()(Y& y, 
const X0& x0, 
const X1& x1) 
const;
 
  449     __host__ __device__ constexpr 
void 
  450     operator()<float, float, 
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  452         const float a = x0 + x1;
 
  453         y             = a > 0.0f ? a : 0.0f;
 
  457     __host__ __device__ constexpr 
void 
  458     operator()<double, double, 
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  460         const double a = x0 + x1;
 
  461         y              = a > 0.0 ? a : 0.0;
 
  465     __host__ __device__ constexpr 
void 
  469         y              = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
 
  473     __host__ __device__ constexpr 
void 
  476         const float a = x0 + type_convert<float>(x1);
 
  477         const float b = a > 0.0f ? a : 0.0f;
 
  478         y             = type_convert<half_t>(b);
 
  482     __host__ __device__ constexpr 
void 
  483     operator()<float, float, 
half_t>(
float& y, 
const float& x0, 
const half_t& x1) 
const 
  485         const float a = x0 + type_convert<float>(x1);
 
  486         y             = a > 0.0f ? a : 0.0f;
 
  490     __host__ __device__ constexpr 
void 
  493         const float a = x0 + type_convert<float>(x1);
 
  494         const float b = a > 0.0f ? a : 0.0f;
 
  495         y             = type_convert<bhalf_t>(b);
 
  499     __host__ __device__ constexpr 
void 
  502         const float a = type_convert<float>(x0) + type_convert<float>(x1);
 
  503         const float b = a > 0.0f ? a : 0.0f;
 
  504         y             = type_convert<bhalf_t>(b);
 
  508     __host__ __device__ constexpr 
void 
  509     operator()<int, int, 
int8_t>(
int& y, 
const int& x0, 
const int8_t& x1) 
const 
  516     __host__ __device__ constexpr 
void 
  526     template <
typename T>
 
  527     __host__ __device__ constexpr 
void operator()(T& y, 
const T& x0, 
const T& x1) 
const;
 
  530     __host__ __device__ constexpr 
void 
  531     operator()<
float>(
float& y, 
const float& x0, 
const float& x1) 
const 
  534         float b = a + 
float{3};
 
  535         float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
 
  540     __host__ __device__ constexpr 
void 
  541     operator()<
double>(
double& y, 
const double& x0, 
const double& x1) 
const 
  545         double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
 
  550     __host__ __device__ constexpr 
void 
  555         float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
 
  563     template <
typename E, 
typename C, 
typename D>
 
  564     __host__ __device__ constexpr 
void operator()(E& e, 
const C& c, 
const D& d) 
const;
 
  567     __host__ __device__ constexpr 
void 
  568     operator()<float, float, 
float>(
float& e, 
const float& c, 
const float& d) 
const 
  570         const float x = c + d;
 
  572         FastGelu{}.template operator()<float, 
float>(e, x);
 
  576     __host__ __device__ constexpr 
void 
  585     __host__ __device__ constexpr 
void 
  588         const float x0_f = c + d;
 
  595         e = type_convert<half_t>(x1_f);
 
  599     __host__ __device__ constexpr 
void 
  602         const float x0_f = type_convert<float>(c) + type_convert<float>(d);
 
  606         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  608         e = type_convert<bhalf_t>(x1_f);
 
  612     __host__ __device__ constexpr 
void 
  615         const float x0_f = c + type_convert<float>(d);
 
  619         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  621         e = type_convert<bhalf_t>(x1_f);
 
  628     template <
typename E, 
typename C, 
typename D>
 
  629     __host__ __device__ constexpr 
void operator()(E& e, 
const C& c, 
const D& d) 
const;
 
  632     __host__ __device__ constexpr 
void 
  633     operator()<float, float, 
float>(
float& e, 
const float& c, 
const float& d) 
const 
  635         const float x = c * d;
 
  637         FastGelu{}.template operator()<float, 
float>(e, x);
 
  641     __host__ __device__ constexpr 
void 
  650     __host__ __device__ constexpr 
void 
  653         const float x0_f = c * d;
 
  660         e = type_convert<half_t>(x1_f);
 
  664     __host__ __device__ constexpr 
void 
  667         const float x0_f = type_convert<float>(c) * type_convert<float>(d);
 
  671         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  673         e = type_convert<bhalf_t>(x1_f);
 
  677     __host__ __device__ constexpr 
void 
  680         const float x0_f = c * type_convert<float>(d);
 
  684         FastGelu{}.template operator()<float, 
float>(x1_f, x0_f);
 
  686         e = type_convert<bhalf_t>(x1_f);
 
  693     template <
typename E, 
typename C, 
typename D>
 
  694     __host__ __device__ constexpr 
void operator()(E& e, 
const C& c, 
const D& d) 
const;
 
  697     __host__ __device__ constexpr 
void 
  698     operator()<float, float, 
float>(
float& e, 
const float& c, 
const float& d) 
const 
  700         const float x = c + d;
 
  702         Silu{}.template operator()<
float>(e, x);
 
  706     __host__ __device__ constexpr 
void 
  715     __host__ __device__ constexpr 
void 
  718         const float x0_f = c + d;
 
  722         Silu{}.template operator()<
float>(x1_f, x0_f);
 
  724         e = type_convert<half_t>(x1_f);
 
  728     __host__ __device__ constexpr 
void 
  731         const float x0_f = c + type_convert<float>(d);
 
  735         Silu{}.template operator()<
float>(x1_f, x0_f);
 
  737         e = type_convert<bhalf_t>(x1_f);
 
  744                                      float scale_wei = 1.f,
 
  745                                      float scale_out = 1.f)
 
  750     template <
typename E, 
typename C, 
typename D>
 
  751     __host__ __device__ 
void operator()(E& e, 
const C& c, 
const D& d) 
const;
 
  754     __host__ __device__ 
void 
  755     operator()<
f8_t, float, 
float>(
f8_t& e, 
const float& c, 
const float& d) 
const 
__host__ T ceil(T x)
Definition: math_v2.hpp:331
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ T floor(T x)
Definition: math_v2.hpp:367
 
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
 
int8_t int8_t
Definition: int8.hpp:20
 
int32_t int32_t
Definition: integer.hpp:10
 
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
 
_Float16 half_t
Definition: data_type.hpp:30
 
ushort bhalf_t
Definition: data_type.hpp:29
 
Definition: numeric_limits.hpp:309
 
Definition: binary_element_wise_operation.hpp:355
 
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition: binary_element_wise_operation.hpp:356
 
const float ceil_
Definition: binary_element_wise_operation.hpp:440
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
 
const float floor_
Definition: binary_element_wise_operation.hpp:437
 
Definition: binary_element_wise_operation.hpp:562
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
 
Definition: binary_element_wise_operation.hpp:525
 
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
 
Definition: binary_element_wise_operation.hpp:14
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
 
Definition: binary_element_wise_operation.hpp:444
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
 
Definition: binary_element_wise_operation.hpp:692
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
 
Definition: binary_element_wise_operation.hpp:281
 
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:282
 
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
 
float beta_
Definition: binary_element_wise_operation.hpp:351
 
float alpha_
Definition: binary_element_wise_operation.hpp:348
 
Definition: binary_element_wise_operation.hpp:742
 
float scale_in_
Definition: binary_element_wise_operation.hpp:760
 
float scale_wei_
Definition: binary_element_wise_operation.hpp:763
 
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:743
 
float scale_out_
Definition: binary_element_wise_operation.hpp:764
 
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
 
Definition: unary_element_wise_operation.hpp:866
 
Definition: binary_element_wise_operation.hpp:96
 
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:98
 
Definition: binary_element_wise_operation.hpp:107
 
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:109
 
Definition: binary_element_wise_operation.hpp:627
 
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
 
Definition: binary_element_wise_operation.hpp:118
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
 
Definition: binary_element_wise_operation.hpp:210
 
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:214
 
float scale_
Definition: binary_element_wise_operation.hpp:231
 
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:211
 
Definition: unary_element_wise_operation.hpp:1023
 
Definition: binary_element_wise_operation.hpp:237
 
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const