Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp Source File

device_multiple_reduce_threadwise.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/tensor_operation/gpu/device/device_multiple_reduce.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_multiple_reduction_threadwise.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

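// This device operation performs NumReduction reductions of the same input tensor in a single
// kernel launch. The input is viewed as a 2-D (M, K) problem: the invariant dimensions are merged
// into M, the reduce dimensions into K, and each thread reduces its own M-slice over the whole K
// range (thread-wise reduction, no cross-thread accumulation). Each reduction has its own output
// data type, element-wise operations, alpha/beta scaling and destination vector size.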
template <index_t NumReduction,
          typename InDataType,
          typename AccDataType,
          typename OutDataTypeTuple,
          index_t Rank,
          index_t NumReduceDim,
          typename ReduceOperation,
          typename InElementwiseOperationTuple,
          typename AccElementwiseOperationTuple,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          typename OutDstVectorSizeSeq>
struct DeviceMultipleReduceThreadWise : public DeviceMultipleReduce<Rank,
                                                                    NumReduceDim,
                                                                    NumReduction,
                                                                    InElementwiseOperationTuple,
                                                                    AccElementwiseOperationTuple>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    static_assert((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                      (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert(NumReduction == OutDataTypeTuple::Size() &&
                      NumReduction == InElementwiseOperationTuple::Size() &&
                      NumReduction == AccElementwiseOperationTuple::Size() &&
                      NumReduction == OutDstVectorSizeSeq::Size(),
                  "All tuple should have the same size as the number of Reductions!");

    static_assert(sequence_all_of(OutDstVectorSizeSeq{},
                                  [](auto vectorSize) {
                                      return (MThreadSliceSize % vectorSize == 0);
                                  }),
                  "The OutDstVectorSize should completely divide the MThreadSliceSize!");

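    // Problem-shape constants: the trailing NumReduceDim dimensions are reduced and the remaining
    // NumInvariantDim dimensions are kept (a fully reduced tensor is treated as a 1-D output of
    // length 1). A block covers BlockSize * MThreadSliceSize rows of M; each thread steps through
    // K in chunks of KThreadSliceSize.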
    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

    static constexpr index_t NumInputDim  = Rank;
    static constexpr index_t NumOutputDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
    static constexpr bool reduceAllDim    = (NumInvariantDim == 0);

    static constexpr index_t M_BlockTileSize = BlockSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = 1 * KThreadSliceSize;

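    // Builds a tuple of typed null pointers, one per reduction; it exists only so that
    // OutDataTypePointerTuple (the pointer types matching OutDataTypeTuple) can be deduced below.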
    static auto GenerateOutDataTypePointerTuple()
    {
        return generate_tuple(
            [&](auto I) {
                using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;

                return static_cast<DataType*>(nullptr);
            },
            Number<NumReduction>{});
    };

    using OutDataTypePointerTuple = decltype(GenerateOutDataTypePointerTuple());

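    // Merges the invariant dimensions into M and the reduce dimensions into K, then right-pads
    // both axes so that M is a multiple of M_BlockTileSize and K a multiple of K_BlockTileSize.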
    static auto MakeSrc2dDescriptor(const std::array<index_t, NumInputDim>& inLengths,
                                    const std::array<index_t, NumInputDim>& inStrides)
    {
        const auto tupleSrcLengths =
            generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInputDim>{});
        const auto tupleSrcStrides =
            generate_tuple([&](auto I) { return inStrides[I]; }, Number<NumInputDim>{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        const auto in_grid_desc_m_k = [&]() {
            if constexpr(reduceAllDim)
            {
                const auto one_dim_inDesc = transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(tupleSrcLengths)),
                    make_tuple(typename arithmetic_sequence_gen<0, NumInputDim, 1>::type{}),
                    make_tuple(Sequence<0>{}));

                return transform_tensor_descriptor(one_dim_inDesc,
                                                   make_tuple(make_unmerge_transform(make_tuple(
                                                       1, one_dim_inDesc.GetLength(Number<0>{})))),
                                                   make_tuple(Sequence<0>{}),
                                                   make_tuple(Sequence<0, 1>{}));
            }
            else
            {
                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

                const auto reduceDimLengths = generate_tuple(
                    [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
                const auto invariantDimLengths =
                    generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});

                return transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(invariantDimLengths),
                               make_merge_transform(reduceDimLengths)),
                    make_tuple(InvariantDims{}, ReduceDims{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
        }();

        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

        const auto inPad_M =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto inPad_K =
            math::integer_least_multiple(reduceLength, K_BlockTileSize) - reduceLength;

        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
            in_grid_desc_m_k,
            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
                       make_right_pad_transform(reduceLength, inPad_K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };

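    // Merges the output dimensions into a single M axis and right-pads it to a multiple of
    // M_BlockTileSize, mirroring the padding applied to the source descriptor.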
    static auto MakeDst1dDescriptor(const std::array<index_t, NumOutputDim>& outLengths,
                                    const std::array<index_t, NumOutputDim>& outStrides)
    {
        const auto tupleDstLengths =
            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumOutputDim>{});
        const auto tupleDstStrides =
            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumOutputDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumOutputDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

        const auto outPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto out_grid_desc_m_padded = transform_tensor_descriptor(
            out_grid_desc_m,
            make_tuple(make_right_pad_transform(invariantLength, outPad)),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0>{}));
        return (out_grid_desc_m_padded);
    };

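    // Generates a tuple of NumReduction placeholder destination descriptors; only the resulting
    // type, OutGridDesc_M_Tuple, is used.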
    static auto GenerateOutGrid1dDescTuple()
    {
        return generate_tuple(
            [&](auto I) {
                (void)I;
                return MakeDst1dDescriptor(std::array<index_t, NumOutputDim>{},
                                           std::array<index_t, NumOutputDim>{});
            },
            Number<NumReduction>{});
    };

    using InGridDesc_M_K      = decltype(MakeSrc2dDescriptor(std::array<index_t, NumInputDim>{},
                                                             std::array<index_t, NumInputDim>{}));
    using OutGridDesc_M_Tuple = decltype(GenerateOutGrid1dDescTuple());

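    // Argument captures everything a launch needs: the input lengths/strides shuffled so that the
    // reduce dimensions come last, per-reduction alpha/beta factors converted to AccDataType,
    // typed device pointers, the padded (M, K) source descriptor, one destination descriptor per
    // reduction, and the grid size derived from the invariant (M) length.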
    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, NumInputDim>& inLengths,
                 const std::array<index_t, NumInputDim>& inStrides,
                 const std::array<index_t, NumOutputDim>& outLengths,
                 const std::array<std::array<index_t, NumOutputDim>, NumReduction>& outStridesArray,
                 const std::array<int, NumReduceDim>& reduceDims,
                 const std::array<double, NumReduction>& alphas,
                 const std::array<double, NumReduction>& betas,
                 const void* in_dev,
                 const std::array<void*, NumReduction>& out_dev_buffers,
                 const InElementwiseOperationTuple in_elementwise_op_tuple,
                 const AccElementwiseOperationTuple acc_elementwise_op_tuple)
            : outLengths_{outLengths},
              outStridesArray_{outStridesArray},
              in_elementwise_op_tuple_{in_elementwise_op_tuple},
              acc_elementwise_op_tuple_{acc_elementwise_op_tuple}
        {
            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

            for(size_t i = 0; i < NumReduction; i++)
            {
                alpha_values_(i) = static_cast<AccDataType>(alphas[i]);
                beta_values_(i)  = static_cast<AccDataType>(betas[i]);
            };

            in_dev_ = static_cast<const InDataType*>(in_dev);

            out_dev_buffers_ = generate_tuple(
                [&](auto iR) {
                    using OutDataTypePointer =
                        remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
                    using OutDataType = remove_pointer_t<OutDataTypePointer>;
                    return static_cast<OutDataType*>(out_dev_buffers[iR]);
                },
                Number<NumReduction>{});

            std::tie(invariant_total_length, reduce_total_length) =
                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

            in_grid_desc_m_k = MakeSrc2dDescriptor(inLengths_, inStrides_);

            out_grid_desc_m_tuple = generate_tuple(
                [&](auto I) { return MakeDst1dDescriptor(outLengths, outStridesArray[I]); },
                Number<NumReduction>{});

            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                       M_BlockTileSize;
        }

        std::array<index_t, NumInputDim> inLengths_;
        std::array<index_t, NumInputDim> inStrides_;

        std::array<index_t, NumOutputDim> outLengths_;
        std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray_;

        Array<AccDataType, NumReduction> alpha_values_;
        Array<AccDataType, NumReduction> beta_values_;

        const InDataType* in_dev_;
        OutDataTypePointerTuple out_dev_buffers_;

        InGridDesc_M_K in_grid_desc_m_k;
        OutGridDesc_M_Tuple out_grid_desc_m_tuple;

        InElementwiseOperationTuple in_elementwise_op_tuple_;
        AccElementwiseOperationTuple acc_elementwise_op_tuple_;

        long_index_t invariant_total_length;
        long_index_t reduce_total_length;

        size_t gridSize;
    };

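    // Invoker instantiates the thread-wise gridwise reduction for this configuration and launches
    // kernel_multiple_reduce_threadwise once, timing it with launch_and_time_kernel.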
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            using GridwiseMultipleReduce =
                GridwiseMultipleReduction_mk_to_m_threadwise<NumReduction,
                                                             InDataType,
                                                             OutDataTypePointerTuple,
                                                             AccDataType,
                                                             InGridDesc_M_K,
                                                             OutGridDesc_M_Tuple,
                                                             ReduceOperation,
                                                             InElementwiseOperationTuple,
                                                             AccElementwiseOperationTuple,
                                                             InMemoryDataOperationEnum::Set,
                                                             PropagateNan,
                                                             BlockSize,
                                                             MThreadSliceSize,
                                                             KThreadSliceSize,
                                                             InSrcVectorDim,
                                                             InSrcVectorSize,
                                                             OutDstVectorSizeSeq>;

            const auto kernel_main =
                kernel_multiple_reduce_threadwise<GridwiseMultipleReduce,
                                                  NumReduction,
                                                  InDataType,
                                                  OutDataTypePointerTuple,
                                                  AccDataType,
                                                  InGridDesc_M_K,
                                                  OutGridDesc_M_Tuple,
                                                  InElementwiseOperationTuple,
                                                  AccElementwiseOperationTuple>;

            float avg_time = 0;

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
                                               dim3(arg.gridSize),
                                               dim3(BlockSize),
                                               0,
                                               arg.in_grid_desc_m_k,
                                               arg.out_grid_desc_m_tuple,
                                               arg.in_elementwise_op_tuple_,
                                               arg.acc_elementwise_op_tuple_,
                                               arg.alpha_values_,
                                               arg.in_dev_,
                                               arg.beta_values_,
                                               arg.out_dev_buffers_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };

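    // Vectorized access requires the fastest-varying dimension to be contiguous (stride 1) and its
    // length to be divisible by the vector size; InSrcVectorDim selects whether input vectors run
    // along the invariant (0) or the reduced (1) dimension.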
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

        if constexpr(InSrcVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
            {
                return (false);
            }
            else
            {
                if(pArg->inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
                    return (false);

                if(pArg->inLengths_[NumInvariantDim - 1] % InSrcVectorSize != 0)
                    return (false);
            };
        }
        else
        {
            if(pArg->inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
                return (false);

            if(pArg->inLengths_[Rank - 1] % InSrcVectorSize != 0)
                return (false);
        };

        // To improve
        bool valid = true;
        static_for<0, NumReduction, 1>{}([&](auto I) {
            if(pArg->outStridesArray_[I.value][NumOutputDim - 1] != 1 &&
               OutDstVectorSizeSeq::At(I) != 1)
                valid = false;

            if(pArg->outLengths_[NumOutputDim - 1] % OutDstVectorSizeSeq::At(I) != 0)
                valid = false;
        });

        if(!valid)
            return (false);

        return (true);
    };

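    // Factory methods used through the DeviceMultipleReduce base interface; they type-erase the
    // concrete Argument and Invoker behind BaseArgument/BaseInvoker pointers.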
    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, NumInputDim> inLengths,
        const std::array<index_t, NumInputDim> inStrides,
        const std::array<index_t, NumOutputDim> outLengths,
        const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStridesArray,
        const std::array<int, NumReduceDim> reduceDims,
        const std::array<double, NumReduction> alphas,
        const std::array<double, NumReduction> betas,
        const void* in_dev,
        const std::array<void*, NumReduction> out_dev_buffers,
        const InElementwiseOperationTuple in_elementwise_op_tuple,
        const AccElementwiseOperationTuple acc_elementwise_op_tuple) override
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          outLengths,
                                          outStridesArray,
                                          reduceDims,
                                          alphas,
                                          betas,
                                          in_dev,
                                          out_dev_buffers,
                                          in_elementwise_op_tuple,
                                          acc_elementwise_op_tuple);
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceMultipleReduceThreadwise<" << BlockSize << ",";
        str << "M_C" << BlockSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << 1 << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << ",";
        str << "OutDstVectorSize";
        static_for<0, OutDstVectorSizeSeq::Size(), 1>{}([&](auto I) {str << "_" << OutDstVectorSizeSeq::At(I); });
        str << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
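A minimal host-side usage sketch follows, assuming a 1024 x 64 fp16 input reduced along its last
dimension into two fp16 outputs with ck::reduce::Add and PassThrough element-wise operations. The
tile sizes, shapes, and buffer names are illustrative choices, not values prescribed by this
header, and the device buffers are left as null placeholders.

#include <array>
#include <memory>

#include "ck/ck.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_multiple_reduce_threadwise.hpp"

int main()
{
    using ck::index_t;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // Two reductions of the same [1024, 64] fp16 tensor, reducing dimension 1.
    using DeviceOp = ck::tensor_operation::device::DeviceMultipleReduceThreadWise<
        2,                                   // NumReduction
        ck::half_t,                          // InDataType
        float,                               // AccDataType
        ck::Tuple<ck::half_t, ck::half_t>,   // OutDataTypeTuple
        2,                                   // Rank
        1,                                   // NumReduceDim
        ck::reduce::Add,                     // ReduceOperation
        ck::Tuple<PassThrough, PassThrough>, // InElementwiseOperationTuple
        ck::Tuple<PassThrough, PassThrough>, // AccElementwiseOperationTuple
        false,                               // PropagateNan
        256,                                 // BlockSize
        4,                                   // MThreadSliceSize
        8,                                   // KThreadSliceSize
        1,                                   // InSrcVectorDim (vectorize along the reduced dim)
        1,                                   // InSrcVectorSize
        ck::Sequence<1, 1>>;                 // OutDstVectorSizeSeq

    // Problem description; the reduce dimension index refers to the original (unshuffled) layout.
    std::array<index_t, 2> in_lengths{1024, 64};
    std::array<index_t, 2> in_strides{64, 1};
    std::array<index_t, 1> out_lengths{1024};
    std::array<std::array<index_t, 1>, 2> out_strides{{{1}, {1}}};
    std::array<int, 1> reduce_dims{1};
    std::array<double, 2> alphas{1.0, 1.0};
    std::array<double, 2> betas{0.0, 0.0};

    // p_in, p_out0 and p_out1 stand for device buffers allocated elsewhere (e.g. with hipMalloc).
    void* p_in   = nullptr;
    void* p_out0 = nullptr;
    void* p_out1 = nullptr;
    std::array<void*, 2> out_buffers{p_out0, p_out1};

    auto device_op = DeviceOp{};
    auto argument  = device_op.MakeArgumentPointer(in_lengths,
                                                   in_strides,
                                                   out_lengths,
                                                   out_strides,
                                                   reduce_dims,
                                                   alphas,
                                                   betas,
                                                   p_in,
                                                   out_buffers,
                                                   ck::make_tuple(PassThrough{}, PassThrough{}),
                                                   ck::make_tuple(PassThrough{}, PassThrough{}));

    if(device_op.IsSupportedArgument(argument.get()))
    {
        auto invoker = device_op.MakeInvokerPointer();
        invoker->Run(argument.get());
    }

    return 0;
}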
Definition: device_multiple_reduce_threadwise.hpp:69