GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <grouped_gemm_kernel.hpp>

Inheritance diagram for ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >:
ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >

Public Types

using TilePartitioner = remove_cvref_t< TilePartitioner_ >
 
using GemmPipeline = remove_cvref_t< GemmPipeline_ >
 
using EpiloguePipeline = remove_cvref_t< EpiloguePipeline_ >
 
using ALayout = remove_cvref_t< typename GemmPipeline::ALayout >
 
using BLayout = remove_cvref_t< typename GemmPipeline::BLayout >
 
using ELayout = remove_cvref_t< typename GemmPipeline::CLayout >
 
using ADataType = remove_cvref_t< typename GemmPipeline::ADataType >
 
using BDataType = remove_cvref_t< typename GemmPipeline::BDataType >
 
using CDataType = remove_cvref_t< typename EpiloguePipeline::ODataType >
 
using OffsetTile1DPartitioner = OffsettedTile1DPartitioner< TilePartitioner >
 
using Base = GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
 
using Kernel = GroupedGemmKernel< TilePartitioner, GemmPipeline, EpiloguePipeline >
 
- Public Types inherited from ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
using TilePartitioner = remove_cvref_t< TilePartitioner_ >
 
using GemmPipeline = remove_cvref_t< GemmPipeline_ >
 
using EpiloguePipeline = remove_cvref_t< EpiloguePipeline_ >
 
using ALayout = remove_cvref_t< typename GemmPipeline::ALayout >
 
using BLayout = remove_cvref_t< typename GemmPipeline::BLayout >
 
using ELayout = remove_cvref_t< typename GemmPipeline::CLayout >
 
using DsLayout = remove_cvref_t< typename EpiloguePipeline::DsLayout >
 
using DsDataType = remove_cvref_t< typename EpiloguePipeline::DsDataType >
 
using ADataType = remove_cvref_t< typename GemmPipeline::ADataType >
 
using BDataType = remove_cvref_t< typename GemmPipeline::BDataType >
 
using EDataType = remove_cvref_t< typename EpiloguePipeline::ODataType >
 
using KernelArgs = GemmKernelArgs< DsLayout::size()>
 

Public Member Functions

CK_TILE_DEVICE void Run (const GemmTransKernelArg &kargs, const tuple< index_t, index_t > &block_idx_2d, const index_t block_idx_z) const
 
CK_TILE_DEVICE void Run (const GemmKernelArgs<> &kargs, const tuple< index_t, index_t > &block_idx_2d, const index_t block_idx_z) const
 
CK_TILE_DEVICE index_t FindGroupId (const GemmTransKernelArg *gemm_desc_ptr, index_t block_id, index_t group_count) const
 
template<bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
CK_TILE_DEVICE void operator() (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, index_t group_count) const
 
template<bool U = UsePersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void operator() (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count) const
 
- Public Member Functions inherited from ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
template<bool U = !PersistentKernel, typename = std::enable_if_t<U>>
CK_TILE_DEVICE void operator() (KernelArgs kargs) const
 
template<bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void operator() (KernelArgs kargs) const
 

Static Public Member Functions

static CK_TILE_HOST const std::string GetName ()
 
static CK_TILE_HOST auto GetWorkSpaceSize (const std::vector< GemmHostArgs<>> &gemm_descs) -> std::size_t
 
static CK_TILE_HOST auto GetWorkSpaceSize (index_t group_count) -> std::size_t
 
static constexpr CK_TILE_HOST auto BlockSize () -> dim3
 
static CK_TILE_HOST auto MaxOccupancyGridSize (const stream_config &s) -> dim3
 Get the maximum occupancy grid size for the persistent kernel on the current device. More...
 
static constexpr CK_TILE_HOST auto GridSize (const std::vector< GemmHostArgs<>> &gemm_descs)
 
static CK_TILE_HOST auto MakeKargs (const std::vector< GemmHostArgs<>> &gemm_descs) -> std::vector< GemmTransKernelArg >
 
static CK_TILE_HOST bool IsSupportedArgument (const std::vector< GemmTransKernelArg > &kargs)
 
static constexpr CK_TILE_HOST_DEVICE auto GetSmemSize () -> index_t
 
