/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_pool_fwd.hpp Source File
device_pool_fwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <memory>
#include <vector>

11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
15 template <index_t InOutRank,
16  index_t WindowRank,
17  typename InDataType,
18  typename OutDataType,
19  typename IndexDataType,
20  typename InLayout,
21  typename OutLayout,
22  ReduceTensorOp ReduceOpId,
23  bool OutputIndex>
24 struct DevicePoolFwd : public BaseOperator
25 {
26  virtual std::unique_ptr<BaseArgument>
27  MakeArgumentPointer(const void* p_in_dev,
28  void* p_out_dev,
29  void* p_out_indices_dev,
30  std::vector<ck::index_t> input_n_c_wis_lengths,
31  std::vector<ck::index_t> window_xs_lengths,
32  std::vector<ck::index_t> output_n_c_wos_lengths,
33  std::vector<ck::index_t> input_n_c_wis_stride,
34  std::vector<ck::index_t> output_n_c_wis_stride,
35  std::vector<ck::index_t> indices_n_c_wis_stride,
36  std::vector<ck::index_t> window_xs_strides,
37  std::vector<ck::index_t> window_xs_dilations,
38  std::vector<ck::index_t> input_left_pads,
39  std::vector<ck::index_t> input_right_pads,
40  std::vector<ck::index_t> pooling_dims) = 0;
41 
42  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
43 };
44 
45 } // namespace device
46 } // namespace tensor_operation
47 } // namespace ck
Definition: ck.hpp:264
ReduceTensorOp
Definition: reduction_enums.hpp:9
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_pool_fwd.hpp:25
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in_dev, void *p_out_dev, void *p_out_indices_dev, std::vector< ck::index_t > input_n_c_wis_lengths, std::vector< ck::index_t > window_xs_lengths, std::vector< ck::index_t > output_n_c_wos_lengths, std::vector< ck::index_t > input_n_c_wis_stride, std::vector< ck::index_t > output_n_c_wis_stride, std::vector< ck::index_t > indices_n_c_wis_stride, std::vector< ck::index_t > window_xs_strides, std::vector< ck::index_t > window_xs_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, std::vector< ck::index_t > pooling_dims)=0