/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/utility/functional.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/utility/functional.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/core/utility/functional.hpp Source File
functional.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
10 #include <stdint.h>
11 #include <utility>
12 
13 namespace ck_tile {
14 
15 namespace detail {
16 
// Sink type used to expand a parameter pack purely for its side effects:
// `swallow{(expr, 0)...}` evaluates every `expr`. The elements of a
// braced-init-list are evaluated left to right, so the side effects happen in
// pack order. The constructor accepts any arguments and discards them.
struct swallow
{
    template <typename... Args>
    CK_TILE_HOST_DEVICE constexpr swallow(Args&&...) {}
};
24 
25 template <class>
27 
28 template <index_t... Is>
29 struct static_for_impl<sequence<Is...>>
30 {
31  template <class F>
32  CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
33  {
34  swallow{(f(number<Is>{}), 0)...};
35  }
36 };
37 
38 } // namespace detail
39 
40 // F signature: F(number<Iter>)
41 template <index_t NBegin, index_t NEnd, index_t Increment>
42 struct static_for
43 {
45  {
46  static_assert(Increment != 0 && (NEnd - NBegin) % Increment == 0,
47  "Wrong! should satisfy (NEnd - NBegin) % Increment == 0");
48  static_assert((Increment > 0 && NBegin <= NEnd) || (Increment < 0 && NBegin >= NEnd),
49  "wrongs! should (Increment > 0 && NBegin <= NEnd) || (Increment < 0 && "
50  "NBegin >= NEnd)");
51  }
52 
53  template <class F>
54  CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
55  {
57  f);
58  }
59 };
60 
// Function object that returns its argument unchanged, preserving value
// category:
// - an lvalue argument comes back as an lvalue reference;
// - an rvalue argument comes back as an rvalue reference.
// Nothing is copied or moved.
struct identity
{
    template <typename T>
    CK_TILE_HOST_DEVICE constexpr T&& operator()(T&& value) const noexcept
    {
        // Equivalent to std::forward<T>(value) for a forwarding reference.
        return static_cast<T&&>(value);
    }
};
69 
70 namespace detail {
71 
72 // RemainLengths: sequence<...>
73 // Orders: sequence<...>
74 template <class RemainLengths, class Orders>
76 {
78  {
79  static_assert(RemainLengths::size() > 0, "wrong! should not get here");
80  }
81 
82  // F signature: F(sequence<...>)
83  // CurrentOrderedId: sequence<...>
84  template <class F, class CurrentOrderedId>
85  CK_TILE_HOST_DEVICE constexpr void operator()(F f, CurrentOrderedId) const
86  {
87  static_for<0, RemainLengths::front(), 1>{}([=](auto I) {
88  static_ford_impl<decltype(RemainLengths::pop_front()), Orders>{}(
89  f, CurrentOrderedId::push_back(I));
90  });
91  }
92 };
93 
94 template <class Orders>
95 struct static_ford_impl<sequence<>, Orders>
96 {
97  // F signature: F(sequence<...>)
98  // OrderedId: sequence<...>
99  template <class F, class OrderedId>
100  CK_TILE_HOST_DEVICE constexpr void operator()(F f, OrderedId) const
101  {
102  // retrive unordered Id
103  f(OrderedId::reorder_old_to_new(Orders{}));
104  }
105 };
106 
107 } // namespace detail
108 
109 // Lengths is sequence<...>, it is the length of each dimension for
110 // N-dimensional loop
111 // Orders is sequence<...>, it is the order of dimension in which static_ford
112 // will loop over each
113 // dimension
114 template <class Lengths,
115  class Orders = typename arithmetic_sequence_gen<0, Lengths::size(), 1>::type>
117 {
119  {
120  static_assert(Lengths::size() > 0, "wrong! Lengths is empty");
121  static_assert(Lengths::size() == Orders::size(), "wrong! inconsistent size");
122  }
123 
124  // F signature: F(sequence<...> multi_id)
125  // multi_id is the unordered multi-index
126  template <class F>
127  CK_TILE_HOST_DEVICE constexpr void operator()(F f) const
128  {
129  constexpr auto ordered_lengths = Lengths::reorder_new_to_old(Orders{});
130  detail::static_ford_impl<decltype(ordered_lengths), Orders>{}(f, sequence<>{});
131  }
132 };
133 
134 namespace detail {
135 
136 template <typename Indices>
137 struct unpack_impl;
138 
139 template <index_t... Is>
140 struct unpack_impl<sequence<Is...>>
141 {
142  template <typename F, typename X>
143  CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x) const
144  {
145 #if 0
146  return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...);
147 #else
148  return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...);
149 #endif
150  }
151 };
152 
153 template <typename Seq0, typename Seq1>
155 
156 // TODO: remove this, after properly implementing unpack that takes any number of containers
157 template <index_t... Is, index_t... Js>
158 struct unpack2_impl<sequence<Is...>, sequence<Js...>>
159 {
160  template <typename F, typename X, typename Y>
161  CK_TILE_HOST_DEVICE constexpr auto operator()(F&& f, X&& x, Y&& y) const
162  {
163 #if 0
164  return std::forward<F>(f)(std::forward<X>(x).at(number<Is>{})...,
165  std::forward<Y>(y).at(number<Js>{})...);
166 #else
167  return std::forward<F>(f)(std::forward<X>(x).template at<Is>()...,
168  std::forward<Y>(y).template at<Js>()...);
169 #endif
170  }
171 };
172 
173 } // namespace detail
174 
175 template <typename F, typename X>
176 CK_TILE_HOST_DEVICE constexpr auto unpack(F&& f, X&& x)
177 {
178  using X_ = remove_reference_t<X>;
179  return detail::unpack_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type>{}(
180  std::forward<F>(f), std::forward<X>(x));
181 }
182 
183 // TODO: properly implement unpack that takes any number of containers
184 template <typename F, typename X, typename Y>
185 CK_TILE_HOST_DEVICE constexpr auto unpack2(F&& f, X&& x, Y&& y)
186 {
187  using X_ = remove_reference_t<X>;
188  using Y_ = remove_reference_t<Y>;
189  return detail::unpack2_impl<typename arithmetic_sequence_gen<0, X_::size(), 1>::type,
190  typename arithmetic_sequence_gen<0, Y_::size(), 1>::type>{}(
191  std::forward<F>(f), std::forward<X>(x), std::forward<Y>(y));
192 }
193 
// z = predicate ? x : y
//
// Compile-time selection between two values that may have different types.
// Only the chosen branch is instantiated (if constexpr), so X and Y need not
// share a common type. The result is returned by value (`auto` decays):
// - a selected lvalue argument is copied;
// - a selected rvalue argument is moved.
template <bool predicate, typename X, typename Y>
constexpr auto conditional_expr(X&& x, Y&& y)
{
    if constexpr(!predicate)
    {
        return std::forward<Y>(y);
    }
    else
    {
        return std::forward<X>(x);
    }
}
207 
208 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr auto conditional_expr(X &&x, Y &&y)
Definition: functional.hpp:196
constexpr CK_TILE_HOST_DEVICE auto unpack2(F &&f, X &&x, Y &&y)
Definition: functional.hpp:185
int32_t index_t
Definition: integer.hpp:9
typename std::remove_reference< T >::type remove_reference_t
Definition: type_traits.hpp:14
constexpr CK_TILE_HOST_DEVICE auto unpack(F &&f, X &&x)
Definition: functional.hpp:176
Definition: sequence.hpp:278
Definition: integral_constant.hpp:13
constexpr CK_TILE_HOST_DEVICE void operator()(F f) const
Definition: functional.hpp:32
Definition: functional.hpp:26
constexpr CK_TILE_HOST_DEVICE void operator()(F f, OrderedId) const
Definition: functional.hpp:100
Definition: functional.hpp:76
constexpr CK_TILE_HOST_DEVICE static_ford_impl()
Definition: functional.hpp:77
constexpr CK_TILE_HOST_DEVICE void operator()(F f, CurrentOrderedId) const
Definition: functional.hpp:85
Definition: functional.hpp:18
constexpr CK_TILE_HOST_DEVICE swallow(Ts &&...)
Definition: functional.hpp:20
constexpr CK_TILE_HOST_DEVICE auto operator()(F &&f, X &&x, Y &&y) const
Definition: functional.hpp:161
Definition: functional.hpp:154
constexpr CK_TILE_HOST_DEVICE auto operator()(F &&f, X &&x) const
Definition: functional.hpp:143
Definition: functional.hpp:137
Definition: functional.hpp:62
constexpr CK_TILE_HOST_DEVICE T && operator()(T &&arg) const noexcept
Definition: functional.hpp:64
Definition: sequence.hpp:52
Definition: functional.hpp:43
constexpr CK_TILE_HOST_DEVICE void operator()(F f) const
Definition: functional.hpp:54
constexpr CK_TILE_HOST_DEVICE static_for()
Definition: functional.hpp:44
Definition: functional.hpp:117
constexpr CK_TILE_HOST_DEVICE void operator()(F f) const
Definition: functional.hpp:127
constexpr CK_TILE_HOST_DEVICE static_ford()
Definition: functional.hpp:118