/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp Source File
device_grouped_gemm.hpp
Go to the documentation of this file.
147 //----------------------------------------------------------------------------------------------
165 //----------------------------------------------------------------------------------------------
Definition: ck.hpp:264
Definition: device_base.hpp:50
Definition: device_base.hpp:76
virtual std::string GetTypeString() const
Definition: device_base.hpp:82
Definition: device_grouped_gemm.hpp:105
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const
Sets the device kernel arguments pointer and may copy data to the device.
Definition: device_grouped_gemm.hpp:154
virtual std::unique_ptr< BaseInvoker > MakeInvokerPointer()=0
virtual std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_a, std::vector< const void * > &p_b, std::vector< std::array< const void *, NumDTensor >> &p_ds, std::vector< void * > &p_e, std::vector< GemmDesc > &gemm_desc, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)=0
virtual void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args, const void *p_host_kernel_args) const
Sets the device kernel arguments pointer and may copy data to the device.
Definition: device_grouped_gemm.hpp:133
static constexpr index_t NumDTensor
Definition: device_grouped_gemm.hpp:106
virtual size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const
Gets the device kernel argument size.
Definition: device_grouped_gemm.hpp:172
Definition: device_grouped_gemm.hpp:86
ck::index_t stride_C_
Definition: device_grouped_gemm.hpp:88
std::vector< ck::index_t > stride_Ds_
Definition: device_grouped_gemm.hpp:90
ck::index_t stride_A_
Definition: device_grouped_gemm.hpp:88
ck::index_t stride_B_
Definition: device_grouped_gemm.hpp:88
Structure representing the arguments of a single GEMM problem.
Definition: device_grouped_gemm.hpp:29
void Print() const
Definition: device_grouped_gemm.hpp:67
index_t StrideB
Definition: device_grouped_gemm.hpp:63
void * p_e_grid
Definition: device_grouped_gemm.hpp:58
index_t StrideE
Definition: device_grouped_gemm.hpp:65
index_t N
Definition: device_grouped_gemm.hpp:60
const void * p_a_grid
Definition: device_grouped_gemm.hpp:55
index_t K
Definition: device_grouped_gemm.hpp:61
std::array< index_t, NumDTensor > StrideDs
Definition: device_grouped_gemm.hpp:64
index_t StrideA
Definition: device_grouped_gemm.hpp:62
index_t M
Definition: device_grouped_gemm.hpp:59
__host__ __device__ GroupedGemmKernelArgument(const void *p_a_grid_, const void *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, void *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_)
Definition: device_grouped_gemm.hpp:30
const void * p_b_grid
Definition: device_grouped_gemm.hpp:56
std::array< const void *, NumDTensor > p_ds_grid
Definition: device_grouped_gemm.hpp:57