static CK_TILE_DEVICE void RunGemmWithPipelineSelection (const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, void *smem_ptr_0, const GemmKernelArgs<> &kargs, const typename Base::SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
 Runs single GEMM problem cooperatively by whole workgroup. More...
 
- Static Public Member Functions inherited from ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
static CK_TILE_HOST const std::string GetName ()
 
static constexpr CK_TILE_HOST auto GridSize (index_t M, index_t N, index_t KBatch)
 
static CK_TILE_HOST auto MaxOccupancyGridSize (const stream_config &s) -> dim3
 Get the maximum occupancy grid size for the persistent kernel on the current device. More...
 
static constexpr CK_TILE_HOST auto BlockSize ()
 
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs (const GemmHostArgs< NumDTensor > &hostArgs)
 
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize ()
 
static CK_TILE_HOST bool IsSupportedArgument (const KernelArgs &kargs)
 
template<memory_operation_enum DstInMemOp = memory_operation_enum::set>
static CK_TILE_DEVICE auto MakeGemmTensorViews (const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
 
template<typename TensorView >
static CK_TILE_DEVICE auto MakeGemmPadViews (const TensorView &views)
 
template<typename PadView >
static CK_TILE_DEVICE auto MakeGemmTileWindows (const PadView &views, const index_t i_m, const index_t i_n)
 
template<bool UseDefaultScheduler = true>
static CK_TILE_DEVICE void RunGemm (const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
 Runs single GEMM problem cooperatively by whole workgroup. More...
 
static CK_TILE_DEVICE void RunGemm2LDS (const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
 Runs single GEMM problem cooperatively by whole workgroup. More...
 

Static Public Attributes

static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize
 
static constexpr bool UsePersistentKernel = GemmPipeline::UsePersistentKernel
 
- Static Public Attributes inherited from ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize
 
static constexpr bool PersistentKernel = has_persistent_kernel::value
 
static constexpr index_t NumDTensor = DsDataType::size()
 
static constexpr auto I0 = number<0>()
 
static constexpr auto I1 = number<1>()
 
static constexpr auto I2 = number<2>()
 
static constexpr auto I3 = number<3>{}
 

Member Typedef Documentation

◆ ADataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ADataType = remove_cvref_t<typename GemmPipeline::ADataType>

◆ ALayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ALayout = remove_cvref_t<typename GemmPipeline::ALayout>

◆ Base

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>

◆ BDataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BDataType = remove_cvref_t<typename GemmPipeline::BDataType>

◆ BLayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BLayout = remove_cvref_t<typename GemmPipeline::BLayout>

◆ CDataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>

◆ ELayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ELayout = remove_cvref_t<typename GemmPipeline::CLayout>

◆ EpiloguePipeline

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>

◆ GemmPipeline

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GemmPipeline = remove_cvref_t<GemmPipeline_>

◆ Kernel

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Kernel = GroupedGemmKernel<TilePartitioner, GemmPipeline, EpiloguePipeline>

◆ OffsetTile1DPartitioner

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::OffsetTile1DPartitioner = OffsettedTile1DPartitioner<TilePartitioner>

◆ TilePartitioner

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::TilePartitioner = remove_cvref_t<TilePartitioner_>

Member Function Documentation

◆ BlockSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BlockSize ( ) -> dim3
inline static constexpr

◆ FindGroupId()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
CK_TILE_DEVICE index_t ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::FindGroupId ( const GemmTransKernelArg *  gemm_desc_ptr,
index_t  block_id,
index_t  group_count 
) const
inline

◆ GetName()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST const std::string ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST_DEVICE auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetSmemSize ( ) -> index_t
inline static constexpr

◆ GetWorkSpaceSize() [1/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetWorkSpaceSize ( const std::vector< GemmHostArgs<>> &  gemm_descs) -> std::size_t
inline static

◆ GetWorkSpaceSize() [2/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetWorkSpaceSize ( index_t  group_count) -> std::size_t
inline static

◆ GridSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GridSize ( const std::vector< GemmHostArgs<>> &  gemm_descs)
inline static constexpr

◆ IsSupportedArgument()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST bool ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::IsSupportedArgument ( const std::vector< GemmTransKernelArg > &  kargs)
inline static

◆ MakeKargs()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::MakeKargs ( const std::vector< GemmHostArgs<>> &  gemm_descs) -> std::vector<GemmTransKernelArg>
inline static

◆ MaxOccupancyGridSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST auto ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::MaxOccupancyGridSize ( const stream_config &  s) -> dim3
inline static

Get the maximum occupancy grid size for the persistent kernel on the current device.

Returns
The maximum occupancy grid size.
Note
This function queries the maximum occupancy of the kernel using hipOccupancyMaxActiveBlocksPerMultiprocessor.

◆ operator()() [1/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
template<bool U = UsePersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::operator() ( const void CK_CONSTANT_ADDRESS_SPACE *  gemm_descs_const,
const index_t  group_count 
) const
inline

◆ operator()() [2/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
template<bool U = UsePersistentKernel, typename = std::enable_if_t<!U>>
CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::operator() ( const void CK_CONSTANT_ADDRESS_SPACE *  gemm_descs_const,
index_t  group_count 
) const
inline

◆ Run() [1/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Run ( const GemmKernelArgs<> &  kargs,
const tuple< index_t, index_t > &  block_idx_2d,
const index_t  block_idx_z 
) const
inline

◆ Run() [2/2]

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Run ( const GemmTransKernelArg &  kargs,
const tuple< index_t, index_t > &  block_idx_2d,
const index_t  block_idx_z 
) const
inline

◆ RunGemmWithPipelineSelection()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_DEVICE void ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::RunGemmWithPipelineSelection ( const ADataType *  a_ptr,
const BDataType *  b_ptr,
CDataType *  c_ptr,
void *  smem_ptr_0,
const GemmKernelArgs<> &  kargs,
const typename Base::SplitKBatchOffset &  splitk_batch_offset,
const index_t  block_idx_m,
const index_t  block_idx_n 
)
inline static

Runs single GEMM problem cooperatively by whole workgroup.

Note
The GEMM pipeline is selected in-kernel based on the number of K-loops and the tail-number. This is needed for the persistent tile-loop, where the K dimension is not available on the host.
Parameters
a_ptr: input A pointer
b_ptr: input B pointer
c_ptr: output C pointer
smem_ptr_0: The start memory pointer of the shared memory block.
kargs: GEMM kernel arguments
splitk_batch_offset: Utility structure used to calculate k batch.
block_idx_m: The GEMM's output M dimension tile index processed by this workgroup.
block_idx_n: The GEMM's output N dimension tile index processed by this workgroup.

Member Data Documentation

◆ KernelBlockSize

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
constexpr index_t ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::KernelBlockSize = GemmPipeline::BlockSize
static constexpr

◆ UsePersistentKernel

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
constexpr bool ck_tile::GroupedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::UsePersistentKernel = GemmPipeline::UsePersistentKernel
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/kernel/grouped_gemm_kernel.hpp