/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File
device_grouped_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <array>
7 #include <iostream>
8 #include <sstream>
9 #include <stdexcept>
10 #include <vector>
11 
12 #include "device_base.hpp"
13 #include "ck/utility/ignore.hpp"
14 
15 namespace ck {
16 namespace tensor_operation {
17 namespace device {
18 
27 template <index_t NumDTensor = 0>
29 {
30  __host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_,
31  const void* p_b_grid_,
32  std::array<const void*, NumDTensor> p_ds_grid_,
33  void* p_e_grid_,
34  index_t M_,
35  index_t N_,
36  index_t K_,
37  index_t StrideA_,
38  index_t StrideB_,
39  std::array<index_t, NumDTensor> StrideDs_,
40  index_t StrideE_)
41  : p_a_grid{p_a_grid_},
42  p_b_grid{p_b_grid_},
43  p_ds_grid{p_ds_grid_},
44  p_e_grid{p_e_grid_},
45  M{M_},
46  N{N_},
47  K{K_},
48  StrideA{StrideA_},
49  StrideB{StrideB_},
50  StrideDs{StrideDs_},
51  StrideE{StrideE_}
52  {
53  }
54 
55  const void* p_a_grid;
56  const void* p_b_grid;
57  std::array<const void*, NumDTensor> p_ds_grid;
58  void* p_e_grid;
64  std::array<index_t, NumDTensor> StrideDs;
66 
67  void Print() const
68  {
69  std::stringstream str;
70  for(auto sd : StrideDs)
71  str << sd << ",";
72 
73  std::cout << "arg {"
74  << "M:" << M << ", "
75  << "N:" << N << ", "
76  << "K:" << K << ", "
77  << "SA:" << StrideA << ", "
78  << "SB:" << StrideB << ", "
79  << "SE:" << StrideE << ", "
80  << "SDs: {" << str.str() << "}"
81  << "}" << std::endl;
82  }
83 };
84 
85 struct GemmDesc
86 {
89 
90  std::vector<ck::index_t> stride_Ds_;
91 };
92 
93 template <typename ALayout,
94  typename BLayout,
95  typename DsLayout,
96  typename ELayout,
97  typename ADataType,
98  typename BDataType,
99  typename DsDataType,
100  typename EDataType,
101  typename AElementwiseOperation,
102  typename BElementwiseOperation,
103  typename CElementwiseOperation>
105 {
106  static constexpr index_t NumDTensor = DsDataType::Size();
107 
108  static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
109 
110  virtual std::unique_ptr<BaseArgument>
111  MakeArgumentPointer(std::vector<const void*>& p_a,
112  std::vector<const void*>& p_b,
113  std::vector<std::array<const void*, NumDTensor>>& p_ds,
114  std::vector<void*>& p_e,
115  std::vector<GemmDesc>& gemm_desc,
116  AElementwiseOperation a_element_op,
117  BElementwiseOperation b_element_op,
118  CElementwiseOperation c_element_op) = 0;
119 
120  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
121 
122  //---------------------------------------------------------------------------------------------
133  virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
134  void* p_dev_kernel_args,
135  const void* p_host_kernel_args) const
136  {
137  ignore = p_arg;
138  ignore = p_dev_kernel_args;
139  ignore = p_host_kernel_args;
140 
141  std::ostringstream err;
142  err << "This function is not implemented by the kernel: " << this->GetTypeString()
143  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
144  throw std::runtime_error(err.str());
145  }
146 
147  //----------------------------------------------------------------------------------------------
154  virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const
155  {
156  ignore = p_arg;
157  ignore = p_dev_kernel_args;
158 
159  std::ostringstream err;
160  err << "This function is not implemented by the kernel: " << this->GetTypeString()
161  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
162  throw std::runtime_error(err.str());
163  }
164 
165  //----------------------------------------------------------------------------------------------
172  virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const
173  {
174  ignore = p_arg;
175 
176  std::ostringstream err;
177  err << "This function is not implemented by the kernel: " << this->GetTypeString()
178  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
179  throw std::runtime_error(err.str());
180  }
181 };
182 
183 } // namespace device
184 } // namespace tensor_operation
185 } // namespace ck
Definition: ck.hpp:264
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
int32_t index_t
Definition: ck.hpp:289
Definition: device_base.hpp:50
Definition: device_base.hpp:76
virtual std::string GetTypeString() const
Definition: device_base.hpp:82
Definition: device_grouped_gemm.hpp:105
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm.hpp:154
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_a, std::vector< const void * > &p_b, std::vector< std::array< const void *, NumDTensor >> &p_ds, std::vector< void * > &p_e, std::vector< GemmDesc > &gemm_desc, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm.hpp:133
static constexpr index_t NumDTensor
Definition: device_grouped_gemm.hpp:106
virtual size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const
Gets the device kernel argument size.
Definition: device_grouped_gemm.hpp:172
Definition: device_grouped_gemm.hpp:86
ck::index_t stride_C_
Definition: device_grouped_gemm.hpp:88
std::vector< ck::index_t > stride_Ds_
Definition: device_grouped_gemm.hpp:90
ck::index_t K_
Definition: device_grouped_gemm.hpp:87
ck::index_t stride_A_
Definition: device_grouped_gemm.hpp:88
ck::index_t N_
Definition: device_grouped_gemm.hpp:87
ck::index_t stride_B_
Definition: device_grouped_gemm.hpp:88
ck::index_t M_
Definition: device_grouped_gemm.hpp:87
Structure representing single GEMM problem arguments.
Definition: device_grouped_gemm.hpp:29
void Print() const
Definition: device_grouped_gemm.hpp:67
index_t StrideB
Definition: device_grouped_gemm.hpp:63
void * p_e_grid
Definition: device_grouped_gemm.hpp:58
index_t StrideE
Definition: device_grouped_gemm.hpp:65
index_t N
Definition: device_grouped_gemm.hpp:60
const void * p_a_grid
Definition: device_grouped_gemm.hpp:55
index_t K
Definition: device_grouped_gemm.hpp:61
std::array< index_t, NumDTensor > StrideDs
Definition: device_grouped_gemm.hpp:64
index_t StrideA
Definition: device_grouped_gemm.hpp:62
index_t M
Definition: device_grouped_gemm.hpp:59
__host__ __device__ GroupedGemmKernelArgument(const void *p_a_grid_, const void *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, void *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_)
Definition: device_grouped_gemm.hpp:30
const void * p_b_grid
Definition: device_grouped_gemm.hpp:56
std::array< const void *, NumDTensor > p_ds_grid
Definition: device_grouped_gemm.hpp:57