/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp Source File
device_elementwise_normalization.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
10 
// NOTE(review): this block uses std::array and std::unique_ptr, but the visible
// includes are only <iostream> and <vector>. Header line 9 was dropped by the
// documentation extraction and may provide them transitively. Consider including
// <array> and <memory> directly so the header does not depend on that.
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
// Abstract device-operation interface for a fused elementwise + normalization op.
// Implementations supply the two pure virtual factories: MakeArgumentPointer and
// MakeInvokerPointer.
// NOTE(review): header line 24, the struct declaration itself, is missing from this
// listing. The type name DeviceElementwiseNormalization is confirmed by the
// DeviceElementwiseNormalizationPtr alias in this namespace and by the Doxygen tooltip
// (device_elementwise_normalization.hpp:25). An unnamed tooltip points at
// device_base.hpp:76, which is presumably the base class. Confirm against the real
// header.
//
// Template parameters (roles partly inferred from names; TODO confirm with an impl):
//   InDataTypeTuple        tuple of input element types. InDataTypeTuple::Size()
//                          sets NumInput, the number of input tensors.
//   GammaDataType          element type of the buffer passed as p_gamma.
//   BetaDataType           element type of the buffer passed as p_beta.
//   AccDataType            presumably the accumulation type used while reducing.
//   YDataType              element type of the output buffer passed as p_y.
//   XElementwiseOperation  type of the x_elementwise_op functor.
//   YElementwiseOperation  type of the y_elementwise_op functor.
//   Rank                   presumably the rank of the input/output tensors.
//   NumReduceDim           presumably how many of those dimensions are reduced.
15 template <typename InDataTypeTuple,
16  typename GammaDataType,
17  typename BetaDataType,
18  typename AccDataType,
19  typename YDataType,
20  typename XElementwiseOperation,
21  typename YElementwiseOperation,
22  index_t Rank,
23  index_t NumReduceDim>
// NOTE(review): the struct declaration line (header line 24) belongs here.
25 {
// Number of input tensors. It fixes the sizes of inStridesArray and in_dev_buffers.
26  static constexpr int NumInput = InDataTypeTuple::Size();
27 
// Packs one problem instance into a type-erased BaseArgument that the caller owns.
//   lengths         tensor lengths (presumably Rank entries).
//   inStridesArray  one stride vector per input tensor (NumInput of them).
//   gammaStrides, betaStrides, yStrides  strides of gamma, beta and the output y.
//   reduceDims      dimensions to reduce over (presumably NumReduceDim entries).
//   epsilon         presumably the stabilizing constant added to the variance.
//   in_dev_buffers  one read-only (const void*) buffer pointer per input tensor;
//                   presumably device memory, as the name suggests.
//   p_gamma, p_beta read-only buffer pointers; p_y is the writable output pointer.
//   x_elementwise_op, y_elementwise_op  functors; presumably applied to the inputs
//                   before normalization and to the result after it.
// NOTE(review): every std::vector (and the std::array of vectors) is taken by const
// value, so each call copies them. Switching to const& would avoid the copies.
// However, it changes the signature of this pure virtual, and existing overriders
// would then no longer match it. Change it together with every implementation.
28  virtual std::unique_ptr<BaseArgument>
29  MakeArgumentPointer(const std::vector<index_t> lengths,
30  const std::array<std::vector<index_t>, NumInput> inStridesArray,
31  const std::vector<index_t> gammaStrides,
32  const std::vector<index_t> betaStrides,
33  const std::vector<index_t> yStrides,
34  const std::vector<index_t> reduceDims,
35  double epsilon,
36  const std::array<const void*, NumInput> in_dev_buffers,
37  const void* p_gamma,
38  const void* p_beta,
39  void* p_y,
40  XElementwiseOperation x_elementwise_op,
41  YElementwiseOperation y_elementwise_op) = 0;
42 
// Creates a caller-owned BaseInvoker. It presumably executes an argument produced by
// MakeArgumentPointer.
43  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
44 };
45 
// Owning-pointer alias: a std::unique_ptr to the abstract DeviceElementwiseNormalization
// with the same template arguments. It lets concrete implementations be held and
// used through the interface.
// NOTE(review): destroying through this pointer needs a virtual destructor somewhere in
// the hierarchy. None is declared in the visible struct, so presumably the base class
// provides one; verify.
46 template <typename InDataTypeTuple,
47  typename GammaDataType,
48  typename BetaDataType,
49  typename AccDataType,
50  typename YDataType,
51  typename XElementwiseOperation,
52  typename YElementwiseOperation,
53  index_t Rank,
54  index_t NumReduceDim>
// NOTE(review): header line 55 (`using DeviceElementwiseNormalizationPtr =`) is missing
// from this listing. The alias name comes from the Doxygen tooltip
// (device_elementwise_normalization.hpp:64).
56  std::unique_ptr<DeviceElementwiseNormalization<InDataTypeTuple,
57  GammaDataType,
58  BetaDataType,
59  AccDataType,
60  YDataType,
61  XElementwiseOperation,
62  YElementwiseOperation,
63  Rank,
64  NumReduceDim>>;
65 
66 } // namespace device
67 } // namespace tensor_operation
68 } // namespace ck
std::unique_ptr< DeviceElementwiseNormalization< InDataTypeTuple, GammaDataType, BetaDataType, AccDataType, YDataType, XElementwiseOperation, YElementwiseOperation, Rank, NumReduceDim > > DeviceElementwiseNormalizationPtr
Definition: device_elementwise_normalization.hpp:64
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
DeviceElementwiseNormalization
Definition: device_elementwise_normalization.hpp:25
static constexpr int NumInput
Definition: device_elementwise_normalization.hpp:26
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::array< std::vector< index_t >, NumInput > inStridesArray, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > reduceDims, double epsilon, const std::array< const void *, NumInput > in_dev_buffers, const void *p_gamma, const void *p_beta, void *p_y, XElementwiseOperation x_elementwise_op, YElementwiseOperation y_elementwise_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0