/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_conv_fwd.hpp Source File
device_grouped_conv_fwd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <array>
#include <memory>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 // Convolution Forward:
15 // input : input image A[G, N, C, Hi, Wi],
16 // input : weight B[G, K, C, Y, X],
17 // output : output image E[G, N, K, Ho, Wo]
18 // C = a_op(A) * b_op(B)
19 // E = cde_op(C, D0, D1, ...)
20 template <index_t NDimSpatial,
21  typename InLayout,
22  typename WeiLayout,
23  typename OutLayout,
24  typename InDataType,
25  typename WeiDataType,
26  typename OutDataType,
27  typename InElementwiseOperation,
28  typename WeiElementwiseOperation,
29  typename OutElementwiseOperation>
31 {
32  virtual std::unique_ptr<BaseArgument>
33  MakeArgumentPointer(const void* p_in, // input image
34  const void* p_wei, // weight
35  void* p_out, // output image
36  const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_lengths,
37  const std::array<index_t, NDimSpatial + 3>& in_g_n_c_wis_strides,
38  const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_lengths,
39  const std::array<index_t, NDimSpatial + 3>& wei_g_k_c_xs_strides,
40  const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_lengths,
41  const std::array<index_t, NDimSpatial + 3>& out_g_n_k_wos_strides,
42  const std::array<index_t, NDimSpatial>& conv_filter_strides,
43  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
44  const std::array<index_t, NDimSpatial>& input_left_pads,
45  const std::array<index_t, NDimSpatial>& input_right_pads,
46  const InElementwiseOperation& in_element_op,
47  const WeiElementwiseOperation& wei_element_op,
48  const OutElementwiseOperation& out_element_op) = 0;
49 
50  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
51 };
52 
53 } // namespace device
54 } // namespace tensor_operation
55 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_grouped_conv_fwd.hpp:31
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, const void *p_wei, void *p_out, const std::array< index_t, NDimSpatial+3 > &in_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &in_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &wei_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &wei_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &out_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &out_g_n_k_wos_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const InElementwiseOperation &in_element_op, const WeiElementwiseOperation &wei_element_op, const OutElementwiseOperation &out_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0