BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

Composable Kernel: ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference
ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ > Struct Template Reference

#include <batched_gemm_kernel.hpp>

Inheritance diagram for ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >:
ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >

Classes

struct  BatchedGemmKernelArgs
 

Public Types

using Base = GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
 
using GemmKernelArgs = typename ck_tile::GemmKernelArgs<>
 
using ADataType = typename Base::ADataType
 
using BDataType = typename Base::BDataType
 
using CDataType = typename Base::EDataType
 
using TilePartitioner = typename Base::TilePartitioner
 
using GemmPipeline = typename Base::GemmPipeline
 
using EpiloguePipeline = typename Base::EpiloguePipeline
 
using ALayout = typename Base::ALayout
 
using BLayout = typename Base::BLayout
 
using CLayout = typename Base::ELayout
 
using KernelArgs = BatchedGemmKernelArgs
 
- Public Types inherited from ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
using TilePartitioner = remove_cvref_t< TilePartitioner_ >
 
using GemmPipeline = remove_cvref_t< GemmPipeline_ >
 
using EpiloguePipeline = remove_cvref_t< EpiloguePipeline_ >
 
using ALayout = remove_cvref_t< typename GemmPipeline::ALayout >
 
using BLayout = remove_cvref_t< typename GemmPipeline::BLayout >
 
using ELayout = remove_cvref_t< typename GemmPipeline::CLayout >
 
using DsLayout = remove_cvref_t< typename EpiloguePipeline::DsLayout >
 
using DsDataType = remove_cvref_t< typename EpiloguePipeline::DsDataType >
 
using ADataType = remove_cvref_t< typename GemmPipeline::ADataType >
 
using BDataType = remove_cvref_t< typename GemmPipeline::BDataType >
 
using EDataType = remove_cvref_t< typename EpiloguePipeline::ODataType >
 
using KernelArgs = GemmKernelArgs< DsLayout::size()>
 

Public Member Functions

CK_TILE_DEVICE void operator() (BatchedGemmKernelArgs kargs) const
 
- Public Member Functions inherited from ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
template<bool U = !PersistentKernel, typename = std::enable_if_t<U>>
CK_TILE_DEVICE void operator() (KernelArgs kargs) const
 
template<bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
CK_TILE_DEVICE void operator() (KernelArgs kargs) const
 

Static Public Member Functions

static CK_TILE_HOST const std::string GetName ()
 
static constexpr __host__ auto GridSize (index_t M, index_t N, index_t KBatch, index_t batch_count)
 
static constexpr __host__ auto BlockSize ()
 
static constexpr CK_TILE_HOST BatchedGemmKernelArgs MakeKernelArgs (const BatchedGemmHostArgs &hostArgs)
 
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize ()
 
- Static Public Member Functions inherited from ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
static CK_TILE_HOST const std::string GetName ()
 
static constexpr CK_TILE_HOST auto GridSize (index_t M, index_t N, index_t KBatch)
 
static CK_TILE_HOST auto MaxOccupancyGridSize (const stream_config &s) -> dim3
 Get the maximum occupancy grid size for the persistent kernel on the current device. More...
 
static constexpr CK_TILE_HOST auto BlockSize ()
 
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs (const GemmHostArgs< NumDTensor > &hostArgs)
 
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize ()
 
static CK_TILE_HOST bool IsSupportedArgument (const KernelArgs &kargs)
 
template<memory_operation_enum DstInMemOp = memory_operation_enum::set>
static CK_TILE_DEVICE auto MakeGemmTensorViews (const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
 
template<typename TensorView >
static CK_TILE_DEVICE auto MakeGemmPadViews (const TensorView &views)
 
template<typename PadView >
static CK_TILE_DEVICE auto MakeGemmTileWindows (const PadView &views, const index_t i_m, const index_t i_n)
 
template<bool UseDefaultScheduler = true>
static CK_TILE_DEVICE void RunGemm (const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
 Runs a single GEMM problem cooperatively by the whole workgroup. More...
 
static CK_TILE_DEVICE void RunGemm2LDS (const ADataType *a_ptr, const BDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
 Runs a single GEMM problem cooperatively by the whole workgroup. More...
 

Additional Inherited Members

- Static Public Attributes inherited from ck_tile::GemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >
static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize
 
static constexpr bool PersistentKernel = has_persistent_kernel::value
 
static constexpr index_t NumDTensor = DsDataType::size()
 
static constexpr auto I0 = number<0>()
 
static constexpr auto I1 = number<1>()
 
static constexpr auto I2 = number<2>()
 
static constexpr auto I3 = number<3>{}
 

Member Typedef Documentation

◆ ADataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ADataType = typename Base::ADataType

◆ ALayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::ALayout = typename Base::ALayout

◆ Base

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::Base = GemmKernel<TilePartitioner_, GemmPipeline_, EpiloguePipeline_>

◆ BDataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BDataType = typename Base::BDataType

◆ BLayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BLayout = typename Base::BLayout

◆ CDataType

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::CDataType = typename Base::EDataType

◆ CLayout

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::CLayout = typename Base::ELayout

◆ EpiloguePipeline

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::EpiloguePipeline = typename Base::EpiloguePipeline

◆ GemmKernelArgs

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GemmKernelArgs = typename ck_tile::GemmKernelArgs<>

◆ GemmPipeline

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GemmPipeline = typename Base::GemmPipeline

◆ KernelArgs

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::KernelArgs = BatchedGemmKernelArgs

◆ TilePartitioner

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
using ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::TilePartitioner = typename Base::TilePartitioner

Member Function Documentation

◆ BlockSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static constexpr __host__ auto ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::BlockSize ( )
inline static constexpr

◆ GetName()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static CK_TILE_HOST const std::string ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetName ( )
inline static

◆ GetSmemSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST_DEVICE index_t ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GetSmemSize ( )
inline static constexpr

◆ GridSize()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static constexpr __host__ auto ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::GridSize ( index_t  M,
index_t  N,
index_t  KBatch,
index_t  batch_count 
)
inline static constexpr

◆ MakeKernelArgs()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
static constexpr CK_TILE_HOST BatchedGemmKernelArgs ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::MakeKernelArgs ( const BatchedGemmHostArgs & hostArgs)
inline static constexpr

◆ operator()()

template<typename TilePartitioner_ , typename GemmPipeline_ , typename EpiloguePipeline_ >
CK_TILE_DEVICE void ck_tile::BatchedGemmKernel< TilePartitioner_, GemmPipeline_, EpiloguePipeline_ >::operator() ( BatchedGemmKernelArgs  kargs) const
inline

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp