10 namespace tensor_operation {
11 namespace element_wise {
15 static constexpr
const char*
name =
"Add";
17 template <
typename Y,
typename X0,
typename X1>
18 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
21 __host__ __device__ constexpr
void
22 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
28 __host__ __device__ constexpr
void
29 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
35 __host__ __device__ constexpr
void
36 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
38 y = x0 + type_convert<half_t>(x1);
42 __host__ __device__ constexpr
void
43 operator()<
half_t>(
half_t& y,
const float& x0,
const float& x1)
const
45 y = type_convert<half_t>(x0 + x1);
49 __host__ __device__ constexpr
void
52 y = x0 + type_convert<float>(x1);
56 __host__ __device__ constexpr
void
63 __host__ __device__ constexpr
void
64 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
66 const float x1_tmp = ck::type_convert<float>(x1);
71 __host__ __device__ constexpr
void
74 const float x1_tmp = ck::type_convert<float>(x0);
75 const float x2_tmp = ck::type_convert<float>(x1);
76 const float y_tmp = x1_tmp + x2_tmp;
77 y = ck::type_convert<bhalf_t>(y_tmp);
81 __host__ __device__ constexpr
void
84 const float x2_tmp = ck::type_convert<float>(x1);
85 const float y_tmp = x0 + x2_tmp;
86 y = ck::type_convert<bhalf_t>(y_tmp);
90 __host__ __device__ constexpr
void
99 static constexpr
const char*
name =
"Max";
101 template <
typename Y,
typename X0,
typename X1>
102 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
104 const Y x0_converted = type_convert<Y>(x0);
105 const Y x1_converted = type_convert<Y>(x1);
112 static constexpr
const char*
name =
"Min";
114 template <
typename Y,
typename X0,
typename X1>
115 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
117 const Y x0_converted = type_convert<Y>(x0);
118 const Y x1_converted = type_convert<Y>(x1);
125 static constexpr
const char*
name =
"Multiply";
127 template <
typename Y,
typename X0,
typename X1>
128 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
131 __host__ __device__ constexpr
void
132 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
138 __host__ __device__ constexpr
void
139 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
145 __host__ __device__ constexpr
void
146 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
148 y = x0 * type_convert<half_t>(x1);
152 __host__ __device__ constexpr
void
155 y = type_convert<half_t>(x0 * x1);
159 __host__ __device__ constexpr
void
162 y = type_convert<half_t>(x0) * x1;
166 __host__ __device__ constexpr
void
173 __host__ __device__ constexpr
void
174 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
176 const float x1_tmp = ck::type_convert<float>(x1);
181 __host__ __device__ constexpr
void
184 const float x1_tmp = ck::type_convert<float>(x0);
185 const float x2_tmp = ck::type_convert<float>(x1);
186 const float y_tmp = x1_tmp * x2_tmp;
187 y = ck::type_convert<bhalf_t>(y_tmp);
191 __host__ __device__ constexpr
void
194 const float x1_tmp = ck::type_convert<float>(x0);
195 const float x2_tmp = ck::type_convert<float>(x1);
196 const float y_tmp = x1_tmp * x2_tmp;
197 y = ck::type_convert<bhalf_t>(y_tmp);
201 __host__ __device__ constexpr
void
204 const float x2_tmp = ck::type_convert<float>(x1);
205 const float y_tmp = x0 * x2_tmp;
206 y = ck::type_convert<bhalf_t>(y_tmp);
210 __host__ __device__ constexpr
void
219 static constexpr
const char*
name =
"ScaleAdd";
223 template <
typename Y,
typename X0,
typename X1>
224 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
226 y = ck::type_convert<Y>(
scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
230 __host__ __device__
void
231 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
233 y =
scale_ * x0 + ck::type_convert<float>(x1);
237 __host__ __device__
void
238 operator()<float, float,
bhalf_t>(
float& y,
const float& x0,
const bhalf_t& x1)
const
240 y =
scale_ * x0 + ck::type_convert<float>(x1);
248 static constexpr
const char*
name =
"Subtract";
250 template <
typename T>
251 __host__ __device__ constexpr
void operator()(T& y,
const T& x0,
const T& x1)
const;
254 __host__ __device__ constexpr
void
255 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
261 __host__ __device__ constexpr
void
262 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
268 __host__ __device__ constexpr
void
275 __host__ __device__ constexpr
void
278 const float x1_tmp = ck::type_convert<float>(x0);
279 const float x2_tmp = ck::type_convert<float>(x1);
280 const float y_tmp = x1_tmp - x2_tmp;
281 y = ck::type_convert<bhalf_t>(y_tmp);
285 __host__ __device__ constexpr
void
294 static constexpr
const char*
name =
"Bilinear";
298 template <
typename Y,
typename X0,
typename X1>
299 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&)
const;
302 __host__ __device__ constexpr
void
303 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
309 __host__ __device__ constexpr
void
310 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
316 __host__ __device__ constexpr
void
319 y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
320 beta_ * type_convert<float>(x1));
324 __host__ __device__ constexpr
void
327 y = type_convert<half_t>(
alpha_) * x0 + type_convert<half_t>(
beta_) * x1;
331 __host__ __device__ constexpr
void
334 y = type_convert<half_t>(
alpha_ * x0 +
beta_ * ck::type_convert<float>(x1));
338 __host__ __device__ constexpr
void
341 const float x0_tmp = type_convert<float>(x0);
342 const float x1_tmp = type_convert<float>(x1);
343 const float y_tmp =
alpha_ * x0_tmp +
beta_ * x1_tmp;
344 y = type_convert<bhalf_t>(y_tmp);
348 __host__ __device__ constexpr
void
351 y = type_convert<bhalf_t>(
alpha_ * x0 +
beta_ * ck::type_convert<float>(x1));
355 __host__ __device__ constexpr
void
358 y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
359 beta_ * type_convert<float>(x1));
368 static constexpr
const char*
name =
"AddClamp";
373 template <
typename Y,
typename X0,
typename X1>
374 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
377 __host__ __device__ constexpr
void
378 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
380 const float a = x0 + x1;
385 __host__ __device__ constexpr
void
386 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
388 const double a = x0 + x1;
393 __host__ __device__ constexpr
void
403 __host__ __device__ constexpr
void
406 const float a = x0 + type_convert<float>(x1);
408 y = type_convert<half_t>(b);
412 __host__ __device__ constexpr
void
413 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
415 const float a = x0 + type_convert<float>(x1);
420 __host__ __device__ constexpr
void
423 const float a = x0 + type_convert<float>(x1);
425 y = type_convert<bhalf_t>(b);
429 __host__ __device__ constexpr
void
432 const float a = type_convert<float>(x0) + type_convert<float>(x1);
434 y = type_convert<bhalf_t>(b);
438 __host__ __device__ constexpr
void
439 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
446 __host__ __device__ constexpr
void
459 static constexpr
const char*
name =
"AddRelu";
461 template <
typename Y,
typename X0,
typename X1>
462 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
465 __host__ __device__ constexpr
void
466 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
468 const float a = x0 + x1;
469 y =
a > 0.0f ?
a : 0.0f;
473 __host__ __device__ constexpr
void
474 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
476 const double a = x0 + x1;
477 y =
a > 0.0 ?
a : 0.0;
481 __host__ __device__ constexpr
void
485 y =
a > type_convert<half_t>(0.0f) ?
a : type_convert<half_t>(0.0f);
489 __host__ __device__ constexpr
void
492 const float a = x0 + type_convert<float>(x1);
493 const float b =
a > 0.0f ?
a : 0.0f;
494 y = type_convert<half_t>(b);
498 __host__ __device__ constexpr
void
499 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
501 const float a = x0 + type_convert<float>(x1);
502 y =
a > 0.0f ?
a : 0.0f;
506 __host__ __device__ constexpr
void
509 const float a = x0 + type_convert<float>(x1);
510 const float b =
a > 0.0f ?
a : 0.0f;
511 y = type_convert<bhalf_t>(b);
515 __host__ __device__ constexpr
void
518 const float a = type_convert<float>(x0) + type_convert<float>(x1);
519 const float b =
a > 0.0f ?
a : 0.0f;
520 y = type_convert<bhalf_t>(b);
524 __host__ __device__ constexpr
void
525 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
532 __host__ __device__ constexpr
void
542 static constexpr
const char*
name =
"AddHardswish";
544 template <
typename T>
545 __host__ __device__ constexpr
void operator()(T& y,
const T& x0,
const T& x1)
const;
548 __host__ __device__ constexpr
void
549 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
552 float b =
a +
float{3};
553 float c = (b > 0) * (b > 6.0f ? 6.0f : b) *
a * 0.166667f;
558 __host__ __device__ constexpr
void
559 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
563 double c = (b > 0) * (b > 6.0 ? 6.0 : b) *
a * 0.166667;
568 __host__ __device__ constexpr
void
573 float c = (b > 0) * (b > 6.0f ? 6.0f : b) *
a * 0.166667f;
581 static constexpr
const char*
name =
"AddFastGelu";
583 template <
typename E,
typename C,
typename D>
584 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
587 __host__ __device__ constexpr
void
588 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
590 const float x = c + d;
592 FastGelu{}.template operator()<float,
float>(e, x);
596 __host__ __device__ constexpr
void
605 __host__ __device__ constexpr
void
608 const float x0_f = c + d;
615 e = type_convert<half_t>(x1_f);
619 __host__ __device__ constexpr
void
622 const float x0_f = type_convert<float>(c) + type_convert<float>(d);
626 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
628 e = type_convert<bhalf_t>(x1_f);
632 __host__ __device__ constexpr
void
635 const float x0_f = c + type_convert<float>(d);
639 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
641 e = type_convert<bhalf_t>(x1_f);
648 static constexpr
const char*
name =
"MultiplyFastGelu";
650 template <
typename E,
typename C,
typename D>
651 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
654 __host__ __device__ constexpr
void
655 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
657 const float x = c * d;
659 FastGelu{}.template operator()<float,
float>(e, x);
663 __host__ __device__ constexpr
void
672 __host__ __device__ constexpr
void
675 const float x0_f = c * d;
682 e = type_convert<half_t>(x1_f);
686 __host__ __device__ constexpr
void
689 const float x0_f = type_convert<float>(c) * type_convert<float>(d);
693 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
695 e = type_convert<bhalf_t>(x1_f);
699 __host__ __device__ constexpr
void
702 const float x0_f = c * type_convert<float>(d);
706 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
708 e = type_convert<bhalf_t>(x1_f);
715 static constexpr
const char*
name =
"AddSilu";
717 template <
typename E,
typename C,
typename D>
718 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
721 __host__ __device__ constexpr
void
722 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
724 const float x = c + d;
726 Silu{}.template operator()<
float>(e, x);
730 __host__ __device__ constexpr
void
739 __host__ __device__ constexpr
void
742 const float x0_f = c + d;
746 Silu{}.template operator()<
float>(x1_f, x0_f);
748 e = type_convert<half_t>(x1_f);
752 __host__ __device__ constexpr
void
755 const float x0_f = c + type_convert<float>(d);
759 Silu{}.template operator()<
float>(x1_f, x0_f);
761 e = type_convert<bhalf_t>(x1_f);
767 static constexpr
const char*
name =
"ConvScaleAdd";
770 float scale_wei = 1.f,
771 float scale_out = 1.f)
776 template <
typename E,
typename C,
typename D>
777 __host__ __device__
void operator()(E& e,
const C& c,
const D& d)
const;
780 __host__ __device__
void
781 operator()<
f8_t, float,
float>(
f8_t& e,
const float& c,
const float& d)
const
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1762
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:310
Definition: amd_ck_fp8.hpp:36
Definition: binary_element_wise_operation.hpp:367
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:368
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition: binary_element_wise_operation.hpp:370
const float ceil_
Definition: binary_element_wise_operation.hpp:454
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
const float floor_
Definition: binary_element_wise_operation.hpp:451
Definition: binary_element_wise_operation.hpp:580
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:581
Definition: binary_element_wise_operation.hpp:541
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:542
Definition: binary_element_wise_operation.hpp:14
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:15
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:458
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:459
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:714
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:715
Definition: binary_element_wise_operation.hpp:293
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:296
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:294
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
float beta_
Definition: binary_element_wise_operation.hpp:363
float alpha_
Definition: binary_element_wise_operation.hpp:360
Definition: binary_element_wise_operation.hpp:766
float scale_in_
Definition: binary_element_wise_operation.hpp:786
float scale_wei_
Definition: binary_element_wise_operation.hpp:789
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:769
float scale_out_
Definition: binary_element_wise_operation.hpp:790
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:767
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition: unary_element_wise_operation.hpp:924
Definition: binary_element_wise_operation.hpp:98
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:99
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:102
Definition: binary_element_wise_operation.hpp:111
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:112
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:115
Definition: binary_element_wise_operation.hpp:647
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:648
Definition: binary_element_wise_operation.hpp:124
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:125
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:218
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:224
float scale_
Definition: binary_element_wise_operation.hpp:241
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:221
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:219
Definition: unary_element_wise_operation.hpp:1087
Definition: binary_element_wise_operation.hpp:247
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:248
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const