/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp Source File
device_grouped_conv_bwd_weight.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 
8 #include "ck/tensor_operation/gpu/device/device_base.hpp"
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 template <ck::index_t NDimSpatial,
15  typename InLayout,
16  typename WeiLayout,
17  typename OutLayout,
18  typename InDataType,
19  typename WeiDataType,
20  typename OutDataType,
21  typename InElementwiseOperation,
22  typename WeiElementwiseOperation,
23  typename OutElementwiseOperation,
24  typename ComputeTypeA = InDataType,
25  typename ComputeTypeB = ComputeTypeA>
26 struct DeviceGroupedConvBwdWeight : public BaseOperator // abstract interface: both factories below are pure virtual
27 {
28  virtual std::unique_ptr<BaseArgument>
29  MakeArgumentPointer(const void* p_in,
30  void* p_wei, // the only non-const buffer pointer among the three
31  const void* p_out,
32  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
33  const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
34  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
35  const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
36  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
37  const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
38  const std::array<ck::index_t, NDimSpatial>& conv_filter_strides,
39  const std::array<ck::index_t, NDimSpatial>& conv_filter_dilations,
40  const std::array<ck::index_t, NDimSpatial>& input_left_pads,
41  const std::array<ck::index_t, NDimSpatial>& input_right_pads,
42  InElementwiseOperation in_element_op,
43  WeiElementwiseOperation wei_element_op,
44  OutElementwiseOperation out_element_op,
45  ck::index_t split_k) = 0;
46 
47  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
48 };
49 
50 } // namespace device
51 } // namespace tensor_operation
52 } // namespace ck
ck (namespace) — Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
ck::tensor_operation::device::BaseOperator — Definition: device_base.hpp:76
ck::tensor_operation::device::DeviceGroupedConvBwdWeight — Definition: device_grouped_conv_bwd_weight.hpp:27
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, void *p_wei, const void *p_out, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_lengths, const std::array< index_t, NDimSpatial+3 > &a_g_n_c_wis_strides, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_lengths, const std::array< index_t, NDimSpatial+3 > &b_g_k_c_xs_strides, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_lengths, const std::array< index_t, NDimSpatial+3 > &e_g_n_k_wos_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_strides, const std::array< ck::index_t, NDimSpatial > &conv_filter_dilations, const std::array< ck::index_t, NDimSpatial > &input_left_pads, const std::array< ck::index_t, NDimSpatial > &input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op, ck::index_t split_k)=0