/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File
device_gemm_multiple_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 // GEMM:
15 // input : A[M, K], B[K, N],
16 // input : D0[M, N], D1[M, N], ...
17 // output : E[M, N]
18 // C = a_op(A) * b_op(B)
19 // E = cde_op(C, D0, D1, ...)
20 // Assume:
21 // D0, D1, ... and E have the same layout
// Abstract interface for a GEMM that takes extra "D" input tensors and writes E
// (see the GEMM description in the comment directly above this template).
// It declares two pure-virtual factories and one compile-time constant.
//
// NOTE(review): the declaration line between the template header and the
// opening brace (listing line 33) is missing from this extract; presumably
// `struct DeviceGemmMultipleD : public BaseOperator` -- confirm against the header.
// NOTE(review): std::unique_ptr is used below, but <array> is the only standard
// include visible in this listing; <memory> presumably arrives through the
// include on the missing listing line 8 -- confirm.
//
// Template parameters:
//   ALayout, BLayout, DsLayout, ELayout - presumably layout tags for A, B, the D
//       tensors and E; not referenced by the members shown here.
//   ADataType, BDataType, EDataType     - presumably element types of A, B and E;
//       not referenced by the members shown here.
//   DsDataType                          - must expose a static Size(); it fixes
//       NumDTensor. Presumably a tuple-like list of D element types -- confirm.
//   AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation
//                                       - types of the a_op / b_op / cde_op objects
//       passed by value to MakeArgumentPointer.
22 template <typename ALayout,
23  typename BLayout,
24  typename DsLayout,
25  typename ELayout,
26  typename ADataType,
27  typename BDataType,
28  typename DsDataType,
29  typename EDataType,
30  typename AElementwiseOperation,
31  typename BElementwiseOperation,
32  typename CDEElementwiseOperation>
34 {
    // Number of D tensors, taken at compile time from DsDataType::Size().
    // It sizes the p_ds and StrideDs arrays below.
35  static constexpr index_t NumDTensor = DsDataType::Size();
36 
    // Pure virtual: returns a std::unique_ptr<BaseArgument> for one GEMM problem.
    //   p_a, p_b   - read-only, untyped pointers to A and B.
    //   p_ds       - NumDTensor read-only, untyped pointers, one per D tensor.
    //   p_e        - writable, untyped pointer to the output E.
    //                (presumably all of these are device buffers -- confirm)
    //   M, N, K    - problem sizes as in A[M, K], B[K, N], E[M, N].
    //   StrideA, StrideB, StrideDs, StrideE
    //              - one stride per tensor (StrideDs has NumDTensor entries);
    //                presumably leading-dimension strides in elements -- confirm.
    //   a_element_op, b_element_op, cde_element_op
    //              - element-wise operation objects, passed by value.
37  virtual std::unique_ptr<BaseArgument>
38  MakeArgumentPointer(const void* p_a,
39  const void* p_b,
40  std::array<const void*, NumDTensor> p_ds,
41  void* p_e,
42  ck::index_t M,
43  ck::index_t N,
44  ck::index_t K,
45  ck::index_t StrideA,
46  ck::index_t StrideB,
47  std::array<ck::index_t, NumDTensor> StrideDs,
48  ck::index_t StrideE,
49  AElementwiseOperation a_element_op,
50  BElementwiseOperation b_element_op,
51  CDEElementwiseOperation cde_element_op) = 0;
52 
    // Pure virtual: returns a std::unique_ptr<BaseInvoker>.
    // NOTE(review): presumably the invoker runs the operation on an argument
    // produced by MakeArgumentPointer -- confirm in device_base.hpp.
53  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
54 };
55 
56 // GEMM:
57 // input : A[M, K], B[K, N],
58 // input : D0[M, N], D1[M, N], ...
59 // output : E[M, N]
60 // C = a_op(A) * b_op(B)
61 // E = cde_op(C, D0, D1, ...)
62 // Assume:
63 // D0, D1, ... and E have the same layout
// Abstract interface for a GEMM that takes extra "D" input tensors and writes E
// (see the GEMM description in the comment directly above this template).
// The template parameter list and members match the earlier interface in this
// file; the only difference is the extra KBatch parameter of MakeArgumentPointer.
//
// NOTE(review): the declaration line between the template header and the
// opening brace (listing line 75) is missing from this extract; presumably
// `struct DeviceGemmMultipleDSplitK : public BaseOperator` -- confirm against
// the header.
//
// Template parameters:
//   ALayout, BLayout, DsLayout, ELayout - presumably layout tags for A, B, the D
//       tensors and E; not referenced by the members shown here.
//   ADataType, BDataType, EDataType     - presumably element types of A, B and E;
//       not referenced by the members shown here.
//   DsDataType                          - must expose a static Size(); it fixes
//       NumDTensor. Presumably a tuple-like list of D element types -- confirm.
//   AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation
//                                       - types of the a_op / b_op / cde_op objects
//       passed by value to MakeArgumentPointer.
64 template <typename ALayout,
65  typename BLayout,
66  typename DsLayout,
67  typename ELayout,
68  typename ADataType,
69  typename BDataType,
70  typename DsDataType,
71  typename EDataType,
72  typename AElementwiseOperation,
73  typename BElementwiseOperation,
74  typename CDEElementwiseOperation>
76 {
    // Number of D tensors, taken at compile time from DsDataType::Size().
    // It sizes the p_ds and StrideDs arrays below.
77  static constexpr index_t NumDTensor = DsDataType::Size();
78 
    // Pure virtual: returns a std::unique_ptr<BaseArgument> for one GEMM problem.
    //   p_a, p_b   - read-only, untyped pointers to A and B.
    //   p_ds       - NumDTensor read-only, untyped pointers, one per D tensor.
    //   p_e        - writable, untyped pointer to the output E.
    //                (presumably all of these are device buffers -- confirm)
    //   M, N, K    - problem sizes as in A[M, K], B[K, N], E[M, N].
    //   StrideA, StrideB, StrideDs, StrideE
    //              - one stride per tensor (StrideDs has NumDTensor entries);
    //                presumably leading-dimension strides in elements -- confirm.
    //   KBatch     - extra ck::index_t placed between StrideE and the element-wise
    //                operations; presumably the number of split-K partitions of
    //                the K dimension -- confirm with an implementation.
    //   a_element_op, b_element_op, cde_element_op
    //              - element-wise operation objects, passed by value.
79  virtual std::unique_ptr<BaseArgument>
80  MakeArgumentPointer(const void* p_a,
81  const void* p_b,
82  std::array<const void*, NumDTensor> p_ds,
83  void* p_e,
84  ck::index_t M,
85  ck::index_t N,
86  ck::index_t K,
87  ck::index_t StrideA,
88  ck::index_t StrideB,
89  std::array<ck::index_t, NumDTensor> StrideDs,
90  ck::index_t StrideE,
91  ck::index_t KBatch,
92  AElementwiseOperation a_element_op,
93  BElementwiseOperation b_element_op,
94  CDEElementwiseOperation cde_element_op) = 0;
95 
    // Pure virtual: returns a std::unique_ptr<BaseInvoker>.
    // NOTE(review): presumably the invoker runs the operation on an argument
    // produced by MakeArgumentPointer -- confirm in device_base.hpp.
96  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
97 };
98 
99 } // namespace device
100 } // namespace tensor_operation
101 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_gemm_multiple_d.hpp:34
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:35
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
Definition: device_gemm_multiple_d.hpp:76
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:77