13 namespace tensor_operation {
14 namespace element_wise {
// Primary call-operator template, declared here without a body (note the
// trailing `const;`). It takes one writable output reference followed by three
// read-only inputs; the parameters are left unnamed in this declaration.
// Behavior is supplied by per-type variants: one spelled
// operator()<float, float, float, float> follows directly, and the bodies
// visible after it clamp an intermediate with `a > 0 ? a : 0`.
36 template <
typename Y,
typename X0,
typename X1,
typename X2>
37 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&,
const X2&)
const;
49 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& y,
52 const float& x2)
const
55 float b = a > 0 ? a : 0;
65 float b = a > 0 ? a : 0;
75 float b = a > 0 ? a : 0;
85 int32_t b = a > 0 ? a : 0;
90 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
96 int32_t b = a > 0 ? a : 0;
// Declaration-only call-operator template (no body; ends in `const;`) with an
// output reference and three const-reference inputs, all unnamed here.
// The per-type bodies visible after it form b = a + 3 and then
// c = (b > 0) * (b > 6 ? 6 : b) * a * 0.166667, i.e. `a` scaled by b capped at
// 6, with the whole product forced to zero when b is not positive.
105 template <
typename Y,
typename X0,
typename X1,
typename X2>
106 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&,
const X2&)
const;
109 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& y,
112 const float& x2)
const
115 float b = a +
float{3};
116 float c = (b > 0) * (b >
float{6} ?
float{6} : b) * a *
float{0.166667};
126 float b = a +
float{3};
127 float c = (b > 0) * (b >
float{6} ?
float{6} : b) * a *
float{0.166667};
// Generic call operator with an inline body: writes e from the three inputs
// c, d0 and d1. Callable from host and device; not marked constexpr.
137 template <
typename E,
typename C,
typename D0,
typename D1>
138 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const
// The four identical message strings below are the closing parts of checks
// whose conditions sit on listing lines that are not shown here (presumably
// one static_assert per template parameter -- confirm against the full header).
143 "Data type is not supported by this operation!");
147 "Data type is not supported by this operation!");
151 "Data type is not supported by this operation!");
155 "Data type is not supported by this operation!");
// Accumulate in C: d0 and d1 are converted to C before being added to c.
157 const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
// Convert the C-typed sum to the output type E.
158 e = type_convert<E>(y);
// Declaration-only call-operator template (ends in `const;`): e is the output,
// c, d0 and d1 are read-only inputs. Not constexpr.
// The per-type bodies visible between this declaration and the next template
// all evaluate (c + d0) * d1, in half_t or in float depending on the variant.
166 template <
typename E,
typename C,
typename D0,
typename D1>
167 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
175 const half_t y = (c + d0) * d1;
184 const half_t y = (type_convert<half_t>(c) + d0) * d1;
188 __host__ __device__
void operator()<float, float,
half_t,
half_t>(
float& e,
193 const float y = (c + d0) * d1;
// Declaration-only call-operator template (ends in `const;`): e is the output,
// c, d0 and d1 are read-only inputs. Not constexpr.
// The per-type bodies visible after it all evaluate c * d0 + d1, with the
// intermediate held in half_t, bhalf_t or float depending on the variant.
202 template <
typename E,
typename C,
typename D0,
typename D1>
203 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
211 const half_t y = (c * d0) + d1;
220 const half_t y = type_convert<half_t>(c) * d0 + d1;
229 const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
233 __host__ __device__
void operator()<float, float,
half_t,
half_t>(
float& e,
238 const float y = c * d0 + d1;
242 __host__ __device__
void operator()<
half_t, float, float,
float>(
half_t& e,
245 const float& d1)
const
247 const float y = c * d0 + d1;
// Declaration-only constexpr call-operator template (ends in `const;`):
// e is the output, c, d0 and d1 are read-only inputs.
// The per-type variants that follow multiply the three inputs together in
// float and convert the product to the output type.
254 template <
typename E,
typename C,
typename D0,
typename D1>
255 __host__ __device__ constexpr
void
256 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
// Variant for a ck::half_t output with float c, d0 and d1.
259 __host__ __device__ constexpr
void operator()<
ck::half_t, float, float,
float>(
260 ck::half_t& e,
const float& c,
const float& d0,
const float& d1)
const
// Form the three-way product entirely in float ...
262 const float x0_f = c * d0 * d1;
// ... and convert only the final value to ck::half_t.
264 e = ck::type_convert<ck::half_t>(x0_f);
// Variant for a ck::bhalf_t output with float c, d0 and d1.
268 __host__ __device__ constexpr
void operator()<
ck::bhalf_t, float, float,
float>(
269 ck::bhalf_t& e,
const float& c,
const float& d0,
const float& d1)
const
// Multiply the three float inputs without leaving float precision.
271 const float x0_f = c * d0 * d1;
// A single conversion to ck::bhalf_t happens at the very end.
273 e = ck::type_convert<ck::bhalf_t>(x0_f);
281 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
283 e = ck::type_convert<ck::half_t>(x0_f);
// Variant for a ck::bhalf_t output where c is an int and d0, d1 are float.
287 __host__ __device__ constexpr
void operator()<
ck::bhalf_t, int, float,
float>(
288 ck::bhalf_t& e,
const int& c,
const float& d0,
const float& d1)
const
// Right-hand side of a statement whose first line (listing line 290) is not
// shown: every operand goes through type_convert<float> before the multiply.
// Presumably this initializes x0_f, as in the float-input variants -- confirm.
291 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
// Convert the float value x0_f to ck::bhalf_t for the output.
293 e = ck::type_convert<ck::bhalf_t>(x0_f);
// Declaration-only constexpr call-operator template (ends in `const;`):
// e is the output, c, d0 and d1 are read-only inputs.
// The fragment visible after it computes c * d0 + d1 in float, passes that
// value through FastGelu, and converts the result to ck::bhalf_t.
299 template <
typename E,
typename C,
typename D0,
typename D1>
300 __host__ __device__ constexpr
void
301 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
307 const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
311 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
313 e = ck::type_convert<ck::bhalf_t>(x1_f);
// Declaration-only constexpr call-operator template (ends in `const;`):
// e is the output, c, d0 and d1 are read-only inputs.
// The per-type variants that follow sum the three inputs (c + d0 + d1); the
// float variant hands that sum to FastGelu, writing straight into e.
320 template <
typename E,
typename C,
typename D0,
typename D1>
321 __host__ __device__ constexpr
void
322 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
325 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& e,
328 const float& d1)
const
330 const float x = c + d0 + d1;
332 FastGelu{}.template operator()<float,
float>(e, x);
339 const half_t x = c + d0 + d1;
348 const float x0_f = c + d0 + d1;
355 e = type_convert<half_t>(x1_f);
362 const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
369 e = type_convert<bhalf_t>(x1_f);
377 type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
384 e = type_convert<int8_t>(x1_f);
// Declaration-only constexpr call-operator template (ends in `const;`):
// e is the output, c, d0 and d1 are read-only inputs.
// The per-type variants visible after it evaluate
// c * alpha1_ + alpha2_ * d0 + d1 in float and replace negative results with 0
// before converting to the output type.
397 template <
typename E,
typename C,
typename D0,
typename D1>
398 __host__ __device__ constexpr
void
399 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
402 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& e,
405 const float& d1)
const
415 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * type_convert<float>(d0) +
416 type_convert<float>(d1);
419 result = x > 0 ? x : 0;
421 e = type_convert<half_t>(result);
428 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * type_convert<float>(d0) +
429 type_convert<float>(d1);
432 result = x > 0 ? x : 0;
434 e = type_convert<bhalf_t>(result);
// Variant for int8_t e and c with float d0 and d1.
438 __host__ __device__ constexpr
void operator()<
int8_t,
int8_t, float,
float>(
439 int8_t& e,
const int8_t& c,
const float& d0,
const float& d1)
const
// Evaluate c * alpha1_ + alpha2_ * d0 + d1 in float; only c needs converting,
// since d0 and d1 are already float.
441 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * d0 + d1;
// Negative values become 0 (ReLU-style clamp). `result` is declared on a
// listing line that is not shown here.
444 result = x > 0 ? x : 0;
// Convert the clamped float to the int8_t output.
446 e = type_convert<int8_t>(result);
458 template <
typename T1,
typename T2,
typename T3>
462 const T2& mean_square,
464 const T3& beta)
const;
470 const float& mean_square,
474 using ck::math::sqrt;
476 float variance = mean_square - (mean * mean);
478 float tmp_x = type_convert<float>(x);
479 float tmp_gamma = type_convert<float>(gamma);
480 float tmp_beta = type_convert<float>(beta);
483 ((tmp_x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * tmp_gamma +
486 y = type_convert<half_t>(tmp_y);
// All-float normalization variant. Only part of the parameter list is visible
// in this listing (y, mean_square, beta); x, mean and gamma are used below and
// are presumably declared on the elided listing lines 491, 492 and 494.
490 __host__ __device__ constexpr
void operator()<float, float,
float>(
float& y,
493 const float& mean_square,
495 const float& beta)
const
497 using ck::math::sqrt;
// Variance is derived from the two inputs as mean_square - mean^2.
// NOTE(review): this difference can round to a slightly negative value, and
// epsilon_ is the only guard before the sqrt below -- confirm that is enough.
499 float variance = mean_square - (mean * mean);
// y = (x - mean) / sqrt(variance + epsilon_) * gamma + beta.
// epsilon_ is converted to float first so the expression stays in float.
500 y = ((x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * gamma + beta;
// All-double normalization variant. Only part of the parameter list is visible
// in this listing (y, mean_square, beta); x, mean and gamma are used below and
// are presumably declared on the elided listing lines 505, 506 and 508.
504 __host__ __device__ constexpr
void operator()<double, double,
double>(
double& y,
507 const double& mean_square,
509 const double& beta)
const
511 using ck::math::sqrt;
// Variance is derived from the two inputs as mean_square - mean^2.
// NOTE(review): this difference can round to a slightly negative value, and
// epsilon_ is the only guard before the sqrt below -- confirm that is enough.
513 double variance = mean_square - (mean * mean);
// y = (x - mean) / sqrt(variance + epsilon_) * gamma + beta.
// epsilon_ is used directly here, with no type_convert, unlike the float variant.
514 y = ((x - mean) / sqrt(variance +
epsilon_)) * gamma + beta;
528 template <
typename T1,
typename T2,
typename T3,
typename T4>
534 const T4& beta)
const
537 "Data type is not supported by this operation!");
540 using ck::math::sqrt;
544 tmp_x = type_convert<T2>(x);
546 tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(
epsilon_))) *
547 type_convert<T2>(gamma) +
548 type_convert<T2>(beta);
549 y = type_convert<T1>(tmp_y);
555 template <
typename Y,
typename X>
563 y = ck::type_convert<float, ck::bhalf_t>(x);
572 y = ck::type_convert<ck::bhalf_t, float>(x);
int8_t int8_t
Definition: int8.hpp:20
_Float16 half_t
Definition: data_type.hpp:25
ushort bhalf_t
Definition: data_type.hpp:24
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:80
_BitInt(4) int4_t
Definition: data_type.hpp:26
Definition: element_wise_operation.hpp:319
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:136
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:138
Definition: element_wise_operation.hpp:104
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:165
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:35
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: unary_element_wise_operation.hpp:688
Definition: element_wise_operation.hpp:298
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:201
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:253
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:454
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:456
double epsilon_
Definition: element_wise_operation.hpp:515
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
Definition: element_wise_operation.hpp:525
double epsilon_
Definition: element_wise_operation.hpp:550
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:529
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:526
Definition: element_wise_operation.hpp:390
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:392
const float alpha2_
Definition: element_wise_operation.hpp:450
const float alpha1_
Definition: element_wise_operation.hpp:449
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:570
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:561
Definition: element_wise_operation.hpp:556