WarpGemmImpl< WarpGemmAttribute_ > Struct Template Reference

WarpGemmImpl< WarpGemmAttribute_ > Struct Template Reference

Composable Kernel: ck_tile::WarpGemmImpl< WarpGemmAttribute_ > Struct Template Reference
ck_tile::WarpGemmImpl< WarpGemmAttribute_ > Struct Template Reference

#include <warp_gemm_impl.hpp>

Public Types

using WarpGemmAttribute = remove_cvref_t< WarpGemmAttribute_ >
 
using ADataType = typename WarpGemmAttribute::ADataType
 
using BDataType = typename WarpGemmAttribute::BDataType
 
using CDataType = typename WarpGemmAttribute::CDataType
 
using AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding
 
using BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding
 
using CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding
 
using AWarpDstr = remove_cvref_t< decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>
 
using BWarpDstr = remove_cvref_t< decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>
 
using CWarpDstr = remove_cvref_t< decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>
 
using AWarpTensor = static_distributed_tensor< ADataType, AWarpDstr >
 
using BWarpTensor = static_distributed_tensor< BDataType, BWarpDstr >
 
using CWarpTensor = static_distributed_tensor< CDataType, CWarpDstr >
 

Public Member Functions

template<typename CTensor , typename ATensor , typename BTensor , bool post_nop_ = false>
CK_TILE_DEVICE void operator() (CTensor &c, const ATensor &a, const BTensor &b, bool_constant< post_nop_ >={}) const
 
template<typename CTensor , typename ATensor , typename BTensor , index_t i_subk, bool post_nop_ = false>
CK_TILE_DEVICE void operator() (CTensor &c, const ATensor &a, const BTensor &b, number< i_subk >, bool_constant< post_nop_ >={}) const
 
template<typename ATensor , typename BTensor >
CK_TILE_DEVICE auto operator() (const ATensor &a, const BTensor &b) const
 

Static Public Member Functions

static constexpr CK_TILE_HOST_DEVICE auto get_num_of_access ()
 

Static Public Attributes

static constexpr index_t kM = WarpGemmAttribute::kM
 
static constexpr index_t kN = WarpGemmAttribute::kN
 
static constexpr index_t kK = WarpGemmAttribute::kK
 
static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread
 The number of elements in the K dimension processed by a single thread in a wavefront. More...
 

Member Typedef Documentation

◆ ADataType

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::ADataType = typename WarpGemmAttribute::ADataType

◆ AWarpDstr

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::AWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(AWarpDstrEncoding{}))>

◆ AWarpDstrEncoding

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::AWarpDstrEncoding = typename WarpGemmAttribute::AWarpDstrEncoding

◆ AWarpTensor

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::AWarpTensor = static_distributed_tensor<ADataType, AWarpDstr>

◆ BDataType

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::BDataType = typename WarpGemmAttribute::BDataType

◆ BWarpDstr

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::BWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(BWarpDstrEncoding{}))>

◆ BWarpDstrEncoding

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::BWarpDstrEncoding = typename WarpGemmAttribute::BWarpDstrEncoding

◆ BWarpTensor

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>

◆ CDataType

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::CDataType = typename WarpGemmAttribute::CDataType

◆ CWarpDstr

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::CWarpDstr = remove_cvref_t<decltype(make_static_tile_distribution(CWarpDstrEncoding{}))>

◆ CWarpDstrEncoding

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::CWarpDstrEncoding = typename WarpGemmAttribute::CWarpDstrEncoding

◆ CWarpTensor

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>

◆ WarpGemmAttribute

template<typename WarpGemmAttribute_ >
using ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::WarpGemmAttribute = remove_cvref_t<WarpGemmAttribute_>

Member Function Documentation

◆ get_num_of_access()

template<typename WarpGemmAttribute_ >
static constexpr CK_TILE_HOST_DEVICE auto ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::get_num_of_access ( )
inline static constexpr

◆ operator()() [1/3]

template<typename WarpGemmAttribute_ >
template<typename ATensor , typename BTensor >
CK_TILE_DEVICE auto ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::operator() ( const ATensor &  a,
const BTensor &  b 
) const
inline

◆ operator()() [2/3]

template<typename WarpGemmAttribute_ >
template<typename CTensor , typename ATensor , typename BTensor , bool post_nop_ = false>
CK_TILE_DEVICE void ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::operator() ( CTensor &  c,
const ATensor &  a,
const BTensor &  b,
bool_constant< post_nop_ >  = {} 
) const
inline

◆ operator()() [3/3]

template<typename WarpGemmAttribute_ >
template<typename CTensor , typename ATensor , typename BTensor , index_t i_subk, bool post_nop_ = false>
CK_TILE_DEVICE void ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::operator() ( CTensor &  c,
const ATensor &  a,
const BTensor &  b,
number< i_subk >  ,
bool_constant< post_nop_ >  = {} 
) const
inline

Member Data Documentation

◆ kK

template<typename WarpGemmAttribute_ >
constexpr index_t ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::kK = WarpGemmAttribute::kK
static constexpr

◆ kKPerThread

template<typename WarpGemmAttribute_ >
constexpr index_t ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::kKPerThread = WarpGemmAttribute::kKPerThread
static constexpr

The number of elements in the K dimension processed by a single thread in a wavefront.

Note
Note that WarpGemm may run the MFMA instruction multiple times (on different K). In such a situation, this value reflects that fact.

◆ kM

template<typename WarpGemmAttribute_ >
constexpr index_t ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::kM = WarpGemmAttribute::kM
static constexpr

◆ kN

template<typename WarpGemmAttribute_ >
constexpr index_t ck_tile::WarpGemmImpl< WarpGemmAttribute_ >::kN = WarpGemmAttribute::kN
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/gemm/warp/warp_gemm_impl.hpp