/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/core/utility/reduce_operator.hpp Source File
reduce_operator.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck_tile {
10 
11 namespace ReduceOp {
12 // y = ReduceOp(y, x);
13 struct Add
14 {
15  template <typename T>
16  CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
17  {
18  return type_convert<T>(0.0f);
19  };
20 
21  template <typename T,
23  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
24  {
25  return y + x;
26  }
27 
28  template <typename T,
30  CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
31  {
32  float y_ = type_convert<float>(y);
33  float x_ = type_convert<float>(x);
34 
35  return type_convert<T>(y_ + x_);
36  }
37 };
38 
39 struct SquareAdd
40 {
41  template <typename T>
42  CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
43  {
44  return type_convert<T>(0.0f);
45  };
46 
47  template <typename T,
49  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
50  {
51  return y + (x * x);
52  }
53 
54  template <typename T,
56  CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
57  {
58  float y_ = type_convert<float>(y);
59  float x_ = type_convert<float>(x);
60  return type_convert<T>(y_ + (x_ * x_));
61  }
62 };
63 
64 struct Max
65 {
66  template <
67  typename T,
68  typename = std::enable_if_t<
71  {
72  return numeric<T>::lowest();
73  };
74 
75  template <
76  typename T,
77  typename = std::enable_if_t<
79  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
80  {
81  return max(y, x);
82  }
83 
84  // Overload with changed flag for index tracking
85  template <
86  typename T,
87  typename = std::enable_if_t<
89  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
90  {
91  T new_max = max(y, x);
92  if(x > y)
93  {
94  changed = true;
95  }
96  return new_max;
97  }
98 };
99 
100 struct AbsMax
101 {
102  template <
103  typename T,
104  typename = std::enable_if_t<
107  {
108  return numeric<T>::zero();
109  };
110 
111  template <
112  typename T,
113  typename = std::enable_if_t<
115  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
116  {
117  return max(y, abs(x));
118  }
119 
120  // Overload with changed flag for index tracking
121  template <
122  typename T,
123  typename = std::enable_if_t<
125  CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
126  {
127  T new_max = max(y, abs(x));
128  if(abs(x) > y)
129  {
130  changed = true;
131  }
132  return new_max;
133  }
134 };
135 
136 } // namespace ReduceOp
137 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
CK_TILE_HOST_DEVICE bfloat16_t abs(const bfloat16_t &x)
Definition: bfloat16.hpp:400
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
typename std::enable_if< B, T >::type enable_if_t
Definition: enable_if.hpp:27
const GenericPointer< typename T::ValueType > T2 value
Definition: pointer.h:1697
ck_tile::ReduceOp::AbsMax
Definition: reduce_operator.hpp:101
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x, bool &changed) const
Definition: reduce_operator.hpp:125
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:115
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:106
ck_tile::ReduceOp::Add
Definition: reduce_operator.hpp:14
constexpr CK_TILE_HOST_DEVICE T operator()(T &y, T x) const
Definition: reduce_operator.hpp:30
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:16
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:23
ck_tile::ReduceOp::Max
Definition: reduce_operator.hpp:65
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x, bool &changed) const
Definition: reduce_operator.hpp:89
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:70
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:79
ck_tile::ReduceOp::SquareAdd
Definition: reduce_operator.hpp:40
constexpr CK_TILE_HOST_DEVICE T operator()(const T &y, const T x) const
Definition: reduce_operator.hpp:49
constexpr CK_TILE_HOST_DEVICE T operator()(T &y, T x) const
Definition: reduce_operator.hpp:56
static constexpr CK_TILE_HOST_DEVICE T GetIdentityValue()
Definition: reduce_operator.hpp:42
Definition: type_traits.hpp:115
static constexpr CK_TILE_HOST_DEVICE T lowest()
Definition: numeric.hpp:23
static constexpr CK_TILE_HOST_DEVICE T zero()
Definition: numeric.hpp:58