/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp Source File
grouped_gemm_kernel.hpp
Go to the documentation of this file.
Definition: cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: arch.hpp:136
Definition: gemm_kernel.hpp:32
Definition: gemm_kernel.hpp:86
Definition: gemm_kernel.hpp:119
Definition: gemm_kernel.hpp:60
static CK_TILE_DEVICE void RunGemm(const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, void *smem_ptr, const GemmKernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively using the whole workgroup.
Definition: gemm_kernel.hpp:465
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Definition: gemm_kernel.hpp:69
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition: gemm_kernel.hpp:64
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: gemm_kernel.hpp:72
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: gemm_kernel.hpp:70
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: gemm_kernel.hpp:62
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: gemm_kernel.hpp:66
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: gemm_kernel.hpp:65
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: gemm_kernel.hpp:61
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: gemm_kernel.hpp:63
Definition: grouped_gemm_kernel.hpp:15
CK_TILE_HOST GroupedGemmHostArgs() noexcept=default
Definition: grouped_gemm_kernel.hpp:55
GemmTransKernelArg(GemmKernelArgs &&karg, index_t bl_start, index_t bl_end)
Definition: grouped_gemm_kernel.hpp:61
GemmKernelArgs group_karg
Definition: grouped_gemm_kernel.hpp:56
ck_tile::index_t block_end
Definition: grouped_gemm_kernel.hpp:58
ck_tile::index_t block_start
Definition: grouped_gemm_kernel.hpp:57
GemmTransKernelArg()=default
Definition: grouped_gemm_kernel.hpp:36
static constexpr index_t KernelBlockSize
Definition: grouped_gemm_kernel.hpp:52
static __host__ auto GetWorkSpaceSize(const std::vector< GroupedGemmHostArgs > &gemm_descs) -> std::size_t
Definition: grouped_gemm_kernel.hpp:67
static CK_TILE_HOST auto MakeKargs(const std::vector< GroupedGemmHostArgs > &gemm_descs) -> std::vector< GemmTransKernelArg >
Definition: grouped_gemm_kernel.hpp:86
CK_TILE_DEVICE void Run(const GemmTransKernelArg &kargs) const
Definition: grouped_gemm_kernel.hpp:138
static constexpr CK_TILE_HOST_DEVICE auto GetSmemSize() -> index_t
Definition: grouped_gemm_kernel.hpp:133
typename Base::GemmKernelArgs GemmKernelArgs
Definition: grouped_gemm_kernel.hpp:50
static constexpr __host__ auto GridSize(const std::vector< GroupedGemmHostArgs > &gemm_descs)
Definition: grouped_gemm_kernel.hpp:75
static constexpr __host__ auto BlockSize() -> dim3
Definition: grouped_gemm_kernel.hpp:73
CK_TILE_DEVICE void operator()(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, index_t group_count) const
Definition: grouped_gemm_kernel.hpp:159
Struct used to calculate offset tile indices.
Definition: gemm_tile_partitioner.hpp:183
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from the raw 1D indices.
Definition: gemm_tile_partitioner.hpp:192