API reference guide

This document describes the APIs of the Composable Kernel (CK) library and introduces some of the key design principles used to write new classes that extend CK functionality.

Using the CK API

This section describes how to use the CK library API.

CK Datatypes

DeviceMem

struct DeviceMem

Container for storing data in GPU device memory.
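To illustrate the ownership pattern such a container follows, the sketch below allocates a fixed number of bytes on construction, copies host data in and out, and releases the storage in its destructor. The class DeviceBufferSketch and its methods are hypothetical and are not CK's DeviceMem interface; std::vector stands in for GPU memory so that the example runs on the host.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Hypothetical host-only stand-in for a device-memory container.
// A real implementation would allocate through the GPU runtime (for example
// hipMalloc/hipFree in HIP); here std::vector plays the role of device memory.
class DeviceBufferSketch
{
public:
    explicit DeviceBufferSketch(std::size_t size_in_bytes) : storage_(size_in_bytes) {}

    // Non-copyable: exactly one object owns the allocation.
    DeviceBufferSketch(const DeviceBufferSketch&) = delete;
    DeviceBufferSketch& operator=(const DeviceBufferSketch&) = delete;

    // Raw pointer that would be handed to a kernel launch.
    void* data() { return storage_.data(); }
    std::size_t size() const { return storage_.size(); }

    // Copy size() bytes from host memory into the buffer.
    void to_device(const void* host_src) { std::memcpy(storage_.data(), host_src, storage_.size()); }

    // Copy size() bytes from the buffer back into host memory.
    void from_device(void* host_dst) const { std::memcpy(host_dst, storage_.data(), storage_.size()); }

    // Fill the whole buffer with zero bytes.
    void set_zero() { std::memset(storage_.data(), 0, storage_.size()); }

private:
    std::vector<unsigned char> storage_; // released automatically in the destructor
};
```

In a real GPU container, the pointer returned by data() is what gets passed to a kernel, and deleting the copy operations keeps two objects from freeing the same allocation.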

Kernels for FlashAttention

The FlashAttention algorithm is described in Dao et al. [DFE+22]. This section lists the classes used in CK's GPU implementation of FlashAttention.
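The idea from Dao et al. that makes a tiled implementation possible is an online softmax: keys are processed in blocks while a running maximum m, a running denominator l and an unnormalized output o are kept, and earlier partial results are rescaled by exp(m_old - m_new) whenever the maximum grows. The sketch below applies this to a single query row on the host; the function name, the use of double and the row-major n x d layout of V are illustrative assumptions, not CK code.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Computes softmax(scores) * V for one query row, where V is n x d (row-major).
// Scores are consumed in blocks of block_size keys, keeping only a running max m,
// a running denominator l and a running unnormalized output o, so the full
// softmax row is never stored. Preconditions: scores non-empty, block_size > 0.
std::vector<double> online_softmax_attention_row(const std::vector<double>& scores,
                                                 const std::vector<double>& V,
                                                 std::size_t d,
                                                 std::size_t block_size)
{
    const std::size_t n = scores.size();
    double m = -std::numeric_limits<double>::infinity();
    double l = 0.0;
    std::vector<double> o(d, 0.0);

    for(std::size_t start = 0; start < n; start += block_size)
    {
        const std::size_t end = std::min(start + block_size, n);

        // New running max over all scores seen so far.
        double m_new = m;
        for(std::size_t j = start; j < end; ++j)
            m_new = std::max(m_new, scores[j]);

        // Rescale the previous partial results to the new max
        // (exp(-inf) == 0 on the first block, so they start at zero).
        const double scale = std::exp(m - m_new);
        l *= scale;
        for(double& x : o)
            x *= scale;

        // Accumulate this block's contribution.
        for(std::size_t j = start; j < end; ++j)
        {
            const double p = std::exp(scores[j] - m_new);
            l += p;
            for(std::size_t c = 0; c < d; ++c)
                o[c] += p * V[j * d + c];
        }
        m = m_new;
    }

    // Final normalization by the softmax denominator.
    for(double& x : o)
        x /= l;
    return o;
}
```

Changing block_size does not change the result (up to rounding), which is what allows a GPU kernel to pick its tile size purely for performance.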

Gridwise classes

template<
    typename FloatAB,
    typename FloatGemmAcc,
    typename FloatCShuffle,
    typename FloatC,
    typename AElementwiseOperation,
    typename BElementwiseOperation,
    typename AccElementwiseOperation,
    typename B1ElementwiseOperation,
    typename CElementwiseOperation,
    InMemoryDataOperationEnum CGlobalMemoryDataOperation,
    typename AGridDesc_AK0_M_AK1,
    typename BGridDesc_BK0_N_BK1,
    typename B1GridDesc_BK0_N_BK1,
    typename CGridDesc_M_N,
    index_t NumGemmKPrefetchStage,
    index_t BlockSize,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t KPerBlock,
    index_t Gemm1NPerBlock,
    index_t Gemm1KPerBlock,
    index_t AK1Value,
    index_t BK1Value,
    index_t B1K1Value,
    index_t MPerXdl,
    index_t NPerXdl,
    index_t MXdlPerWave,
    index_t NXdlPerWave,
    index_t Gemm1NXdlPerWave,
    typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
    typename ABlockTransferThreadClusterArrangeOrder,
    typename ABlockTransferSrcAccessOrder,
    index_t ABlockTransferSrcVectorDim,
    index_t ABlockTransferSrcScalarPerVector,
    index_t ABlockTransferDstScalarPerVector_AK1,
    bool AThreadTransferSrcResetCoordinateAfterRun,
    index_t ABlockLdsExtraM,
    typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
    typename BBlockTransferThreadClusterArrangeOrder,
    typename BBlockTransferSrcAccessOrder,
    index_t BBlockTransferSrcVectorDim,
    index_t BBlockTransferSrcScalarPerVector,
    index_t BBlockTransferDstScalarPerVector_BK1,
    bool BThreadTransferSrcResetCoordinateAfterRun,
    index_t BBlockLdsExtraN,
    typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
    typename B1BlockTransferThreadClusterArrangeOrder,
    typename B1BlockTransferSrcAccessOrder,
    index_t B1BlockTransferSrcVectorDim,
    index_t B1BlockTransferSrcScalarPerVector,
    index_t B1BlockTransferDstScalarPerVector_BK1,
    bool B1ThreadTransferSrcResetCoordinateAfterRun,
    index_t B1BlockLdsExtraN,
    index_t CShuffleMXdlPerWavePerShuffle,
    index_t CShuffleNXdlPerWavePerShuffle,
    typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
    index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
    LoopScheduler LoopSched,
    bool PadN,
    bool MaskOutUpperTriangle,
    PipelineVersion PipelineVer = PipelineVersion::v1>
struct GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle

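Gridwise GEMM + softmax + GEMM fusion. For each batch, the kernel computes C = softmax(A · B) · B1 in a single kernel, keeping the intermediate A · B tile on chip instead of writing it to global memory.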
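To make the fused computation concrete, the following host reference computes the same result without fusion, materializing the M x N score matrix that a fused kernel avoids storing. It is a sketch under stated assumptions: a single batch, row-major A (M x K), B (K x N) and B1 (N x O), softmax taken along N, and a scalar alpha applied to A · B before the softmax as a stand-in for an accumulator elementwise operation such as the usual attention scale. The function name is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Unfused reference for one batch: C = softmax(alpha * A * B) * B1.
// A: M x K, B: K x N, B1: N x O, C: M x O, all row-major (illustrative layouts).
std::vector<float> gemm_softmax_gemm_reference(const std::vector<float>& A,
                                               const std::vector<float>& B,
                                               const std::vector<float>& B1,
                                               std::size_t M, std::size_t K,
                                               std::size_t N, std::size_t O,
                                               float alpha)
{
    // First GEMM: the M x N intermediate a fused kernel keeps on chip.
    std::vector<float> S(M * N, 0.0f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];
            S[m * N + n] = alpha * acc;
        }

    // Numerically stable softmax along each row (over N).
    for(std::size_t m = 0; m < M; ++m)
    {
        float* row = &S[m * N];
        const float row_max = *std::max_element(row, row + N);
        float sum = 0.0f;
        for(std::size_t n = 0; n < N; ++n)
        {
            row[n] = std::exp(row[n] - row_max);
            sum += row[n];
        }
        for(std::size_t n = 0; n < N; ++n)
            row[n] /= sum;
    }

    // Second GEMM: C = P * B1.
    std::vector<float> C(M * O, 0.0f);
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t o = 0; o < O; ++o)
        {
            float acc = 0.0f;
            for(std::size_t n = 0; n < N; ++n)
                acc += S[m * N + n] * B1[n * O + o];
            C[m * O + o] = acc;
        }
    return C;
}
```

A reference of this kind is what a fused kernel's output would be compared against, within a tolerance that reflects its lower-precision arithmetic.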

Blockwise classes

template<
    typename ThreadGroup,
    typename SrcElementwiseOperation,
    typename DstElementwiseOperation,
    InMemoryDataOperationEnum DstInMemOp,
    typename BlockSliceLengths,
    typename ThreadClusterLengths,
    typename ThreadClusterArrangeOrder,
    typename SrcData,
    typename DstData,
    typename SrcDesc,
    typename DstDesc,
    typename SrcDimAccessOrder,
    typename DstDimAccessOrder,
    index_t SrcVectorDim,
    index_t DstVectorDim,
    index_t SrcScalarPerVector,
    index_t DstScalarPerVector,
    index_t SrcScalarStrideInVector,
    index_t DstScalarStrideInVector,
    bool ThreadTransferSrcResetCoordinateAfterRun,
    bool ThreadTransferDstResetCoordinateAfterRun,
    index_t NumThreadScratch = 1>
struct ThreadGroupTensorSliceTransfer_v4r1

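Blockwise data transfer: the threads of a thread group cooperatively copy a tensor slice from a source tensor to a destination tensor.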

This version does the following to avoid scratch-memory issues:

  1. Uses StaticallyIndexedArray instead of a C array for the thread buffer.

  2. ThreadwiseTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor.

  3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate.
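The reason for the first point is that GPU registers generally cannot be addressed with a run-time index, so a private C array accessed with an index the compiler cannot resolve often ends up in scratch memory. Keeping every index a compile-time constant lets each element live in its own register. The standard C++17 sketch below imitates that discipline: the hypothetical helper unrolled_for unrolls a loop with a fold expression, and each buffer element is reached through std::get with a constant index. It illustrates the idea only and does not reproduce CK's StaticallyIndexedArray.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>

// Calls f(std::integral_constant<std::size_t, I>{}) for I = 0 .. N-1.
// The comma fold unrolls the loop at compile time, so every call sees its
// index as a constant expression.
template <typename F, std::size_t... Is>
void unrolled_for_impl(F& f, std::index_sequence<Is...>)
{
    (f(std::integral_constant<std::size_t, Is>{}), ...);
}

template <std::size_t N, typename F>
void unrolled_for(F f)
{
    unrolled_for_impl(f, std::make_index_sequence<N>{});
}

// Per-thread buffer whose elements are only ever accessed with
// compile-time indices (std::get<I>), never with a run-time index.
template <std::size_t N>
float scaled_sum(const std::array<float, N>& in, float scale)
{
    std::array<float, N> thread_buf{};
    unrolled_for<N>([&](auto i) {
        constexpr std::size_t I = decltype(i)::value;
        std::get<I>(thread_buf) = in[I] * scale;
    });

    float sum = 0.0f;
    unrolled_for<N>([&](auto i) { sum += std::get<decltype(i)::value>(thread_buf); });
    return sum;
}
```

On the host this compiles to ordinary code; the point is the indexing discipline, which is what gives a GPU compiler the chance to keep the buffer in registers.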

template<
    index_t BlockSize,
    typename FloatAB,
    typename FloatAcc,
    typename ATileDesc,
    typename BTileDesc,
    typename AMmaTileDesc,
    typename BMmaTileDesc,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t KPerBlock,
    index_t MPerXDL,
    index_t NPerXDL,
    index_t MRepeat,
    index_t NRepeat,
    index_t KPack,
    bool TransposeC = false,
    index_t AMmaKStride = KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
    index_t BMmaKStride = KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_v2

Blockwise GEMM.

Supports:

  1. regular XDL output (M2_M3_M4_M2) and transposed XDL output (M2_N2_N3_N4)

  2. decoupled input tile and MMA tile descriptors, so that the source buffer can be either VGPRs or LDS

  3. a configurable K-index starting position and step size after each FMA/XDL instruction
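A blockwise GEMM divides its MPerBlock x NPerBlock tile among the wavefronts of the block. The constexpr sketch below shows the bookkeeping this implies, assuming a 64-lane wavefront and that each wave computes MRepeat x NRepeat XDL tiles of MPerXDL x NPerXDL elements. The struct name is hypothetical and the checks are derived from these assumptions rather than copied from CK.

```cpp
#include <cassert>

// Hypothetical tiling bookkeeping for a blockwise XDL GEMM.
// Assumes a 64-lane wavefront; each wave owns MRepeat x NRepeat XDL tiles.
template <int BlockSize, int MPerBlock, int NPerBlock,
          int MPerXDL, int NPerXDL, int MRepeat, int NRepeat>
struct BlockGemmTilingSketch
{
    static constexpr int WaveSize = 64;

    static_assert(MPerBlock % (MPerXDL * MRepeat) == 0, "M tile must split evenly across waves");
    static_assert(NPerBlock % (NPerXDL * NRepeat) == 0, "N tile must split evenly across waves");

    // Number of waves along M and N inside the block.
    static constexpr int MWaves = MPerBlock / (MPerXDL * MRepeat);
    static constexpr int NWaves = NPerBlock / (NPerXDL * NRepeat);

    // Every thread of the block belongs to exactly one wave.
    static_assert(BlockSize == MWaves * NWaves * WaveSize, "block size must equal waves x wave size");
};
```

For example, with BlockSize = 256, a 128 x 128 block tile, 32 x 32 XDL tiles and 2 x 2 repeats, the block holds 2 x 2 waves of 64 lanes, which accounts for all 256 threads.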

template<
    index_t BlockSize,
    typename AccDataType,
    typename ThreadMap_M_K,
    typename ThreadClusterDesc_M_K,
    typename ThreadSliceDesc_M_K,
    bool IgnoreNaN = false>
struct BlockwiseSoftmax

Blockwise softmax.

Template Parameters
  • BlockSize – Block size

  • AccDataType – Accumulator data type

  • ThreadMap_M_K – Mapping from thread ID to (m, k) position

  • ThreadClusterDesc_M_K – Thread cluster descriptor

  • ThreadSliceDesc_M_K – Threadwise slice descriptor

  • IgnoreNaN – Flag to ignore NaN values; false by default
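A blockwise softmax is built on blockwise reductions (a row maximum, then a sum of exponentials). The host simulation below sketches how a thread map can split one such reduction: threads form an MCluster x KCluster grid, each owns an MSlice x KSlice sub-tile and reduces it, and the partial results of threads that share rows are then combined. Assumptions: the reduction runs along K, as the _M_K suffix suggests; thread IDs are ordered row-major; and ThreadMapSketch is a hypothetical type, not CK's descriptor machinery.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical thread map: threads form an MCluster x KCluster grid
// (row-major thread IDs) and each owns an MSlice x KSlice sub-tile,
// so the block tile is (MCluster * MSlice) x (KCluster * KSlice).
struct ThreadMapSketch
{
    std::size_t MCluster, KCluster, MSlice, KSlice;

    // Thread ID -> (m, k) position of the thread in the cluster.
    std::size_t m_cluster(std::size_t tid) const { return tid / KCluster; }
    std::size_t k_cluster(std::size_t tid) const { return tid % KCluster; }
};

// Returns the maximum of each row of a row-major M x K tile.
std::vector<float> blockwise_row_max(const std::vector<float>& tile, const ThreadMapSketch& map)
{
    const std::size_t M = map.MCluster * map.MSlice;
    const std::size_t K = map.KCluster * map.KSlice;
    const std::size_t num_threads = map.MCluster * map.KCluster;
    const float neg_inf = -std::numeric_limits<float>::infinity();

    // Step 1 (threadwise): each thread reduces its own slice.
    // partial[tid * MSlice + i] holds the max of row i of thread tid's slice.
    std::vector<float> partial(num_threads * map.MSlice, neg_inf);
    for(std::size_t tid = 0; tid < num_threads; ++tid)
        for(std::size_t i = 0; i < map.MSlice; ++i)
            for(std::size_t j = 0; j < map.KSlice; ++j)
            {
                const std::size_t m = map.m_cluster(tid) * map.MSlice + i;
                const std::size_t k = map.k_cluster(tid) * map.KSlice + j;
                float& p = partial[tid * map.MSlice + i];
                p = std::max(p, tile[m * K + k]);
            }

    // Step 2 (across threads): combine partials of threads sharing a row.
    // On a GPU this exchange would go through LDS or cross-lane operations.
    std::vector<float> row_max(M, neg_inf);
    for(std::size_t tid = 0; tid < num_threads; ++tid)
        for(std::size_t i = 0; i < map.MSlice; ++i)
        {
            const std::size_t m = map.m_cluster(tid) * map.MSlice + i;
            row_max[m] = std::max(row_max[m], partial[tid * map.MSlice + i]);
        }
    return row_max;
}
```

Different cluster and slice shapes cover the same tile and give the same result; the choice only changes how much work each thread does locally versus how much is exchanged between threads.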

Threadwise classes

template<
    typename SrcData,
    typename DstData,
    typename SrcDesc,
    typename DstDesc,
    typename ElementwiseOperation,
    typename SliceLengths,
    typename DimAccessOrder,
    index_t DstVectorDim,
    index_t DstScalarPerVector,
    typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_StaticToStatic

Threadwise data transfer.

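Does not involve any tensor coordinates: data is transferred between StaticBuffers whose source and destination descriptors must both be known at compile time.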
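The sketch below illustrates this style of transfer in standard C++17. Two hypothetical compile-time descriptors (StaticDesc2D) map a 2D index to a buffer offset, and every source and destination offset is computed from compile-time loop indices, so no run-time coordinate object is needed. The example also applies an elementwise operation and changes the layout from row-major to column-major during the copy. The type and function names and the std::array buffers are assumptions; CK's StaticBuffer and tensor descriptor types are not used.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical compile-time 2D descriptor: offset(i, j) = i * S0 + j * S1.
template <std::size_t S0, std::size_t S1>
struct StaticDesc2D
{
    static constexpr bool IsKnownAtCompileTime() { return true; }
    static constexpr std::size_t Offset(std::size_t i, std::size_t j) { return i * S0 + j * S1; }
};

// Element Is of the slice is (Is / L1, Is % L1); both offsets are constant
// expressions, so std::get checks them against the buffer size at compile time.
template <std::size_t L1, typename SrcDesc, typename DstDesc, typename Op,
          typename SrcBuf, typename DstBuf, std::size_t... Is>
void static_copy_impl(const SrcBuf& src, DstBuf& dst, Op& op, std::index_sequence<Is...>)
{
    ((std::get<DstDesc::Offset(Is / L1, Is % L1)>(dst) =
          op(std::get<SrcDesc::Offset(Is / L1, Is % L1)>(src))),
     ...);
}

// Copies an L0 x L1 slice from src to dst, applying op to every element.
template <std::size_t L0, std::size_t L1, typename SrcDesc, typename DstDesc,
          typename Op, typename SrcBuf, typename DstBuf>
void static_to_static_copy(const SrcBuf& src, DstBuf& dst, Op op)
{
    static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                  "both descriptors must be known at compile time");
    static_copy_impl<L1, SrcDesc, DstDesc>(src, dst, op, std::make_index_sequence<L0 * L1>{});
}
```

Usage: copying a 2 x 3 row-major slice (StaticDesc2D<3, 1>) into column-major storage (StaticDesc2D<1, 2>) with op(x) = 10 * x turns {1, 2, 3, 4, 5, 6} into {10, 40, 20, 50, 30, 60}.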

DFE+22

Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. FlashAttention: Fast and memory-efficient exact attention with IO-awareness. arXiv preprint arXiv:2205.14135, 2022.