Composable Kernel documentation (docs-6.4.3) — source listing for
include/ck/tensor_operation/gpu/device/impl/device_reduce_threadwise.hpp

device_reduce_threadwise.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 #include <array>
9 
15 
16 namespace ck {
17 namespace tensor_operation {
18 namespace device {
19 
// ---------------------------------------------------------------------------
// NOTE(review): this listing was extracted from a generated documentation
// page. The leading integer on each code line is the original source line
// number; gaps in that numbering (86-87, 90, 92-93, 98, 105, 110, 126-127,
// 145-146, 156-157, 189, 193, 195, 199, 201-202, 215, 220-223, 225, 234,
// 236, 243, 246, 252, 264, 268, 282) mean those original lines are missing
// here. Statements adjacent to a gap are truncated — consult the original
// header before relying on them.
// ---------------------------------------------------------------------------
// DeviceReduceThreadWise: device-level tensor-reduction operation. The K
// (reduction) block tile is sized with a factor of 1 (see K_BlockTileSize
// below), i.e. a single thread appears to cover its entire K slice —
// consistent with the "ThreadWise" name. IsSupportedArgument() accordingly
// rejects large reduction lengths, deferring them to a blockwise kernel.
//
// Template parameters (as visible in this listing):
//   InDataType / AccDataType / OutDataType — input / accumulation / output
//     element types.
//   Rank / NumReduceDim — input tensor rank and number of reduced dims.
//   ReduceOperation — the reduction functor.
//   InElementwiseOperation / AccElementwiseOperation — elementwise ops
//     applied on the input side / accumulated side.
//   PropagateNan, OutputIndex, TransformIndexKtoGlobal,
//   HaveIndexInputIfOutputIndex — behavioral flags forwarded to the kernel.
//   BlockSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim,
//   InSrcVectorSize, OutDstVectorSize — tiling / vector-access config.
20 template <typename InDataType,
21  typename AccDataType,
22  typename OutDataType,
23  index_t Rank,
24  index_t NumReduceDim,
25  typename ReduceOperation,
26  typename InElementwiseOperation,
27  typename AccElementwiseOperation,
28  bool PropagateNan,
29  bool OutputIndex,
30  bool TransformIndexKtoGlobal,
31  bool HaveIndexInputIfOutputIndex,
32  index_t BlockSize,
33  index_t MThreadSliceSize,
34  index_t KThreadSliceSize,
35  index_t InSrcVectorDim,
36  index_t InSrcVectorSize,
37  index_t OutDstVectorSize>
38 struct DeviceReduceThreadWise : public DeviceReduce<InDataType,
39  AccDataType,
40  OutDataType,
41  Rank,
42  NumReduceDim,
43  ReduceOperation,
44  InElementwiseOperation,
45  AccElementwiseOperation,
46  PropagateNan,
47  OutputIndex>
48 
49 {
// Compile-time validation: supported rank, and the vectorized dimension's
// slice size must be divisible by its vector size (likewise M slice by the
// output vector size).
50  static_assert(Rank <= 12, "Bigger Rank size is not supported!");
51 
52  static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
53  (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
54  (MThreadSliceSize % OutDstVectorSize == 0),
55  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
56 
57  using IndexDataType = int32_t;
58 
// An index input is consumed only when indices are both produced and supplied.
59  static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;
60 
61  static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
62 
63  static constexpr index_t NumSrcDim = Rank;
// When every dimension is reduced, the destination is treated as 1-D.
64  static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
65  static constexpr bool reduceAllDim = (NumInvariantDim == 0);
66 
// M tile spans the whole thread block; the K tile uses a factor of 1, i.e.
// it covers only a single thread's K slice.
67  static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
68  static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;
69 
// Builds a 2-D (M = merged invariant dims, K = merged reduce dims) descriptor
// for the input, then right-pads both axes up to multiples of the block tile
// sizes. Expects lengths/strides already shuffled so the reduced dimensions
// are the trailing ones (see Argument's constructor).
70  static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
71  const std::array<index_t, Rank>& inStrides)
72  {
73  const auto tupleSrcLengths =
74  generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
75  const auto tupleSrcStrides =
76  generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
77 
78  const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
79 
80  const auto in_grid_desc_m_k = [&]() {
81  if constexpr(reduceAllDim)
82  {
// All dims reduced: first merge everything into one axis; the follow-up
// transform below is truncated but appears to unmerge into (1, total).
83  const auto one_dim_inDesc = transform_tensor_descriptor(
84  inDesc,
85  make_tuple(make_merge_transform(tupleSrcLengths)),
// NOTE(review): listing gap — original lines 86-87 are missing here.
88 
89  return transform_tensor_descriptor(one_dim_inDesc,
// NOTE(review): listing gap — original lines 90 and 92-93 are missing here.
91  1, one_dim_inDesc.GetLength(Number<0>{})))),
94  }
95  else
96  {
// Leading dims are invariant (M); trailing dims are reduced (K), as shown
// by the index arithmetic on reduceDimLengths below.
97  using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
// NOTE(review): listing gap — original line 98 (presumably the ReduceDims
// alias used at line 109) is missing here.
99 
100  const auto reduceDimLengths = generate_tuple(
101  [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
102  const auto invariantDimLengths =
103  generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
104 
// NOTE(review): listing gap — original line 105 (presumably the head of a
// transform_tensor_descriptor call) is missing here.
106  inDesc,
107  make_tuple(make_merge_transform(invariantDimLengths),
108  make_merge_transform(reduceDimLengths)),
109  make_tuple(InvariantDims{}, ReduceDims{}),
// NOTE(review): listing gap — original line 110 is missing here.
111  }
112  }();
113 
114  const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
115  const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
116 
// Pad M and K up to the next tile-size multiple so partial tiles are valid.
117  const auto inPad_M =
118  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
119  const auto inPad_K =
120  math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;
121 
122  auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
123  in_grid_desc_m_k,
124  make_tuple(make_right_pad_transform(invariantLength, inPad_M),
125  make_right_pad_transform(reduceLength, inPad_K)),
// NOTE(review): listing gap — original lines 126-127 are missing here.
128 
129  return (in_grid_desc_m_k_padded);
130  };
131 
// Builds a 1-D descriptor for the output by merging all destination dims,
// right-padded so its length is a multiple of the M block tile size.
132  static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
133  const std::array<index_t, NumDstDim>& outStrides)
134  {
135  const auto tupleDstLengths =
136  generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
137  const auto tupleDstStrides =
138  generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});
139 
140  auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);
141 
142  auto out_grid_desc_m = transform_tensor_descriptor(
143  outDesc,
144  make_tuple(make_merge_transform(tupleDstLengths)),
// NOTE(review): listing gap — original lines 145-146 are missing here.
147 
148  const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});
149 
150  const auto outPad =
151  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
152 
153  auto out_grid_desc_m_padded = transform_tensor_descriptor(
154  out_grid_desc_m,
155  make_tuple(make_right_pad_transform(invariantLength, outPad)),
// NOTE(review): listing gap — original lines 156-157 are missing here.
158  return (out_grid_desc_m_padded);
159  };
160 
// Problem descriptor for one reduction invocation: shuffled lengths/strides,
// alpha/beta scaling factors, device pointers, and the elementwise functors.
161  struct Argument : public BaseArgument
162  {
163  Argument(const std::array<index_t, Rank> inLengths,
164  const std::array<index_t, Rank> inStrides,
165  const std::array<index_t, NumDstDim> outLengths,
166  const std::array<index_t, NumDstDim> outStrides,
167  const std::array<int, NumReduceDim> reduceDims,
168  double alpha,
169  double beta,
170  const InDataType* in_dev,
171  OutDataType* out_dev,
172  IndexDataType* out_index_dev,
173  const InElementwiseOperation in_elementwise_op,
174  const AccElementwiseOperation acc_elementwise_op)
175  : outLengths_{outLengths},
176  outStrides_{outStrides},
177  in_dev_{in_dev},
178  out_dev_{out_dev},
179  out_index_dev_{out_index_dev},
180  in_elementwise_op_{in_elementwise_op},
181  acc_elementwise_op_{acc_elementwise_op}
182  {
// Reorder dimensions so the reduced ones become the trailing positions,
// as MakeSrc2dDescriptor expects.
183  inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
184  inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
185 
// Scaling factors are converted once to the accumulation type.
186  alpha_ = type_convert<AccDataType>(alpha);
187  beta_ = type_convert<AccDataType>(beta);
188 
// NOTE(review): listing gap — original line 189 (the assignment target of
// the get_2d_lengths call below; presumably the invariant/reduce total
// lengths listed in the documentation index) is missing here.
190  get_2d_lengths<Rank, NumReduceDim>(inLengths_);
191 
192  if constexpr(NumInvariantDim == 0)
// NOTE(review): listing gap — original lines 193 and 195 (the two branch
// bodies, presumably setting invariant_lowest_length) are missing here.
194  else
196 
197  reduce_lowest_length = inLengths_[Rank - 1];
198 
// NOTE(review): listing gap — original lines 199 and 201-202 (presumably
// the gridSize computation) are missing here.
200 
203  }
204 
205  std::array<index_t, Rank> inLengths_;
206  std::array<index_t, Rank> inStrides_;
207  std::array<index_t, NumDstDim> outLengths_;
208  std::array<index_t, NumDstDim> outStrides_;
209 
210  AccDataType alpha_;
211  AccDataType beta_;
212 
213  const InDataType* in_dev_;
214  OutDataType* out_dev_;
// NOTE(review): listing gap — original line 215 (member
// `IndexDataType* out_index_dev_;` per the documentation index) is missing.
216 
217  InElementwiseOperation in_elementwise_op_;
218  AccElementwiseOperation acc_elementwise_op_;
219 
// NOTE(review): listing gap — original lines 220-223 and 225 (members
// invariant_lowest_length, reduce_lowest_length, invariant_total_length,
// reduce_total_length and numBlockTileIteration per the documentation
// index) are missing here.
224 
226  size_t gridSize;
227  };
228 
// Launches the gridwise threadwise-reduction kernel for a prepared Argument
// and returns the measured kernel time (via launch_and_time_kernel).
229  struct Invoker : public BaseInvoker
230  {
231  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
232  {
// Build the padded 2-D source and 1-D destination grid descriptors.
233  const auto in_grid_desc_m_k =
// NOTE(review): listing gap — original lines 234 and 236 (the right-hand
// sides, presumably MakeSrc2dDescriptor / MakeDst1dDescriptor calls) are
// missing here.
235  const auto out_grid_desc_m =
237  using InGridDesc_M_K = decltype(in_grid_desc_m_k);
238  using OutGridDesc_M = decltype(out_grid_desc_m);
239 
240  float avg_time = 0;
241 
// Instantiate the gridwise reduction with this instance's configuration.
242  using GridwiseReduce =
// NOTE(review): listing gap — original lines 243, 246 and 252 (including
// the gridwise class-template name) are missing here.
244  OutDataType,
245  AccDataType,
247  InGridDesc_M_K,
248  OutGridDesc_M,
249  ReduceOperation,
250  InElementwiseOperation,
251  AccElementwiseOperation,
253  PropagateNan,
254  BlockSize,
255  MThreadSliceSize,
256  KThreadSliceSize,
257  InSrcVectorDim,
258  InSrcVectorSize,
259  OutDstVectorSize>;
260 
261  const auto kernel = kernel_reduce_threadwise<GridwiseReduce,
262  OutputIndex,
263  TransformIndexKtoGlobal,
// NOTE(review): listing gap — original lines 264 and 268 are missing here.
265  InDataType,
266  OutDataType,
267  AccDataType,
269  InGridDesc_M_K,
270  OutGridDesc_M,
271  InElementwiseOperation,
272  AccElementwiseOperation>;
273 
// Launch gridSize blocks of BlockSize threads with zero dynamic LDS bytes.
// The nullptr argument is the (absent) input-index buffer — matching the
// kernel's p_in_index_global parameter.
274  avg_time = launch_and_time_kernel(stream_config,
275  kernel,
276  dim3(arg.gridSize),
277  dim3(BlockSize),
278  0,
279  in_grid_desc_m_k,
280  out_grid_desc_m,
281  arg.in_elementwise_op_,
// NOTE(review): listing gap — original line 282 (presumably
// arg.acc_elementwise_op_, per the kernel's parameter order) is missing.
283  arg.alpha_,
284  arg.in_dev_,
285  nullptr,
286  arg.beta_,
287  arg.out_dev_,
288  arg.out_index_dev_);
289 
290  return (avg_time);
291  };
292 
// Type-erased overload required by the BaseInvoker interface.
293  float Run(const BaseArgument* p_arg,
294  const StreamConfig& stream_config = StreamConfig{}) override
295  {
296  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
297  };
298  };
299 
// Validates this instance's vectorization configuration against a runtime
// argument, and rejects reductions large enough that the blockwise kernel
// should handle them instead.
300  bool IsSupportedArgument(const BaseArgument* p_arg) override
301  {
302  const Argument* pArg = dynamic_cast<const Argument*>(p_arg);
303 
304  if constexpr(InSrcVectorDim == 0)
305  {
// Vectorizing along M is impossible when there is no invariant dimension.
306  if constexpr(NumInvariantDim == 0)
307  {
308  return (false);
309  }
310  else
311  {
// Vector loads along M require unit stride on the lowest invariant dim
// and a length divisible by the vector size.
312  if(pArg->inStrides_[NumInvariantDim - 1] != 1)
313  return (false);
314 
315  if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
316  return (false);
317  };
318  }
319  else
320  {
// Vector loads along K require unit stride on the innermost dim and a
// length divisible by the vector size.
321  if(pArg->inStrides_[Rank - 1] != 1)
322  return (false);
323 
324  if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
325  return (false);
326  };
327 
328  // To improve
329  if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
330  return (false);
331 
332  // cases with big reduce_total_length should be handled by Blockwise kernel
333  if(pArg->reduce_total_length / KThreadSliceSize >= 32)
334  return (false);
335 
336  return (true);
337  };
338 
// Type-erased Argument factory required by the DeviceReduce interface.
// in_index_dev is accepted for interface compatibility but deliberately
// unused by this operation (explicitly voided below).
339  std::unique_ptr<BaseArgument>
340  MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
341  const std::array<index_t, Rank> inStrides,
342  const std::array<index_t, NumDstDim> outLengths,
343  const std::array<index_t, NumDstDim> outStrides,
344  const std::array<int, NumReduceDim> reduceDims,
345  double alpha,
346  double beta,
347  const void* in_dev,
348  const void* in_index_dev,
349  void* out_dev,
350  void* out_index_dev,
351  const InElementwiseOperation in_elementwise_op,
352  const AccElementwiseOperation acc_elementwise_op) override
353  {
354  (void)in_index_dev;
355 
356  return std::make_unique<Argument>(inLengths,
357  inStrides,
358  outLengths,
359  outStrides,
360  reduceDims,
361  alpha,
362  beta,
363  static_cast<const InDataType*>(in_dev),
364  static_cast<OutDataType*>(out_dev),
365  static_cast<IndexDataType*>(out_index_dev),
366  in_elementwise_op,
367  acc_elementwise_op);
368  };
369 
370  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
371  {
372  return std::make_unique<Invoker>();
373  };
374 
// Human-readable identifier encoding this instance's tiling/vector config.
375  std::string GetTypeString() const override
376  {
377  auto str = std::stringstream();
378 
379  // clang-format off
380  str << "DeviceReduceThreadWise<" << BlockSize << ",";
381  str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
382  str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
383  str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
384  // clang-format on
385 
386  return str.str();
387  }
388 };
389 
390 } // namespace device
391 } // namespace tensor_operation
392 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
int64_t long_index_t
Definition: ck.hpp:290
__global__ void kernel_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M out_grid_desc_m, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, const IndexDataType *const __restrict__ p_in_index_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global, IndexDataType *const __restrict__ p_out_index_global)
Definition: gridwise_2d_reduction_threadwise.hpp:28
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_2d_reduction_threadwise.hpp:84
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:256
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_reduce.hpp:27
Definition: device_reduce_threadwise.hpp:162
std::array< index_t, Rank > inLengths_
Definition: device_reduce_threadwise.hpp:205
std::array< index_t, NumDstDim > outStrides_
Definition: device_reduce_threadwise.hpp:208
InElementwiseOperation in_elementwise_op_
Definition: device_reduce_threadwise.hpp:217
int numBlockTileIteration
Definition: device_reduce_threadwise.hpp:225
IndexDataType * out_index_dev_
Definition: device_reduce_threadwise.hpp:215
index_t invariant_lowest_length
Definition: device_reduce_threadwise.hpp:220
OutDataType * out_dev_
Definition: device_reduce_threadwise.hpp:214
AccDataType alpha_
Definition: device_reduce_threadwise.hpp:210
size_t gridSize
Definition: device_reduce_threadwise.hpp:226
AccDataType beta_
Definition: device_reduce_threadwise.hpp:211
AccElementwiseOperation acc_elementwise_op_
Definition: device_reduce_threadwise.hpp:218
std::array< index_t, Rank > inStrides_
Definition: device_reduce_threadwise.hpp:206
std::array< index_t, NumDstDim > outLengths_
Definition: device_reduce_threadwise.hpp:207
index_t reduce_lowest_length
Definition: device_reduce_threadwise.hpp:221
long_index_t invariant_total_length
Definition: device_reduce_threadwise.hpp:222
long_index_t reduce_total_length
Definition: device_reduce_threadwise.hpp:223
Argument(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, double alpha, double beta, const InDataType *in_dev, OutDataType *out_dev, IndexDataType *out_index_dev, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op)
Definition: device_reduce_threadwise.hpp:163
const InDataType * in_dev_
Definition: device_reduce_threadwise.hpp:213
Definition: device_reduce_threadwise.hpp:230
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_reduce_threadwise.hpp:293
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_reduce_threadwise.hpp:231
Definition: device_reduce_threadwise.hpp:49
static constexpr index_t NumSrcDim
Definition: device_reduce_threadwise.hpp:63
int32_t IndexDataType
Definition: device_reduce_threadwise.hpp:57
static constexpr index_t NumDstDim
Definition: device_reduce_threadwise.hpp:64
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_reduce_threadwise.hpp:300
std::string GetTypeString() const override
Definition: device_reduce_threadwise.hpp:375
static constexpr index_t NumInvariantDim
Definition: device_reduce_threadwise.hpp:61
static constexpr index_t M_BlockTileSize
Definition: device_reduce_threadwise.hpp:67
static constexpr bool reduceAllDim
Definition: device_reduce_threadwise.hpp:65
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_reduce_threadwise.hpp:370
static constexpr bool HaveIndexInput
Definition: device_reduce_threadwise.hpp:59
static constexpr index_t K_BlockTileSize
Definition: device_reduce_threadwise.hpp:68
static auto MakeSrc2dDescriptor(const std::array< index_t, Rank > &inLengths, const std::array< index_t, Rank > &inStrides)
Definition: device_reduce_threadwise.hpp:70
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > inLengths, const std::array< index_t, Rank > inStrides, const std::array< index_t, NumDstDim > outLengths, const std::array< index_t, NumDstDim > outStrides, const std::array< int, NumReduceDim > reduceDims, double alpha, double beta, const void *in_dev, const void *in_index_dev, void *out_dev, void *out_index_dev, const InElementwiseOperation in_elementwise_op, const AccElementwiseOperation acc_elementwise_op) override
Definition: device_reduce_threadwise.hpp:340
static auto MakeDst1dDescriptor(const std::array< index_t, NumDstDim > &outLengths, const std::array< index_t, NumDstDim > &outStrides)
Definition: device_reduce_threadwise.hpp:132