/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_avgpool_bwd.hpp Source File
device_avgpool_bwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <vector>
7 
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
// Abstract (pure-virtual) interface for an average-pool backward device
// operation, per this header's filename (device_avgpool_bwd.hpp).
// Implementations supply an argument factory and an invoker factory.
//
// Template parameters:
//   NDimSpatial  - an index_t; presumably the number of spatial dimensions
//                  (the "wos" / "xs" vectors below) -- TODO confirm.
//   DOutDataType - element type of the buffer passed as p_dout.
//   DInDataType  - element type of the buffer passed as p_din.
//   DOutLayout   - layout tag for the p_dout tensor.
//   DInLayout    - layout tag for the p_din tensor.
//
// NOTE(review): the struct declaration line (original source line 19, which
// carries the struct name and base class) is missing from this extracted
// listing; restore it from the original header before using this text.
14 template <index_t NDimSpatial,
15  typename DOutDataType,
16  typename DInDataType,
17  typename DOutLayout,
18  typename DInLayout>
20 {
    // Builds a type-erased argument object for one invocation.
    //
    // p_dout - read-only (const) input buffer; presumably the gradient w.r.t.
    //          the pooling output -- confirm against implementations.
    // p_din  - writable output buffer; presumably receives the gradient
    //          w.r.t. the pooling input -- confirm.
    // *_n_k_wos_lengths / *_strides - tensor descriptors; the name suggests
    //          [N, K, spatial...] ordering -- TODO confirm.
    // window_k_c_xs_lengths - pooling-window lengths (exact layout of the
    //          leading entries is not visible here).
    // window_strides, window_dilations - per-spatial-dim window step/dilation.
    // input_left_pads, input_right_pads - per-spatial-dim padding.
    //
    // All vectors are taken by value. Returns an owning
    // std::unique_ptr<BaseArgument>; ownership passes to the caller.
    // NOTE(review): `din_n_k_wos_length` is singular while its siblings
    // (`dout_n_k_wos_lengths`) are plural; presumably the same meaning.
21  virtual std::unique_ptr<BaseArgument>
22  MakeArgumentPointer(const void* p_dout,
23  void* p_din,
24  std::vector<ck::index_t> dout_n_k_wos_lengths,
25  std::vector<ck::index_t> dout_n_k_wos_strides,
26  std::vector<ck::index_t> din_n_k_wos_length,
27  std::vector<ck::index_t> din_n_k_wos_strides,
28  std::vector<ck::index_t> window_k_c_xs_lengths,
29  std::vector<ck::index_t> window_strides,
30  std::vector<ck::index_t> window_dilations,
31  std::vector<ck::index_t> input_left_pads,
32  std::vector<ck::index_t> input_right_pads) = 0;
33 
    // Returns an owning std::unique_ptr<BaseInvoker> used to launch the
    // operation (presumably with an argument from MakeArgumentPointer).
34  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
35 };
36 
37 } // namespace device
38 } // namespace tensor_operation
39 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_avgpool_bwd.hpp:20
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_dout, void *p_din, std::vector< ck::index_t > dout_n_k_wos_lengths, std::vector< ck::index_t > dout_n_k_wos_strides, std::vector< ck::index_t > din_n_k_wos_length, std::vector< ck::index_t > din_n_k_wos_strides, std::vector< ck::index_t > window_k_c_xs_lengths, std::vector< ck::index_t > window_strides, std::vector< ck::index_t > window_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads)=0