10 namespace tensor_operation {
11 namespace element_wise {
15 static constexpr
const char*
name =
"Add";
17 template <
typename Y,
typename X0,
typename X1>
18 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
21 __host__ __device__ constexpr
void
22 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
28 __host__ __device__ constexpr
void
29 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
35 __host__ __device__ constexpr
void
36 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
38 y = x0 + type_convert<half_t>(x1);
42 __host__ __device__ constexpr
void
43 operator()<
half_t>(
half_t& y,
const float& x0,
const float& x1)
const
45 y = type_convert<half_t>(x0 + x1);
49 __host__ __device__ constexpr
void
52 y = x0 + type_convert<float>(x1);
56 __host__ __device__ constexpr
void
63 __host__ __device__ constexpr
void
64 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
66 const float x1_tmp = ck::type_convert<float>(x1);
71 __host__ __device__ constexpr
void
74 const float x1_tmp = ck::type_convert<float>(x0);
75 const float x2_tmp = ck::type_convert<float>(x1);
76 const float y_tmp = x1_tmp + x2_tmp;
77 y = ck::type_convert<bhalf_t>(y_tmp);
81 __host__ __device__ constexpr
void
84 const float x2_tmp = ck::type_convert<float>(x1);
85 const float y_tmp = x0 + x2_tmp;
86 y = ck::type_convert<bhalf_t>(y_tmp);
90 __host__ __device__ constexpr
void
99 static constexpr
const char*
name =
"Max";
101 template <
typename Y,
typename X0,
typename X1>
102 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
104 const Y x0_converted = type_convert<Y>(x0);
105 const Y x1_converted = type_convert<Y>(x1);
112 static constexpr
const char*
name =
"Min";
114 template <
typename Y,
typename X0,
typename X1>
115 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
117 const Y x0_converted = type_convert<Y>(x0);
118 const Y x1_converted = type_convert<Y>(x1);
125 static constexpr
const char*
name =
"Multiply";
127 template <
typename Y,
typename X0,
typename X1>
128 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
131 __host__ __device__ constexpr
void
132 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
138 __host__ __device__ constexpr
void
139 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
145 __host__ __device__ constexpr
void
146 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
148 y = x0 * type_convert<half_t>(x1);
152 __host__ __device__ constexpr
void
155 y = type_convert<half_t>(x0 * x1);
159 __host__ __device__ constexpr
void
162 y = type_convert<half_t>(x0) * x1;
166 __host__ __device__ constexpr
void
173 __host__ __device__ constexpr
void
174 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
176 const float x1_tmp = ck::type_convert<float>(x1);
181 __host__ __device__ constexpr
void
184 const float x1_tmp = ck::type_convert<float>(x0);
185 const float x2_tmp = ck::type_convert<float>(x1);
186 const float y_tmp = x1_tmp * x2_tmp;
187 y = ck::type_convert<bhalf_t>(y_tmp);
191 __host__ __device__ constexpr
void
194 const float x1_tmp = ck::type_convert<float>(x0);
195 const float x2_tmp = ck::type_convert<float>(x1);
196 const float y_tmp = x1_tmp * x2_tmp;
197 y = ck::type_convert<bhalf_t>(y_tmp);
201 __host__ __device__ constexpr
void
204 const float x2_tmp = ck::type_convert<float>(x1);
205 const float y_tmp = x0 * x2_tmp;
206 y = ck::type_convert<bhalf_t>(y_tmp);
210 __host__ __device__ constexpr
void
219 static constexpr
const char*
name =
"ScaleAdd";
223 template <
typename Y,
typename X0,
typename X1>
224 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
226 y = ck::type_convert<Y>(
scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
230 __host__ __device__
void
231 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
233 y =
scale_ * x0 + ck::type_convert<float>(x1);
237 __host__ __device__
void
238 operator()<float, float,
bhalf_t>(
float& y,
const float& x0,
const bhalf_t& x1)
const
240 y =
scale_ * x0 + ck::type_convert<float>(x1);
248 static constexpr
const char*
name =
"Subtract";
250 template <
typename T>
251 __host__ __device__ constexpr
void operator()(T& y,
const T& x0,
const T& x1)
const;
254 __host__ __device__ constexpr
void
255 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
261 __host__ __device__ constexpr
void
262 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
268 __host__ __device__ constexpr
void
275 __host__ __device__ constexpr
void
278 const float x1_tmp = ck::type_convert<float>(x0);
279 const float x2_tmp = ck::type_convert<float>(x1);
280 const float y_tmp = x1_tmp - x2_tmp;
281 y = ck::type_convert<bhalf_t>(y_tmp);
285 __host__ __device__ constexpr
void
294 static constexpr
const char*
name =
"Bilinear";
298 template <
typename Y,
typename X0,
typename X1>
299 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&)
const;
302 __host__ __device__ constexpr
void
303 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
309 __host__ __device__ constexpr
void
310 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
316 __host__ __device__ constexpr
void
319 y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
320 beta_ * type_convert<float>(x1));
324 __host__ __device__ constexpr
void
327 y = type_convert<half_t>(
alpha_) * x0 + type_convert<half_t>(
beta_) * x1;
331 __host__ __device__ constexpr
void
334 y = type_convert<half_t>(
alpha_ * x0 +
beta_ * ck::type_convert<float>(x1));
338 __host__ __device__ constexpr
void
341 const float x0_tmp = type_convert<float>(x0);
342 const float x1_tmp = type_convert<float>(x1);
343 const float y_tmp =
alpha_ * x0_tmp +
beta_ * x1_tmp;
344 y = type_convert<bhalf_t>(y_tmp);
348 __host__ __device__ constexpr
void
351 const float x1_tmp = ck::type_convert<float>(x1);
357 __host__ __device__ constexpr
void
360 y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
361 beta_ * type_convert<float>(x1));
370 static constexpr
const char*
name =
"AddClamp";
375 template <
typename Y,
typename X0,
typename X1>
376 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
379 __host__ __device__ constexpr
void
380 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
382 const float a = x0 + x1;
387 __host__ __device__ constexpr
void
388 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
390 const double a = x0 + x1;
395 __host__ __device__ constexpr
void
405 __host__ __device__ constexpr
void
408 const float a = x0 + type_convert<float>(x1);
410 y = type_convert<half_t>(b);
414 __host__ __device__ constexpr
void
415 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
417 const float a = x0 + type_convert<float>(x1);
422 __host__ __device__ constexpr
void
425 const float a = x0 + type_convert<float>(x1);
427 y = type_convert<bhalf_t>(b);
431 __host__ __device__ constexpr
void
434 const float a = type_convert<float>(x0) + type_convert<float>(x1);
436 y = type_convert<bhalf_t>(b);
440 __host__ __device__ constexpr
void
441 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
448 __host__ __device__ constexpr
void
461 static constexpr
const char*
name =
"AddRelu";
463 template <
typename Y,
typename X0,
typename X1>
464 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
467 __host__ __device__ constexpr
void
468 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
470 const float a = x0 + x1;
471 y =
a > 0.0f ?
a : 0.0f;
475 __host__ __device__ constexpr
void
476 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
478 const double a = x0 + x1;
479 y =
a > 0.0 ?
a : 0.0;
483 __host__ __device__ constexpr
void
487 y =
a > type_convert<half_t>(0.0f) ?
a : type_convert<half_t>(0.0f);
491 __host__ __device__ constexpr
void
494 const float a = x0 + type_convert<float>(x1);
495 const float b =
a > 0.0f ?
a : 0.0f;
496 y = type_convert<half_t>(b);
500 __host__ __device__ constexpr
void
501 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
503 const float a = x0 + type_convert<float>(x1);
504 y =
a > 0.0f ?
a : 0.0f;
508 __host__ __device__ constexpr
void
511 const float a = x0 + type_convert<float>(x1);
512 const float b =
a > 0.0f ?
a : 0.0f;
513 y = type_convert<bhalf_t>(b);
517 __host__ __device__ constexpr
void
520 const float a = type_convert<float>(x0) + type_convert<float>(x1);
521 const float b =
a > 0.0f ?
a : 0.0f;
522 y = type_convert<bhalf_t>(b);
526 __host__ __device__ constexpr
void
527 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
534 __host__ __device__ constexpr
void
544 static constexpr
const char*
name =
"AddHardswish";
546 template <
typename T>
547 __host__ __device__ constexpr
void operator()(T& y,
const T& x0,
const T& x1)
const;
550 __host__ __device__ constexpr
void
551 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
554 float b =
a +
float{3};
555 float c = (b > 0) * (b > 6.0f ? 6.0f : b) *
a * 0.166667f;
560 __host__ __device__ constexpr
void
561 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
565 double c = (b > 0) * (b > 6.0 ? 6.0 : b) *
a * 0.166667;
570 __host__ __device__ constexpr
void
575 float c = (b > 0) * (b > 6.0f ? 6.0f : b) *
a * 0.166667f;
583 static constexpr
const char*
name =
"AddFastGelu";
585 template <
typename E,
typename C,
typename D>
586 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
589 __host__ __device__ constexpr
void
590 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
592 const float x = c + d;
594 FastGelu{}.template operator()<float,
float>(e, x);
598 __host__ __device__ constexpr
void
607 __host__ __device__ constexpr
void
610 const float x0_f = c + d;
617 e = type_convert<half_t>(x1_f);
621 __host__ __device__ constexpr
void
624 const float x0_f = type_convert<float>(c) + type_convert<float>(d);
628 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
630 e = type_convert<bhalf_t>(x1_f);
634 __host__ __device__ constexpr
void
637 const float x0_f = c + type_convert<float>(d);
641 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
643 e = type_convert<bhalf_t>(x1_f);
650 static constexpr
const char*
name =
"MultiplyFastGelu";
652 template <
typename E,
typename C,
typename D>
653 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
656 __host__ __device__ constexpr
void
657 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
659 const float x = c * d;
661 FastGelu{}.template operator()<float,
float>(e, x);
665 __host__ __device__ constexpr
void
674 __host__ __device__ constexpr
void
677 const float x0_f = c * d;
684 e = type_convert<half_t>(x1_f);
688 __host__ __device__ constexpr
void
691 const float x0_f = type_convert<float>(c) * type_convert<float>(d);
695 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
697 e = type_convert<bhalf_t>(x1_f);
701 __host__ __device__ constexpr
void
704 const float x0_f = c * type_convert<float>(d);
708 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
710 e = type_convert<bhalf_t>(x1_f);
717 static constexpr
const char*
name =
"AddSilu";
719 template <
typename E,
typename C,
typename D>
720 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
723 __host__ __device__ constexpr
void
724 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
726 const float x = c + d;
728 Silu{}.template operator()<
float>(e, x);
732 __host__ __device__ constexpr
void
741 __host__ __device__ constexpr
void
744 const float x0_f = c + d;
748 Silu{}.template operator()<
float>(x1_f, x0_f);
750 e = type_convert<half_t>(x1_f);
754 __host__ __device__ constexpr
void
757 const float x0_f = c + type_convert<float>(d);
761 Silu{}.template operator()<
float>(x1_f, x0_f);
763 e = type_convert<bhalf_t>(x1_f);
769 static constexpr
const char*
name =
"ConvScaleAdd";
772 float scale_wei = 1.f,
773 float scale_out = 1.f)
778 template <
typename E,
typename C,
typename D>
779 __host__ __device__
void operator()(E& e,
const C& c,
const D& d)
const;
782 __host__ __device__
void
783 operator()<
f8_t, float,
float>(
f8_t& e,
const float& c,
const float& d)
const
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1762
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:309
Definition: amd_ck_fp8.hpp:36
Definition: binary_element_wise_operation.hpp:369
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:370
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition: binary_element_wise_operation.hpp:372
const float ceil_
Definition: binary_element_wise_operation.hpp:456
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
const float floor_
Definition: binary_element_wise_operation.hpp:453
Definition: binary_element_wise_operation.hpp:582
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:583
Definition: binary_element_wise_operation.hpp:543
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:544
Definition: binary_element_wise_operation.hpp:14
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:15
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:460
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:461
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:716
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:717
Definition: binary_element_wise_operation.hpp:293
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:296
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:294
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
float beta_
Definition: binary_element_wise_operation.hpp:365
float alpha_
Definition: binary_element_wise_operation.hpp:362
Definition: binary_element_wise_operation.hpp:768
float scale_in_
Definition: binary_element_wise_operation.hpp:788
float scale_wei_
Definition: binary_element_wise_operation.hpp:791
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:771
float scale_out_
Definition: binary_element_wise_operation.hpp:792
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:769
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition: unary_element_wise_operation.hpp:924
Definition: binary_element_wise_operation.hpp:98
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:99
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:102
Definition: binary_element_wise_operation.hpp:111
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:112
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:115
Definition: binary_element_wise_operation.hpp:649
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:650
Definition: binary_element_wise_operation.hpp:124
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:125
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:218
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:224
float scale_
Definition: binary_element_wise_operation.hpp:241
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:221
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:219
Definition: unary_element_wise_operation.hpp:1087
Definition: binary_element_wise_operation.hpp:247
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:248
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const