Composable Kernel: include/ck_tile/core/tensor/tile_elementwise.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

namespace ck_tile {

// TODO: support tensors with different distribution
template <typename InOutElementFunc,
          typename... InOutDstrTensors,
          typename = std::enable_if_t<std::conjunction_v<
              std::negation<std::is_same<std::remove_const_t<InOutDstrTensors>, null_tensor>>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc& inout_element_func,
                                           InOutDstrTensors&... inout_dstr_tensors)
{
    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InOutDstrTensors...>::get_thread_buffer_size();

    static_for<0, thread_buffer_size, 1>{}(
        [&](auto i) { inout_element_func(inout_dstr_tensors.get_thread_buffer().at(i)...); });
}
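// Illustrative usage sketch (not part of the original source): applying an in-place
// element-wise op inside a kernel. The tile names are hypothetical and assumed to be
// static_distributed_tensor<float> objects sharing one distribution.
//
//   tile_elementwise_inout([](auto& x) { x *= 0.5f; }, acc_tile);
//
//   // variadic form, one element from each tile per call:
//   tile_elementwise_inout([](auto& c, const auto& a, const auto& b) { c = a + b; },
//                          c_tile, a_tile, b_tile);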

template <typename InElementFunc,
          typename... InTensor,
          typename = std::enable_if_t<
              std::conjunction_v<std::negation<std::is_same<InTensor, null_tensor>>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
                                        const InTensor&... in_dstr_tensors)
{
    using OutDataType = decltype(in_element_func(typename InTensor::DataType{}...));

    // TODO: make sure all distributed tensors have same lengths and distribution
    // static_assert(xxx);
    constexpr auto in_tile_dstr = __type_pack_element<0, InTensor...>::get_tile_distribution();

    constexpr index_t thread_buffer_size =
        __type_pack_element<0, InTensor...>::get_thread_buffer_size();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    static_for<0, thread_buffer_size, 1>{}([&](auto i) {
        out_dstr_tensor.get_thread_buffer()(i) =
            in_element_func(in_dstr_tensors.get_thread_buffer()[i]...);
    });

    return out_dstr_tensor;
}
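// Illustrative usage sketch (not part of the original source): producing a new tile
// from one or more input tiles with the same distribution. Tile names are hypothetical
// static_distributed_tensor<float> objects; the output element type is deduced from
// the functor's return type.
//
//   auto sum_tile = tile_elementwise_in([](auto a, auto b) { return a + b; },
//                                       a_tile, b_tile);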

/**
 * @brief Template function that "unpacks" a tuple and applies an element-wise operation.
 */
template <typename InElementFunc, typename Tuple, size_t... I>
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
                                                  const Tuple& t,
                                                  std::index_sequence<I...>)
{
    return tile_elementwise_inout(in_element_func, t[number<I>{}]...);
}

template <typename InElementFunc, typename Tuple>
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
                                                  const Tuple& t)
{
    static constexpr auto size = Tuple::size();
    return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence<size>{});
}
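// Illustrative note (not part of the original source): for a tuple `t` holding N
// distributed tensors (or references to them), the call
//
//   tile_elementwise_inout_unpack(func, t);
//
// is equivalent to
//
//   tile_elementwise_inout(func, t[number<0>{}], ..., t[number<N - 1>{}]);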

template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{
    tile_elementwise_inout(
        [&value](auto& x) {
            x = type_convert<typename DstrTensors::DataType, remove_cvref_t<T>>(value);
        },
        dstr_tensor);
}
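// Illustrative usage sketch (not part of the original source): filling a tile with a
// runtime scalar; the value is converted to the tile's DataType element by element.
// `p_tile` is a hypothetical static_distributed_tensor<fp16_t>.
//
//   set_tile(p_tile, 1.0f); // every element becomes fp16(1.0)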

// no-op overload for null_tensor
template <typename T>
CK_TILE_DEVICE void set_tile(null_tensor&, const T& /*value*/)
{
}

// TODO: prefer to use a per-dword value to set a tensor, in case the compiler does not
// handle sub-dword tensors well...
template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
CK_TILE_DEVICE void
set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
{
    using elem_type = typename DstrTensors::DataType;
    constexpr index_t elem_size = sizeof(elem_type);

    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;

    // # bytes per write = 4
    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
    {
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
        auto& buffer = dstr_tensor.get_thread_buffer();

        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
            if constexpr(elem_size == 1)
            {
                // # elements per write = 4
                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};

                buffer[i_write * 4 + 0] = values.x;
                buffer[i_write * 4 + 1] = values.y;
                buffer[i_write * 4 + 2] = values.z;
                buffer[i_write * 4 + 3] = values.w;
            }
            else if constexpr(elem_size == 2)
            {
                // # elements per write = 2
                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};

                buffer[i_write * 2 + 0] = values.x;
                buffer[i_write * 2 + 1] = values.y;
            }
            else if constexpr(elem_size == 4)
            {
                // # elements per write = 1
                constexpr elem_type value = 0;

                buffer[i_write] = value;
            }
            else
            {
                static_assert(false, "type not supported");
            }
        });
#else
        using dvec_t = array<index_t, tensor_bytes / 4>;
        auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
        for(auto i = 0; i < tensor.size(); i++)
            tensor.get(i) = v;
#endif
    }
    else
    {
        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
                               dstr_tensor);
    }
}
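// Worked example of the fast path above (assumed sizes, for illustration only): for a
// thread buffer of 16 fp16 elements, tensor_bytes = 16 * 2 = 32, so the zero fill is
// issued as 32 / 4 = 8 dword-sized writes of two packed fp16 zeros each, instead of
// 16 individual sub-dword writes.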

// no-op overload for null_tensor
template <index_t v>
CK_TILE_DEVICE void set_tile(null_tensor&, number<v>)
{
}

template <typename DstrTensors>
CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
{
    set_tile(dstr_tensor, 0);
}
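// Illustrative usage sketch (not part of the original source): zeroing a hypothetical
// accumulator tile before a GEMM main loop. `c_dstr` is an assumed tile distribution.
//
//   auto c_acc = make_static_distributed_tensor<float>(c_dstr);
//   clear_tile(c_acc); // all elements of the thread buffer are now 0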

namespace impl {
// TODO: this is ugly
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx94__) || defined(__gfx12__)
    // This API is designed to use the _pk_ series of functions
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 4 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 4;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wuninitialized"
    // __builtin_amdgcn_cvt_pk_fp8_f32() requires the old value and would otherwise
    // generate a v_mov_b32 vxxx [old] before the cvt, which results in unwanted ISA,
    // so we purposely pass an uninitialized variable and turn off the warning
    int dummy_old;
    static_for<0, thread_buffer_size_pk, 1>{}([&](auto i) {
        uint32_t x = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 0>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 1>{}],
            dummy_old,
            false); // false -> WORD0

        uint32_t y = __builtin_amdgcn_cvt_pk_fp8_f32(
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 2>{}],
            in_dstr_tensors.get_thread_buffer()[number<4 * i + 3>{}],
            x,
            true); // true -> WORD1

        using vec_t = array<OutDataType, 4>;

        vec_t d = bit_cast<vec_t>(y);
        out_dstr_tensor.get_thread_buffer().template set_as<vec_t>(number<i>{}, d);
    });
#pragma clang diagnostic pop

    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}
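// Note on the packing above (summary, not part of the original source): each
// __builtin_amdgcn_cvt_pk_fp8_f32 call converts a pair of f32 values into two packed
// fp8 bytes placed in the selected half (WORD0 or WORD1) of a 32-bit value, so the two
// calls per iteration produce one dword holding four fp8 results, which is then
// bit_cast to array<OutDataType, 4> and stored with a single dword-sized write.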

template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pkrtz_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
    // This API is designed to use the _pk_ series of functions
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 2 == 0);
    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    // TODO: this is an RTZ conversion, so be very careful
    for(index_t i = 0; i < thread_buffer_size_pk; i++)
    {
        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);

        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
    }

    return out_dstr_tensor;
#else
    // fallback
    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
                               in_dstr_tensors);
#endif
}
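// Note (summary, not part of the original source): __builtin_amdgcn_cvt_pkrtz converts
// a pair of f32 values to two packed f16 values using round-toward-zero, which can
// differ in the last ULP from a round-to-nearest conversion; that is the reason for
// the "be very careful" TODO above.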

template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16bf16_fp32(const InTensor& in_dstr_tensors)
{
    // This API is designed to help the compiler identify pairs of f32 -> fp16/bf16 casts and
    // use a cvt_pk instruction when possible
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
    static_assert(thread_buffer_size % 2 == 0);
    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
    using f16x2_t = std::conditional_t<std::is_same_v<OutDataType, fp16_t>, fp16x2_t, bf16x2_t>;
    for(index_t i = 0; i < thread_buffer_size / 2; i++)
    {
        auto o = type_convert<f16x2_t>(fp32x2_t{
            in_dstr_tensors.get_thread_buffer()[2 * i + 0],
            in_dstr_tensors.get_thread_buffer()[2 * i + 1],
        });

        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
    }
    return out_dstr_tensor;
}

#if CK_TILE_USE_SUBDWORD_TILE_CAST
// this function assumes either the src or dst (or both) data type is smaller than 1 dword;
// we pack subdword values into 1 dword to avoid the compiler's default subdword behavior
// (which is buggy)
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_opt_subdword(const InTensor& in_dstr_tensors)
{
    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();

    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);

    using i_type = remove_cvref_t<typename InTensor::DataType>;
    using o_type = remove_cvref_t<OutDataType>;
    constexpr index_t i_elem_bytes = sizeof(i_type);
    constexpr index_t o_elem_bytes = sizeof(o_type);
    static_assert(i_elem_bytes < 4 || o_elem_bytes < 4);

    constexpr index_t bulk_size =
        (i_elem_bytes >= o_elem_bytes) ? (4 / o_elem_bytes) : (4 / i_elem_bytes);
    static_assert(bulk_size != 0);

    using o_bulk_type =
        std::conditional_t<i_elem_bytes >= o_elem_bytes, float, array<o_type, bulk_size>>;

    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();

    constexpr index_t iters = thread_buffer_size / bulk_size;
    constexpr index_t rems  = thread_buffer_size % bulk_size;

    // cast the sequence per-bulk
    static_for<0, iters, 1>{}([&](auto i) {
        union bulk_wrapper
        {
            o_bulk_type bulk{};
            o_type data[bulk_size];
        } o_bulk;

        // TODO: should use the function below, but somehow it results in spills (same as a
        // C for-loop)
        static_for<0, bulk_size, 1>{}([&o_bulk, &in_dstr_tensors, &i](auto ib) {
            o_bulk.data[ib.value] = static_cast<o_type>(
                in_dstr_tensors.get_thread_buffer()
                    .template get_as<i_type>()[number<bulk_size * i.value + ib.value>{}]);
        });

        // TODO: fixme, should use the above!
        // static_assert(sizeof(i_type) / sizeof(o_type) == 2);
        // o_bulk.data[0] = static_cast<o_type>(
        //     in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 0>{}]);
        // o_bulk.data[1] = static_cast<o_type>(
        //     in_dstr_tensors.get_thread_buffer().template get_as<i_type>()[number<2 * i + 1>{}]);

        out_dstr_tensor.get_thread_buffer().template set_as<o_bulk_type>(i, o_bulk.bulk);
    });

    static_for<0, rems, 1>{}([&](auto r) {
        // TODO: introduce a local scratch pad?
        auto idx = number<iters * bulk_size + r>{};
        out_dstr_tensor.get_thread_buffer().at(idx) =
            static_cast<o_type>(in_dstr_tensors.get_thread_buffer().at(idx));
    });

    return out_dstr_tensor;
}
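// Worked example of the bulk packing above (assumed types, for illustration only):
// casting fp32 -> fp16 gives i_elem_bytes = 4 and o_elem_bytes = 2, so bulk_size =
// 4 / 2 = 2 and o_bulk_type = float; each iteration converts two fp16 results and
// stores them with a single dword-sized write instead of two 16-bit writes.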
#endif
} // namespace impl

template <typename DstType, typename SrcTensor>
CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
{
    if constexpr((std::is_same_v<DstType, fp8_t> || std::is_same_v<DstType, bf8_t>) &&
                 std::is_same_v<typename SrcTensor::DataType, float> &&
                 (SrcTensor::get_thread_buffer_size() % 4 == 0))
        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
#if CK_TILE_USE_PK_FP16_TILE_CAST
    else if constexpr(std::is_same_v<DstType, fp16_t> &&
                      std::is_same_v<typename SrcTensor::DataType, float> &&
                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
        return impl::cast_tile_pkrtz_fp16_fp32<DstType, SrcTensor>(src_tensor);
#endif
#if 0 // currently it causes extra spills in qr_async_vr pipeline of fmha_fwd
    else if constexpr((std::is_same_v<DstType, fp16_t> || std::is_same_v<DstType, bf16_t>) &&
                      std::is_same_v<typename SrcTensor::DataType, float> &&
                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
        return impl::cast_tile_pk_fp16bf16_fp32<DstType, SrcTensor>(src_tensor);
#endif
#if CK_TILE_USE_SUBDWORD_TILE_CAST
    else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
        return impl::cast_tile_opt_subdword<DstType, SrcTensor>(src_tensor);
#endif
    else
        return tile_elementwise_in(type_convert<DstType, typename SrcTensor::DataType>, src_tensor);
}
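// Illustrative usage sketch (not part of the original source): down-casting an fp32
// accumulator tile to the output element type before storing. `acc_tile` is a
// hypothetical static_distributed_tensor<float>.
//
//   auto out_tile = cast_tile<fp16_t>(acc_tile); // picks a packed path when available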

// no-op function for null_tensor arguments
template <typename InOutElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc&, MaybeNullTensor&&...)
{
}

// no-op function for null_tensor arguments
template <typename InElementFunc,
          typename... MaybeNullTensor,
          typename = std::enable_if_t<
              std::disjunction_v<std::is_same<remove_cvref_t<MaybeNullTensor>, null_tensor>...>>>
CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc&, MaybeNullTensor&&...)
{
    return null_tensor{};
}

} // namespace ck_tile