/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp Source File
binary_element_wise_operation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 namespace tensor_operation {
11 namespace element_wise {
12 
13 struct Add
14 {
15  template <typename Y, typename X0, typename X1>
16  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
17 
18  template <>
19  __host__ __device__ constexpr void
20  operator()<float>(float& y, const float& x0, const float& x1) const
21  {
22  y = x0 + x1;
23  };
24 
25  template <>
26  __host__ __device__ constexpr void
27  operator()<double>(double& y, const double& x0, const double& x1) const
28  {
29  y = x0 + x1;
30  };
31 
32  template <>
33  __host__ __device__ constexpr void
34  operator()<float>(float& y, const float& x0, const half_t& x1) const
35  {
36  y = x0 + type_convert<half_t>(x1);
37  };
38 
39  template <>
40  __host__ __device__ constexpr void
41  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
42  {
43  y = type_convert<half_t>(x0 + x1);
44  };
45 
46  template <>
47  __host__ __device__ constexpr void
48  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
49  {
50  y = type_convert<half_t>(x0) + x1;
51  };
52 
53  template <>
54  __host__ __device__ constexpr void
55  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
56  {
57  y = x0 + x1;
58  };
59 
60  template <>
61  __host__ __device__ constexpr void
62  operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
63  {
64  const float x1_tmp = ck::type_convert<float>(x1);
65  y = x0 + x1_tmp;
66  }
67 
68  template <>
69  __host__ __device__ constexpr void
70  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
71  {
72  const float x1_tmp = ck::type_convert<float>(x0);
73  const float x2_tmp = ck::type_convert<float>(x1);
74  const float y_tmp = x1_tmp + x2_tmp;
75  y = ck::type_convert<bhalf_t>(y_tmp);
76  }
77 
78  template <>
79  __host__ __device__ constexpr void
80  operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
81  {
82  const float x2_tmp = ck::type_convert<float>(x1);
83  const float y_tmp = x0 + x2_tmp;
84  y = ck::type_convert<bhalf_t>(y_tmp);
85  }
86 
87  template <>
88  __host__ __device__ constexpr void
89  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
90  {
91  y = x0 + x1;
92  };
93 };
94 
95 struct Max
96 {
97  template <typename Y, typename X0, typename X1>
98  __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
99  {
100  const Y x0_converted = type_convert<Y>(x0);
101  const Y x1_converted = type_convert<Y>(x1);
102  y = ck::math::max(x0_converted, x1_converted);
103  }
104 };
105 
106 struct Min
107 {
108  template <typename Y, typename X0, typename X1>
109  __host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
110  {
111  const Y x0_converted = type_convert<Y>(x0);
112  const Y x1_converted = type_convert<Y>(x1);
113  y = ck::math::min(x0_converted, x1_converted);
114  }
115 };
116 
117 struct Multiply
118 {
119  template <typename Y, typename X0, typename X1>
120  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
121 
122  template <>
123  __host__ __device__ constexpr void
124  operator()<float>(float& y, const float& x0, const float& x1) const
125  {
126  y = x0 * x1;
127  };
128 
129  template <>
130  __host__ __device__ constexpr void
131  operator()<double>(double& y, const double& x0, const double& x1) const
132  {
133  y = x0 * x1;
134  };
135 
136  template <>
137  __host__ __device__ constexpr void
138  operator()<float>(float& y, const float& x0, const half_t& x1) const
139  {
140  y = x0 * type_convert<half_t>(x1);
141  };
142 
143  template <>
144  __host__ __device__ constexpr void
145  operator()<half_t>(half_t& y, const float& x0, const float& x1) const
146  {
147  y = type_convert<half_t>(x0 * x1);
148  };
149 
150  template <>
151  __host__ __device__ constexpr void
152  operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
153  {
154  y = type_convert<half_t>(x0) * x1;
155  };
156 
157  template <>
158  __host__ __device__ constexpr void
159  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
160  {
161  y = x0 * x1;
162  };
163 
164  template <>
165  __host__ __device__ constexpr void
166  operator()<float>(float& y, const float& x0, const bhalf_t& x1) const
167  {
168  const float x1_tmp = ck::type_convert<float>(x1);
169  y = x0 * x1_tmp;
170  }
171 
172  template <>
173  __host__ __device__ constexpr void
174  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
175  {
176  const float x1_tmp = ck::type_convert<float>(x0);
177  const float x2_tmp = ck::type_convert<float>(x1);
178  const float y_tmp = x1_tmp * x2_tmp;
179  y = ck::type_convert<bhalf_t>(y_tmp);
180  }
181 
182  template <>
183  __host__ __device__ constexpr void
184  operator()<bhalf_t>(bhalf_t& y, const int8_t& x0, const bhalf_t& x1) const
185  {
186  const float x1_tmp = ck::type_convert<float>(x0);
187  const float x2_tmp = ck::type_convert<float>(x1);
188  const float y_tmp = x1_tmp * x2_tmp;
189  y = ck::type_convert<bhalf_t>(y_tmp);
190  }
191 
192  template <>
193  __host__ __device__ constexpr void
194  operator()<bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
195  {
196  const float x2_tmp = ck::type_convert<float>(x1);
197  const float y_tmp = x0 * x2_tmp;
198  y = ck::type_convert<bhalf_t>(y_tmp);
199  }
200 
201  template <>
202  __host__ __device__ constexpr void
203  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
204  {
205  y = x0 * x1;
206  };
207 };
208 
209 struct ScaleAdd
210 {
211  __host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
212 
213  template <typename Y, typename X0, typename X1>
214  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const
215  {
216  y = ck::type_convert<Y>(scale_ * ck::type_convert<float>(x0) + ck::type_convert<float>(x1));
217  }
218 
219  template <>
220  __host__ __device__ void
221  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
222  {
223  y = scale_ * x0 + ck::type_convert<float>(x1);
224  };
225 
226  template <>
227  __host__ __device__ void
228  operator()<float, float, bhalf_t>(float& y, const float& x0, const bhalf_t& x1) const
229  {
230  y = scale_ * x0 + ck::type_convert<float>(x1);
231  };
232 
233  float scale_;
234 };
235 
236 struct Subtract
237 {
238  template <typename T>
239  __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
240 
241  template <>
242  __host__ __device__ constexpr void
243  operator()<float>(float& y, const float& x0, const float& x1) const
244  {
245  y = x0 - x1;
246  };
247 
248  template <>
249  __host__ __device__ constexpr void
250  operator()<double>(double& y, const double& x0, const double& x1) const
251  {
252  y = x0 - x1;
253  };
254 
255  template <>
256  __host__ __device__ constexpr void
257  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
258  {
259  y = x0 - x1;
260  };
261 
262  template <>
263  __host__ __device__ constexpr void
264  operator()<bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
265  {
266  const float x1_tmp = ck::type_convert<float>(x0);
267  const float x2_tmp = ck::type_convert<float>(x1);
268  const float y_tmp = x1_tmp - x2_tmp;
269  y = ck::type_convert<bhalf_t>(y_tmp);
270  }
271 
272  template <>
273  __host__ __device__ constexpr void
274  operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
275  {
276  y = x0 - x1;
277  };
278 };
279 
280 struct Bilinear
281 {
282  Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
283 
284  template <typename Y, typename X0, typename X1>
285  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&) const;
286 
287  template <>
288  __host__ __device__ constexpr void
289  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
290  {
291  y = alpha_ * x0 + beta_ * x1;
292  };
293 
294  template <>
295  __host__ __device__ constexpr void
296  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
297  {
298  y = alpha_ * x0 + beta_ * x1;
299  };
300 
301  template <>
302  __host__ __device__ constexpr void
303  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
304  {
305  y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
306  beta_ * type_convert<float>(x1));
307  };
308 
309  template <>
310  __host__ __device__ constexpr void
311  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
312  {
313  y = type_convert<half_t>(alpha_) * x0 + type_convert<half_t>(beta_) * x1;
314  };
315 
316  template <>
317  __host__ __device__ constexpr void
318  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
319  {
320  y = type_convert<half_t>(alpha_ * x0 + beta_ * ck::type_convert<float>(x1));
321  };
322 
323  template <>
324  __host__ __device__ constexpr void
325  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& y, const bhalf_t& x0, const bhalf_t& x1) const
326  {
327  const float x0_tmp = type_convert<float>(x0);
328  const float x1_tmp = type_convert<float>(x1);
329  const float y_tmp = alpha_ * x0_tmp + beta_ * x1_tmp;
330  y = type_convert<bhalf_t>(y_tmp);
331  };
332 
333  template <>
334  __host__ __device__ constexpr void
335  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
336  {
337  const float x1_tmp = ck::type_convert<float>(x1);
338  const float y_tmp = alpha_ * x0 + beta_ * x1_tmp;
339  y = y_tmp;
340  };
341 
342  template <>
343  __host__ __device__ constexpr void
344  operator()<int8_t, int32_t, int8_t>(int8_t& y, const int32_t& x0, const int8_t& x1) const
345  {
346  y = type_convert<int8_t>(alpha_ * type_convert<float>(x0) +
347  beta_ * type_convert<float>(x1));
348  };
349 
350  float alpha_;
351  float beta_;
352 };
353 
354 struct AddRelu
355 {
356  template <typename Y, typename X0, typename X1>
357  __host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
358 
359  template <>
360  __host__ __device__ constexpr void
361  operator()<float, float, float>(float& y, const float& x0, const float& x1) const
362  {
363  const float a = x0 + x1;
364  y = a > 0.0f ? a : 0.0f;
365  };
366 
367  template <>
368  __host__ __device__ constexpr void
369  operator()<double, double, double>(double& y, const double& x0, const double& x1) const
370  {
371  const double a = x0 + x1;
372  y = a > 0.0 ? a : 0.0;
373  };
374 
375  template <>
376  __host__ __device__ constexpr void
377  operator()<half_t, half_t, half_t>(half_t& y, const half_t& x0, const half_t& x1) const
378  {
379  const half_t a = x0 + x1;
380  y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
381  };
382 
383  template <>
384  __host__ __device__ constexpr void
385  operator()<half_t, float, half_t>(half_t& y, const float& x0, const half_t& x1) const
386  {
387  const float a = x0 + x1;
388  y = a > type_convert<half_t>(0.0f) ? a : type_convert<half_t>(0.0f);
389  };
390 
391  template <>
392  __host__ __device__ constexpr void
393  operator()<float, float, half_t>(float& y, const float& x0, const half_t& x1) const
394  {
395  const float a = x0 + type_convert<float>(x1);
396  y = a > 0.0f ? a : 0.0f;
397  };
398 
399  template <>
400  __host__ __device__ constexpr void
401  operator()<bhalf_t, float, bhalf_t>(bhalf_t& y, const float& x0, const bhalf_t& x1) const
402  {
403  const float a = x0 + type_convert<float>(x1);
404  y = a > type_convert<bhalf_t>(0.0f) ? a : type_convert<bhalf_t>(0.0f);
405  };
406 
407  template <>
408  __host__ __device__ constexpr void
409  operator()<int, int, int8_t>(int& y, const int& x0, const int8_t& x1) const
410  {
411  const int8_t a = x0 + x1;
412  y = a > 0 ? a : 0;
413  };
414 
415  template <>
416  __host__ __device__ constexpr void
417  operator()<int8_t, int8_t, int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
418  {
419  const int8_t a = x0 + x1;
420  y = a > 0 ? a : 0;
421  };
422 };
423 
425 {
426  template <typename T>
427  __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
428 
429  template <>
430  __host__ __device__ constexpr void
431  operator()<float>(float& y, const float& x0, const float& x1) const
432  {
433  float a = x0 + x1;
434  float b = a + float{3};
435  float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
436  y = c;
437  };
438 
439  template <>
440  __host__ __device__ constexpr void
441  operator()<double>(double& y, const double& x0, const double& x1) const
442  {
443  double a = x0 + x1;
444  double b = a + 3.0;
445  double c = (b > 0) * (b > 6.0 ? 6.0 : b) * a * 0.166667;
446  y = c;
447  };
448 
449  template <>
450  __host__ __device__ constexpr void
451  operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
452  {
453  float a = x0 + x1;
454  float b = a + 3.0f;
455  float c = (b > 0) * (b > 6.0f ? 6.0f : b) * a * 0.166667f;
456  y = c;
457  };
458 };
459 
460 // E = FastGelu(C + D)
462 {
463  template <typename E, typename C, typename D>
464  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
465 
466  template <>
467  __host__ __device__ constexpr void
468  operator()<float, float, float>(float& e, const float& c, const float& d) const
469  {
470  const float x = c + d;
471 
472  FastGelu{}.template operator()<float, float>(e, x);
473  }
474 
475  template <>
476  __host__ __device__ constexpr void
477  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
478  {
479  const half_t x = c + d;
480 
481  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
482  }
483 
484  template <>
485  __host__ __device__ constexpr void
486  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
487  {
488  const float x0_f = c + d;
489 
490  float x1_f = 0;
491 
492  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
493  x0_f);
494 
495  e = type_convert<half_t>(x1_f);
496  }
497 
498  template <>
499  __host__ __device__ constexpr void
500  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
501  {
502  const float x0_f = type_convert<float>(c) + type_convert<float>(d);
503 
504  float x1_f = 0;
505 
506  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
507 
508  e = type_convert<bhalf_t>(x1_f);
509  }
510 
511  template <>
512  __host__ __device__ constexpr void
513  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
514  {
515  const float x0_f = c + type_convert<float>(d);
516 
517  float x1_f = 0;
518 
519  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
520 
521  e = type_convert<bhalf_t>(x1_f);
522  }
523 };
524 
525 // E = MultiplyFastGelu(C + D)
527 {
528  template <typename E, typename C, typename D>
529  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
530 
531  template <>
532  __host__ __device__ constexpr void
533  operator()<float, float, float>(float& e, const float& c, const float& d) const
534  {
535  const float x = c * d;
536 
537  FastGelu{}.template operator()<float, float>(e, x);
538  }
539 
540  template <>
541  __host__ __device__ constexpr void
542  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
543  {
544  const half_t x = c * d;
545 
546  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
547  }
548 
549  template <>
550  __host__ __device__ constexpr void
551  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
552  {
553  const float x0_f = c * d;
554 
555  float x1_f = 0;
556 
557  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
558  x0_f);
559 
560  e = type_convert<half_t>(x1_f);
561  }
562 
563  template <>
564  __host__ __device__ constexpr void
565  operator()<bhalf_t, bhalf_t, bhalf_t>(bhalf_t& e, const bhalf_t& c, const bhalf_t& d) const
566  {
567  const float x0_f = type_convert<float>(c) * type_convert<float>(d);
568 
569  float x1_f = 0;
570 
571  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
572 
573  e = type_convert<bhalf_t>(x1_f);
574  }
575 
576  template <>
577  __host__ __device__ constexpr void
578  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
579  {
580  const float x0_f = c * type_convert<float>(d);
581 
582  float x1_f = 0;
583 
584  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
585 
586  e = type_convert<bhalf_t>(x1_f);
587  }
588 };
589 
590 // E = Silu(C + D)
591 struct AddSilu
592 {
593  template <typename E, typename C, typename D>
594  __host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
595 
596  template <>
597  __host__ __device__ constexpr void
598  operator()<float, float, float>(float& e, const float& c, const float& d) const
599  {
600  const float x = c + d;
601 
602  Silu{}.template operator()<float>(e, x);
603  }
604 
605  template <>
606  __host__ __device__ constexpr void
607  operator()<half_t, half_t, half_t>(half_t& e, const half_t& c, const half_t& d) const
608  {
609  const half_t x = c + d;
610 
611  Silu{}.template operator()<half_t>(e, x);
612  }
613 
614  template <>
615  __host__ __device__ constexpr void
616  operator()<half_t, float, half_t>(half_t& e, const float& c, const half_t& d) const
617  {
618  const float x0_f = c + d;
619 
620  float x1_f = 0;
621 
622  Silu{}.template operator()<float>(x1_f, x0_f);
623 
624  e = type_convert<half_t>(x1_f);
625  }
626 
627  template <>
628  __host__ __device__ constexpr void
629  operator()<bhalf_t, float, bhalf_t>(bhalf_t& e, const float& c, const bhalf_t& d) const
630  {
631  const float x0_f = c + type_convert<float>(d);
632 
633  float x1_f = 0;
634 
635  Silu{}.template operator()<float>(x1_f, x0_f);
636 
637  e = type_convert<bhalf_t>(x1_f);
638  }
639 };
640 
642 {
643  __host__ __device__ ConvScaleAdd(float scale_in = 1.f,
644  float scale_wei = 1.f,
645  float scale_out = 1.f)
646  : scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
647  {
648  }
649 
650  template <typename E, typename C, typename D>
651  __host__ __device__ void operator()(E& e, const C& c, const D& d) const;
652 
653  template <>
654  __host__ __device__ void
655  operator()<f8_t, float, float>(f8_t& e, const float& c, const float& d) const
656  {
657  float x;
658  Add{}.template operator()<float>(x, c * scale_in_ * scale_wei_, d);
659  e = type_convert<f8_t>(x * scale_out_);
660  };
661 
662  float scale_in_;
663  float scale_wei_;
664  float scale_out_;
665 };
666 
667 } // namespace element_wise
668 } // namespace tensor_operation
669 } // namespace ck
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
int8_t int8_t
Definition: int8.hpp:20
Definition: ck.hpp:264
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:990
_Float16 half_t
Definition: data_type.hpp:25
ushort bhalf_t
Definition: data_type.hpp:24
Definition: binary_element_wise_operation.hpp:462
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:425
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const
Definition: binary_element_wise_operation.hpp:14
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:355
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:592
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:281
Bilinear(float alpha=1.f, float beta=1.f)
Definition: binary_element_wise_operation.hpp:282
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &) const
float beta_
Definition: binary_element_wise_operation.hpp:351
float alpha_
Definition: binary_element_wise_operation.hpp:348
Definition: binary_element_wise_operation.hpp:642
float scale_in_
Definition: binary_element_wise_operation.hpp:660
float scale_wei_
Definition: binary_element_wise_operation.hpp:663
__host__ __device__ ConvScaleAdd(float scale_in=1.f, float scale_wei=1.f, float scale_out=1.f)
Definition: binary_element_wise_operation.hpp:643
float scale_out_
Definition: binary_element_wise_operation.hpp:664
__host__ __device__ void operator()(E &e, const C &c, const D &d) const
Definition: unary_element_wise_operation.hpp:688
Definition: binary_element_wise_operation.hpp:96
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:98
Definition: binary_element_wise_operation.hpp:107
__host__ __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:109
Definition: binary_element_wise_operation.hpp:527
__host__ constexpr __device__ void operator()(E &e, const C &c, const D &d) const
Definition: binary_element_wise_operation.hpp:118
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:210
__host__ constexpr __device__ void operator()(Y &y, const X0 &x0, const X1 &x1) const
Definition: binary_element_wise_operation.hpp:214
float scale_
Definition: binary_element_wise_operation.hpp:231
__host__ __device__ ScaleAdd(float scale=1.f)
Definition: binary_element_wise_operation.hpp:211
Definition: unary_element_wise_operation.hpp:836
Definition: binary_element_wise_operation.hpp:237
__host__ constexpr __device__ void operator()(T &y, const T &x0, const T &x1) const