Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

template <typename GridwiseReduction,
          bool OutputIndex,
          bool TransformIndexKtoGlobal,
          bool HaveIndexInput,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k,
                                         const OutGridDesc_M out_grid_desc_m,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_value_global,
                                         const IndexDataType* const __restrict__ p_in_index_global,
                                         AccDataType beta,
                                         OutDataType* const __restrict__ p_out_value_global,
                                         IndexDataType* const __restrict__ p_out_index_global)
{
    if constexpr(!OutputIndex)
    {
        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               alpha,
                               p_in_value_global,
                               beta,
                               p_out_value_global);
    }
    else
    {
        GridwiseReduction::template RunWithIndex<TransformIndexKtoGlobal, HaveIndexInput>(
            in_grid_desc_m_k,
            out_grid_desc_m,
            in_elementwise_op,
            acc_elementwise_op,
            alpha,
            p_in_value_global,
            p_in_index_global,
            beta,
            p_out_value_global,
            p_out_index_global);
    };
};

template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_threadwise
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    using ThreadBufferDimAccessOrder =
        typename conditional<InSrcVectorDim == 0, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};

    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                     ThreadReduceSrcDesc_M_K,
                                                     ThreadReduceDstDesc_M,
                                                     ReduceOperation,
                                                     PropagateNan>;

        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1,
                                             false>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t reducedLength = 0;
        do
        {
            threadwise_src_val_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                // do element-wise pre-reduction operation
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    in_elementwise_op(in_thread_buf(Number<offset>{}),
                                      in_thread_buf(Number<offset>{}));
                });
            });

            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);

            threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            reducedLength += KThreadSliceSize;
        } while(reducedLength < toReduceLength);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

            accu_value_buf(I) *= alpha;
        });

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if(!float_equal_zero{}(beta))
        {
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                        OutDataType,
                                                                        OutGridDesc_M,
                                                                        decltype(reduced_data_desc),
                                                                        Sequence<MThreadSliceSize>,
                                                                        Sequence<0>,
                                                                        0,
                                                                        1,
                                                                        1,
                                                                        true>(
                out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                priorDstValue_buf;

            threadwise_dst_load.Run(out_grid_desc_m,
                                    dst_global_buf,
                                    reduced_data_desc,
                                    make_tuple(I0),
                                    priorDstValue_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
            });
        };

        auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                                       OutDataType,
                                                                       decltype(reduced_data_desc),
                                                                       OutGridDesc_M,
                                                                       PassThroughOp,
                                                                       Sequence<MThreadSliceSize>,
                                                                       Sequence<0>,
                                                                       0,
                                                                       OutDstVectorSize,
                                                                       OutMemoryDataOperation,
                                                                       1,
                                                                       false>(
            out_grid_desc_m,
            make_multi_index(thread_global_1d_id * MThreadSliceSize),
            PassThroughOp{});

        threadwise_dst_store.Run(
            reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, dst_global_buf);
    };

    template <bool TransformIndexKtoGlobal, bool HaveIndexInput>
    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation& in_elementwise_op,
                                        const AccElementwiseOperation& acc_elementwise_op,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_value_global,
                                        const IndexDataType* const __restrict__ p_in_index_global,
                                        AccDataType beta,
                                        OutDataType* const __restrict__ p_out_value_global,
                                        IndexDataType* const __restrict__ p_out_index_global)
    {
        using ThreadwiseReduceWithIndex = ThreadwiseReductionWithIndex<AccDataType,
                                                                       IndexDataType,
                                                                       ThreadReduceSrcDesc_M_K,
                                                                       ThreadReduceDstDesc_M,
                                                                       ReduceOperation,
                                                                       PropagateNan>;

        (void)acc_elementwise_op;

        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_val_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr,
                     IndexDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_idx_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = identityVal;
            accu_index_buf(I) = 0;
        });

        const auto toReduceLength = in_grid_desc_m_k.GetLength(Number<1>{});

        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id();

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             ThreadBufferLengths,
                                             ThreadBufferDimAccessOrder,
                                             InSrcVectorDim,
                                             InSrcVectorSize,
                                             1,
                                             false>(
                in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

        constexpr auto in_thread_copy_step = make_multi_index(0, KThreadSliceSize);

        index_t indexStart = 0;
        index_t reducedLength = 0;
        if constexpr(HaveIndexInput)
        {
            auto threadwise_src_idx_load =
                ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                                 IndexDataType,
                                                 InGridDesc_M_K,
                                                 decltype(thread_buffer_desc),
                                                 ThreadBufferLengths,
                                                 ThreadBufferDimAccessOrder,
                                                 InSrcVectorDim,
                                                 InSrcVectorSize,
                                                 1,
                                                 false>(
                    in_grid_desc_m_k, make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));

            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                            in_global_idx_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_idx_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduceWithIndex::Reduce(
                    in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
                threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexStart += KThreadSliceSize;
                reducedLength += KThreadSliceSize;
            } while(reducedLength < toReduceLength);
        }
        else
        {
            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // do element-wise pre-reduction operation
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        in_thread_idx_buf(Number<offset>{}) = indexStart + iK();

                        in_elementwise_op(in_thread_val_buf(Number<offset>{}),
                                          in_thread_val_buf(Number<offset>{}));
                    });
                });

                ThreadwiseReduceWithIndex::Reduce(
                    in_thread_val_buf, in_thread_idx_buf, accu_value_buf, accu_index_buf);

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexStart += KThreadSliceSize;
                reducedLength += KThreadSliceSize;
            } while(reducedLength < toReduceLength);

            if constexpr(TransformIndexKtoGlobal)
            {
                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    const auto coord = make_tensor_coordinate(
                        in_grid_desc_m_k,
                        make_multi_index(thread_global_1d_id * MThreadSliceSize + I,
                                         accu_index_buf(I)));

                    accu_index_buf(I) = coord.GetOffset();
                });
            }
        };

        // for indexed operation, acc_elementwise_op should do nothing
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

            accu_value_buf(I) *= alpha;
        });

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        if(!float_equal_zero{}(beta))
        {
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                        OutDataType,
                                                                        OutGridDesc_M,
                                                                        decltype(reduced_data_desc),
                                                                        Sequence<MThreadSliceSize>,
                                                                        Sequence<0>,
                                                                        0,
                                                                        1,
                                                                        1,
                                                                        false>(
                out_grid_desc_m, make_multi_index(thread_global_1d_id * MThreadSliceSize));

            StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                priorDstValue_buf;

            threadwise_dst_load.Run(out_grid_desc_m,
                                    out_global_val_buf,
                                    reduced_data_desc,
                                    make_tuple(I0),
                                    priorDstValue_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValue_buf[I]) * beta;
            });
        };

        auto threadwise_dst_val_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               OutDataType,
                                               decltype(reduced_data_desc),
                                               OutGridDesc_M,
                                               PassThroughOp,
                                               Sequence<MThreadSliceSize>,
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               OutMemoryDataOperation,
                                               1,
                                               false>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp{});

        auto threadwise_dst_idx_store =
            ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                               IndexDataType,
                                               decltype(reduced_data_desc),
                                               OutGridDesc_M,
                                               PassThroughOp,
                                               Sequence<MThreadSliceSize>,
                                               Sequence<0>,
                                               0,
                                               OutDstVectorSize,
                                               OutMemoryDataOperation,
                                               1,
                                               false>(
                out_grid_desc_m,
                make_multi_index(thread_global_1d_id * MThreadSliceSize),
                PassThroughOp{});

        threadwise_dst_val_store.Run(
            reduced_data_desc, make_tuple(I0), accu_value_buf, out_grid_desc_m, out_global_val_buf);

        threadwise_dst_idx_store.Run(
            reduced_data_desc, make_tuple(I0), accu_index_buf, out_grid_desc_m, out_global_idx_buf);
    };
};

} // namespace ck
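
For orientation, the sketch below shows how the pieces of this header fit together when driven from the host: GridwiseReduction_mk_to_m_threadwise supplies the per-thread reduction body and kernel_reduce_threadwise dispatches to Run (values only) or RunWithIndex (values plus winning K indices). This is a minimal, hedged illustration, not library code: the include paths, the example lengths M and K, the vector-size choices, and the raw kernel launch are assumptions made to keep the snippet self-contained; in Composable Kernel this wiring is normally done by the DeviceReduce-style device operations rather than by hand.

// Usage sketch (assumptions: include paths, problem sizes, raw <<<>>> launch).
#include <hip/hip_runtime.h>

#include "ck/ck.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_threadwise.hpp"

// Sum a row-major [M, K] device buffer over K into an [M] device buffer:
// p_out = alpha * reduce_add_k(p_in) + beta * p_out.
void run_sum_over_k(const float* p_in, float* p_out, float alpha, float beta)
{
    constexpr ck::index_t M = 1024; // illustrative sizes only
    constexpr ck::index_t K = 64;

    constexpr ck::index_t BlockSize        = 256;
    constexpr ck::index_t MThreadSliceSize = 1; // one output row per thread
    constexpr ck::index_t KThreadSliceSize = K; // whole reduction window per load

    const auto in_grid_desc_m_k = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M, K));
    const auto out_grid_desc_m  = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M));

    using InElemOp  = ck::tensor_operation::element_wise::PassThrough;
    using AccElemOp = ck::tensor_operation::element_wise::PassThrough;

    using GridwiseReduce =
        ck::GridwiseReduction_mk_to_m_threadwise<float, float, float, ck::index_t,
                                                 decltype(in_grid_desc_m_k),
                                                 decltype(out_grid_desc_m),
                                                 ck::reduce::Add,
                                                 InElemOp,
                                                 AccElemOp,
                                                 ck::InMemoryDataOperationEnum::Set,
                                                 false, // PropagateNan
                                                 BlockSize,
                                                 MThreadSliceSize,
                                                 KThreadSliceSize,
                                                 1,  // InSrcVectorDim (vectorize along K)
                                                 1,  // InSrcVectorSize
                                                 1>; // OutDstVectorSize

    const auto kernel = ck::kernel_reduce_threadwise<GridwiseReduce,
                                                     false, // OutputIndex
                                                     false, // TransformIndexKtoGlobal
                                                     false, // HaveIndexInput
                                                     float, float, float, ck::index_t,
                                                     decltype(in_grid_desc_m_k),
                                                     decltype(out_grid_desc_m),
                                                     InElemOp,
                                                     AccElemOp>;

    // Each thread owns MThreadSliceSize output rows, so M / (BlockSize * MThreadSliceSize)
    // blocks cover the whole output (M is assumed divisible by that product here).
    const ck::index_t grid_size = M / (BlockSize * MThreadSliceSize);

    kernel<<<grid_size, BlockSize, 0, nullptr>>>(in_grid_desc_m_k,
                                                 out_grid_desc_m,
                                                 InElemOp{},
                                                 AccElemOp{},
                                                 alpha,
                                                 p_in,
                                                 nullptr, // no index input
                                                 beta,
                                                 p_out,
                                                 nullptr); // no index output written
}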