/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File
device_grouped_gemm.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <array>
7 #include <iostream>
8 #include <sstream>
9 #include <stdexcept>
10 #include <vector>
11 
12 #include "device_base.hpp"
13 #include "ck/utility/ignore.hpp"
14 
15 namespace ck {
16 namespace tensor_operation {
17 namespace device {
18 
27 template <index_t NumDTensor = 0>
29 {
30  __host__ __device__ GroupedGemmKernelArgument(const void* p_a_grid_,
31  const void* p_b_grid_,
32  std::array<const void*, NumDTensor> p_ds_grid_,
33  void* p_e_grid_,
34  index_t M_,
35  index_t N_,
36  index_t K_,
37  index_t StrideA_,
38  index_t StrideB_,
39  std::array<index_t, NumDTensor> StrideDs_,
40  index_t StrideE_)
41  : p_a_grid{p_a_grid_},
42  p_b_grid{p_b_grid_},
43  p_ds_grid{p_ds_grid_},
44  p_e_grid{p_e_grid_},
45  M{M_},
46  N{N_},
47  K{K_},
48  StrideA{StrideA_},
49  StrideB{StrideB_},
50  StrideDs{StrideDs_},
51  StrideE{StrideE_}
52  {
53  }
54 
55  const void* p_a_grid;
56  const void* p_b_grid;
57  std::array<const void*, NumDTensor> p_ds_grid;
58  void* p_e_grid;
64  std::array<index_t, NumDTensor> StrideDs;
66 
67  void Print() const
68  {
69  std::stringstream str;
70  for(auto sd : StrideDs)
71  str << sd << ",";
72 
73  std::cout << "arg {" << "M:" << M << ", " << "N:" << N << ", " << "K:" << K << ", "
74  << "SA:" << StrideA << ", " << "SB:" << StrideB << ", " << "SE:" << StrideE
75  << ", " << "SDs: {" << str.str() << "}" << "}" << std::endl;
76  }
77 };
78 
79 struct GemmDesc
80 {
83 
84  std::vector<ck::index_t> stride_Ds_;
85 };
86 
87 template <typename ALayout,
88  typename BLayout,
89  typename DsLayout,
90  typename ELayout,
91  typename ADataType,
92  typename BDataType,
93  typename DsDataType,
94  typename EDataType,
95  typename AElementwiseOperation,
96  typename BElementwiseOperation,
97  typename CElementwiseOperation,
98  typename ComputeDataType = ADataType>
100 {
101  static constexpr index_t NumDTensor = DsDataType::Size();
102 
103  static_assert(DsLayout::Size() == DsDataType::Size(), "wrong! inconsistent NumDTensor");
104 
105  virtual std::unique_ptr<BaseArgument>
106  MakeArgumentPointer(std::vector<const void*>& p_a,
107  std::vector<const void*>& p_b,
108  std::vector<std::array<const void*, NumDTensor>>& p_ds,
109  std::vector<void*>& p_e,
110  std::vector<GemmDesc>& gemm_desc,
111  AElementwiseOperation a_element_op,
112  BElementwiseOperation b_element_op,
113  CElementwiseOperation c_element_op) = 0;
114 
115  virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
116 
117  //---------------------------------------------------------------------------------------------
128  virtual void SetDeviceKernelArgs(BaseArgument* p_arg,
129  void* p_dev_kernel_args,
130  const void* p_host_kernel_args) const
131  {
132  ignore = p_arg;
133  ignore = p_dev_kernel_args;
134  ignore = p_host_kernel_args;
135 
136  std::ostringstream err;
137  err << "This function is not implemented by the kernel: " << this->GetTypeString()
138  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
139  throw std::runtime_error(err.str());
140  }
141 
142  //----------------------------------------------------------------------------------------------
149  virtual void SetDeviceKernelArgs(BaseArgument* p_arg, void* p_dev_kernel_args) const
150  {
151  ignore = p_arg;
152  ignore = p_dev_kernel_args;
153 
154  std::ostringstream err;
155  err << "This function is not implemented by the kernel: " << this->GetTypeString()
156  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
157  throw std::runtime_error(err.str());
158  }
159 
160  //----------------------------------------------------------------------------------------------
167  virtual size_t GetDeviceKernelArgSize(const BaseArgument* p_arg) const
168  {
169  ignore = p_arg;
170 
171  std::ostringstream err;
172  err << "This function is not implemented by the kernel: " << this->GetTypeString()
173  << __FILE__ << ":" << __LINE__ << ", in function: " << __func__;
174  throw std::runtime_error(err.str());
175  }
176 };
177 
178 } // namespace device
179 } // namespace tensor_operation
180 } // namespace ck
Definition: ck.hpp:270
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
int32_t index_t
Definition: ck.hpp:301
Definition: device_base.hpp:197
Definition: device_base.hpp:223
virtual std::string GetTypeString() const
Definition: device_base.hpp:229
Definition: device_grouped_gemm.hpp:100
static constexpr index_t NumDTensor
Definition: device_grouped_gemm.hpp:101
virtual size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const
Gets the device kernel argument size.
Definition: device_grouped_gemm.hpp:167
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm.hpp:149
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm.hpp:128
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_a, std::vector< const void * > &p_b, std::vector< std::array< const void *, NumDTensor >> &p_ds, std::vector< void * > &p_e, std::vector< GemmDesc > &gemm_desc, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
Definition: device_grouped_gemm.hpp:80
ck::index_t stride_C_
Definition: device_grouped_gemm.hpp:82
std::vector< ck::index_t > stride_Ds_
Definition: device_grouped_gemm.hpp:84
ck::index_t K_
Definition: device_grouped_gemm.hpp:81
ck::index_t stride_A_
Definition: device_grouped_gemm.hpp:82
ck::index_t N_
Definition: device_grouped_gemm.hpp:81
ck::index_t stride_B_
Definition: device_grouped_gemm.hpp:82
ck::index_t M_
Definition: device_grouped_gemm.hpp:81
Structure representing single GEMM problem arguments.
Definition: device_grouped_gemm.hpp:29
void Print() const
Definition: device_grouped_gemm.hpp:67
index_t StrideB
Definition: device_grouped_gemm.hpp:63
void * p_e_grid
Definition: device_grouped_gemm.hpp:58
index_t StrideE
Definition: device_grouped_gemm.hpp:65
index_t N
Definition: device_grouped_gemm.hpp:60
const void * p_a_grid
Definition: device_grouped_gemm.hpp:55
index_t K
Definition: device_grouped_gemm.hpp:61
std::array< index_t, NumDTensor > StrideDs
Definition: device_grouped_gemm.hpp:64
index_t StrideA
Definition: device_grouped_gemm.hpp:62
index_t M
Definition: device_grouped_gemm.hpp:59
__host__ __device__ GroupedGemmKernelArgument(const void *p_a_grid_, const void *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, void *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_)
Definition: device_grouped_gemm.hpp:30
const void * p_b_grid
Definition: device_grouped_gemm.hpp:56
std::array< const void *, NumDTensor > p_ds_grid
Definition: device_grouped_gemm.hpp:57