/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_batched_contraction_multiple_d.hpp Source File
device_batched_contraction_multiple_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <array>
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
15 // Tensor Contraction:
16 // input : A
17 // input : B
18 // input : D0, D1, ...
19 // output : E
20 // C = a_op(A) * b_op(B)
21 // E = cde_op(C, D0, D1, ...)
22 // Assume:
23 // A[G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
24 // B[G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
25 // D[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
26 // E[G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
27 template <index_t NumDimG,
28  index_t NumDimM,
29  index_t NumDimN,
30  index_t NumDimK,
31  typename ADataType,
32  typename BDataType,
33  typename DsDataType,
34  typename EDataType,
35  typename AElementwiseOperation,
36  typename BElementwiseOperation,
37  typename CDEElementwiseOperation>
39 {
40  static constexpr index_t NumDTensor = DsDataType::Size();
41 
42  virtual std::unique_ptr<BaseArgument>
43  MakeArgumentPointer(const void* p_a,
44  const void* p_b,
45  std::array<const void*, NumDTensor> p_ds,
46  void* p_e,
47  const std::vector<index_t>& a_gs_ms_ns_lengths,
48  const std::vector<index_t>& a_gs_ms_ks_strides,
49  const std::vector<index_t>& b_gs_ns_ks_lengths,
50  const std::vector<index_t>& b_gs_ns_ks_strides,
51  const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_lengths,
52  const std::array<std::vector<index_t>, NumDTensor>& ds_gs_ms_ns_strides,
53  const std::vector<index_t>& e_gs_ms_ns_lengths,
54  const std::vector<index_t>& e_gs_ms_ns_strides,
55  AElementwiseOperation a_element_op,
56  BElementwiseOperation b_element_op,
57  CDEElementwiseOperation cde_element_op) = 0;
58 
59  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
60 };
61 
62 } // namespace device
63 } // namespace tensor_operation
64 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_batched_contraction_multiple_d.hpp:39
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::vector< index_t > &a_gs_ms_ns_lengths, const std::vector< index_t > &a_gs_ms_ks_strides, const std::vector< index_t > &b_gs_ns_ks_lengths, const std::vector< index_t > &b_gs_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &ds_gs_ms_ns_strides, const std::vector< index_t > &e_gs_ms_ns_lengths, const std::vector< index_t > &e_gs_ms_ns_strides, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumDTensor
Definition: device_batched_contraction_multiple_d.hpp:40
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0