Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_reduce_multiblock.hpp Source File

device_reduce_multiblock.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <array>

#include "ck/utility/reduction_operator.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_2d_reduction_multiblock.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_set_buffer_value.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
namespace ck {
namespace tensor_operation {
namespace device {

template <typename InDataType,
          typename AccDataType,
          typename OutDataType,
          index_t Rank,
          index_t NumReduceDim,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          bool OutputIndex,
          bool HaveIndexInputIfOutputIndex,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct DeviceReduceMultiBlock : public DeviceReduce<InDataType,
                                                    AccDataType,
                                                    OutDataType,
                                                    Rank,
                                                    NumReduceDim,
                                                    ReduceOperation,
                                                    InElementwiseOperation,
                                                    AccElementwiseOperation,
                                                    PropagateNan,
                                                    OutputIndex>
{
    static_assert(Rank <= 12, "Bigger Rank size is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    using IndexDataType = int32_t;

    static constexpr bool HaveIndexInput = OutputIndex && HaveIndexInputIfOutputIndex;

    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

    static constexpr index_t NumSrcDim = Rank;
    static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
    static constexpr bool reduceAllDim = (NumInvariantDim == 0);

    // So far, only AtomicAdd is considered; other atomic operations like AtomicMax can be added
    // later
    static constexpr bool use_multiblock =
        (OutMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd);

    static_assert(ck::reduce::InMemoryDataOperationSupportedOnDataType<OutMemoryDataOperation,
                                                                       OutDataType>::value,
                  "The OutDataType must support the specified OutMemoryDataOperation!");

    static_assert(!use_multiblock || (use_multiblock && !OutputIndex),
                  "MultiBlock reduction can only be used when outputting an index is not required");

    static_assert(
        ReduceOperation::IsCompatibleInMemoryDataOperation(OutMemoryDataOperation),
        "The reduction accumulation operation must be compatible with the OutMemoryDataOperation!");

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

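    // Illustration (hypothetical numbers, not part of the original header): with
    // MThreadClusterSize = 8, KThreadClusterSize = 32 (so BlockSize must be 8 * 32 = 256),
    // MThreadSliceSize = 4 and KThreadSliceSize = 8, each workgroup covers a tile of
    // M_BlockTileSize = 8 * 4 = 32 invariant rows by K_BlockTileSize = 32 * 8 = 256
    // reduced elements per block-tile iteration.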
    static auto MakeSrc2dDescriptor(const std::array<index_t, Rank>& inLengths,
                                    const std::array<index_t, Rank>& inStrides,
                                    int blkGroupSize,
                                    int numBlockTileIteration)
    {
        const auto tupleSrcLengths =
            generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
        const auto tupleSrcStrides =
            generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        const auto in_grid_desc_m_k = [&]() {
            if constexpr(reduceAllDim)
            {
                const auto one_dim_inDesc = transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(tupleSrcLengths)),
                    make_tuple(typename arithmetic_sequence_gen<0, NumSrcDim, 1>::type{}),
                    make_tuple(Sequence<0>{}));

                return transform_tensor_descriptor(one_dim_inDesc,
                                                   make_tuple(make_unmerge_transform(make_tuple(
                                                       1, one_dim_inDesc.GetLength(Number<0>{})))),
                                                   make_tuple(Sequence<0>{}),
                                                   make_tuple(Sequence<0, 1>{}));
            }
            else
            {
                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

                const auto reduceDimLengths = generate_tuple(
                    [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
                const auto invariantDimLengths =
                    generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});

                return transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(invariantDimLengths),
                               make_merge_transform(reduceDimLengths)),
                    make_tuple(InvariantDims{}, ReduceDims{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
        }();

        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
        const auto inPad_M =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;

        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
            in_grid_desc_m_k,
            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
                       make_right_pad_transform(reduceLength, inPad_K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };
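    // Worked example (hypothetical shapes): for Rank = 3, inLengths = {16, 64, 128} and
    // NumReduceDim = 1, the invariant dims merge into M = 16 * 64 = 1024 and the reduce
    // dim gives K = 128. With M_BlockTileSize = 32, K_BlockTileSize = 256, blkGroupSize = 1
    // and numBlockTileIteration = 1, the pads are inPad_M = 1024 - 1024 = 0 and
    // inPad_K = 256 * 1 - 128 = 128, so the padded descriptor is 1024 x 256.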

    static auto MakeDst1dDescriptor(const std::array<index_t, NumDstDim>& outLengths,
                                    const std::array<index_t, NumDstDim>& outStrides)
    {
        const auto tupleDstLengths =
            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
        const auto tupleDstStrides =
            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto invariantLength = out_grid_desc_m.GetLength(Number<0>{});

        const auto outPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto out_grid_desc_m_padded = transform_tensor_descriptor(
            out_grid_desc_m,
            make_tuple(make_right_pad_transform(invariantLength, outPad)),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0>{}));
        return (out_grid_desc_m_padded);
    };

    static auto MakeDst1dDescriptorForBufferSet(const std::array<index_t, NumDstDim>& outLengths,
                                                const std::array<index_t, NumDstDim>& outStrides)
    {
        const auto tupleDstLengths =
            generate_tuple([&](auto I) { return outLengths[I]; }, Number<NumDstDim>{});
        const auto tupleDstStrides =
            generate_tuple([&](auto I) { return outStrides[I]; }, Number<NumDstDim>{});

        auto outDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides);

        auto out_grid_desc_m = transform_tensor_descriptor(
            outDesc,
            make_tuple(make_merge_transform(tupleDstLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumDstDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto length = out_grid_desc_m.GetLength(Number<0>{});

        const auto pad = math::integer_least_multiple(length, BlockSize) - length;

        auto out_grid_desc_m_padded =
            transform_tensor_descriptor(out_grid_desc_m,
                                        make_tuple(make_right_pad_transform(length, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return (out_grid_desc_m_padded);
    };

    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, Rank> inLengths,
                 const std::array<index_t, Rank> inStrides,
                 const std::array<index_t, NumDstDim> outLengths,
                 const std::array<index_t, NumDstDim> outStrides,
                 const std::array<int, NumReduceDim> reduceDims,
                 double alpha,
                 double beta,
                 const InDataType* in_dev,
                 const IndexDataType* in_index_dev,
                 OutDataType* out_dev,
                 IndexDataType* out_index_dev,
                 const InElementwiseOperation in_elementwise_op,
                 const AccElementwiseOperation acc_elementwise_op)
            : outLengths_{outLengths},
              outStrides_{outStrides},
              in_dev_{in_dev},
              in_index_dev_{in_index_dev},
              out_dev_{out_dev},
              out_index_dev_{out_index_dev},
              in_elementwise_op_{in_elementwise_op},
              acc_elementwise_op_{acc_elementwise_op}
        {
            if(Rank != inLengths.size() || Rank != inStrides.size() ||
               NumReduceDim != reduceDims.size())
            {
                throw std::runtime_error(
                    "One of inLengths/inStrides/reduceDims has invalid size!"
                    "\nExpected size inLengths: " +
                    std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
                    ", reduceDims: " + std::to_string(NumReduceDim) +
                    "\nBut have inLengths: " + std::to_string(inLengths.size()) +
                    ", inStrides: " + std::to_string(inStrides.size()) +
                    ", reduceDims: " + std::to_string(reduceDims.size()));
            }

            for(std::size_t i = 0; i < reduceDims.size(); ++i)
            {
                if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
                {
                    throw std::runtime_error("Provided reduce dimension exceeds input tensor Rank!"
                                             "\nHave reduceDims[" +
                                             std::to_string(i) +
                                             "]: " + std::to_string(reduceDims[i]));
                }
            }

            inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
            inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);

            alpha_ = type_convert<AccDataType>(alpha);
            beta_  = type_convert<AccDataType>(beta);

            std::tie(invariant_total_length, reduce_total_length) =
                get_2d_lengths<Rank, NumReduceDim>(inLengths_);

            if constexpr(NumInvariantDim == 0)
                invariant_lowest_length = 1;
            else
                invariant_lowest_length = inLengths_[NumInvariantDim - 1];

            reduce_lowest_length = inLengths_[Rank - 1];

            if constexpr(use_multiblock)
            {

                int iterations = 1;
                while(true)
                {
                    int testBlkGroupSize =
                        (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                        (K_BlockTileSize * iterations);

                    // we want the blkGroupSize to be not more than 128
                    if(testBlkGroupSize <= 128)
                        break;

                    iterations++;
                };

                blkGroupSize = (reduce_total_length + (K_BlockTileSize * iterations) - 1) /
                               (K_BlockTileSize * iterations);

                numBlockTileIteration = iterations;
            }
            else
            {
                blkGroupSize = 1;
                numBlockTileIteration =
                    (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
            };

            gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                       M_BlockTileSize * blkGroupSize;

            gridSize_pre =
                math::integer_least_multiple(invariant_total_length, BlockSize) / BlockSize;
        }

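        // Example of the sizing above (hypothetical numbers): with K_BlockTileSize = 256 and
        // reduce_total_length = 100000, the loop stops at iterations = 4, since
        // ceil(100000 / (256 * i)) = 391, 196, 131 for i = 1..3 all exceed 128, while
        // i = 4 gives 98 <= 128; hence blkGroupSize = 98 and numBlockTileIteration = 4.
        // With invariant_total_length = 1024 and M_BlockTileSize = 32,
        // gridSize = (1024 / 32) * 98 = 3136 workgroups.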
        std::array<index_t, Rank> inLengths_;
        std::array<index_t, Rank> inStrides_;
        std::array<index_t, NumDstDim> outLengths_;
        std::array<index_t, NumDstDim> outStrides_;

        AccDataType alpha_;
        AccDataType beta_;

        const InDataType* in_dev_;
        const IndexDataType* in_index_dev_;
        OutDataType* out_dev_;
        IndexDataType* out_index_dev_;

        InElementwiseOperation in_elementwise_op_;
        AccElementwiseOperation acc_elementwise_op_;

        index_t invariant_lowest_length;
        index_t reduce_lowest_length;
        long_index_t invariant_total_length;
        long_index_t reduce_total_length;

        int blkGroupSize;
        int numBlockTileIteration;
        size_t gridSize;

        size_t gridSize_pre;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const auto in_grid_desc_m_k = DeviceReduceMultiBlock::MakeSrc2dDescriptor(
                arg.inLengths_, arg.inStrides_, arg.blkGroupSize, arg.numBlockTileIteration);
            const auto out_grid_desc_m =
                DeviceReduceMultiBlock::MakeDst1dDescriptor(arg.outLengths_, arg.outStrides_);
            const auto out_grid_desc_m_2 = DeviceReduceMultiBlock::MakeDst1dDescriptorForBufferSet(
                arg.outLengths_, arg.outStrides_);

            using InGridDesc_M_K  = decltype(in_grid_desc_m_k);
            using OutGridDesc_M   = decltype(out_grid_desc_m);
            using OutGridDesc_M_2 = decltype(out_grid_desc_m_2);

            using GridwiseReduce = GridwiseReduction_mk_to_m_multiblock<InDataType,
                                                                        OutDataType,
                                                                        AccDataType,
                                                                        IndexDataType,
                                                                        InGridDesc_M_K,
                                                                        OutGridDesc_M,
                                                                        ReduceOperation,
                                                                        InElementwiseOperation,
                                                                        AccElementwiseOperation,
                                                                        OutMemoryDataOperation,
                                                                        PropagateNan,
                                                                        BlockSize,
                                                                        MThreadClusterSize,
                                                                        KThreadClusterSize,
                                                                        MThreadSliceSize,
                                                                        KThreadSliceSize,
                                                                        InSrcVectorDim,
                                                                        InSrcVectorSize,
                                                                        OutDstVectorSize>;

            const auto kernel_main = kernel_reduce_multiblock<GridwiseReduce,
                                                              OutputIndex,
                                                              HaveIndexInput,
                                                              InDataType,
                                                              OutDataType,
                                                              AccDataType,
                                                              int32_t,
                                                              InGridDesc_M_K,
                                                              OutGridDesc_M,
                                                              InElementwiseOperation,
                                                              AccElementwiseOperation>;

            float avg_time = 0;

            if constexpr(use_multiblock)
            {
                const auto identityVal =
                    ck::reduce::GetIdentityValueForInMemoryDataOperation<OutDataType>(
                        OutMemoryDataOperation);

                const auto kernel_pre =
                    kernel_buffer_set_value<BlockSize, OutDataType, OutGridDesc_M_2>;

                avg_time += launch_and_time_kernel(stream_config,
                                                   kernel_pre,
                                                   dim3(arg.gridSize_pre),
                                                   dim3(BlockSize),
                                                   0,
                                                   out_grid_desc_m_2,
                                                   arg.out_dev_,
                                                   identityVal);
            };

            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
                                               dim3(arg.gridSize),
                                               dim3(BlockSize),
                                               0,
                                               in_grid_desc_m_k,
                                               out_grid_desc_m,
                                               arg.in_elementwise_op_,
                                               arg.acc_elementwise_op_,
                                               arg.blkGroupSize,
                                               arg.numBlockTileIteration,
                                               arg.alpha_,
                                               arg.in_dev_,
                                               arg.in_index_dev_,
                                               arg.beta_,
                                               arg.out_dev_,
                                               arg.out_index_dev_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };
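    // Note on the launch sequence above: in multiblock mode several workgroups accumulate
    // partial results into the same output element via OutMemoryDataOperation (AtomicAdd),
    // so kernel_pre first fills the output buffer with the operation's identity value
    // (0 for AtomicAdd) before kernel_main runs.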

    static bool IsSupportedArgument(const Argument* pArg)
    {
        if constexpr(use_multiblock)
        {
            if(static_cast<float>(pArg->beta_) != 0.0f)
                return (false);
        };

        if constexpr(InSrcVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
            {
                return (false);
            }
            else
            {
                if(pArg->inStrides_[NumInvariantDim - 1] != 1)
                    return (false);

                if(pArg->invariant_lowest_length % InSrcVectorSize != 0)
                    return (false);
            };
        }
        else
        {
            if(pArg->inStrides_[Rank - 1] != 1)
                return (false);

            if(pArg->reduce_lowest_length % InSrcVectorSize != 0)
                return (false);
        };

        // To improve
        if(pArg->invariant_lowest_length % OutDstVectorSize != 0)
            return (false);

        if constexpr(use_multiblock)
        {
            // blkGroupSize of 1 should be handled by the Blockwise path using
            // InMemoryDataOperationEnum::Set
            if(pArg->blkGroupSize == 1)
                return (false);

            // This is a very strong restriction, but it is needed to avoid some failures
            if(pArg->invariant_lowest_length % M_BlockTileSize != 0)
                return (false);
        }
        else
        {
            // cases with very small reduce_total_length should be handled by the ThreadWise kernel
            // if(pArg->reduce_total_length / KThreadSliceSize < 2)
            //     return (false);
        };

        return (true);
    }
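    // Example of the vectorization checks (hypothetical config): with InSrcVectorDim = 1 and
    // InSrcVectorSize = 4, the innermost (fastest-varying) input dimension must be contiguous
    // (inStrides_[Rank - 1] == 1) and reduce_lowest_length must be a multiple of 4;
    // similarly, invariant_lowest_length must be a multiple of OutDstVectorSize.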

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        return IsSupportedArgument(dynamic_cast<const Argument*>(p_arg));
    };

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::array<index_t, Rank> inLengths,
                        const std::array<index_t, Rank> inStrides,
                        const std::array<index_t, NumDstDim> outLengths,
                        const std::array<index_t, NumDstDim> outStrides,
                        const std::array<int, NumReduceDim> reduceDims,
                        double alpha,
                        double beta,
                        const void* in_dev,
                        const void* in_index_dev,
                        void* out_dev,
                        void* out_index_dev,
                        const InElementwiseOperation in_elementwise_op,
                        const AccElementwiseOperation acc_elementwise_op) override
    {
        return std::make_unique<Argument>(inLengths,
                                          inStrides,
                                          outLengths,
                                          outStrides,
                                          reduceDims,
                                          alpha,
                                          beta,
                                          static_cast<const InDataType*>(in_dev),
                                          static_cast<const IndexDataType*>(in_index_dev),
                                          static_cast<OutDataType*>(out_dev),
                                          static_cast<IndexDataType*>(out_index_dev),
                                          in_elementwise_op,
                                          acc_elementwise_op);
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << (OutMemoryDataOperation == InMemoryDataOperationEnum::Set ? "DeviceReduceBlockWise<" : "DeviceReduceMultiBlock<") << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "InSrcVectorDim_" << InSrcVectorDim << "_InSrcVectorSize_" << InSrcVectorSize << "_OutDstVectorSize_" << OutDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
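    // For example (hypothetical instantiation), BlockSize = 256, M_C8_S4, K_C32_S8,
    // InSrcVectorDim = 1, InSrcVectorSize = 4, OutDstVectorSize = 4 with AtomicAdd yields:
    // "DeviceReduceMultiBlock<256,M_C8_S4,K_C32_S8,InSrcVectorDim_1_InSrcVectorSize_4_OutDstVectorSize_4>"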
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
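A minimal host-side usage sketch follows. It is illustrative rather than part of this header: the chosen data types, tile configuration, ck::reduce::Add and element_wise::PassThrough operations, and the device pointers p_in_dev/p_out_dev are all assumptions, and the CK headers providing those operations must be included.

// Usage sketch (illustrative assumptions: float tensors, reduce the last of 3 dims,
// ck::reduce::Add with AtomicAdd accumulation, PassThrough element-wise ops,
// p_in_dev/p_out_dev are valid HIP device allocations).
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using DeviceOp = ck::tensor_operation::device::DeviceReduceMultiBlock<
    float, float, float,                       // InDataType, AccDataType, OutDataType
    3, 1,                                      // Rank, NumReduceDim
    ck::reduce::Add, PassThrough, PassThrough, // ReduceOperation, In/Acc element-wise ops
    ck::InMemoryDataOperationEnum::AtomicAdd,  // multiblock accumulation
    false, false, false,     // PropagateNan, OutputIndex, HaveIndexInputIfOutputIndex
    256, 8, 32, 4, 8,        // BlockSize, M/K thread cluster, M/K thread slice
    1, 4, 4>;                // InSrcVectorDim, InSrcVectorSize, OutDstVectorSize

DeviceOp op;
auto argument = op.MakeArgumentPointer({16, 64, 65536}, {64 * 65536, 65536, 1}, // in lengths/strides
                                       {16, 64}, {64, 1},  // out lengths/strides
                                       {2},                // dimension to reduce
                                       1.0, 0.0,           // alpha, beta (beta must be 0 for AtomicAdd)
                                       p_in_dev, nullptr,  // input values, no input indices
                                       p_out_dev, nullptr, // output values, no output indices
                                       PassThrough{}, PassThrough{});

if(op.IsSupportedArgument(argument.get()))
{
    auto invoker = op.MakeInvokerPointer();
    float time   = invoker->Run(argument.get(), StreamConfig{nullptr, true});
}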