/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_elementwise_scale.hpp Source File
device_elementwise_scale.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <memory>
7 #include <array>
8 
9 #include "ck/ck.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
20 template <typename InDataTypeTuple,
21  typename OutDataTypeTuple,
22  typename ElementwiseOperation,
23  typename UnaryOperation,
24  typename Scale,
25  index_t NumDim>
26 struct DeviceElementwise : public BaseOperator
27 {
28  static constexpr int NumInput = InDataTypeTuple::Size();
29  static constexpr int NumOutput = OutDataTypeTuple::Size();
30 
31  virtual std::unique_ptr<BaseArgument>
32  MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
33  const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
34  const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
35  const std::array<const void*, NumInput> in_dev_buffers,
36  const std::array<void*, NumOutput> out_dev_buffers,
37  ElementwiseOperation elementwise_op,
38  UnaryOperation unary_op,
39  Scale scale_op) = 0;
40 
41  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
42 }; // namespace device
43 
44 template <typename InDataTypeTuple,
45  typename OutDataTypeTuple,
46  typename ElementwiseOperation,
47  typename UnaryOperation,
48  typename Scale,
49  index_t NumDim>
50 using DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple,
51  OutDataTypeTuple,
52  ElementwiseOperation,
53  UnaryOperation,
54  Scale,
55  NumDim>>;
56 
57 } // namespace device
58 } // namespace tensor_operation
59 } // namespace ck
std::unique_ptr< DeviceElementwise< InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim > > DeviceElementwisePtr
Definition: device_elementwise.hpp:41
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_elementwise.hpp:21
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr int NumInput
Definition: device_elementwise.hpp:22
static constexpr int NumOutput
Definition: device_elementwise.hpp:23
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op, UnaryOperation unary_op, Scale scale_op)=0