/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_normalization_fwd.hpp Source File
device_normalization_fwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 template <typename XDataType,
15  typename GammaDataType,
16  typename BetaDataType,
17  typename YDataType,
18  typename SaveMeanInvStdDataType,
19  typename YElementwiseOperation,
20  index_t Rank,
21  index_t NumReduceDim>
23 {
24  virtual std::unique_ptr<BaseArgument>
25  MakeArgumentPointer(const std::vector<index_t> lengths,
26  const std::vector<index_t> xStrides,
27  const std::vector<index_t> gammaStrides,
28  const std::vector<index_t> betaStrides,
29  const std::vector<index_t> yStrides,
30  const std::vector<index_t> saveMeanStrides,
31  const std::vector<index_t> saveInvStdStrides,
32  const std::vector<index_t> reduceDims,
33  double epsilon,
34  const void* p_x,
35  const void* p_gamma,
36  const void* p_beta,
37  void* p_y,
38  void* p_savedMean,
39  void* p_savedInvVar,
40  YElementwiseOperation y_elementwise_op) = 0;
41 
42  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
43 };
44 
45 template <typename XDataType,
46  typename GammaDataType,
47  typename BetaDataType,
48  typename YDataType,
49  typename SaveMeanInvStdDataType,
50  typename YElementwiseOperation,
51  index_t Rank,
52  index_t NumReduceDim>
53 using DeviceNormalizationFwdPtr = std::unique_ptr<DeviceNormalizationFwd<XDataType,
54  GammaDataType,
55  BetaDataType,
56  YDataType,
57  SaveMeanInvStdDataType,
58  YElementwiseOperation,
59  Rank,
60  NumReduceDim>>;
61 
62 } // namespace device
63 } // namespace tensor_operation
64 } // namespace ck
std::unique_ptr< DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim > > DeviceNormalizationFwdPtr
Definition: device_normalization_fwd.hpp:60
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_normalization_fwd.hpp:23
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, double epsilon, const void *p_x, const void *p_gamma, const void *p_beta, void *p_y, void *p_savedMean, void *p_savedInvVar, YElementwiseOperation y_elementwise_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0