/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_contraction_multiple_abd.hpp Source File
device_contraction_multiple_abd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
#include <array>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 // GEMM:
15 // input : A0[M0, M1, ... K0, K1, ...], ...
16 // input : B0[N0, N1, ... K0, K1, ...], ...
17 // input : D0[M0, M1, ... N0, N1, ...], D1[M0, M1, ... N0, N1, ...], ...
18 // output : E[M0, M1, ... N0, N1, ...]
19 // C = a_op(A) * b_op(B)
20 // E = cde_op(C, D0, D1, ...)
21 // Assume:
22 // D0, D1, ... and E have the same layout
23 template <index_t NumDimM,
24  index_t NumDimN,
25  index_t NumDimK,
26  typename AsDataType,
27  typename BsDataType,
28  typename DsDataType,
29  typename EDataType,
30  typename AElementwiseOperation,
31  typename BElementwiseOperation,
32  typename CDEElementwiseOperation>
34 {
35  static constexpr index_t NumATensor = AsDataType::Size();
36  static constexpr index_t NumBTensor = BsDataType::Size();
37  static constexpr index_t NumDTensor = DsDataType::Size();
38 
39  virtual std::unique_ptr<BaseArgument>
40  MakeArgumentPointer(std::array<const void*, NumATensor> p_as,
41  std::array<const void*, NumBTensor> p_bs,
42  std::array<const void*, NumDTensor> p_ds,
43  void* p_e,
44  const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_lengths,
45  const std::array<std::vector<index_t>, NumATensor>& a_ms_ks_strides,
46  const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_lengths,
47  const std::array<std::vector<index_t>, NumBTensor>& b_ns_ks_strides,
48  const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_lengths,
49  const std::array<std::vector<index_t>, NumDTensor>& d_ms_ns_strides,
50  const std::vector<index_t>& e_ms_ns_length,
51  const std::vector<index_t>& e_ms_ns_stride,
52  AElementwiseOperation a_element_op,
53  BElementwiseOperation b_element_op,
54  CDEElementwiseOperation cde_element_op) = 0;
55 
56  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
57 };
58 
59 } // namespace device
60 } // namespace tensor_operation
61 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_contraction_multiple_abd.hpp:34
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumATensor
Definition: device_contraction_multiple_abd.hpp:35
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::array< const void *, NumATensor > p_as, std::array< const void *, NumBTensor > p_bs, std::array< const void *, NumDTensor > p_ds, void *p_e, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_lengths, const std::array< std::vector< index_t >, NumATensor > &a_ms_ks_strides, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_lengths, const std::array< std::vector< index_t >, NumBTensor > &b_ns_ks_strides, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_lengths, const std::array< std::vector< index_t >, NumDTensor > &d_ms_ns_strides, const std::vector< index_t > &e_ms_ns_length, const std::vector< index_t > &e_ms_ns_stride, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumBTensor
Definition: device_contraction_multiple_abd.hpp:36
static constexpr index_t NumDTensor
Definition: device_contraction_multiple_abd.hpp:37