/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm_multi_abd.hpp Source File
device_grouped_gemm_multi_abd.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <vector>
8 
9 #include "device_base.hpp"
10 
11 namespace ck {
12 namespace tensor_operation {
13 namespace device {
14 
16 {
18 
19  std::vector<ck::index_t> stride_As_;
20  std::vector<ck::index_t> stride_Bs_;
21  std::vector<ck::index_t> stride_Ds_;
22 
24 };
25 
26 /*
27  * \brief Grouped Gemm Multi ABD
28  *
29  * C = a_op(A, A1...) * b_op(B, B1...)
30  * E = cde_op(C, D0, D1, ...)
31  *
32  * \tparam AsLayout A layouts (tuple).
33  * \tparam BsLayout B layouts (tuple).
34  * \tparam DsLayout Ds layouts (tuple).
35  * \tparam ELayout Output layout.
36  * \tparam AsDataType A data types (tuple).
37  * \tparam BsDataType B data types (tuple).
38  * \tparam DsDataType D data types (tuple).
39  * \tparam EDataType Output data type.
40  * \tparam AElementwiseOperation A elementwise operation.
41  * \tparam BElementwiseOperation B elementwise operation.
42  * \tparam CDEElementwiseOperation C elementwise operation.
43  */
44 template <typename AsLayout,
45  typename BsLayout,
46  typename DsLayout,
47  typename ELayout,
48  typename AsDataType,
49  typename BsDataType,
50  typename DsDataType,
51  typename EDataType,
52  typename AElementwiseOperation,
53  typename BElementwiseOperation,
54  typename CDEElementwiseOperation>
56 {
57  static constexpr index_t NumATensor = AsDataType::Size();
58  static constexpr index_t NumBTensor = BsDataType::Size();
59  static constexpr index_t NumDTensor = DsDataType::Size();
60 
61  static_assert(AsLayout::Size() == AsDataType::Size(), "wrong! inconsistent NumATensor");
62  static_assert(BsLayout::Size() == BsDataType::Size(), "wrong! inconsistent NumBTensor");
63  static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
64 
65  /*
66  * \brief Make argument pointer for grouped gemm multi abd.
67  *
68  * \param p_as A pointers to the A.
69  * \param p_bs A pointers to the B.
70  * \param p_ds A pointers to the Ds.
71  * \param p_e A pointers to the E.
72  * \param gemm_desc Gemm descriptors for each group.
73  * \param a_element_op A elementwise operation object.
74  * \param b_element_op B elementwise operation object.
75  * \param cde_element_op CDE elementwise operation object.
76  * \return Pointer to the argument.
77  */
78  virtual std::unique_ptr<BaseArgument>
79  MakeArgumentPointer(std::vector<std::array<const void*, NumATensor>>& p_as,
80  std::vector<std::array<const void*, NumBTensor>>& p_bs,
81  std::vector<std::array<const void*, NumDTensor>>& p_ds,
82  std::vector<void*>& p_e,
83  std::vector<GemmMultiABDDesc>& gemm_desc,
84  AElementwiseOperation a_element_op = AElementwiseOperation{},
85  BElementwiseOperation b_element_op = BElementwiseOperation{},
86  CDEElementwiseOperation c_element_op = CDEElementwiseOperation{}) = 0;
87 
88  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
89 
90  virtual void SetElementwiseOps(BaseArgument* p_arg,
91  AElementwiseOperation a_element_op,
92  BElementwiseOperation b_element_op,
93  CDEElementwiseOperation cde_element_op) const = 0;
94 };
95 
96 } // namespace device
97 } // namespace tensor_operation
98 } // namespace ck
Definition: ck.hpp:264
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:50
Definition: device_base.hpp:76
Definition: device_grouped_gemm_multi_abd.hpp:56
virtual void SetElementwiseOps(BaseArgument *p_arg, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) const =0
static constexpr index_t NumDTensor
Definition: device_grouped_gemm_multi_abd.hpp:59
static constexpr index_t NumATensor
Definition: device_grouped_gemm_multi_abd.hpp:57
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< std::array< const void *, NumATensor >> &p_as, std::vector< std::array< const void *, NumBTensor >> &p_bs, std::vector< std::array< const void *, NumDTensor >> &p_ds, std::vector< void * > &p_e, std::vector< GemmMultiABDDesc > &gemm_desc, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation c_element_op=CDEElementwiseOperation{})=0
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
static constexpr index_t NumBTensor
Definition: device_grouped_gemm_multi_abd.hpp:58
ck::tensor_operation::device::GemmMultiABDDesc
Definition: device_grouped_gemm_multi_abd.hpp:16
ck::index_t stride_C_
Definition: device_grouped_gemm_multi_abd.hpp:23
std::vector< ck::index_t > stride_Ds_
Definition: device_grouped_gemm_multi_abd.hpp:21
std::vector< ck::index_t > stride_As_
Definition: device_grouped_gemm_multi_abd.hpp:19
ck::index_t M_
Definition: device_grouped_gemm_multi_abd.hpp:17
std::vector< ck::index_t > stride_Bs_
Definition: device_grouped_gemm_multi_abd.hpp:20
ck::index_t N_
Definition: device_grouped_gemm_multi_abd.hpp:17
ck::index_t K_
Definition: device_grouped_gemm_multi_abd.hpp:17