/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_multiple_reduce.hpp Source File
device_multiple_reduce.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <vector>
7 #include <memory>
8 #include <array>
9 #include <iostream>
10 
11 #include "ck/ck.hpp"
14 
15 namespace ck {
16 namespace tensor_operation {
17 namespace device {
18 
19 template <index_t Rank,
20  index_t NumReduceDim,
21  index_t NumReduction,
22  typename InElementwiseOperationTuple,
23  typename AccElementwiseOperationTuple>
25 {
26  static constexpr index_t NumInputDim = Rank;
27  static constexpr index_t NumOutputDim = (Rank - NumReduceDim > 1) ? Rank - NumReduceDim : 1;
28 
29  virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
30  const std::array<index_t, NumInputDim> inLengths,
31  const std::array<index_t, NumInputDim> inStrides,
32  const std::array<index_t, NumOutputDim> outLengths,
33  const std::array<std::array<index_t, NumOutputDim>, NumReduction> outStrides,
34  const std::array<int, NumReduceDim> reduceDims,
35  const std::array<double, NumReduction> alphas,
36  const std::array<double, NumReduction> betas,
37  const void* in_dev,
38  const std::array<void*, NumReduction> out_dev_buffers,
39  const InElementwiseOperationTuple in_elementwise_op_tuple,
40  const AccElementwiseOperationTuple acc_elementwise_op_tuple) = 0;
41 
42  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
43 };
44 
45 template <index_t Rank,
46  index_t NumReduceDim,
47  index_t NumReduction,
48  typename InElementwiseOperationTuple,
49  typename AccElementwiseOperationTuple>
50 using DeviceMultipleReducePtr = std::unique_ptr<DeviceMultipleReduce<Rank,
51  NumReduceDim,
52  NumReduction,
53  InElementwiseOperationTuple,
54  AccElementwiseOperationTuple>>;
55 
56 } // namespace device
57 } // namespace tensor_operation
58 } // namespace ck
std::unique_ptr< DeviceMultipleReduce< Rank, NumReduceDim, NumReduction, InElementwiseOperationTuple, AccElementwiseOperationTuple > > DeviceMultipleReducePtr
Definition: device_multiple_reduce.hpp:54
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_multiple_reduce.hpp:25
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumInputDim > inLengths, const std::array< index_t, NumInputDim > inStrides, const std::array< index_t, NumOutputDim > outLengths, const std::array< std::array< index_t, NumOutputDim >, NumReduction > outStrides, const std::array< int, NumReduceDim > reduceDims, const std::array< double, NumReduction > alphas, const std::array< double, NumReduction > betas, const void *in_dev, const std::array< void *, NumReduction > out_dev_buffers, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple)=0
static constexpr index_t NumInputDim
Definition: device_multiple_reduce.hpp:26
static constexpr index_t NumOutputDim
Definition: device_multiple_reduce.hpp:27
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0