/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/element/element_wise_operation.hpp Source File
element_wise_operation.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math_v2.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace element_wise {
15 
16 // Ensure compilation fails when there is no matching candidate, instead of the compiler
17 // silently doing an implicit type conversion.
18 //
19 // Example:
20 //
21 // struct ExampleElementwiseOp
22 // {
23 // template<typename Y, typename X>
24 // __host__ __device__ constexpr void
25 // operator()(Y&, const X&) const;
26 //
27 // template<>
28 // __host__ __device__ constexpr void
29 // operator()<half_t, half_t>(half_t& y, const half_t& x) const
30 // {
31 // }
32 // };
33 
34 struct AddReluAdd
35 {
36  template <typename Y, typename X0, typename X1, typename X2>
37  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
38 
39  template <>
40  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
41  half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
42  {
43  half_t a = x0 + x1;
44  half_t b = a > 0 ? a : 0;
45  y = b + x2;
46  }
47 
48  template <>
49  __host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
50  const float& x0,
51  const float& x1,
52  const float& x2) const
53  {
54  float a = x0 + x1;
55  float b = a > 0 ? a : 0;
56  float c = b + x2;
57  y = c;
58  }
59 
60  template <>
61  __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
62  half_t& y, const float& x0, const half_t& x1, const half_t& x2) const
63  {
64  float a = x0 + x1;
65  float b = a > 0 ? a : 0;
66  float c = b + x2;
67  y = c;
68  }
69 
70  template <>
71  __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
72  bhalf_t& y, const float& x0, const bhalf_t& x1, const bhalf_t& x2) const
73  {
74  float a = x0 + x1;
75  float b = a > 0 ? a : 0;
76  float c = b + x2;
77  y = c;
78  }
79 
80  template <>
81  __host__ __device__ constexpr void operator()<int8_t, int8_t, int8_t, int8_t>(
82  int8_t& y, const int8_t& x0, const int8_t& x1, const int8_t& x2) const
83  {
84  int32_t a = x0 + x1;
85  int32_t b = a > 0 ? a : 0;
86  int32_t c = b + x2;
87  y = c;
88  }
89 
90 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
91  template <>
92  __host__ __device__ constexpr void operator()<int4_t, int8_t, int4_t, int4_t>(
93  int4_t& y, const int8_t& x0, const int4_t& x1, const int4_t& x2) const
94  {
95  int32_t a = x0 + x1;
96  int32_t b = a > 0 ? a : 0;
97  int32_t c = b + x2;
98  y = c;
99  }
100 #endif // CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
101 };
102 
104 {
105  template <typename Y, typename X0, typename X1, typename X2>
106  __host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
107 
108  template <>
109  __host__ __device__ constexpr void operator()<float, float, float, float>(float& y,
110  const float& x0,
111  const float& x1,
112  const float& x2) const
113  {
114  float a = x0 + x1;
115  float b = a + float{3};
116  float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
117  float d = c + x2;
118  y = d;
119  }
120 
121  template <>
122  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
123  half_t& y, const half_t& x0, const half_t& x1, const half_t& x2) const
124  {
125  float a = x0 + x1;
126  float b = a + float{3};
127  float c = (b > 0) * (b > float{6} ? float{6} : b) * a * float{0.166667};
128  float d = c + x2;
129  y = d;
130  }
131 };
132 
133 // C = A * B
134 // E = C + D0 + D1
// E = C + D0 + D1: D0 and D1 are converted to C, summed in C, and the sum is
// converted to E.
// NOTE(review): this rendering lost the opening `static_assert(` condition lines
// (source lines 141-142, 145-146, 149-150, 153-154); only their message lines
// remain below. Restore them from the original header before reusing this text.
135 struct AddAdd
136 {
137  template <typename E, typename C, typename D0, typename D1>
138  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
139  {
140  // Only support floating so far
143  "Data type is not supported by this operation!");
144 
147  "Data type is not supported by this operation!");
148 
151  "Data type is not supported by this operation!");
152 
155  "Data type is not supported by this operation!");
156 
// Accumulate in C's precision, then convert once to the output type.
157  const C y = c + type_convert<C>(d0) + type_convert<C>(d1);
158  e = type_convert<E>(y);
159  }
160 };
161 
162 // C = A * B
163 // E = (C + D0) x D1
165 {
166  template <typename E, typename C, typename D0, typename D1>
167  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
168 
169  template <>
170  __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
171  const half_t& c,
172  const half_t& d0,
173  const half_t& d1) const
174  {
175  const half_t y = (c + d0) * d1;
176  e = y;
177  }
178  template <>
179  __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
180  const float& c,
181  const half_t& d0,
182  const half_t& d1) const
183  {
184  const half_t y = (type_convert<half_t>(c) + d0) * d1;
185  e = y;
186  }
187  template <>
188  __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
189  const float& c,
190  const half_t& d0,
191  const half_t& d1) const
192  {
193  const float y = (c + d0) * d1;
194  e = y;
195  }
196 };
197 
198 // C = A * B
199 // E = C x D0 + D1
201 {
202  template <typename E, typename C, typename D0, typename D1>
203  __host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
204 
205  template <>
206  __host__ __device__ void operator()<half_t, half_t, half_t, half_t>(half_t& e,
207  const half_t& c,
208  const half_t& d0,
209  const half_t& d1) const
210  {
211  const half_t y = (c * d0) + d1;
212  e = y;
213  }
214  template <>
215  __host__ __device__ void operator()<half_t, float, half_t, half_t>(half_t& e,
216  const float& c,
217  const half_t& d0,
218  const half_t& d1) const
219  {
220  const half_t y = type_convert<half_t>(c) * d0 + d1;
221  e = y;
222  }
223  template <>
224  __host__ __device__ void operator()<bhalf_t, float, bhalf_t, bhalf_t>(bhalf_t& e,
225  const float& c,
226  const bhalf_t& d0,
227  const bhalf_t& d1) const
228  {
229  const bhalf_t y = type_convert<bhalf_t>(c) * d0 + d1;
230  e = y;
231  }
232  template <>
233  __host__ __device__ void operator()<float, float, half_t, half_t>(float& e,
234  const float& c,
235  const half_t& d0,
236  const half_t& d1) const
237  {
238  const float y = c * d0 + d1;
239  e = y;
240  }
241  template <>
242  __host__ __device__ void operator()<half_t, float, float, float>(half_t& e,
243  const float& c,
244  const float& d0,
245  const float& d1) const
246  {
247  const float y = c * d0 + d1;
248  e = y;
249  }
250 };
251 
253 {
254  template <typename E, typename C, typename D0, typename D1>
255  __host__ __device__ constexpr void
256  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
257 
258  template <>
259  __host__ __device__ constexpr void operator()<ck::half_t, float, float, float>(
260  ck::half_t& e, const float& c, const float& d0, const float& d1) const
261  {
262  const float x0_f = c * d0 * d1;
263 
264  e = ck::type_convert<ck::half_t>(x0_f);
265  }
266 
267  template <>
268  __host__ __device__ constexpr void operator()<ck::bhalf_t, float, float, float>(
269  ck::bhalf_t& e, const float& c, const float& d0, const float& d1) const
270  {
271  const float x0_f = c * d0 * d1;
272 
273  e = ck::type_convert<ck::bhalf_t>(x0_f);
274  }
275 
276  template <>
277  __host__ __device__ constexpr void operator()<ck::half_t, int, ck::half_t, ck::half_t>(
278  ck::half_t& e, const int& c, const ck::half_t& d0, const ck::half_t& d1) const
279  {
280  const float x0_f =
281  ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
282 
283  e = ck::type_convert<ck::half_t>(x0_f);
284  }
285 
286  template <>
287  __host__ __device__ constexpr void operator()<ck::bhalf_t, int, float, float>(
288  ck::bhalf_t& e, const int& c, const float& d0, const float& d1) const
289  {
290  const float x0_f =
291  ck::type_convert<float>(c) * ck::type_convert<float>(d0) * ck::type_convert<float>(d1);
292 
293  e = ck::type_convert<ck::bhalf_t>(x0_f);
294  }
295 };
296 
298 {
299  template <typename E, typename C, typename D0, typename D1>
300  __host__ __device__ constexpr void
301  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
302 
303  template <>
304  __host__ __device__ constexpr void operator()<ck::bhalf_t, float, ck::bhalf_t, ck::bhalf_t>(
305  ck::bhalf_t& e, const float& c, const ck::bhalf_t& d0, const ck::bhalf_t& d1) const
306  {
307  const float x0_f = c * ck::type_convert<float>(d0) + ck::type_convert<float>(d1);
308 
309  float x1_f = 0;
310 
311  FastGelu{}.template operator()<float, float>(x1_f, x0_f);
312 
313  e = ck::type_convert<ck::bhalf_t>(x1_f);
314  }
315 };
316 
317 // E = FastGelu(C + D0 + D1)
319 {
320  template <typename E, typename C, typename D0, typename D1>
321  __host__ __device__ constexpr void
322  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
323 
324  template <>
325  __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
326  const float& c,
327  const float& d0,
328  const float& d1) const
329  {
330  const float x = c + d0 + d1;
331 
332  FastGelu{}.template operator()<float, float>(e, x);
333  }
334 
335  template <>
336  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
337  half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
338  {
339  const half_t x = c + d0 + d1;
340 
341  ck::tensor_operation::element_wise::FastGelu{}.template operator()<half_t, half_t>(e, x);
342  }
343 
344  template <>
345  __host__ __device__ constexpr void operator()<half_t, float, half_t, half_t>(
346  half_t& e, const float& c, const half_t& d0, const half_t& d1) const
347  {
348  const float x0_f = c + d0 + d1;
349 
350  float x1_f = 0;
351 
352  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
353  x0_f);
354 
355  e = type_convert<half_t>(x1_f);
356  }
357 
358  template <>
359  __host__ __device__ constexpr void operator()<bhalf_t, float, bhalf_t, bhalf_t>(
360  bhalf_t& e, const float& c, const bhalf_t& d0, const bhalf_t& d1) const
361  {
362  const float x0_f = c + type_convert<float>(d0) + type_convert<float>(d1);
363 
364  float x1_f = 0;
365 
366  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
367  x0_f);
368 
369  e = type_convert<bhalf_t>(x1_f);
370  }
371 
372  template <>
373  __host__ __device__ constexpr void operator()<int8_t, int32_t, int8_t, int8_t>(
374  int8_t& e, const int32_t& c, const int8_t& d0, const int8_t& d1) const
375  {
376  const float x0_f =
377  type_convert<float>(c) + type_convert<float>(d0) + type_convert<float>(d1);
378 
379  float x1_f = 0;
380 
381  ck::tensor_operation::element_wise::FastGelu{}.template operator()<float, float>(x1_f,
382  x0_f);
383 
384  e = type_convert<int8_t>(x1_f);
385  }
386 };
387 
388 // E = Relu(alpha1 * C + alpha2 * D0 + D1)
390 {
391 
392  ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
393  : alpha1_(alpha1), alpha2_(alpha2)
394  {
395  }
396 
397  template <typename E, typename C, typename D0, typename D1>
398  __host__ __device__ constexpr void
399  operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
400 
401  template <>
402  __host__ __device__ constexpr void operator()<float, float, float, float>(float& e,
403  const float& c,
404  const float& d0,
405  const float& d1) const
406  {
407  const float x = c * alpha1_ + alpha2_ * d0 + d1;
408  e = x > 0 ? x : 0;
409  }
410 
411  template <>
412  __host__ __device__ constexpr void operator()<half_t, half_t, half_t, half_t>(
413  half_t& e, const half_t& c, const half_t& d0, const half_t& d1) const
414  {
415  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
416  type_convert<float>(d1);
417 
418  float result = 0;
419  result = x > 0 ? x : 0;
420 
421  e = type_convert<half_t>(result);
422  }
423 
424  template <>
425  __host__ __device__ constexpr void operator()<bhalf_t, bhalf_t, bhalf_t, bhalf_t>(
426  bhalf_t& e, const bhalf_t& c, const bhalf_t& d0, const bhalf_t& d1) const
427  {
428  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * type_convert<float>(d0) +
429  type_convert<float>(d1);
430 
431  float result = 0;
432  result = x > 0 ? x : 0;
433 
434  e = type_convert<bhalf_t>(result);
435  }
436 
437  template <>
438  __host__ __device__ constexpr void operator()<int8_t, int8_t, float, float>(
439  int8_t& e, const int8_t& c, const float& d0, const float& d1) const
440  {
441  const float x = type_convert<float>(c) * alpha1_ + alpha2_ * d0 + d1;
442 
443  float result = 0;
444  result = x > 0 ? x : 0;
445 
446  e = type_convert<int8_t>(result);
447  }
448 
449  const float alpha1_;
450  const float alpha2_;
451 };
452 
453 struct Normalize
454 {
455  // FIXME: is double absolutely necessary?
456  Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
457 
458  template <typename T1, typename T2, typename T3>
459  __host__ __device__ constexpr void operator()(T1& y,
460  const T1& x,
461  const T2& mean,
462  const T2& mean_square,
463  const T3& gamma,
464  const T3& beta) const;
465 
466  template <>
467  __host__ __device__ constexpr void operator()<half_t, float, half_t>(half_t& y,
468  const half_t& x,
469  const float& mean,
470  const float& mean_square,
471  const half_t& gamma,
472  const half_t& beta) const
473  {
474  using ck::math::sqrt;
475 
476  float variance = mean_square - (mean * mean);
477 
478  float tmp_x = type_convert<float>(x);
479  float tmp_gamma = type_convert<float>(gamma);
480  float tmp_beta = type_convert<float>(beta);
481 
482  float tmp_y =
483  ((tmp_x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * tmp_gamma +
484  tmp_beta;
485 
486  y = type_convert<half_t>(tmp_y);
487  };
488 
489  template <>
490  __host__ __device__ constexpr void operator()<float, float, float>(float& y,
491  const float& x,
492  const float& mean,
493  const float& mean_square,
494  const float& gamma,
495  const float& beta) const
496  {
497  using ck::math::sqrt;
498 
499  float variance = mean_square - (mean * mean);
500  y = ((x - mean) / sqrt(variance + type_convert<float>(epsilon_))) * gamma + beta;
501  };
502 
503  template <>
504  __host__ __device__ constexpr void operator()<double, double, double>(double& y,
505  const double& x,
506  const double& mean,
507  const double& mean_square,
508  const double& gamma,
509  const double& beta) const
510  {
511  using ck::math::sqrt;
512 
513  double variance = mean_square - (mean * mean);
514  y = ((x - mean) / sqrt(variance + epsilon_)) * gamma + beta;
515  };
516 
517  // FIXME: is double absolutely necessary?
518  double epsilon_;
519 };
520 
521 // used by BatchNorm inference
522 // y = gamma * (x-mean) / sqrt(epsilon+variance) + beta
523 // The data type of mean and variance is used as AccDataType
525 {
// Batch-norm inference: y = (x - mean) / sqrt(variance + epsilon) * gamma + beta.
// x, gamma, beta and epsilon are converted to T2 (the mean/variance type), the
// expression is evaluated in T2, and the result is converted back to T1.
526  NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
527 
528  template <typename T1, typename T2, typename T3, typename T4>
529  __host__ __device__ constexpr void operator()(T1& y,
530  const T1& x,
531  const T2& mean,
532  const T2& variance,
533  const T3& gamma,
534  const T4& beta) const
535  {
// NOTE(review): the `static_assert(...` condition line (source line 536) is
// missing from this rendering; only its message line survives below.
537  "Data type is not supported by this operation!");
538 
539  using ck::type_convert;
540  using ck::math::sqrt;
541 
542  T2 tmp_x, tmp_y;
543 
// Widen x to the accumulation type T2 before subtracting mean.
544  tmp_x = type_convert<T2>(x);
545 
546  tmp_y = ((tmp_x - mean) / sqrt(variance + type_convert<T2>(epsilon_))) *
547  type_convert<T2>(gamma) +
548  type_convert<T2>(beta);
549  y = type_convert<T1>(tmp_y);
550  };
551 
// Stored as double and converted to T2 at each use.
552  double epsilon_;
553 };
554 
// Primary template of the single-value conversion functor Y <- X; the explicit
// specializations that follow cover bf16 <-> fp32.
// NOTE(review): the declaration line itself (source line 556) is missing from
// this rendering; only its template header survived.
555 template <typename Y, typename X>
557 
558 template <>
559 struct UnaryTypeConvert<float, ck::bhalf_t>
560 {
561  __host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
562  {
563  y = ck::type_convert<float, ck::bhalf_t>(x);
564  }
565 };
566 
567 template <>
568 struct UnaryTypeConvert<ck::bhalf_t, float>
569 {
570  __host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
571  {
572  y = ck::type_convert<ck::bhalf_t, float>(x);
573  }
574 };
575 
576 } // namespace element_wise
577 } // namespace tensor_operation
578 } // namespace ck
int8_t int8_t
Definition: int8.hpp:20
Definition: ck.hpp:264
_Float16 half_t
Definition: data_type.hpp:25
ushort bhalf_t
Definition: data_type.hpp:24
__host__ constexpr __device__ Y type_convert(X x)
Definition: type_convert.hpp:80
_BitInt(4) int4_t
Definition: data_type.hpp:26
Definition: type.hpp:177
Definition: element_wise_operation.hpp:319
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:136
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:138
Definition: element_wise_operation.hpp:104
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: element_wise_operation.hpp:165
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:35
__host__ constexpr __device__ void operator()(Y &, const X0 &, const X1 &, const X2 &) const
Definition: unary_element_wise_operation.hpp:688
Definition: element_wise_operation.hpp:298
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:201
__host__ __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:253
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
Definition: element_wise_operation.hpp:454
Normalize(double epsilon=1e-4)
Definition: element_wise_operation.hpp:456
double epsilon_
Definition: element_wise_operation.hpp:515
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &mean_square, const T3 &gamma, const T3 &beta) const
Definition: element_wise_operation.hpp:525
double epsilon_
Definition: element_wise_operation.hpp:550
__host__ constexpr __device__ void operator()(T1 &y, const T1 &x, const T2 &mean, const T2 &variance, const T3 &gamma, const T4 &beta) const
Definition: element_wise_operation.hpp:529
NormalizeInInfer(double epsilon=1e-4)
Definition: element_wise_operation.hpp:526
Definition: element_wise_operation.hpp:390
ScaleAddScaleAddRelu(const float alpha1=1.f, const float alpha2=1.f)
Definition: element_wise_operation.hpp:392
const float alpha2_
Definition: element_wise_operation.hpp:450
const float alpha1_
Definition: element_wise_operation.hpp:449
__host__ constexpr __device__ void operator()(E &e, const C &c, const D0 &d0, const D1 &d1) const
__host__ __device__ void operator()(ck::bhalf_t &y, float &x) const
Definition: element_wise_operation.hpp:570
__host__ __device__ void operator()(float &y, ck::bhalf_t &x) const
Definition: element_wise_operation.hpp:561
Definition: element_wise_operation.hpp:556