10 namespace tensor_operation {
11 namespace element_wise {
// Primary declaration of a binary element-wise call operator: y is the output
// reference, x0 and x1 are the two inputs. It is only declared here (note the
// trailing ';'); the per-type behaviour comes from the fragments that follow.
// NOTE(review): the enclosing struct header, the braces, and several bodies
// are not present in this excerpt, so only the visible statements are described.
15 template <
typename Y,
typename X0,
typename X1>
16 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
// Signature for y: float, x0: float, x1: float. Body not visible in this excerpt.
19 __host__ __device__ constexpr
void
20 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
// Signature for y: double, x0: double, x1: double. Body not visible in this excerpt.
26 __host__ __device__ constexpr
void
27 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
// Signature for y: float, x0: float, x1: half_t.
33 __host__ __device__ constexpr
void
34 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
// Adds x0 and x1 and stores the sum in the float output.
// NOTE(review): x1 is already declared half_t, so the conversion target equals
// its declared type; type_convert<float>(x1) may have been intended -- confirm.
36 y = x0 + type_convert<half_t>(x1);
// Signature for y: half_t, x0: float, x1: float.
40 __host__ __device__ constexpr
void
41 operator()<
half_t>(
half_t& y,
const float& x0,
const float& x1)
const
// The sum is formed on the float inputs first, then converted once to half_t.
43 y = type_convert<half_t>(x0 + x1);
// Signature not visible in this excerpt. Only x0 is converted to half_t before
// the addition; x1 is used as-is.
47 __host__ __device__ constexpr
void
50 y = type_convert<half_t>(x0) + x1;
// Signature and body not visible in this excerpt.
54 __host__ __device__ constexpr
void
// Signature for y: float, x0: float, x1: bhalf_t.
61 __host__ __device__ constexpr
void
62 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
// x1 is converted to float first; the statement that uses x1_tmp is not visible here.
64 const float x1_tmp = ck::type_convert<float>(x1);
// Signature not visible in this excerpt. Both inputs are converted to float,
// added in float, and the sum is converted to bhalf_t for the output.
// Note the local naming: x1_tmp holds the converted x0 and x2_tmp the converted x1.
69 __host__ __device__ constexpr
void
72 const float x1_tmp = ck::type_convert<float>(x0);
73 const float x2_tmp = ck::type_convert<float>(x1);
74 const float y_tmp = x1_tmp + x2_tmp;
75 y = ck::type_convert<bhalf_t>(y_tmp);
// Signature not visible in this excerpt. x0 is used directly in the float sum
// and only x1 is converted; the result is converted to bhalf_t.
79 __host__ __device__ constexpr
void
82 const float x2_tmp = ck::type_convert<float>(x1);
83 const float y_tmp = x0 + x2_tmp;
84 y = ck::type_convert<bhalf_t>(y_tmp);
// Signature and body not visible in this excerpt.
88 __host__ __device__ constexpr
void
// Generic binary call operator (not constexpr, unlike its neighbours): both
// inputs are first converted to the output type Y.
// NOTE(review): the statement that combines x0_converted and x1_converted into
// y is not visible in this excerpt, nor is the enclosing struct header.
97 template <
typename Y,
typename X0,
typename X1>
98 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
// Bring both operands to the output type so they can be combined as Y values.
100 const Y x0_converted = type_convert<Y>(x0);
101 const Y x1_converted = type_convert<Y>(x1);
// Generic binary call operator (not constexpr): converts both inputs to the
// output type Y before combining them.
// NOTE(review): the visible lines match the operator declared just above this
// one in the file; the statement that distinguishes the two (how the converted
// values are combined into y) is not visible in this excerpt.
108 template <
typename Y,
typename X0,
typename X1>
109 __host__ __device__
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
// Bring both operands to the output type so they can be combined as Y values.
111 const Y x0_converted = type_convert<Y>(x0);
112 const Y x1_converted = type_convert<Y>(x1);
// Primary declaration of a binary element-wise call operator (declaration only,
// note the trailing ';'). The visible fragments that follow all multiply x0 by x1.
// NOTE(review): the enclosing struct header, the braces, and several bodies and
// signatures are not present in this excerpt.
119 template <
typename Y,
typename X0,
typename X1>
120 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
// Signature for y: float, x0: float, x1: float. Body not visible in this excerpt.
123 __host__ __device__ constexpr
void
124 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
// Signature for y: double, x0: double, x1: double. Body not visible in this excerpt.
130 __host__ __device__ constexpr
void
131 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
// Signature for y: float, x0: float, x1: half_t.
137 __host__ __device__ constexpr
void
138 operator()<
float>(
float& y,
const float& x0,
const half_t& x1)
const
// Multiplies x0 by x1 and stores the product in the float output.
// NOTE(review): x1 is already declared half_t, so the conversion target equals
// its declared type; type_convert<float>(x1) may have been intended -- confirm.
140 y = x0 * type_convert<half_t>(x1);
// Signature not visible in this excerpt. The product is formed on the inputs
// first and then converted once to half_t.
144 __host__ __device__ constexpr
void
147 y = type_convert<half_t>(x0 * x1);
// Signature not visible in this excerpt. Only x0 is converted to half_t before
// the multiplication; x1 is used as-is.
151 __host__ __device__ constexpr
void
154 y = type_convert<half_t>(x0) * x1;
// Signature and body not visible in this excerpt.
158 __host__ __device__ constexpr
void
// Signature for y: float, x0: float, x1: bhalf_t.
165 __host__ __device__ constexpr
void
166 operator()<
float>(
float& y,
const float& x0,
const bhalf_t& x1)
const
// x1 is converted to float first; the statement that uses x1_tmp is not visible here.
168 const float x1_tmp = ck::type_convert<float>(x1);
// Signature not visible in this excerpt. Both inputs are converted to float,
// multiplied in float, and the product is converted to bhalf_t.
// Note the local naming: x1_tmp holds the converted x0 and x2_tmp the converted x1.
173 __host__ __device__ constexpr
void
176 const float x1_tmp = ck::type_convert<float>(x0);
177 const float x2_tmp = ck::type_convert<float>(x1);
178 const float y_tmp = x1_tmp * x2_tmp;
179 y = ck::type_convert<bhalf_t>(y_tmp);
// Signature not visible in this excerpt. The visible statements are the same as
// in the previous fragment: convert both inputs to float, multiply, convert the
// product to bhalf_t. Presumably the parameter types differ -- verify in the full file.
183 __host__ __device__ constexpr
void
186 const float x1_tmp = ck::type_convert<float>(x0);
187 const float x2_tmp = ck::type_convert<float>(x1);
188 const float y_tmp = x1_tmp * x2_tmp;
189 y = ck::type_convert<bhalf_t>(y_tmp);
// Signature not visible in this excerpt. x0 is used directly in the float
// product and only x1 is converted; the result is converted to bhalf_t.
193 __host__ __device__ constexpr
void
196 const float x2_tmp = ck::type_convert<float>(x1);
197 const float y_tmp = x0 * x2_tmp;
198 y = ck::type_convert<bhalf_t>(y_tmp);
// Signature and body not visible in this excerpt.
202 __host__ __device__ constexpr
void
// Generic scale-and-add call operator: computes scale_ * x0 + x1 and writes it to y.
// Both inputs are converted to float for the arithmetic and the float result is
// converted to the output type Y. scale_ is a member of the enclosing object
// (its declaration is not part of these lines).
// NOTE(review): the enclosing struct header and body braces are not present in
// this excerpt.
213 template <
typename Y,
typename X0,
typename X1>
214 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const
216 y = ck::type_convert<Y>(
scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
// Signature for y: float, x0: float, x1: half_t (declared without constexpr).
// x0 and y are already float, so only x1 is converted before the add.
220 __host__ __device__
void
221 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
223 y =
scale_ * x0 + ck::type_convert<float>(x1);
// Signature for y: float, x0: float, x1: bhalf_t (declared without constexpr).
// Same shape as the half_t case: only x1 is converted to float.
227 __host__ __device__
void
228 operator()<float, float,
bhalf_t>(
float& y,
const float& x0,
const bhalf_t& x1)
const
230 y =
scale_ * x0 + ck::type_convert<float>(x1);
// Primary declaration of a binary call operator whose output and both inputs
// share a single type T (declaration only, note the trailing ';').
// NOTE(review): the enclosing struct header, the braces, and most bodies are not
// present in this excerpt; the one visible body subtracts x1 from x0.
238 template <
typename T>
239 __host__ __device__ constexpr
void operator()(T& y,
const T& x0,
const T& x1)
const;
// Signature for T = float. Body not visible in this excerpt.
242 __host__ __device__ constexpr
void
243 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
// Signature for T = double. Body not visible in this excerpt.
249 __host__ __device__ constexpr
void
250 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
// Signature and body not visible in this excerpt.
256 __host__ __device__ constexpr
void
// Signature not visible in this excerpt. Both inputs are converted to float,
// x1 is subtracted from x0 in float, and the difference is converted to bhalf_t.
// Note the local naming: x1_tmp holds the converted x0 and x2_tmp the converted x1.
263 __host__ __device__ constexpr
void
266 const float x1_tmp = ck::type_convert<float>(x0);
267 const float x2_tmp = ck::type_convert<float>(x1);
268 const float y_tmp = x1_tmp - x2_tmp;
269 y = ck::type_convert<bhalf_t>(y_tmp);
// Signature and body not visible in this excerpt.
273 __host__ __device__ constexpr
void
// Primary declaration of a binary call operator with unnamed parameters
// (declaration only, note the trailing ';'). The visible fragments that follow
// compute the weighted sum alpha_ * x0 + beta_ * x1, where alpha_ and beta_ are
// members of the enclosing object (their declarations are not part of these lines).
// NOTE(review): the enclosing struct header, the braces, and several signatures
// and bodies are not present in this excerpt.
284 template <
typename Y,
typename X0,
typename X1>
285 __host__ __device__ constexpr
void operator()(Y&,
const X0&,
const X1&)
const;
// Signature for y: double, x0: double, x1: double. Body not visible in this excerpt.
288 __host__ __device__ constexpr
void
289 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
// Signature for y: float, x0: float, x1: float. Body not visible in this excerpt.
295 __host__ __device__ constexpr
void
296 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
// Signature not visible in this excerpt. Both inputs are converted to float,
// the weighted sum is formed in float, and the result is converted to int8_t.
302 __host__ __device__ constexpr
void
305 y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
306 beta_ * type_convert<float>(x1));
// Signature not visible in this excerpt. Here the coefficients are converted to
// half_t instead of the inputs being converted, so the weighted sum is written
// in terms of half_t coefficients and the inputs as given.
310 __host__ __device__ constexpr
void
313 y = type_convert<half_t>(
alpha_) * x0 + type_convert<half_t>(
beta_) * x1;
// Signature not visible in this excerpt. x0 is used directly, x1 is converted to
// float, and the whole weighted sum is converted once to half_t.
317 __host__ __device__ constexpr
void
320 y = type_convert<half_t>(
alpha_ * x0 +
beta_ * ck::type_convert<float>(x1));
// Signature not visible in this excerpt. Both inputs are converted to float,
// the weighted sum is formed in float, and the result is converted to bhalf_t.
324 __host__ __device__ constexpr
void
327 const float x0_tmp = type_convert<float>(x0);
328 const float x1_tmp = type_convert<float>(x1);
329 const float y_tmp =
alpha_ * x0_tmp +
beta_ * x1_tmp;
330 y = type_convert<bhalf_t>(y_tmp);
// Signature not visible in this excerpt. Only the conversion of x1 to float is
// visible; the statement that uses x1_tmp is not.
334 __host__ __device__ constexpr
void
337 const float x1_tmp = ck::type_convert<float>(x1);
// Signature not visible in this excerpt. The visible statement is identical to
// the earlier int8_t fragment: float weighted sum converted to int8_t.
// Presumably the input types differ -- verify in the full file.
343 __host__ __device__ constexpr
void
346 y = type_convert<int8_t>(
alpha_ * type_convert<float>(x0) +
347 beta_ * type_convert<float>(x1));
// Primary declaration of a binary call operator (declaration only, note the
// trailing ';'). The visible fragments that follow add the two inputs and then
// clamp the sum at zero from below (y = sum if sum > 0, otherwise 0).
// NOTE(review): the enclosing struct header, the braces, and several signatures
// and bodies are not present in this excerpt.
356 template <
typename Y,
typename X0,
typename X1>
357 __host__ __device__ constexpr
void operator()(Y& y,
const X0& x0,
const X1& x1)
const;
// Signature for y: float, x0: float, x1: float.
360 __host__ __device__ constexpr
void
361 operator()<float, float,
float>(
float& y,
const float& x0,
const float& x1)
const
// Sum in float, then keep it only when strictly positive.
363 const float a = x0 + x1;
364 y = a > 0.0f ? a : 0.0f;
// Signature for y: double, x0: double, x1: double.
368 __host__ __device__ constexpr
void
369 operator()<double, double,
double>(
double& y,
const double& x0,
const double& x1)
const
// Same computation carried out in double.
371 const double a = x0 + x1;
372 y = a > 0.0 ? a : 0.0;
// Signature and the definition of `a` are not visible in this excerpt.
// The zero used for both the comparison and the fallback value is built as half_t.
376 __host__ __device__ constexpr
void
380 y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
// Signature not visible in this excerpt. The sum is held in a float, while the
// zero it is compared with (and replaced by) is built as half_t.
384 __host__ __device__ constexpr
void
387 const float a = x0 + x1;
388 y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
// Signature for y: float, x0: float, x1: half_t.
392 __host__ __device__ constexpr
void
393 operator()<float, float,
half_t>(
float& y,
const float& x0,
const half_t& x1)
const
// x1 is widened to float before the add; comparison and result stay in float.
395 const float a = x0 + type_convert<float>(x1);
396 y = a > 0.0f ? a : 0.0f;
// Signature not visible in this excerpt. x1 is converted to float before the add.
// NOTE(review): `a` is float but the zero is built as bhalf_t -- confirm that a
// float zero (as in the half_t-input case above) was not intended.
400 __host__ __device__ constexpr
void
403 const float a = x0 + type_convert<float>(x1);
404 y = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
// Signature for y: int, x0: int, x1: int8_t. Body not visible in this excerpt.
408 __host__ __device__ constexpr
void
409 operator()<int, int,
int8_t>(
int& y,
const int& x0,
const int8_t& x1)
const
// Signature and body not visible in this excerpt.
416 __host__ __device__ constexpr
void
// Primary declaration of a binary call operator whose output and both inputs
// share a single type T (declaration only, note the trailing ';').
// NOTE(review): the enclosing struct header, the braces, the definition of `a`,
// and the statement that stores the result into y are not present in this excerpt.
426 template <
typename T>
427 __host__ __device__ constexpr
void operator()(T& y,
const T& x0,
const T& x1)
const;
// Signature for T = float.
430 __host__ __device__ constexpr
void
431 operator()<
float>(
float& y,
const float& x0,
const float& x1)
const
// b is `a` shifted up by 3 (`a` is defined on a line not visible here).
434 float b = a +
float{3};
// c = a * (b limited to at most 6) * 0.166667, forced to zero when b <= 0:
// the boolean (b > 0) acts as a 0/1 multiplier.
// NOTE(review): 0.166667f looks like a rounded 1/6 -- confirm the intended precision.
435 float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
// Signature for T = double.
440 __host__ __device__ constexpr
void
441 operator()<
double>(
double& y,
const double& x0,
const double& x1)
const
// Same expression as the float case, evaluated in double; `a` and `b` are
// defined on lines not visible in this excerpt.
445 double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
// Signature not visible in this excerpt. The visible statement repeats the float
// expression above; `a` and `b` are defined on lines not visible here.
450 __host__ __device__ constexpr
void
455 float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
// Primary declaration of a binary call operator with output e and inputs c, d
// (declaration only, note the trailing ';'). The visible fragments that follow
// add c and d in float and pass the sum through FastGelu's call operator.
// NOTE(review): FastGelu is defined elsewhere; from the call sites its first
// argument appears to be the output and the second the input -- confirm.
// The enclosing struct header, braces, and some signatures/lines are not present here.
463 template <
typename E,
typename C,
typename D>
464 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
// Signature for e: float, c: float, d: float.
467 __host__ __device__ constexpr
void
468 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
// All-float case: FastGelu is invoked on the sum with e passed directly.
470 const float x = c + d;
472 FastGelu{}.template operator()<float,
float>(e, x);
// Signature and body not visible in this excerpt.
476 __host__ __device__ constexpr
void
// Signature not visible in this excerpt. The sum is formed as a float, and the
// float value x1_f (its definition and the FastGelu call are not visible here)
// is converted to half_t for the output.
485 __host__ __device__ constexpr
void
488 const float x0_f = c + d;
495 e = type_convert<half_t>(x1_f);
// Signature not visible in this excerpt. Both inputs are converted to float
// before the add; FastGelu is called with x1_f and x0_f (x1_f is declared on a
// line not visible here), and x1_f is then converted to bhalf_t for the output.
499 __host__ __device__ constexpr
void
502 const float x0_f = type_convert<float>(c) + type_convert<float>(d);
506 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
508 e = type_convert<bhalf_t>(x1_f);
// Signature not visible in this excerpt. Same flow as the previous fragment,
// except c is used directly and only d is converted to float.
512 __host__ __device__ constexpr
void
515 const float x0_f = c + type_convert<float>(d);
519 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
521 e = type_convert<bhalf_t>(x1_f);
// Primary declaration of a binary call operator with output e and inputs c, d
// (declaration only, note the trailing ';'). The visible fragments that follow
// multiply c by d in float and pass the product through FastGelu's call operator.
// NOTE(review): FastGelu is defined elsewhere; from the call sites its first
// argument appears to be the output and the second the input -- confirm.
// The enclosing struct header, braces, and some signatures/lines are not present here.
528 template <
typename E,
typename C,
typename D>
529 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
// Signature for e: float, c: float, d: float.
532 __host__ __device__ constexpr
void
533 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
// All-float case: FastGelu is invoked on the product with e passed directly.
535 const float x = c * d;
537 FastGelu{}.template operator()<float,
float>(e, x);
// Signature and body not visible in this excerpt.
541 __host__ __device__ constexpr
void
// Signature not visible in this excerpt. The product is formed as a float, and
// the float value x1_f (its definition and the FastGelu call are not visible
// here) is converted to half_t for the output.
550 __host__ __device__ constexpr
void
553 const float x0_f = c * d;
560 e = type_convert<half_t>(x1_f);
// Signature not visible in this excerpt. Both inputs are converted to float
// before the multiply; FastGelu is called with x1_f and x0_f (x1_f is declared
// on a line not visible here), and x1_f is then converted to bhalf_t.
564 __host__ __device__ constexpr
void
567 const float x0_f = type_convert<float>(c) * type_convert<float>(d);
571 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
573 e = type_convert<bhalf_t>(x1_f);
// Signature not visible in this excerpt. Same flow as the previous fragment,
// except c is used directly and only d is converted to float.
577 __host__ __device__ constexpr
void
580 const float x0_f = c * type_convert<float>(d);
584 FastGelu{}.template operator()<float,
float>(x1_f, x0_f);
586 e = type_convert<bhalf_t>(x1_f);
// Primary declaration of a binary call operator with output e and inputs c, d
// (declaration only, note the trailing ';'). The visible fragments that follow
// add c and d in float and pass the sum through Silu's call operator.
// NOTE(review): Silu is defined elsewhere; from the call sites its first argument
// appears to be the output and the second the input -- confirm.
// The enclosing struct header, braces, and some signatures/lines are not present here.
593 template <
typename E,
typename C,
typename D>
594 __host__ __device__ constexpr
void operator()(E& e,
const C& c,
const D& d)
const;
// Signature for e: float, c: float, d: float.
597 __host__ __device__ constexpr
void
598 operator()<float, float,
float>(
float& e,
const float& c,
const float& d)
const
// All-float case: Silu is invoked on the sum with e passed directly.
600 const float x = c + d;
602 Silu{}.template operator()<
float>(e, x);
// Signature and body not visible in this excerpt.
606 __host__ __device__ constexpr
void
// Signature not visible in this excerpt. The sum is formed as a float, Silu is
// called with x1_f and x0_f (x1_f is declared on a line not visible here), and
// x1_f is then converted to half_t for the output.
615 __host__ __device__ constexpr
void
618 const float x0_f = c + d;
622 Silu{}.template operator()<
float>(x1_f, x0_f);
624 e = type_convert<half_t>(x1_f);
// Signature not visible in this excerpt. c is used directly and only d is
// converted to float before the add; the Silu result x1_f is converted to bhalf_t.
628 __host__ __device__ constexpr
void
631 const float x0_f = c + type_convert<float>(d);
635 Silu{}.template operator()<
float>(x1_f, x0_f);
637 e = type_convert<bhalf_t>(x1_f);
// Tail of a parameter list: two float parameters, each defaulting to 1.f.
// NOTE(review): the start of this declaration (function name and any earlier
// parameters) and its body are not visible in this excerpt; presumably these
// initialise scale members of the enclosing object -- verify in the full file.
644 float scale_wei = 1.f,
645 float scale_out = 1.f)
// Primary declaration of a binary call operator with output e and inputs c, d
// (declaration only, note the trailing ';'). Declared without constexpr.
650 template <
typename E,
typename C,
typename D>
651 __host__ __device__
void operator()(E& e,
const C& c,
const D& d)
const;
// Signature for e: f8_t, c: float, d: float. Body not visible in this excerpt,
// so how the scale parameters enter the computation cannot be stated from here.
654 __host__ __device__
void
655 operator()<
f8_t, float,
float>(
f8_t& e,
const float& c,
const float& d)
const
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
int8_t int8_t
Definition: int8.hpp:20
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:990
_Float16 half_t
Definition: data_type.hpp:25
ushort bhalf_t
Definition: data_type.hpp:24
Definition: binary_element_wise_operation.hpp:462
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:425
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
Definition: binary_element_wise_operation.hpp:14
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:355
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:592
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:281
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:282
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
float beta_
Definition: binary_element_wise_operation.hpp:351
float alpha_
Definition: binary_element_wise_operation.hpp:348
Definition: binary_element_wise_operation.hpp:642
float scale_in_
Definition: binary_element_wise_operation.hpp:660
float scale_wei_
Definition: binary_element_wise_operation.hpp:663
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:643
float scale_out_
Definition: binary_element_wise_operation.hpp:664
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition: unary_element_wise_operation.hpp:688
Definition: binary_element_wise_operation.hpp:96
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:98
Definition: binary_element_wise_operation.hpp:107
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:109
Definition: binary_element_wise_operation.hpp:527
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:118
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:210
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:214
float scale_
Definition: binary_element_wise_operation.hpp:231
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:211
Definition: unary_element_wise_operation.hpp:836
Definition: binary_element_wise_operation.hpp:237
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const