/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_multiple_r.hpp Source File
device_gemm_multiple_d_multiple_r.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 
8 #include "ck/tensor_operation/gpu/device/device_base.hpp"
9 
10 namespace ck {
11 namespace tensor_operation {
12 namespace device {
13 
14 // FIXME: the DeviceGemmReduce-type operations need a well-defined problem description
15 // GEMM:
16 // input : A[AK0, M, AK1]
17 // input : B[AK0, N, AK1]
18 // input : D0[M, N], D1[M, N], ...
19 // output : E[M, N]
20 // output : R0[M], R1[M], ...
21 // C = a_op(A) * b_op(B)
22 // E = cde_op(C, D0, D1, ...)
23 // Q0 = reduce0(q_op0(E)), Q1 = reduce1(q_op1(E)), ...
24 // R0 = r_op0(Q0), R1 = r_op1(Q1), ...
25 // Assume:
26 // D0, D1, ... and E have the same layout
27 template <typename ALayout,
28  typename BLayout,
29  typename DELayout,
30  typename ADataType,
31  typename BDataType,
32  typename DsDataType,
33  typename EDataType,
34  typename RsDataType,
35  typename AElementwiseOperation,
36  typename BElementwiseOperation,
37  typename CDEElementwiseOperation,
38  typename QsElementwiseOperation,
39  typename RsElementwiseOperation>
40 struct DeviceGemmMultipleDMultipleR : public BaseOperator
41 {
42  static constexpr index_t NumDTensor = DsDataType::Size();
43  static constexpr index_t NumRTensor = RsDataType::Size();
44 
45  virtual std::unique_ptr<BaseArgument>
46  MakeArgumentPointer(const void* p_a,
47  const void* p_b,
48  std::array<const void*, NumDTensor> p_ds,
49  void* p_e,
50  std::array<void*, NumRTensor> p_rs,
51  ck::index_t M,
52  ck::index_t N,
53  ck::index_t K,
54  ck::index_t StrideA,
55  ck::index_t StrideB,
56  std::array<ck::index_t, NumDTensor> StrideDs,
57  ck::index_t StrideE,
58  AElementwiseOperation a_element_op,
59  BElementwiseOperation b_element_op,
60  CDEElementwiseOperation cde_element_op,
61  QsElementwiseOperation qs_element_op,
62  RsElementwiseOperation rs_element_op) = 0;
63 
64  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
65 };
66 
67 template <typename ALayout,
68  typename BLayout,
69  typename DELayout,
70  typename ADataType,
71  typename BDataType,
72  typename DsDataType,
73  typename EDataType,
74  typename RsDataType,
75  typename AElementwiseOperation,
76  typename BElementwiseOperation,
77  typename CDEElementwiseOperation,
78  typename QsElementwiseOperation,
79  typename RsElementwiseOperation>
80 using DeviceGemmMultipleDMultipleRPtr =
81  std::unique_ptr<DeviceGemmMultipleDMultipleR<ALayout,
82  BLayout,
83  DELayout,
84  ADataType,
85  BDataType,
86  DsDataType,
87  EDataType,
88  RsDataType,
89  AElementwiseOperation,
90  BElementwiseOperation,
91  CDEElementwiseOperation,
92  QsElementwiseOperation,
93  RsElementwiseOperation>>;
94 
95 } // namespace device
96 } // namespace tensor_operation
97 } // namespace ck
std::unique_ptr< DeviceGemmMultipleDMultipleR< ALayout, BLayout, DELayout, ADataType, BDataType, DsDataType, EDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation > > DeviceGemmMultipleDMultipleRPtr
Definition: device_gemm_multiple_d_multiple_r.hpp:93
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_gemm_multiple_d_multiple_r.hpp:41
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d_multiple_r.hpp:42
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumRTensor
Definition: device_gemm_multiple_d_multiple_r.hpp:43
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, std::array< void *, NumRTensor > p_rs, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op, QsElementwiseOperation qs_element_op, RsElementwiseOperation rs_element_op)=0