/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_contraction_multiple_d.hpp Source File
device_grouped_contraction_multiple_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
// ContractionDesc<NumDTensor>: lengths/strides description of one contraction problem.
// A std::vector of these is passed as `contraction_descs` to MakeArgumentPointer,
// alongside parallel vectors of buffer pointers (presumably one entry per group — confirm).
// Member names follow <tensor>_<dim groups>_<lengths|strides>; e.g. a_ms_ks_* describe
// tensor A with its M dimensions followed by its K dimensions, matching the
// "A[M0, M1, ..., K0, K1, ...]" layout in the Tensor Contraction comment below.
// NOTE(review): listing line 16 (the struct head, presumably `struct ContractionDesc`)
// is missing from this rendered page; only the template header and the body survive.
// NOTE(review): std::array is used below, but the visible includes are only <iostream>
// and <vector>; <array> presumably arrives transitively — consider including it directly.
15 template <index_t NumDTensor>
17 {
18  std::vector<index_t> a_ms_ks_lengths; // A: per-dimension lengths, M dims then K dims (assumed from name)
19  std::vector<index_t> a_ms_ks_strides; // A: per-dimension strides, same order
20 
21  std::vector<index_t> b_ns_ks_lengths; // B: N dims then K dims (assumed from name)
22  std::vector<index_t> b_ns_ks_strides; // B: strides, same order
23 
24  std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_lengths; // one lengths vector per D tensor
25  std::array<std::vector<index_t>, NumDTensor> ds_ms_ns_strides; // one strides vector per D tensor
26 
27  std::vector<index_t> e_ms_ns_lengths; // E (output): M dims then N dims (assumed from name)
28  std::vector<index_t> e_ms_ns_strides; // E: strides, same order
29 };
30 
31 // Tensor Contraction:
32 // input : A
33 // input : B
34 // input : D0, D1, ...
35 // output : E
36 // C = a_op(A) * b_op(B)   (contraction over the K dimensions shared by A and B)
37 // E = cde_op(C, D0, D1, ...)
38 // Assume:
39 // A[M0, M1, M2, ..., K0, K1, K2, ...]
40 // B[N0, N1, N2, ..., K0, K1, K2, ...]
41 // D[M0, M1, M2, ..., N0, N1, N2, ...]
42 // E[M0, M1, M2, ..., N0, N1, N2, ...]
// Grouped variant: MakeArgumentPointer takes parallel std::vectors of buffer pointers and
// ContractionDesc shapes, presumably one entry per independent contraction ("group") —
// confirm the per-index pairing against an implementation.
// NOTE(review): listing line 53 (the class head) is missing from this rendered page, so
// the class name and any base class are not visible here.
// NOTE(review): std::unique_ptr and std::array appear in the signatures, but the visible
// includes are only <iostream> and <vector>; <memory>/<array> presumably arrive transitively.
43 template <index_t NumDimM, // presumably the number of M dimensions (M0, M1, ...)
44  index_t NumDimN, // presumably the number of N dimensions
45  index_t NumDimK, // presumably the number of K dimensions
46  typename ADataType,
47  typename BDataType,
48  typename DsDataType, // must provide a static Size() usable in a constant expression (see NumDTensor)
49  typename EDataType,
50  typename AElementwiseOperation, // a_op in the formula above
51  typename BElementwiseOperation, // b_op in the formula above
52  typename CDEElementwiseOperation> // cde_op in the formula above
54 {
55  static constexpr index_t NumDTensor = DsDataType::Size(); // number of D tensors; sizes the per-D std::arrays below
56 
  // Pure virtual: implementations build an argument object from the buffer pointers,
  // shape descriptions and element-wise operators (all taken by value) and return it
  // as std::unique_ptr<BaseArgument>. Overriders must match this signature exactly.
57  virtual std::unique_ptr<BaseArgument>
58  MakeArgumentPointer(std::vector<const void*> p_a_vec, // A buffers, one per entry (assumed per group)
59  std::vector<const void*> p_b_vec, // B buffers
60  std::vector<std::array<const void*, NumDTensor>> p_ds_vec, // NumDTensor D buffers per entry
61  std::vector<void*> p_e_vec, // E (output) buffers; the only non-const pointers
62  std::vector<ContractionDesc<NumDTensor>> contraction_descs, // lengths/strides per entry
63  AElementwiseOperation a_element_op,
64  BElementwiseOperation b_element_op,
65  CDEElementwiseOperation cde_element_op) = 0;
66 
  // Pure virtual: implementations return their invoker as std::unique_ptr<BaseInvoker>.
67  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
68 };
69 
70 } // namespace device
71 } // namespace tensor_operation
72 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
ContractionDesc
Definition: device_grouped_contraction_multiple_d.hpp:17
std::array< std::vector< index_t >, NumDTensor > ds_ms_ns_lengths
Definition: device_grouped_contraction_multiple_d.hpp:24
std::vector< index_t > b_ns_ks_lengths
Definition: device_grouped_contraction_multiple_d.hpp:21
std::vector< index_t > a_ms_ks_strides
Definition: device_grouped_contraction_multiple_d.hpp:19
std::vector< index_t > a_ms_ks_lengths
Definition: device_grouped_contraction_multiple_d.hpp:18
std::array< std::vector< index_t >, NumDTensor > ds_ms_ns_strides
Definition: device_grouped_contraction_multiple_d.hpp:25
std::vector< index_t > e_ms_ns_strides
Definition: device_grouped_contraction_multiple_d.hpp:28
std::vector< index_t > e_ms_ns_lengths
Definition: device_grouped_contraction_multiple_d.hpp:27
std::vector< index_t > b_ns_ks_strides
Definition: device_grouped_contraction_multiple_d.hpp:22
Definition: device_grouped_contraction_multiple_d.hpp:54
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > p_a_vec, std::vector< const void * > p_b_vec, std::vector< std::array< const void *, NumDTensor >> p_ds_vec, std::vector< void * > p_e_vec, std::vector< ContractionDesc< NumDTensor >> contraction_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition: device_grouped_contraction_multiple_d.hpp:55