13 namespace tensor_operation {
14 namespace element_wise {
36 static constexpr
const char*
name =
"AddReluAdd";
38 template <
typename Y,
typename X0,
typename X1,
typename X2>
39 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&,
const X2&)
const;
51 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& y,
54 const float& x2)
const
57 float b =
a > 0 ?
a : 0;
63 __host__ __device__ constexpr
void operator()<float, float,
half_t,
half_t>(
64 float& y,
const float& x0,
const half_t& x1,
const half_t& x2)
const
67 float b =
a > 0 ?
a : 0;
77 (*this)(y_float, x0, x1, x2);
86 float b =
a > 0 ?
a : 0;
101 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
116 static constexpr
const char*
name =
"AddHardswishAdd";
118 template <
typename Y,
typename X0,
typename X1,
typename X2>
119 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&,
const X2&)
const;
122 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& y,
125 const float& x2)
const
128 float b =
a +
float{3};
129 float c = (b > 0) * (b >
float{6} ?
float{6} : b) *
a *
float{0.166667};
139 float b =
a +
float{3};
140 float c = (b > 0) * (b >
float{6} ?
float{6} : b) *
a *
float{0.166667};
150 static constexpr
const char*
name =
"AddAdd";
152 template <
typename E,
typename C,
typename D0,
typename D1>
153 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const
158 "Data type is not supported by this operation!");
162 "Data type is not supported by this operation!");
166 "Data type is not supported by this operation!");
170 "Data type is not supported by this operation!");
172 const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
173 e = type_convert<E>(y);
181 static constexpr
const char*
name =
"AddMultiply";
183 template <
typename E,
typename C,
typename D0,
typename D1>
184 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
192 const half_t y = (c + d0) * d1;
201 const half_t y = (type_convert<half_t>(c) + d0) * d1;
205 __host__ __device__
void operator()<float, float,
half_t,
half_t>(
float& e,
210 const float y = (c + d0) * d1;
219 static constexpr
const char*
name =
"MultiplyAdd";
221 template <
typename E,
typename C,
typename D0,
typename D1>
222 __host__ __device__
void operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
230 const half_t y = (c * d0) + d1;
239 const half_t y = type_convert<half_t>(c) * d0 + d1;
248 const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
252 __host__ __device__
void operator()<float, float,
half_t,
half_t>(
float& e,
257 const float y = c * d0 + d1;
261 __host__ __device__
void operator()<
half_t, float, float,
float>(
half_t& e,
264 const float& d1)
const
266 const float y = c * d0 + d1;
273 static constexpr
const char*
name =
"MultiplyMultiply";
275 template <
typename E,
typename C,
typename D0,
typename D1>
276 __host__ __device__ constexpr
void
277 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
280 __host__ __device__ constexpr
void operator()<
ck::half_t, float, float,
float>(
281 ck::half_t& e,
const float& c,
const float& d0,
const float& d1)
const
283 const float x0_f = c * d0 * d1;
285 e = ck::type_convert<ck::half_t>(x0_f);
289 __host__ __device__ constexpr
void operator()<
ck::bhalf_t, float, float,
float>(
290 ck::bhalf_t& e,
const float& c,
const float& d0,
const float& d1)
const
292 const float x0_f = c * d0 * d1;
294 e = ck::type_convert<ck::bhalf_t>(x0_f);
302 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
304 e = ck::type_convert<ck::half_t>(x0_f);
308 __host__ __device__ constexpr
void operator()<
ck::half_t, int, float,
float>(
309 ck::half_t& e,
const int& c,
const float& d0,
const float& d1)
const
312 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
314 e = ck::type_convert<ck::half_t>(x0_f);
318 __host__ __device__ constexpr
void operator()<
ck::bhalf_t, int, float,
float>(
319 ck::bhalf_t& e,
const int& c,
const float& d0,
const float& d1)
const
322 ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
324 e = ck::type_convert<ck::bhalf_t>(x0_f);
330 static constexpr
const char*
name =
"MultiplyAddFastGelu";
332 template <
typename E,
typename C,
typename D0,
typename D1>
333 __host__ __device__ constexpr
void
334 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
340 const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
344 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
346 e = ck::type_convert<ck::bhalf_t>(x1_f);
353 static constexpr
const char*
name =
"AddAddFastGelu";
355 template <
typename E,
typename C,
typename D0,
typename D1>
356 __host__ __device__ constexpr
void
357 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
360 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& e,
363 const float& d1)
const
365 const float x = c + d0 + d1;
367 FastGelu{}.template operator()<float,
float>(e, x);
374 const half_t x = c + d0 + d1;
383 const float x0_f = c + d0 + d1;
390 e = type_convert<half_t>(x1_f);
397 const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
404 e = type_convert<bhalf_t>(x1_f);
412 type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
419 e = type_convert<int8_t>(x1_f);
426 static constexpr
const char*
name =
"ScaleAddScaleAddRelu";
433 template <
typename E,
typename C,
typename D0,
typename D1>
434 __host__ __device__ constexpr
void
435 operator()(E& e,
const C& c,
const D0& d0,
const D1& d1)
const;
438 __host__ __device__ constexpr
void operator()<float, float, float,
float>(
float& e,
441 const float& d1)
const
451 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * type_convert<float>(d0) +
452 type_convert<float>(d1);
455 result = x > 0 ? x : 0;
457 e = type_convert<half_t>(result);
464 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * type_convert<float>(d0) +
465 type_convert<float>(d1);
468 result = x > 0 ? x : 0;
470 e = type_convert<bhalf_t>(result);
474 __host__ __device__ constexpr
void operator()<
int8_t,
int8_t, float,
float>(
475 int8_t& e,
const int8_t& c,
const float& d0,
const float& d1)
const
477 const float x = type_convert<float>(c) *
alpha1_ +
alpha2_ * d0 + d1;
480 result = x > 0 ? x : 0;
482 e = type_convert<int8_t>(result);
491 static constexpr
const char*
name =
"Normalize";
496 template <
typename T1,
typename T2,
typename T3>
500 const T2& mean_square,
502 const T3& beta)
const;
508 const float& mean_square,
512 using ck::math::sqrt;
514 float variance = mean_square - (mean * mean);
516 float tmp_x = type_convert<float>(x);
517 float tmp_gamma = type_convert<float>(gamma);
518 float tmp_beta = type_convert<float>(beta);
521 ((tmp_x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * tmp_gamma +
524 y = type_convert<half_t>(tmp_y);
528 __host__ __device__ constexpr
void operator()<float, float,
float>(
float& y,
531 const float& mean_square,
533 const float& beta)
const
535 using ck::math::sqrt;
537 float variance = mean_square - (mean * mean);
538 y = ((x - mean) / sqrt(variance + type_convert<float>(
epsilon_))) * gamma + beta;
542 __host__ __device__ constexpr
void operator()<double, double,
double>(
double& y,
545 const double& mean_square,
547 const double& beta)
const
549 using ck::math::sqrt;
551 double variance = mean_square - (mean * mean);
552 y = ((x - mean) / sqrt(variance +
epsilon_)) * gamma + beta;
564 static constexpr
const char*
name =
"NormalizeInInfer";
568 template <
typename T1,
typename T2,
typename T3,
typename T4>
574 const T4& beta)
const
577 "Data type is not supported by this operation!");
580 using ck::math::sqrt;
584 tmp_x = type_convert<T2>(x);
586 tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(
epsilon_))) *
587 type_convert<T2>(gamma) +
588 type_convert<T2>(beta);
589 y = type_convert<T1>(tmp_y);
598 static constexpr
const char*
name =
"BiasNormalizeInInferClamp";
602 float epsilon = 1e-4)
607 template <
typename T>
617 using ck::math::sqrt;
619 float tmp_x = type_convert<float>(x) + type_convert<float>(bias);
622 ((tmp_x - type_convert<float>(mean)) / sqrt(type_convert<float>(variance) +
epsilon_)) *
623 type_convert<float>(gamma) +
624 type_convert<float>(beta);
626 y = type_convert<T>(tmp_y);
634 const float& variance,
636 const float& beta)
const
639 using ck::math::sqrt;
641 float tmp_y = (((x + bias) - mean) / sqrt(variance +
epsilon_)) * gamma + beta;
649 template <
typename Y,
typename X>
655 static constexpr
const char* name =
"UnaryTypeConvert";
659 y = ck::type_convert<float, ck::bhalf_t>(x);
666 static constexpr
const char* name =
"UnaryTypeConvert";
670 y = ck::type_convert<ck::bhalf_t, float>(x);
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ T floor(T x)
Definition: math_v2.hpp:367
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:98
_BitInt(4) int4_t
Definition: data_type.hpp:32
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: element_wise_operation.hpp:352
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:353
Definition: element_wise_operation.hpp:149
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:153
static constexpr const char * name
Definition: element_wise_operation.hpp:150
Definition: element_wise_operation.hpp:115
static constexpr const char * name
Definition: element_wise_operation.hpp:116
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:180
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:181
Definition: element_wise_operation.hpp:35
static constexpr const char * name
Definition: element_wise_operation.hpp:36
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:597
BiasNormalizeInInferClamp(float floor=0.f, float ceil=NumericLimits< float >::Max(), float epsilon=1e-4)
Definition: element_wise_operation.hpp:600
__host__ constexpr __device__ void operator()(T &y, const T &x, const T &bias, const T &mean, const T &variance, const T &gamma, const T &beta) const
Definition: element_wise_operation.hpp:608
float epsilon_
Definition: element_wise_operation.hpp:646
Clamp clamp_
Definition: element_wise_operation.hpp:643
__host__ constexpr __device__ void operator()(float &y, const float &x, const float &bias, const float &mean, const float &variance, const float &gamma, const float &beta) const
Definition: element_wise_operation.hpp:630
static constexpr const char * name
Definition: element_wise_operation.hpp:598
Definition: unary_element_wise_operation.hpp:811
Definition: unary_element_wise_operation.hpp:924
Definition: element_wise_operation.hpp:329
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:330
Definition: element_wise_operation.hpp:218
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
static constexpr const char * name
Definition: element_wise_operation.hpp:219
Definition: element_wise_operation.hpp:272
static constexpr const char * name
Definition: element_wise_operation.hpp:273
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:490
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:494
double epsilon_
Definition: element_wise_operation.hpp:553
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
static constexpr const char * name
Definition: element_wise_operation.hpp:491
Definition: element_wise_operation.hpp:563
static constexpr const char * name
Definition: element_wise_operation.hpp:564
double epsilon_
Definition: element_wise_operation.hpp:590
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:569
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:566
Definition: element_wise_operation.hpp:425
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:428
static constexpr const char * name
Definition: element_wise_operation.hpp:426
const float alpha2_
Definition: element_wise_operation.hpp:486
const float alpha1_
Definition: element_wise_operation.hpp:485
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:668
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:657
Definition: element_wise_operation.hpp:650