/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp Source File
device_gemm_multiple_d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 #ifndef __HIPCC_RTC__
6 #include <array>
7 #endif
8 
9 #include "ck/utility/array.hpp"
11 
12 namespace ck {
13 namespace tensor_operation {
14 namespace device {
15 
16 // GEMM:
17 // input : A[M, K], B[K, N],
18 // input : D0[M, N], D1[M, N], ...
19 // output : E[M, N]
20 // C = a_op(A) * b_op(B)
21 // E = cde_op(C, D0, D1, ...)
22 // Assume:
23 // D0, D1, ... and E have the same layout
//
// Abstract interface for the operation described above: both member functions
// declared below are pure virtual and must be supplied by a concrete operation.
//
// NOTE(review): the declaration line that names this type (listing line 35,
// between the template header and the opening brace) is absent from this
// listing; restore it from the original header before compiling.
24 template <typename ALayout,
25  typename BLayout,
26  typename DsLayout,
27  typename ELayout,
28  typename ADataType,
29  typename BDataType,
30  typename DsDataType,
31  typename EDataType,
32  typename AElementwiseOperation,
33  typename BElementwiseOperation,
34  typename CDEElementwiseOperation>
36 {
    // Number of D tensors, as reported by DsDataType::Size().
37  static constexpr index_t NumDTensor = DsDataType::Size();
38 
    // Everything up to the matching #endif is compiled out when __HIPCC_RTC__
    // is defined (the same macro that guards the <array> include at the top of
    // the file); the members below use std::array and std::unique_ptr.
39 #ifndef __HIPCC_RTC__
    // Builds the argument object for one GEMM problem and returns it as an
    // owning pointer to BaseArgument.
    //   p_a, p_b       : read-only pointers, presumably the A and B tensor data
    //   p_ds           : one read-only pointer per D tensor (NumDTensor entries)
    //   p_e            : writable pointer, presumably the E output tensor
    //   M, N, K        : problem sizes, see the shapes in the header comment
    //   StrideA/B/E    : strides of A, B and E (TODO confirm units: elements vs bytes)
    //   StrideDs       : one stride per D tensor (NumDTensor entries)
    //   a_element_op, b_element_op, cde_element_op :
    //                    the a_op / b_op / cde_op functors of the header comment
40  virtual std::unique_ptr<BaseArgument>
41  MakeArgumentPointer(const void* p_a,
42  const void* p_b,
43  std::array<const void*, NumDTensor> p_ds,
44  void* p_e,
45  ck::index_t M,
46  ck::index_t N,
47  ck::index_t K,
48  ck::index_t StrideA,
49  ck::index_t StrideB,
50  std::array<ck::index_t, NumDTensor> StrideDs,
51  ck::index_t StrideE,
52  AElementwiseOperation a_element_op,
53  BElementwiseOperation b_element_op,
54  CDEElementwiseOperation cde_element_op) = 0;
55 
    // Returns an owning pointer to a BaseInvoker; presumably the object that
    // runs the operation on an argument built above (confirm in device_base.hpp).
56  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
57 #endif
58 };
59 
60 // GEMM:
61 // input : A[M, K], B[K, N],
62 // input : D0[M, N], D1[M, N], ...
63 // output : E[M, N]
64 // C = a_op(A) * b_op(B)
65 // E = cde_op(C, D0, D1, ...)
66 // Assume:
67 // D0, D1, ... and E have the same layout
//
// Abstract interface for the operation described above. Its
// MakeArgumentPointer takes an additional KBatch parameter, placed between
// StrideE and the element-wise operations.
//
// NOTE(review): the declaration line that names this type (listing line 79,
// between the template header and the opening brace) is absent from this
// listing; restore it from the original header before compiling.
68 template <typename ALayout,
69  typename BLayout,
70  typename DsLayout,
71  typename ELayout,
72  typename ADataType,
73  typename BDataType,
74  typename DsDataType,
75  typename EDataType,
76  typename AElementwiseOperation,
77  typename BElementwiseOperation,
78  typename CDEElementwiseOperation>
80 {
    // Number of D tensors, as reported by DsDataType::Size().
81  static constexpr index_t NumDTensor = DsDataType::Size();
82 
    // Everything up to the matching #endif is compiled out when __HIPCC_RTC__
    // is defined (the same macro that guards the <array> include at the top of
    // the file); the members below use std::array and std::unique_ptr.
83 #ifndef __HIPCC_RTC__
    // Builds the argument object for one GEMM problem and returns it as an
    // owning pointer to BaseArgument.
    //   p_a, p_b       : read-only pointers, presumably the A and B tensor data
    //   p_ds           : one read-only pointer per D tensor (NumDTensor entries)
    //   p_e            : writable pointer, presumably the E output tensor
    //   M, N, K        : problem sizes, see the shapes in the header comment
    //   StrideA/B/E    : strides of A, B and E (TODO confirm units: elements vs bytes)
    //   StrideDs       : one stride per D tensor (NumDTensor entries)
    //   KBatch         : presumably the number of splits along K - TODO confirm
    //                    against a concrete implementation
    //   a_element_op, b_element_op, cde_element_op :
    //                    the a_op / b_op / cde_op functors of the header comment
84  virtual std::unique_ptr<BaseArgument>
85  MakeArgumentPointer(const void* p_a,
86  const void* p_b,
87  std::array<const void*, NumDTensor> p_ds,
88  void* p_e,
89  ck::index_t M,
90  ck::index_t N,
91  ck::index_t K,
92  ck::index_t StrideA,
93  ck::index_t StrideB,
94  std::array<ck::index_t, NumDTensor> StrideDs,
95  ck::index_t StrideE,
96  ck::index_t KBatch,
97  AElementwiseOperation a_element_op,
98  BElementwiseOperation b_element_op,
99  CDEElementwiseOperation cde_element_op) = 0;
100 
    // Returns an owning pointer to a BaseInvoker; presumably the object that
    // runs the operation on an argument built above (confirm in device_base.hpp).
101  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
102 #endif
103 };
104 
105 // GEMM:
106 // input : A[M, K], B[K, N],
107 // input : D0[M, N], D1[M, N], ...
108 // output : E[M, N]
109 // C = a_op(A) * b_op(B)
110 // E = cde_op(C, D0, D1, ...)
111 // Assume:
112 // D0, D1, ... and E have the same layout
//
// Abstract interface for the operation described above. Its
// MakeArgumentPointer takes a KBatch parameter, and the interface additionally
// declares GetPreShuffleParameters().
//
// NOTE(review): the declaration line that names this type (listing line 124,
// between the template header and the opening brace) is absent from this
// listing; restore it from the original header before compiling.
113 template <typename ALayout,
114  typename BLayout,
115  typename DsLayout,
116  typename ELayout,
117  typename ADataType,
118  typename BDataType,
119  typename DsDataType,
120  typename EDataType,
121  typename AElementwiseOperation,
122  typename BElementwiseOperation,
123  typename CDEElementwiseOperation>
125 {
    // Number of D tensors, as reported by DsDataType::Size().
126  static constexpr index_t NumDTensor = DsDataType::Size();
127 
    // Host-only section. The guard is __HIPCC_RTC__, the same macro that guards
    // the <array> include at the top of this file and the host-only sections of
    // the other interfaces here. A different macro would leave these
    // std::array / std::unique_ptr declarations visible in a build where
    // <array> has been compiled out.
128 #ifndef __HIPCC_RTC__
    // Builds the argument object for one GEMM problem and returns it as an
    // owning pointer to BaseArgument.
    //   p_a, p_b       : read-only pointers, presumably the A and B tensor data
    //   p_ds           : one read-only pointer per D tensor (NumDTensor entries)
    //   p_e            : writable pointer, presumably the E output tensor
    //   M, N, K        : problem sizes, see the shapes in the header comment
    //   StrideA/B/E    : strides of A, B and E (TODO confirm units: elements vs bytes)
    //   StrideDs       : one stride per D tensor (NumDTensor entries)
    //   KBatch         : presumably the number of splits along K - TODO confirm
    //                    against a concrete implementation
    //   a_element_op, b_element_op, cde_element_op :
    //                    the a_op / b_op / cde_op functors of the header comment
129  virtual std::unique_ptr<BaseArgument>
130  MakeArgumentPointer(const void* p_a,
131  const void* p_b,
132  std::array<const void*, NumDTensor> p_ds,
133  void* p_e,
134  ck::index_t M,
135  ck::index_t N,
136  ck::index_t K,
137  ck::index_t StrideA,
138  ck::index_t StrideB,
139  std::array<ck::index_t, NumDTensor> StrideDs,
140  ck::index_t StrideE,
141  ck::index_t KBatch,
142  AElementwiseOperation a_element_op,
143  BElementwiseOperation b_element_op,
144  CDEElementwiseOperation cde_element_op) = 0;
145 
    // Returns an owning pointer to a BaseInvoker; presumably the object that
    // runs the operation on an argument built above (confirm in device_base.hpp).
146  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
147 
    // Returns an int supplied by the concrete operation. NOTE(review): the
    // meaning of the value is not visible in this file - presumably it
    // describes how B must be pre-shuffled; confirm against an implementation.
148  virtual int GetPreShuffleParameters() = 0;
149 #endif
150 };
151 
// Adapter that owns an operation of type DeviceOp and exposes it through a
// MakeArgumentPointer overload that has no KBatch parameter: every call is
// forwarded to the wrapped operation, with KBatch fixed to 1.
//
// NOTE(review): two declaration lines are absent from this listing and must be
// restored from the original header before compiling:
//   - listing line 170: the struct name and the start of its base-class
//     template-id (the argument list on lines 171-180 belongs to it);
//   - listing line 182: the start of the DeviceOp alias (the argument list on
//     lines 183-192 belongs to it).
159 template <typename ALayout,
160  typename BLayout,
161  typename DsLayout,
162  typename ELayout,
163  typename ADataType,
164  typename BDataType,
165  typename DsDataType,
166  typename EDataType,
167  typename AElementwiseOperation,
168  typename BElementwiseOperation,
169  typename CDEElementwiseOperation>
171  BLayout,
172  DsLayout,
173  ELayout,
174  ADataType,
175  BDataType,
176  DsDataType,
177  EDataType,
178  AElementwiseOperation,
179  BElementwiseOperation,
180  CDEElementwiseOperation>
181 {
183  BLayout,
184  DsLayout,
185  ELayout,
186  ADataType,
187  BDataType,
188  DsDataType,
189  EDataType,
190  AElementwiseOperation,
191  BElementwiseOperation,
192  CDEElementwiseOperation>;
193 
    // Number of D tensors, as reported by DsDataType::Size().
194  static constexpr index_t NumDTensor = DsDataType::Size();
195 
    // Everything up to the matching #endif is compiled out when __HIPCC_RTC__
    // is defined.
196 #ifndef __HIPCC_RTC__
197 
    // Takes ownership of the wrapped operation by moving p_op into p_op_.
    // p_op is not checked for null here, and every forwarding member below
    // dereferences p_op_, so callers must pass a non-null pointer.
198  explicit DeviceGemmMultipleDSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
199  : p_op_(std::move(p_op))
200  {
201  }
202 
    // Forwards the support check to the wrapped operation unchanged.
203  bool IsSupportedArgument(const BaseArgument* p_arg) override
204  {
205  return p_op_->IsSupportedArgument(p_arg);
206  }
    // Builds an argument through the wrapped operation. All parameters are
    // passed through in order; the wrapped operation's extra KBatch parameter,
    // which this signature does not have, is supplied as the constant 1.
207  std::unique_ptr<BaseArgument>
208  MakeArgumentPointer(const void* p_a,
209  const void* p_b,
210  std::array<const void*, NumDTensor> p_ds,
211  void* p_e,
212  ck::index_t M,
213  ck::index_t N,
214  ck::index_t K,
215  ck::index_t StrideA,
216  ck::index_t StrideB,
217  std::array<ck::index_t, NumDTensor> StrideDs,
218  ck::index_t StrideE,
219  AElementwiseOperation a_element_op,
220  BElementwiseOperation b_element_op,
221  CDEElementwiseOperation cde_element_op) override
222  {
223  return p_op_->MakeArgumentPointer(p_a,
224  p_b,
225  p_ds,
226  p_e,
227  M,
228  N,
229  K,
230  StrideA,
231  StrideB,
232  StrideDs,
233  StrideE,
234  1, // KBatch
235  a_element_op,
236  b_element_op,
237  cde_element_op);
238  }
239 
    // Returns the invoker created by the wrapped operation.
240  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
241  {
242  return p_op_->MakeInvokerPointer();
243  }
244 
    // Reports the wrapped operation's type string, not one specific to the wrapper.
245  std::string GetTypeString() const override { return p_op_->GetTypeString(); }
246 
247  private:
    // The wrapped operation; owned exclusively by this wrapper.
248  std::unique_ptr<DeviceOp> p_op_;
249 
250 #endif // __HIPCC_RTC__
251 };
252 
253 } // namespace device
254 } // namespace tensor_operation
255 } // namespace ck
Definition: ck.hpp:267
int32_t index_t
Definition: ck.hpp:298
Definition: device_base.hpp:51
Definition: device_base.hpp:77
Definition: device_gemm_multiple_d.hpp:36
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:37
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:126
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
Definition: device_gemm_multiple_d.hpp:80
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, ck::index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)=0
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:81
Wrapper for backward compatibility that allows instances of DeviceGemmMultipleDSplitK to be used in contexts where DeviceGemmMultipleD is expected.
Definition: device_gemm_multiple_d.hpp:181
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_multiple_d.hpp:240
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_multiple_d.hpp:203
DeviceGemmMultipleDSplitKWrapper(std::unique_ptr< DeviceOp > p_op)
Definition: device_gemm_multiple_d.hpp:198
std::string GetTypeString() const override
Definition: device_gemm_multiple_d.hpp:245
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, ck::index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_multiple_d.hpp:208
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d.hpp:194