TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ > Struct Template Reference

TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ > Struct Template Reference

Composable Kernel: ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ > Struct Template Reference
ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ > Struct Template Reference

#include <tile_fmha_shape.hpp>

Public Types

using BlockTile = remove_cvref_t< BlockTile_ >
 
using Gemm0BlockWarps = remove_cvref_t< Gemm0BlockWarps_ >
 
using Gemm0WarpTile = remove_cvref_t< Gemm0WarpTile_ >
 
using Gemm1BlockWarps = remove_cvref_t< Gemm1BlockWarps_ >
 
using Gemm1WarpTile = remove_cvref_t< Gemm1WarpTile_ >
 
using VLayout = std::conditional_t< IsVLayoutRowMajor, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor >
 

Static Public Attributes

static constexpr index_t NumGemm0Warps
 
static constexpr index_t NumGemm1Warps
 
static constexpr index_t NumWarps = max(NumGemm0Warps, NumGemm1Warps)
 
static constexpr index_t kM0 = BlockTile::at(number<0>{})
 
static constexpr index_t kN0 = BlockTile::at(number<1>{})
 
static constexpr index_t kK0 = BlockTile::at(number<2>{})
 
static constexpr index_t kN1 = BlockTile::at(number<3>{})
 
static constexpr index_t kK1 = BlockTile::at(number<4>{})
 
static constexpr index_t kQKHeaddim
 
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim)
 
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_
 

Member Typedef Documentation

◆ BlockTile

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
using ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::BlockTile = remove_cvref_t<BlockTile_>

◆ Gemm0BlockWarps

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
using ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::Gemm0BlockWarps = remove_cvref_t<Gemm0BlockWarps_>

◆ Gemm0WarpTile

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
using ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::Gemm0WarpTile = remove_cvref_t<Gemm0WarpTile_>

◆ Gemm1BlockWarps

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
using ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::Gemm1BlockWarps = remove_cvref_t<Gemm1BlockWarps_>

◆ Gemm1WarpTile

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
using ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::Gemm1WarpTile = remove_cvref_t<Gemm1WarpTile_>

◆ VLayout

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
using ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::VLayout = std::conditional_t<IsVLayoutRowMajor, ck_tile::tensor_layout::gemm::RowMajor, ck_tile::tensor_layout::gemm::ColumnMajor>

Member Data Documentation

◆ IsVLayoutRowMajor

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr bool ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::IsVLayoutRowMajor = IsVLayoutRowMajor_
static constexpr

◆ kK0

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::kK0 = BlockTile::at(number<2>{})
static constexpr

◆ kK1

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::kK1 = BlockTile::at(number<4>{})
static constexpr

◆ kM0

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::kM0 = BlockTile::at(number<0>{})
static constexpr

◆ kN0

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::kN0 = BlockTile::at(number<1>{})
static constexpr

◆ kN1

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::kN1 = BlockTile::at(number<3>{})
static constexpr

◆ kQKHeaddim

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::kQKHeaddim
static constexpr
Initial value:
= BlockTile::at(number<5>{})

◆ kSubQKHeaddim

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim)
static constexpr

◆ NumGemm0Warps

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::NumGemm0Warps
static constexpr
Initial value:
= reduce_on_sequence(Gemm0BlockWarps{}, multiplies{}, number<1>{})
That is, the product of all entries of Gemm0BlockWarps.
Referenced symbols: reduce_on_sequence (defined in sequence.hpp:973), multiplies, Gemm0BlockWarps (defined in tile_fmha_shape.hpp:33)

◆ NumGemm1Warps

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::NumGemm1Warps
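static constexpr
Initial value:
= reduce_on_sequence(Gemm1BlockWarps{}, multiplies{}, number<1>{})
That is, the product of all entries of Gemm1BlockWarps.
Referenced symbols: Gemm1BlockWarps (defined in tile_fmha_shape.hpp:35)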

◆ NumWarps

template<typename BlockTile_ , typename Gemm0BlockWarps_ , typename Gemm0WarpTile_ , typename Gemm1BlockWarps_ , typename Gemm1WarpTile_ , bool IsVLayoutRowMajor_>
constexpr index_t ck_tile::TileFmhaShape< BlockTile_, Gemm0BlockWarps_, Gemm0WarpTile_, Gemm1BlockWarps_, Gemm1WarpTile_, IsVLayoutRowMajor_ >::NumWarps = max(NumGemm0Warps, NumGemm1Warps)
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp