/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_gemm_v2.hpp Source File
device_gemm_v2.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck {
9 namespace tensor_operation {
10 namespace device {
11 
/// @brief Abstract interface declared as a class template deriving publicly from BaseOperator.
///
/// The struct contains only pure virtual member functions: it has no data
/// members and no implementation, so every behavior is supplied by a derived
/// class that is not visible in this file.
///
/// @tparam ALayout, BLayout, CLayout
///         Not referenced anywhere inside this struct body (presumably layout
///         tags for the A, B and C operands -- TODO confirm with implementers).
/// @tparam ADataType, BDataType, CDataType
///         Not referenced inside this struct body either; the data pointers
///         below are type-erased (`void*`), so presumably these name the
///         element types the implementer casts to -- TODO confirm.
/// @tparam AElementwiseOperation, BElementwiseOperation, CElementwiseOperation
///         Types of the three operation objects that MakeArgumentPointer()
///         receives by value.
12 template <typename ALayout,
13  typename BLayout,
14  typename CLayout,
15  typename ADataType,
16  typename BDataType,
17  typename CDataType,
18  typename AElementwiseOperation,
19  typename BElementwiseOperation,
20  typename CElementwiseOperation>
21 struct DeviceGemmV2 : public BaseOperator
22 {
    /// @brief Pure virtual factory returning an owning pointer to a BaseArgument.
    ///
    /// @param p_a, p_b  Type-erased pointers to const data.
    /// @param p_c       Type-erased pointer to non-const data (presumably the
    ///                  output buffer -- TODO confirm).
    /// @param M, N, K   Sizes as ck::index_t (presumably the GEMM problem
    ///                  dimensions -- TODO confirm).
    /// @param StrideA, StrideB, StrideC
    ///                  Strides as ck::index_t; their unit and which dimension
    ///                  they apply to are not visible here.
    /// @param KSplit    ck::index_t (presumably a split-K factor -- TODO confirm).
    /// @param a_element_op, b_element_op, c_element_op
    ///                  Operation objects, passed by value.
    /// @return std::unique_ptr<BaseArgument> owning the created argument object.
23  virtual std::unique_ptr<BaseArgument>
24  MakeArgumentPointer(const void* p_a,
25  const void* p_b,
26  void* p_c,
27  ck::index_t M,
28  ck::index_t N,
29  ck::index_t K,
30  ck::index_t StrideA,
31  ck::index_t StrideB,
32  ck::index_t StrideC,
33  ck::index_t KSplit,
34  AElementwiseOperation a_element_op,
35  BElementwiseOperation b_element_op,
36  CElementwiseOperation c_element_op) = 0;
37 
    /// @brief Pure virtual factory returning an owning pointer to a BaseInvoker.
38  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
39 
    // Pure virtual, non-const queries answered by the implementer. Only the
    // return types are fixed here (bool / ck::index_t); what "permute" means
    // for A and B is not defined in this file.
40  virtual bool GetPermuteA() = 0;
41  virtual bool GetPermuteB() = 0;
42  virtual ck::index_t GetKPerBlock() = 0;
43 };
44 
/// @brief Abstract interface template whose MakeArgumentPointer() additionally
///        takes an array of "Ds" pointers and an array of "Ds" strides.
///
/// NOTE(review): listing line 56 -- the declaration that sits between the
/// template header and the opening brace -- is absent from this extract, so
/// the struct's name and base class cannot be read here. Restore that line
/// from the original header before compiling; do not guess it.
///
/// @tparam ALayout, BLayout, DsLayout, CLayout
///         Not referenced inside the visible body (presumably layout tags --
///         TODO confirm).
/// @tparam ADataType, BDataType, CDataType
///         Not referenced inside the visible body; the data pointers are
///         type-erased (`void*`).
/// @tparam DsDataType
///         Must provide a static `Size()` usable in a constant expression; it
///         defines NumDTensor below.
/// @tparam AElementwiseOperation, BElementwiseOperation, CElementwiseOperation
///         Types of the operation objects passed by value to
///         MakeArgumentPointer().
45 template <typename ALayout,
46  typename BLayout,
47  typename DsLayout,
48  typename CLayout,
49  typename ADataType,
50  typename BDataType,
51  typename DsDataType,
52  typename CDataType,
53  typename AElementwiseOperation,
54  typename BElementwiseOperation,
55  typename CElementwiseOperation>
57 {
    /// Compile-time constant equal to DsDataType::Size(); it is the length of
    /// both std::array parameters of MakeArgumentPointer().
58  static constexpr index_t NumDTensor = DsDataType::Size();
59 
    /// @brief Pure virtual factory returning an owning pointer to a BaseArgument.
    ///
    /// @param p_a, p_b  Type-erased pointers to const data.
    /// @param p_ds      NumDTensor type-erased pointers to const data, passed
    ///                  by value as a std::array.
    /// @param p_c       Type-erased pointer to non-const data (presumably the
    ///                  output buffer -- TODO confirm).
    /// @param M, N, K   Sizes as ck::index_t (presumably the GEMM problem
    ///                  dimensions -- TODO confirm).
    /// @param StrideA, StrideB, StrideC
    ///                  Strides as ck::index_t; unit and axis not visible here.
    /// @param DsStrides NumDTensor strides, passed by value as a std::array
    ///                  (presumably one per entry of p_ds, in the same order --
    ///                  TODO confirm).
    /// @param KSplit    ck::index_t (presumably a split-K factor -- TODO confirm).
    /// @param a_element_op, b_element_op, c_element_op
    ///                  Operation objects, passed by value.
    /// @return std::unique_ptr<BaseArgument> owning the created argument object.
60  virtual std::unique_ptr<BaseArgument>
61  MakeArgumentPointer(const void* p_a,
62  const void* p_b,
63  std::array<const void*, NumDTensor> p_ds,
64  void* p_c,
65  ck::index_t M,
66  ck::index_t N,
67  ck::index_t K,
68  ck::index_t StrideA,
69  ck::index_t StrideB,
70  std::array<ck::index_t, NumDTensor> DsStrides,
71  ck::index_t StrideC,
72  ck::index_t KSplit,
73  AElementwiseOperation a_element_op,
74  BElementwiseOperation b_element_op,
75  CElementwiseOperation c_element_op) = 0;
76 
    /// @brief Pure virtual factory returning an owning pointer to a BaseInvoker.
77  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
78 };
79 
/// @brief Abstract interface template whose MakeArgumentPointer() additionally
///        takes a B-scale pointer (`p_b_scale`) and its stride (`StrideScaleB`).
///
/// NOTE(review): listing line 92 -- the declaration that sits between the
/// template header and the opening brace -- is absent from this extract, so
/// the struct's name and base class cannot be read here. Restore that line
/// from the original header before compiling; do not guess it.
///
/// @tparam ALayout, BLayout, CLayout
///         Not referenced inside the visible body (presumably layout tags --
///         TODO confirm).
/// @tparam ADataType, BDataType, BScaleType, CDataType
///         Not referenced inside the visible body; the data and scale pointers
///         are type-erased (`void*`).
/// @tparam ScaleBlockN, ScaleBlockK
///         Non-type index_t parameters, not referenced inside the visible body
///         (presumably the N and K extents covered by one scale value --
///         TODO confirm).
/// @tparam AElementwiseOperation, BElementwiseOperation, CElementwiseOperation
///         Types of the operation objects passed by value to
///         MakeArgumentPointer().
80 template <typename ALayout,
81  typename BLayout,
82  typename CLayout,
83  typename ADataType,
84  typename BDataType,
85  typename BScaleType,
86  typename CDataType,
87  index_t ScaleBlockN,
88  index_t ScaleBlockK,
89  typename AElementwiseOperation,
90  typename BElementwiseOperation,
91  typename CElementwiseOperation>
93 {
    /// @brief Pure virtual factory returning an owning pointer to a BaseArgument.
    ///
    /// @param p_a, p_b  Type-erased pointers to const data.
    /// @param p_c       Type-erased pointer to non-const data (presumably the
    ///                  output buffer -- TODO confirm).
    /// @param M, N, K   Sizes as ck::index_t (presumably the GEMM problem
    ///                  dimensions -- TODO confirm).
    /// @param StrideA, StrideB, StrideC
    ///                  Strides as ck::index_t; unit and axis not visible here.
    /// @param StrideScaleB
    ///                  ck::index_t stride (presumably for the data behind
    ///                  p_b_scale -- TODO confirm).
    /// @param p_b_scale Type-erased pointer to const data. Note that it comes
    ///                  after StrideScaleB in the parameter list, unlike the
    ///                  other pointer/stride pairs.
    /// @param KSplit    ck::index_t (presumably a split-K factor -- TODO confirm).
    /// @param a_element_op, b_element_op, c_element_op
    ///                  Operation objects, passed by value.
    /// @return std::unique_ptr<BaseArgument> owning the created argument object.
94  virtual std::unique_ptr<BaseArgument>
95  MakeArgumentPointer(const void* p_a,
96  const void* p_b,
97  void* p_c,
98  ck::index_t M,
99  ck::index_t N,
100  ck::index_t K,
101  ck::index_t StrideA,
102  ck::index_t StrideB,
103  ck::index_t StrideC,
104  ck::index_t StrideScaleB,
105  const void* p_b_scale,
106  ck::index_t KSplit,
107  AElementwiseOperation a_element_op,
108  BElementwiseOperation b_element_op,
109  CElementwiseOperation c_element_op) = 0;
110 
    /// @brief Pure virtual factory returning an owning pointer to a BaseInvoker.
111  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
112 
    // Pure virtual, non-const queries answered by the implementer. This
    // interface declares GetPermuteB() but no GetPermuteA().
113  virtual bool GetPermuteB() = 0;
114  virtual ck::index_t GetKPerBlock() = 0;
115 };
116 
117 } // namespace device
118 } // namespace tensor_operation
119 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:76
Definition: device_gemm_v2.hpp:93
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t StrideScaleB, const void *p_b_scale, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
Definition: device_gemm_v2.hpp:22
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
Definition: device_gemm_v2.hpp:57
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition: device_gemm_v2.hpp:58
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > DsStrides, ck::index_t StrideC, ck::index_t KSplit, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0