/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File
binary_element_wise_operation.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
8 
9 namespace ck {
10 namespace tensor_operation {
11 namespace element_wise {
12 
13 struct Add
14 {
15  static constexpr const char* name = "Add";
16 
17  template <typename Y, typename X0, typename X1>
18  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
19 
20  template <>
21  __host__ __device__ constexpr void
22  operator()<float>(float& y, const float& x0, const float& x1) const
23  {
24  y = x0 + x1;
25  };
26 
27  template <>
28  __host__ __device__ constexpr void
29  operator()<double>(double& y, const double& x0, const double& x1) const
30  {
31  y = x0 + x1;
32  };
33 
34  template <>
35  __host__ __device__ constexpr void
36  operator()<float>(float& y, const float& x0, const half_t& x1) const
37  {
38  y = x0 + type_convert<half_t>(x1);
39  };
40 
41  template <>
42  __host__ __device__ constexpr void
43  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
44  {
45  y = type_convert<half_t>(x0 + x1);
46  };
47 
48  template <>
49  __host__ __device__ constexpr void
50  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
51  {
52  y = x0 + type_convert<float>(x1);
53  };
54 
55  template <>
56  __host__ __device__ constexpr void
57  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
58  {
59  y = x0 + x1;
60  };
61 
62  template <>
63  __host__ __device__ constexpr void
64  operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
65  {
66  const float x1_tmp = ck::type_convert<float>(x1);
67  y = x0 + x1_tmp;
68  }
69 
70  template <>
71  __host__ __device__ constexpr void
72  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
73  {
74  const float x1_tmp = ck::type_convert<float>(x0);
75  const float x2_tmp = ck::type_convert<float>(x1);
76  const float y_tmp = x1_tmp + x2_tmp;
77  y = ck::type_convert<bhalf_t>(y_tmp);
78  }
79 
80  template <>
81  __host__ __device__ constexpr void
82  operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
83  {
84  const float x2_tmp = ck::type_convert<float>(x1);
85  const float y_tmp = x0 + x2_tmp;
86  y = ck::type_convert<bhalf_t>(y_tmp);
87  }
88 
89  template <>
90  __host__ __device__ constexpr void
91  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
92  {
93  y = x0 + x1;
94  };
95 };
96 
97 struct Max
98 {
99  static constexpr const char* name = "Max";
100 
101  template <typename Y, typename X0, typename X1>
102  __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
103  {
104  const Y x0_converted = type_convert<Y>(x0);
105  const Y x1_converted = type_convert<Y>(x1);
106  y = ck::math::max(x0_converted, x1_converted);
107  }
108 };
109 
110 struct Min
111 {
112  static constexpr const char* name = "Min";
113 
114  template <typename Y, typename X0, typename X1>
115  __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
116  {
117  const Y x0_converted = type_convert<Y>(x0);
118  const Y x1_converted = type_convert<Y>(x1);
119  y = ck::math::min(x0_converted, x1_converted);
120  }
121 };
122 
123 struct Multiply
124 {
125  static constexpr const char* name = "Multiply";
126 
127  template <typename Y, typename X0, typename X1>
128  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
129 
130  template <>
131  __host__ __device__ constexpr void
132  operator()<float>(float& y, const float& x0, const float& x1) const
133  {
134  y = x0 * x1;
135  };
136 
137  template <>
138  __host__ __device__ constexpr void
139  operator()<double>(double& y, const double& x0, const double& x1) const
140  {
141  y = x0 * x1;
142  };
143 
144  template <>
145  __host__ __device__ constexpr void
146  operator()<float>(float& y, const float& x0, const half_t& x1) const
147  {
148  y = x0 * type_convert<half_t>(x1);
149  };
150 
151  template <>
152  __host__ __device__ constexpr void
153  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
154  {
155  y = type_convert<half_t>(x0 * x1);
156  };
157 
158  template <>
159  __host__ __device__ constexpr void
160  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
161  {
162  y = type_convert<half_t>(x0) * x1;
163  };
164 
165  template <>
166  __host__ __device__ constexpr void
167  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
168  {
169  y = x0 * x1;
170  };
171 
172  template <>
173  __host__ __device__ constexpr void
174  operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
175  {
176  const float x1_tmp = ck::type_convert<float>(x1);
177  y = x0 * x1_tmp;
178  }
179 
180  template <>
181  __host__ __device__ constexpr void
182  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
183  {
184  const float x1_tmp = ck::type_convert<float>(x0);
185  const float x2_tmp = ck::type_convert<float>(x1);
186  const float y_tmp = x1_tmp * x2_tmp;
187  y = ck::type_convert<bhalf_t>(y_tmp);
188  }
189 
190  template <>
191  __host__ __device__ constexpr void
192  operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
193  {
194  const float x1_tmp = ck::type_convert<float>(x0);
195  const float x2_tmp = ck::type_convert<float>(x1);
196  const float y_tmp = x1_tmp * x2_tmp;
197  y = ck::type_convert<bhalf_t>(y_tmp);
198  }
199 
200  template <>
201  __host__ __device__ constexpr void
202  operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
203  {
204  const float x2_tmp = ck::type_convert<float>(x1);
205  const float y_tmp = x0 * x2_tmp;
206  y = ck::type_convert<bhalf_t>(y_tmp);
207  }
208 
209  template <>
210  __host__ __device__ constexpr void
211  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
212  {
213  y = x0 * x1;
214  };
215 };
216 
217 struct ScaleAdd
218 {
219  static constexpr const char* name = "ScaleAdd";
220 
221  __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
222 
223  template <typename Y, typename X0, typename X1>
224  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
225  {
226  y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
227  }
228 
229  template <>
230  __host__ __device__ void
231  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
232  {
233  y = scale_ * x0 + ck::type_convert<float>(x1);
234  };
235 
236  template <>
237  __host__ __device__ void
238  operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
239  {
240  y = scale_ * x0 + ck::type_convert<float>(x1);
241  };
242 
243  float scale_;
244 };
245 
246 struct Subtract
247 {
248  static constexpr const char* name = "Subtract";
249 
250  template <typename T>
251  __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
252 
253  template <>
254  __host__ __device__ constexpr void
255  operator()<float>(float& y, const float& x0, const float& x1) const
256  {
257  y = x0 - x1;
258  };
259 
260  template <>
261  __host__ __device__ constexpr void
262  operator()<double>(double& y, const double& x0, const double& x1) const
263  {
264  y = x0 - x1;
265  };
266 
267  template <>
268  __host__ __device__ constexpr void
269  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
270  {
271  y = x0 - x1;
272  };
273 
274  template <>
275  __host__ __device__ constexpr void
276  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
277  {
278  const float x1_tmp = ck::type_convert<float>(x0);
279  const float x2_tmp = ck::type_convert<float>(x1);
280  const float y_tmp = x1_tmp - x2_tmp;
281  y = ck::type_convert<bhalf_t>(y_tmp);
282  }
283 
284  template <>
285  __host__ __device__ constexpr void
286  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
287  {
288  y = x0 - x1;
289  };
290 };
291 
292 struct Bilinear
293 {
294  static constexpr const char* name = "Bilinear";
295 
296  Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
297 
298  template <typename Y, typename X0, typename X1>
299  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
300 
301  template <>
302  __host__ __device__ constexpr void
303  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
304  {
305  y = alpha_ * x0 + beta_ * x1;
306  };
307 
308  template <>
309  __host__ __device__ constexpr void
310  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
311  {
312  y = alpha_ * x0 + beta_ * x1;
313  };
314 
315  template <>
316  __host__ __device__ constexpr void
317  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
318  {
319  y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
320  beta_ * type_convert<float>(x1));
321  };
322 
323  template <>
324  __host__ __device__ constexpr void
325  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
326  {
327  y = type_convert<half_t>(alpha_) * x0 + type_convert<half_t>(beta_) * x1;
328  };
329 
330  template <>
331  __host__ __device__ constexpr void
332  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
333  {
334  y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
335  };
336 
337  template <>
338  __host__ __device__ constexpr void
339  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
340  {
341  const float x0_tmp = type_convert<float>(x0);
342  const float x1_tmp = type_convert<float>(x1);
343  const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
344  y = type_convert<bhalf_t>(y_tmp);
345  };
346 
347  template <>
348  __host__ __device__ constexpr void
349  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
350  {
351  y = type_convert<bhalf_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
352  };
353 
354  template <>
355  __host__ __device__ constexpr void
356  operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
357  {
358  y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
359  beta_ * type_convert<float>(x1));
360  };
361 
362  float alpha_;
363  float beta_;
364 };
365 
366 struct AddClamp
367 {
368  static constexpr const char* name = "AddClamp";
369 
371  : floor_(floor), ceil_(ceil){};
372 
373  template <typename Y, typename X0, typename X1>
374  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
375 
376  template <>
377  __host__ __device__ constexpr void
378  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
379  {
380  const float a = x0 + x1;
381  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
382  };
383 
384  template <>
385  __host__ __device__ constexpr void
386  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
387  {
388  const double a = x0 + x1;
389  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
390  };
391 
392  template <>
393  __host__ __device__ constexpr void
394  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
395  {
396  const half_t floor = type_convert<half_t>(floor_);
397  const half_t ceil = type_convert<half_t>(ceil_);
398  const half_t a = x0 + x1;
399  y = a > floor ? (a < ceil ? a : ceil) : floor;
400  };
401 
402  template <>
403  __host__ __device__ constexpr void
404  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
405  {
406  const float a = x0 + type_convert<float>(x1);
407  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
408  y = type_convert<half_t>(b);
409  };
410 
411  template <>
412  __host__ __device__ constexpr void
413  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
414  {
415  const float a = x0 + type_convert<float>(x1);
416  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
417  };
418 
419  template <>
420  __host__ __device__ constexpr void
421  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
422  {
423  const float a = x0 + type_convert<float>(x1);
424  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
425  y = type_convert<bhalf_t>(b);
426  };
427 
428  template <>
429  __host__ __device__ constexpr void
430  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
431  {
432  const float a = type_convert<float>(x0) + type_convert<float>(x1);
433  const float b = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
434  y = type_convert<bhalf_t>(b);
435  };
436 
437  template <>
438  __host__ __device__ constexpr void
439  operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
440  {
441  const int8_t a = x0 + x1;
442  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
443  };
444 
445  template <>
446  __host__ __device__ constexpr void
447  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
448  {
449  const int8_t a = x0 + x1;
450  y = a > floor_ ? (a < ceil_ ? a : ceil_) : floor_;
451  };
452 
453  const float floor_;
454  const float ceil_;
455 };
456 
457 struct AddRelu
458 {
459  static constexpr const char* name = "AddRelu";
460 
461  template <typename Y, typename X0, typename X1>
462  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
463 
464  template <>
465  __host__ __device__ constexpr void
466  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
467  {
468  const float a = x0 + x1;
469  y = a > 0.0f ? a : 0.0f;
470  };
471 
472  template <>
473  __host__ __device__ constexpr void
474  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
475  {
476  const double a = x0 + x1;
477  y = a > 0.0 ? a : 0.0;
478  };
479 
480  template <>
481  __host__ __device__ constexpr void
482  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
483  {
484  const half_t a = x0 + x1;
485  y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
486  };
487 
488  template <>
489  __host__ __device__ constexpr void
490  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
491  {
492  const float a = x0 + type_convert<float>(x1);
493  const float b = a > 0.0f ? a : 0.0f;
494  y = type_convert<half_t>(b);
495  };
496 
497  template <>
498  __host__ __device__ constexpr void
499  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
500  {
501  const float a = x0 + type_convert<float>(x1);
502  y = a > 0.0f ? a : 0.0f;
503  };
504 
505  template <>
506  __host__ __device__ constexpr void
507  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
508  {
509  const float a = x0 + type_convert<float>(x1);
510  const float b = a > 0.0f ? a : 0.0f;
511  y = type_convert<bhalf_t>(b);
512  };
513 
514  template <>
515  __host__ __device__ constexpr void
516  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
517  {
518  const float a = type_convert<float>(x0) + type_convert<float>(x1);
519  const float b = a > 0.0f ? a : 0.0f;
520  y = type_convert<bhalf_t>(b);
521  };
522 
523  template <>
524  __host__ __device__ constexpr void
525  operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
526  {
527  const int8_t a = x0 + x1;
528  y = a > 0 ? a : 0;
529  };
530 
531  template <>
532  __host__ __device__ constexpr void
533  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
534  {
535  const int8_t a = x0 + x1;
536  y = a > 0 ? a : 0;
537  };
538 };
539 
541 {
542  static constexpr const char* name = "AddHardswish";
543 
544  template <typename T>
545  __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
546 
547  template <>
548  __host__ __device__ constexpr void
549  operator()<float>(float& y, const float& x0, const float& x1) const
550  {
551  float a = x0 + x1;
552  float b = a + float{3};
553  float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
554  y = c;
555  };
556 
557  template <>
558  __host__ __device__ constexpr void
559  operator()<double>(double& y, const double& x0, const double& x1) const
560  {
561  double a = x0 + x1;
562  double b = a + 3.0;
563  double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
564  y = c;
565  };
566 
567  template <>
568  __host__ __device__ constexpr void
569  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
570  {
571  float a = x0 + x1;
572  float b = a + 3.0f;
573  float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
574  y = c;
575  };
576 };
577 
578 // E = FastGelu(C + D)
580 {
581  static constexpr const char* name = "AddFastGelu";
582 
583  template <typename E, typename C, typename D>
584  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
585 
586  template <>
587  __host__ __device__ constexpr void
588  operator()<float, float, float>(float& e, const float& c, const float& d) const
589  {
590  const float x = c + d;
591 
592  FastGelu{}.template operator()<float, float>(e, x);
593  }
594 
595  template <>
596  __host__ __device__ constexpr void
597  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
598  {
599  const half_t x = c + d;
600 
601  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
602  }
603 
604  template <>
605  __host__ __device__ constexpr void
606  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
607  {
608  const float x0_f = c + d;
609 
610  float x1_f = 0;
611 
612  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
613  x0_f);
614 
615  e = type_convert<half_t>(x1_f);
616  }
617 
618  template <>
619  __host__ __device__ constexpr void
620  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
621  {
622  const float x0_f = type_convert<float>(c) + type_convert<float>(d);
623 
624  float x1_f = 0;
625 
626  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
627 
628  e = type_convert<bhalf_t>(x1_f);
629  }
630 
631  template <>
632  __host__ __device__ constexpr void
633  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
634  {
635  const float x0_f = c + type_convert<float>(d);
636 
637  float x1_f = 0;
638 
639  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
640 
641  e = type_convert<bhalf_t>(x1_f);
642  }
643 };
644 
645 // E = MultiplyFastGelu(C + D)
647 {
648  static constexpr const char* name = "MultiplyFastGelu";
649 
650  template <typename E, typename C, typename D>
651  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
652 
653  template <>
654  __host__ __device__ constexpr void
655  operator()<float, float, float>(float& e, const float& c, const float& d) const
656  {
657  const float x = c * d;
658 
659  FastGelu{}.template operator()<float, float>(e, x);
660  }
661 
662  template <>
663  __host__ __device__ constexpr void
664  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
665  {
666  const half_t x = c * d;
667 
668  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
669  }
670 
671  template <>
672  __host__ __device__ constexpr void
673  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
674  {
675  const float x0_f = c * d;
676 
677  float x1_f = 0;
678 
679  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
680  x0_f);
681 
682  e = type_convert<half_t>(x1_f);
683  }
684 
685  template <>
686  __host__ __device__ constexpr void
687  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
688  {
689  const float x0_f = type_convert<float>(c) * type_convert<float>(d);
690 
691  float x1_f = 0;
692 
693  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
694 
695  e = type_convert<bhalf_t>(x1_f);
696  }
697 
698  template <>
699  __host__ __device__ constexpr void
700  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
701  {
702  const float x0_f = c * type_convert<float>(d);
703 
704  float x1_f = 0;
705 
706  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
707 
708  e = type_convert<bhalf_t>(x1_f);
709  }
710 };
711 
712 // E = Silu(C + D)
713 struct AddSilu
714 {
715  static constexpr const char* name = "AddSilu";
716 
717  template <typename E, typename C, typename D>
718  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
719 
720  template <>
721  __host__ __device__ constexpr void
722  operator()<float, float, float>(float& e, const float& c, const float& d) const
723  {
724  const float x = c + d;
725 
726  Silu{}.template operator()<float>(e, x);
727  }
728 
729  template <>
730  __host__ __device__ constexpr void
731  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
732  {
733  const half_t x = c + d;
734 
735  Silu{}.template operator()<half_t>(e, x);
736  }
737 
738  template <>
739  __host__ __device__ constexpr void
740  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
741  {
742  const float x0_f = c + d;
743 
744  float x1_f = 0;
745 
746  Silu{}.template operator()<float>(x1_f, x0_f);
747 
748  e = type_convert<half_t>(x1_f);
749  }
750 
751  template <>
752  __host__ __device__ constexpr void
753  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
754  {
755  const float x0_f = c + type_convert<float>(d);
756 
757  float x1_f = 0;
758 
759  Silu{}.template operator()<float>(x1_f, x0_f);
760 
761  e = type_convert<bhalf_t>(x1_f);
762  }
763 };
764 
766 {
767  static constexpr const char* name = "ConvScaleAdd";
768 
769  __host__ __device__ ConvScaleAdd(float scale_in = 1.f,
770  float scale_wei = 1.f,
771  float scale_out = 1.f)
772  : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
773  {
774  }
775 
776  template <typename E, typename C, typename D>
777  __host__ __device__ void operator()(E& e, const C& c, const D& d) const;
778 
779  template <>
780  __host__ __device__ void
781  operator()<f8_t, float, float>(f8_t& e, const float& c, const float& d) const
782  {
783  float x;
784  Add{}.template operator()<float>(x, c * scale_in_ * scale_wei_, d);
785  e = type_convert<f8_t>(x * scale_out_);
786  };
787 
788  float scale_in_;
789  float scale_wei_;
790  float scale_out_;
791 };
792 
793 } // namespace element_wise
794 } // namespace tensor_operation
795 } // namespace ck
__host__ T ceil(T x)
Definition: math_v2.hpp:331
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ T floor(T x)
Definition: math_v2.hpp:367
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: ck.hpp:270
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1762
_Float16 half_t
Definition: data_type.hpp:31
ushort bhalf_t
Definition: data_type.hpp:30
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
signed int int32_t
Definition: stdint.h:123
signed char int8_t
Definition: stdint.h:121
Definition: numeric_limits.hpp:310
Definition: amd_ck_fp8.hpp:36
Definition: binary_element_wise_operation.hpp:367
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:368
AddClamp(float floor=0.f, float ceil=NumericLimits< float >::Max())
Definition: binary_element_wise_operation.hpp:370
const float ceil_
Definition: binary_element_wise_operation.hpp:454
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
const float floor_
Definition: binary_element_wise_operation.hpp:451
Definition: binary_element_wise_operation.hpp:580
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:581
Definition: binary_element_wise_operation.hpp:541
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:542
Definition: binary_element_wise_operation.hpp:14
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:15
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:458
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:459
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:714
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:715
Definition: binary_element_wise_operation.hpp:293
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:296
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:294
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
float beta_
Definition: binary_element_wise_operation.hpp:363
float alpha_
Definition: binary_element_wise_operation.hpp:360
Definition: binary_element_wise_operation.hpp:766
float scale_in_
Definition: binary_element_wise_operation.hpp:786
float scale_wei_
Definition: binary_element_wise_operation.hpp:789
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:769
float scale_out_
Definition: binary_element_wise_operation.hpp:790
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:767
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition: unary_element_wise_operation.hpp:924
Definition: binary_element_wise_operation.hpp:98
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:99
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:102
Definition: binary_element_wise_operation.hpp:111
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:112
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:115
Definition: binary_element_wise_operation.hpp:647
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:648
Definition: binary_element_wise_operation.hpp:124
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:125
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:218
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:224
float scale_
Definition: binary_element_wise_operation.hpp:241
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:221
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:219
Definition: unary_element_wise_operation.hpp:1087
Definition: binary_element_wise_operation.hpp:247
static constexpr const char * name
Definition: binary_element_wise_operation.hpp:248
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const