ck::tensor_operation::device Namespace Reference

Classes

struct  DeviceAvgPoolBwd
 
struct  BaseArgument
 
struct  BaseInvoker
 
struct  BaseOperator
 
struct  DeviceBatchedContractionMultipleD
 
struct  DeviceBatchedGemm
 
struct  BatchedGemmEPermuteDesc
 
struct  DeviceBatchedGemmEPermute
 
struct  DeviceBatchedGemmGemm
 
struct  DeviceBatchedGemmMultiD
 
struct  DeviceBatchedGemmV2MultiD
 
struct  DeviceBatchedGemmMultipleDGemmMultipleD
 
struct  DeviceBatchedGemmSoftmaxGemm
 
struct  DeviceBatchedGemmSoftmaxGemmPermute
 
struct  DeviceBatchNormBwd
 
struct  DeviceBatchNormFwd
 
struct  DeviceBatchNormInfer
 
struct  DeviceCGemm
 
struct  DeviceContractionMultipleABD
 
struct  DeviceContractionMultipleD
 
struct  DeviceConvBwdData
 
struct  DeviceConvFwd
 
struct  DeviceConvFwdBiasActivation
 
struct  DeviceConvFwdBiasActivationAdd
 
struct  DeviceConvTensorRearrange
 Convolution Tensor Rearrange.
 
struct  DeviceElementwise
 
struct  DeviceElementwiseNormalization
 
struct  DeviceGemm
 
struct  DEGridDesc_M0_M1_M2_N0_N1
 
struct  DeviceGemmBiasCPermute
 
struct  DeviceGemm_dequantB
 
struct  DeviceGemmMultipleABD
 
struct  DeviceGemmMultipleD
 
struct  DeviceGemmMultipleDSplitK
 
struct  DeviceGemmMultipleD_ABScale
 
struct  DeviceGemmMultipleDLayernorm
 
struct  DeviceGemmMultipleDMultipleR
 
struct  DeviceGemmReduce
 
struct  DeviceGemmSplitK
 
struct  DeviceGemmStreamK
 
struct  DeviceGemm_Streamk_V2
 
struct  DeviceGemmV2
 
struct  DeviceGemmV2R1
 
struct  DeviceGemmV2BScale
 
struct  ContractionDesc
 
struct  DeviceGroupedContractionMultipleD
 
struct  DeviceGroupedConvBwdDataMultipleD
 
struct  DeviceGroupedConvBwdWeight
 
struct  DeviceGroupedConvBwdWeightMultipleD
 
struct  DeviceGroupedConvFwd
 
struct  DeviceGroupedConvFwdMultipleABD
 Grouped Convolution Forward.
 
struct  GroupedGemmKernelArgument
 Structure representing the arguments of a single GEMM problem.
 
struct  GemmDesc
 
struct  DeviceGroupedGemm
 
struct  DeviceGroupedGemmFixedNK
 
struct  GemmMultiABDDesc
 
struct  DeviceGroupedGemmMultiABD
 
struct  GroupedGemmMultiABDKernelArgument
 
struct  DeviceGroupedGemmMultiABDFixedNK
 
struct  DeviceGroupedGemmSoftmaxGemmPermute
 
struct  DeviceGroupedGemmSplitK
 
struct  DeviceGroupedGemmTileLoop
 Grouped GEMM kernel using the output Tile Looping algorithm.
 
struct  DeviceMaxPoolBwd
 
struct  DeviceMultipleReduce
 
struct  DeviceNormalizationBwdData
 
struct  DeviceNormalizationBwdGammaBeta
 
struct  DeviceNormalizationFwd
 
struct  DevicePermute
 
struct  DevicePoolFwd
 
struct  DevicePutElement
 
struct  DeviceReduce
 
struct  DeviceReduceMultiD
 
struct  DeviceSoftmax
 
struct  DeviceSplitKContractionMultipleD
 
struct  CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
 
struct  DeviceAvgPool2dBwd_NHWC_NHWC
 
struct  DeviceAvgPool3dBwd_NDHWC_NDHWC
 
struct  DeviceBatchedContractionMultipleD_Wmma_CShuffle
 
struct  DeviceBatchedContractionMultipleD_Xdl_CShuffle
 
struct  DeviceBatchedGemmEPermuteXdl
 
struct  DeviceBatchedGemmGemm_Xdl_CShuffle
 
struct  DeviceBatchedGemmMultiD_Xdl
 
struct  DeviceBatchedGemmMultipleD_Dl
 
struct  DeviceBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
 
struct  DeviceBatchedGemmMultiD_Xdl_CShuffle_V3
 
struct  DeviceBatchedGemmReduce_Xdl_CShuffle
 
struct  DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
 
struct  DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle
 
struct  DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
 
struct  DeviceBatchedGemmXdl
 
struct  DeviceBatchNormBwdImpl
 
struct  DeviceBatchNormFwdImpl
 
struct  DeviceCGemm_4Gemm_Xdl_CShuffle
 
struct  DeviceColumnToImageImpl
 
struct  DeviceContractionMultipleABD_Xdl_CShuffle
 
struct  DeviceContractionMultipleD_Xdl_CShuffle
 
struct  DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 
struct  DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 
struct  DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 
struct  DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 
struct  DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 
struct  DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
 
struct  DeviceConv3dFwdNaive_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
 
struct  DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_K
 
struct  DeviceConvNdBwdDataNwcKxcNwk_Dl
 
struct  DeviceConvNdBwdDataNwcKxcNwk_Xdl
 
struct  DeviceElementwiseImpl
 
struct  DeviceElementwiseNormalizationImpl
 
struct  DeviceFpAintBGemm_Wmma_CShuffle
 
struct  DeviceGemmBiasAddReduce_Xdl_CShuffle
 
struct  DeviceGemmDl
 
struct  DeviceGemmDpp
 
struct  DeviceGemmMultipleABD_Xdl_CShuffle
 
struct  DeviceGemmMultipleD_Dl
 
struct  DeviceGemmMultipleDLayernorm_Xdl_CShuffle
 
struct  DeviceGemmMultipleDMultipleR_Xdl_CShuffle
 
struct  DeviceGemmMultipleD_Wmma_CShuffle
 
struct  DeviceGemmMultipleD_Xdl_CShuffle
 
struct  DeviceGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
 
struct  DeviceGemmMultiD_Xdl_CShuffle_V3
 
struct  DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
 
struct  DeviceGemmReduce_Xdl_CShuffle
 
struct  DeviceGemmWmma_CShuffle
 
struct  DeviceGemmXdl
 
struct  DeviceGemm_Xdl_CShuffle
 
struct  DeviceGemm_Xdl_CShuffle_LdsDirectLoad
 
struct  DeviceGemm_Xdl_CShuffle_Streamk_V3
 
struct  DeviceGemm_Xdl_CShuffleV2
 
struct  DeviceGemm_Xdl_CShuffleV3
 
struct  DeviceGemm_Xdl_CShuffleV3R1
 
struct  DeviceGemmLayerNorm_Xdl_CShuffle
 
struct  DeviceGemmXdlSkipBLds
 
struct  DeviceGemmXdlSplitKCShuffle
 
struct  DeviceGemmXdlSplitKCShuffle_LdsDirectLoad
 
struct  DeviceGemmXdlStreamK
 
struct  DeviceGemm_Xdl_WaveletModel_CShuffle
 
struct  DeviceGroupedContractionMultipleD_Xdl_CShuffle
 
struct  DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
 
struct  DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
 
struct  DeviceGroupedConvBwdWeight_Dl
 
struct  DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle
 
struct  DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle
 
struct  DeviceGroupedConvBwdWeight_Wmma_CShuffle
 
struct  DeviceGroupedConvBwdWeight_Xdl_CShuffle
 
struct  DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
 
struct  DeviceGroupedConvFwdDl_NHWC_KYXC_NHWK
 
struct  DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
 
struct  DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
 
struct  DeviceGroupedConvFwdMultipleDMultipleR
 
struct  DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
 
struct  DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
 
struct  DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
 
struct  ComputePtrOffsetOfStridedBatch
 
struct  ComputePtrOffsetOfStridedBatch< NumATensor, NumBTensor, NumDTensor, enable_if_t<(NumATensor > 1 || NumBTensor > 1)> >
 
struct  ComputePtrOffsetOfStridedBatch< NumATensor, NumBTensor, NumDTensor, enable_if_t<(NumATensor == 1 && NumBTensor == 1)> >
 
struct  DeviceGroupedGemm_Xdl_Multi_ABD_Fixed_NK
 
struct  DeviceGroupedGemmMultipleD_Dl
 
struct  DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
 
struct  DeviceGroupedGemmMultipleDXdlCShuffleTileLoop
 
struct  DeviceGroupedGemmSoftmaxGemmPermute_Xdl_CShuffle
 
struct  DeviceGroupedGemm_Xdl
 
struct  DeviceGroupedGemm_Xdl_Fixed_NK
 
struct  DeviceGroupedGemmXdlSplitKCShuffle
 
struct  DeviceGroupedQueryAttentionForward_Wmma
 
struct  DeviceImageToColumnImpl
 
struct  DeviceMaxPoolBwdImpl
 
struct  DeviceMultiQueryAttentionForward_Wmma
 
struct  DeviceMultipleReduceMultiBlock
 
struct  DeviceMultipleReduceThreadWise
 
struct  DeviceNormalizationBwdDataImpl
 
struct  DeviceNormalizationBwdGammaBetaImpl
 
struct  DeviceNormalizationFwdImpl
 
struct  DeviceNormalizationFwdSplitKImpl
 
struct  DevicePermuteImpl
 
struct  DevicePool2dFwd_NHWC_NHWC
 
struct  DevicePool3dFwd_NDHWC_NDHWC
 
struct  DevicePutElementImpl
 
struct  DeviceReduceMultiBlock
 
struct  DeviceReduceThreadWise
 
struct  DeviceReduceThreadWiseMultiD
 
struct  DeviceSoftmaxImpl
 
struct  DeviceSparseEmbeddingsForwardLayernorm
 
struct  DeviceSplitKContractionMultipleD_Xdl_CShuffle
 
struct  MaskDisabledPredicate
 
struct  MaskOutUpperTrianglePredicate
 
struct  C0MatrixMask_impl
 
struct  GemmGemmPadder
 
struct  GemmPadder
 
struct  MatrixPadder
 
struct  GemmPadder_v2
 
struct  MatrixPadder_v2
 
struct  GetReduceCountPerThreadForBlockwiseWelford
 
struct  GetReduceCountPerThreadForMultiblockWelford
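
All of the concrete Device* operators above derive from BaseOperator and are driven through the BaseArgument / BaseInvoker pair: build a type-erased argument, check support, then run. The sketch below illustrates that flow under the assumption of a plain-GEMM MakeArgumentPointer parameter list; the exact list varies per operator, and p_a/p_b/p_c stand for pre-allocated device buffers.

    // Minimal sketch of the BaseOperator / BaseArgument / BaseInvoker flow.
    // DeviceGemmT stands for any concrete DeviceGemm implementation; the
    // MakeArgumentPointer parameter list shown is the plain-GEMM one and
    // differs for other operators.
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    template <typename DeviceGemmT>
    bool run_gemm(DeviceGemmT& gemm,
                  const void* p_a, const void* p_b, void* p_c,
                  ck::index_t M, ck::index_t N, ck::index_t K,
                  ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC)
    {
        auto argument = gemm.MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                                 StrideA, StrideB, StrideC,
                                                 PassThrough{}, PassThrough{},
                                                 PassThrough{});
        auto invoker  = gemm.MakeInvokerPointer();

        // Concrete instances may reject shapes/strides they were not built for.
        if(!gemm.IsSupportedArgument(argument.get()))
            return false; // let the caller fall back to another instance

        invoker->Run(argument.get(), StreamConfig{});
        return true;
    }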
 

Typedefs

template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using DeviceBatchedGemmPtr = std::unique_ptr< DeviceBatchedGemm< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > >
 
template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormBwdPtr = std::unique_ptr< DeviceBatchNormBwd< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim > >
 
template<typename XDataType , typename YDataType , typename AccDataType , typename ScaleDataType , typename BiasDataType , typename MeanVarDataType , typename YElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormFwdPtr = std::unique_ptr< DeviceBatchNormFwd< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumBatchNormReduceDim > >
 
template<typename XDataType , typename YDataType , typename AccDataType , typename ScaleDataType , typename BiasDataType , typename MeanVarDataType , typename YElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using DeviceBatchNormInferPtr = std::unique_ptr< DeviceBatchNormInfer< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumBatchNormReduceDim > >
 
template<typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using DeviceCGemmPtr = std::unique_ptr< DeviceCGemm< AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > >
 
template<typename InElementwiseOperation , typename WeiElementwiseOperation , typename OutElementwiseOperation >
using DeviceConvFwdBiasActivationPtr = std::unique_ptr< DeviceConvFwdBiasActivation< InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation > >
 
template<typename InElementwiseOperation , typename WeiElementwiseOperation , typename OutElementwiseOperation >
using DeviceConvFwdBiasActivationAddPtr = std::unique_ptr< DeviceConvFwdBiasActivationAdd< InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation > >
 
template<typename InDataTypeTuple , typename OutDataTypeTuple , typename ElementwiseOperation , index_t NumDim>
using DeviceElementwisePtr = std::unique_ptr< DeviceElementwise< InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim > >
 
template<typename InDataTypeTuple , typename GammaDataType , typename BetaDataType , typename AccDataType , typename YDataType , typename XElementwiseOperation , typename YElementwiseOperation , index_t Rank, index_t NumReduceDim>
using DeviceElementwiseNormalizationPtr = std::unique_ptr< DeviceElementwiseNormalization< InDataTypeTuple, GammaDataType, BetaDataType, AccDataType, YDataType, XElementwiseOperation, YElementwiseOperation, Rank, NumReduceDim > >
 
template<typename ALayout , typename BLayout , typename DELayout , typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename RsDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename QsElementwiseOperation , typename RsElementwiseOperation >
using DeviceGemmMultipleDMultipleRPtr = std::unique_ptr< DeviceGemmMultipleDMultipleR< ALayout, BLayout, DELayout, ADataType, BDataType, DsDataType, EDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation > >
 
template<ck::index_t NumDTensor, ck::index_t NumReduce>
using DeviceGemmReducePtr = std::unique_ptr< DeviceGemmReduce< NumDTensor, NumReduce > >
 
template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename ComputeType = CDataType>
using DeviceGemmSplitKPtr = std::unique_ptr< DeviceGemmSplitK< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ComputeType > >
 
template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using DeviceGemmStreamKPtr = std::unique_ptr< DeviceGemmStreamK< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > >
 
template<typename T >
using is_tuple = decltype(std::declval< T & >().IsTuple())
 
template<index_t NDimSpatial, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename ComputeType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value, Number<0>, ADataType>())>
using DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ComputeType >
 Grouped Convolution Forward.
 
template<index_t Rank, index_t NumReduceDim, index_t NumReduction, typename InElementwiseOperationTuple , typename AccElementwiseOperationTuple >
using DeviceMultipleReducePtr = std::unique_ptr< DeviceMultipleReduce< Rank, NumReduceDim, NumReduction, InElementwiseOperationTuple, AccElementwiseOperationTuple > >
 
template<typename DYDataType , typename XDataType , typename GammaDataType , typename MeanInvStdDataType , typename DXDataType , index_t Rank, index_t NumReduceDim>
using DeviceNormalizationBwdDataPtr = std::unique_ptr< DeviceNormalizationBwdData< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType, Rank, NumReduceDim > >
 
template<typename DYDataType , typename XDataType , typename MeanInvStdDataType , typename DGammaDataType , typename DBetaDataType , index_t Rank, index_t NumReduceDim>
using DeviceNormalizationBwdGammaBetaPtr = std::unique_ptr< DeviceNormalizationBwdGammaBeta< DYDataType, XDataType, MeanInvStdDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim > >
 
template<typename XDataType , typename GammaDataType , typename BetaDataType , typename YDataType , typename SaveMeanInvStdDataType , typename YElementwiseOperation , index_t Rank, index_t NumReduceDim>
using DeviceNormalizationFwdPtr = std::unique_ptr< DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim > >
 
template<typename InDataType , typename AccDataType , typename OutDataType , index_t Rank, index_t NumReduceDim, typename ReduceOperation , typename InElementwiseOperation , typename AccElementwiseOperation , bool PropagateNan, bool OutputIndex>
using DeviceReducePtr = std::unique_ptr< DeviceReduce< InDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, AccElementwiseOperation, PropagateNan, OutputIndex > >
 
template<typename InDataType , typename DsDataType , typename AccDataType , typename OutDataType , index_t Rank, index_t NumReduceDim, typename ReduceOperation , typename InElementwiseOperation , typename OutElementwiseOperation >
using DeviceReduceMultiDPtr = std::unique_ptr< DeviceReduceMultiD< InDataType, DsDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, OutElementwiseOperation > >
 
template<typename InDataType , typename AccDataType , typename OutDataType , typename InElementwiseOp , typename AccElementwiseOp , index_t Rank, index_t NumReduceDim>
using DeviceSoftmaxPtr = std::unique_ptr< DeviceSoftmax< InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim > >
 
template<index_t NDimSpatial, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename ADataType , typename BDataType , typename AccDataType , typename CShuffleDataType , typename DsDataType , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , ConvolutionForwardSpecialization ConvForwardSpecialization, GemmSpecialization GemmSpec, index_t NumGemmKPrefetchStage, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, index_t ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , index_t CDEBlockTransferScalarPerVector_NPerBlock, typename AComputeDataType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value, Number<0>, ADataType>()), typename BComputeDataType = AComputeDataType, LoopScheduler LoopSched = make_default_loop_scheduler()>
using DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched >
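
The *Ptr aliases above are owning handles over the abstract operator interfaces, which lets callers collect heterogeneous tuned instances and try them in turn. A hedged sketch of that pattern follows; populate_instances is a made-up stand-in for an instance factory, and the trailing split-K factor in MakeArgumentPointer is an assumption of this sketch.

    // Hypothetical sketch: run the first DeviceGemmSplitK instance that
    // accepts the problem. populate_instances is a made-up factory name;
    // layout and element-wise types come from ck::tensor_layout::gemm and
    // ck::tensor_operation::element_wise.
    #include <memory>
    #include <vector>

    using namespace ck::tensor_operation::device;
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using GemmPtr     = DeviceGemmSplitKPtr<Row, Col, Row, ck::half_t, ck::half_t,
                                            ck::half_t, PassThrough, PassThrough,
                                            PassThrough>;

    void populate_instances(std::vector<GemmPtr>&); // hypothetical factory

    void run_first_supported(const void* p_a, const void* p_b, void* p_c,
                             ck::index_t M, ck::index_t N, ck::index_t K,
                             ck::index_t StrideA, ck::index_t StrideB,
                             ck::index_t StrideC)
    {
        std::vector<GemmPtr> instances;
        populate_instances(instances);

        for(auto& op : instances)
        {
            // Parameter list assumed: plain GEMM arguments plus a split-K factor.
            auto arg = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K,
                                               StrideA, StrideB, StrideC,
                                               PassThrough{}, PassThrough{},
                                               PassThrough{}, /*k_batch=*/4);
            if(op->IsSupportedArgument(arg.get()))
            {
                op->MakeInvokerPointer()->Run(arg.get(), StreamConfig{});
                break;
            }
        }
    }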
 

Enumerations

enum class  ConvolutionBackwardDataSpecialization {
  Default ,
  Filter1x1Stride1Pad0
}
 
enum class  ConvolutionBackwardWeightSpecialization {
  Default ,
  Filter1x1Stride1Pad0 ,
  Filter1x1Pad0 ,
  OddC
}
 
enum class  ConvolutionForwardSpecialization {
  Default ,
  Filter1x1Pad0 ,
  Filter1x1Stride1Pad0 ,
  OddC ,
  Filter3x3
}
 
enum class  GemmSpecialization {
  Default ,
  MPadding ,
  NPadding ,
  KPadding ,
  MNPadding ,
  MKPadding ,
  NKPadding ,
  MNKPadding ,
  OPadding ,
  MOPadding ,
  NOPadding ,
  KOPadding ,
  MNOPadding ,
  MKOPadding ,
  NKOPadding ,
  MNKOPadding
}
 
enum class  MaskingSpecialization {
  MaskDisabled ,
  MaskOutUpperTriangle
}
 
enum class  TensorSpecialization {
  Default ,
  Packed
}
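
Each GemmSpecialization enumerator names the dimensions a kernel instance pads up to its tile sizes (the O variants address the second output dimension used by the fused GEMM+GEMM operators). The hypothetical helper below shows how the M/N/K flavors map onto simple divisibility checks; the tile sizes are per-instance template parameters.

    // Hypothetical helper: derive the padding specialization a problem needs
    // from whether each dimension is a multiple of the instance's tile size.
    // Only the M/N/K flavors are handled here.
    using ck::tensor_operation::device::GemmSpecialization;

    GemmSpecialization choose_gemm_spec(ck::index_t M, ck::index_t N, ck::index_t K,
                                        ck::index_t MPerBlock, ck::index_t NPerBlock,
                                        ck::index_t KPerBlock)
    {
        const bool pad_m = (M % MPerBlock != 0);
        const bool pad_n = (N % NPerBlock != 0);
        const bool pad_k = (K % KPerBlock != 0);

        if(pad_m && pad_n && pad_k) return GemmSpecialization::MNKPadding;
        if(pad_m && pad_n)          return GemmSpecialization::MNPadding;
        if(pad_m && pad_k)          return GemmSpecialization::MKPadding;
        if(pad_n && pad_k)          return GemmSpecialization::NKPadding;
        if(pad_m)                   return GemmSpecialization::MPadding;
        if(pad_n)                   return GemmSpecialization::NPadding;
        if(pad_k)                   return GemmSpecialization::KPadding;
        return GemmSpecialization::Default;
    }

    // getGemmSpecializationString(choose_gemm_spec(1000, 1024, 64, 256, 128, 32))
    // then yields the printable name of the chosen enumerator.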
 

Functions

std::string getConvBackwardDataSpecializationString (const ConvolutionBackwardDataSpecialization &s)
 
std::string getConvBackwardWeightSpecializationString (const ConvolutionBackwardWeightSpecialization &s)
 
std::string getConvForwardSpecializationString (const ConvolutionForwardSpecialization &s)
 
std::string getGemmSpecializationString (const GemmSpecialization &s)
 
template<typename GridwiseGemm , typename ABDataType , typename EDataType , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename ComputePtrOffsetOfBatch , typename Block2ETileMap , bool HasMainKBlockLoop>
__global__ void kernel_batched_gemm_e_permute_xdl (const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
 
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , bool HasMainKBlockLoop>
__global__ void kernel_gemm_gemm_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 
template<typename GridwiseGemm , typename ABDataType , typename DsPointer , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , typename Block2ETileMap , bool HasMainKBlockLoop>
__global__ void kernel_batched_gemm_xdl (const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
 
template<typename GridwiseGemm , typename ABDataType , typename DsPointer , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename AGridDesc_K0_M0_M1_K1 , typename BGridDesc_K0_N0_N1_K1 , typename DsGridDesc_M0_M10_M11_N0_N10_N11 , typename CGridDesc_M0_M10_M11_N0_N10_N11 , typename ComputePtrOffsetOfBatch , typename Block2CTileMap , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void kernel_gemm_dl_multiple_d (const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map)
 
template<typename GridwiseGemm , typename A0B0B1DataType , typename D0sPointer , typename D1sPointer , typename E1DataType , typename A0ElementwiseOperation , typename B0ElementwiseOperation , typename CDE0ElementwiseOperation , typename B1ElementwiseOperation , typename CDE1ElementwiseOperation , typename A0GridDesc_AK0_M_AK1 , typename B0GridDesc_BK0_N_BK1 , typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 , typename B1GridDesc_BK0_N_BK1 , typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2E1TileMap , typename ComputeBasePtrOfStridedBatch , bool HasMainKBlockLoop>
__global__ void kernel_batched_gemm_gemm_xdl_cshuffle_v1 (const A0B0B1DataType *__restrict__ p_a0_grid, const A0B0B1DataType *__restrict__ p_b0_grid, D0sPointer p_d0s_grid, const A0B0B1DataType *__restrict__ p_b1_grid, D1sPointer p_d1s_grid, E1DataType *__restrict__ p_e1_grid, const A0ElementwiseOperation a0_element_op, const B0ElementwiseOperation b0_element_op, const CDE0ElementwiseOperation cde0_element_op, const B1ElementwiseOperation b1_element_op, const CDE1ElementwiseOperation cde1_element_op, const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock d1s_grid_desc_mblock_mperblock_nblock_nperblock, const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e1_grid_desc_mblock_mperblock_nblock_nperblock, const Block2E1TileMap block_2_e1tile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch)
 
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename ReducePtrsGlobal , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename ReduceInElementwiseOperations , typename ReduceAccElementwiseOperations , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename ReduceGridDescriptor_MBlock_MPerBlock , typename ComputeBasePrtOfBatch , typename Block2CTileMap , bool HasMainK0BlockLoop>
__global__ void kernel_batched_gemm_reduce_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, ReducePtrsGlobal p_reduces_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map)
 
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void kernel_batched_gemm_softmax_gemm_wmma_cshuffle (const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
 
template<typename DeviceOp , typename GridwiseOp , typename QKVDataType , typename ODataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void kernel_wmma_self_attention_forward (const QKVDataType *__restrict__ p_qkv_grid, ODataType *__restrict__ p_out_grid, index_t batch_size, index_t sequence_length, index_t head_count, index_t head_size, float alpha)
 
template<typename DeviceOp , typename GridwiseOp , typename QDataType , typename KVDataType , typename ODataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void kernel_wmma_cross_attention_forward (const QDataType *__restrict__ p_q_grid, const KVDataType *__restrict__ p_kv_grid, ODataType *__restrict__ p_out_grid, index_t batch_size, index_t q_sequence_length, index_t kv_sequence_length, index_t head_count, index_t head_size, float alpha)
 
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename D0sPointer , typename AElementwiseOperation , typename BElementwiseOperation , typename C0DEElementwiseOperation , typename B1ElementwiseOperation , typename C1DEElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , typename C0MatrixMask , bool HasMainKBlockLoop>
__global__ void kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, D0sPointer p_d0s_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const C0DEElementwiseOperation c0de_element_op, const B1ElementwiseOperation b1_element_op, const C1DEElementwiseOperation c1de_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask)
 
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , typename C0MatrixMask , bool HasMainKBlockLoop>
__global__ void kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask)
 
template<typename DeviceOp , typename GridwiseGemm , bool HasMainKBlockLoop>
__global__ void kernel_batched_gemm_xdlops_v2r3 (const typename DeviceOp::Argument karg)
 
template<index_t NumDim1, index_t NumDim2>
auto CalculateMaxRead (const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
 
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AGridDesc_K0_M_K1 , typename BGridDesc_K0_N_K1 , typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename Block2CTileMap , bool HasMainKBlockLoop>
__global__ void kernel_gemm_xdlops_v2r3_for_conv3d (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const index_t num_batches, const index_t a_batch_stride, const index_t b_batch_stride, const index_t c_batch_stride, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
 
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AGridDesc_B_K0_M0_M1_K1 , typename BGridDesc_B_K0_N0_N1_K1 , typename CGridDesc_M0_M10_M11_N0_N10_N11 , typename Block2CTileMap , typename ComputePtrOffsetOfBatch , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void kernel_batched_gemm_dlops_bwd_weight (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const index_t batch_count, const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 
template<typename GridwiseGemm , typename FloatA , typename FloatB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_B_K0_M_K1 , typename BGridDesc_B_K0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputePtrOffsetOfBatch , bool HasMainKBlockLoop>
__global__ void kernel_batched_gemm_xdlops_bwd_weight (const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const index_t batch_count, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
 
template<typename GridwiseGemm , typename AGridDesc_AK0_M_K1 , typename BGridDesc_BK0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , index_t NumGroupsToMerge, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, index_t MinimumOccupancy = 1, TailNumber TailNum = TailNumber::Full>
__global__ void kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3 (typename GridwiseGemm::Argument karg, [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block)
 
template<typename GridwiseGemm , typename AGridDesc_AK0_M_K1 , typename BGridDesc_BK0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , index_t NumGroupsToMerge, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, index_t MinimumOccupancy = 1, TailNumber TailNum = TailNumber::Full>
__global__ void kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds (typename GridwiseGemm::Argument karg, [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block)
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_NWGC_GKXC_NWGK ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_GNWC_GKXC_GNWK ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_NGCW_GKXC_NGKW ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_NHWGC_GKYXC_NHWGK ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_GNHWC_GKYXC_GNHWK ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_NGCHW_GKYXC_NGKHW ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_NDHWGC_GKZYXC_NDHWGK ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_GNDHWC_GKZYXC_GNDHWK ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_NGCDHW_GKZYXC_NGKDHW ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_NSpatialGC_GKSpatial_NSpatialGK ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_GNSpatialC_GKSpatial_GNSpatialK ()
 
template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool is_NGCSpatial_GKSpatial_NGKSpatial ()
 
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, typename AsLayout , typename BsLayout , typename DsLayout , typename ELayout , typename Block2ETileMap , typename GroupedGemmBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop>
__global__ void kernel_grouped_gemm_xdl_fixed_nk (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const index_t grid_size_grp, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
 
template<typename GridwiseGemm , typename GemmDesc , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void kernel_grouped_gemm_multiple_d_dl (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
 
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename ALayout , typename BLayout , typename DsLayout , typename ELayout , index_t KPerBlock, typename OffsettedBlockToCTileMap , typename LocalBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , BlockGemmPipelineScheduler BlkGemmPipeSched, BlockGemmPipelineVersion BlkGemmPipelineVer>
__global__ void kernel_grouped_gemm_multiple_d_xdl (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op)
 Entry point kernel for a device-wide Grouped GEMM operation.
 
template<typename GridwiseGemm , typename GroupKernelArg , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1 (const void CK_CONSTANT_ADDRESS_SPACE *group_kernel_args, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op)
 
template<typename GridwiseGemm , typename GemmDesc , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , bool HasMainKBlockLoop>
__global__ void kernel_grouped_gemm_xdl (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op)
 
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, bool Zeroing, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename DsDataType , typename Block2ETileMap , typename GroupedGemmBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop>
__global__ void kernel_grouped_gemm_xdl_fixed_nk (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, uint32_t *barrier_count, const index_t barrier_size_grp, const index_t group_count, const index_t grid_size_grp, const index_t KBatch, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op)
 
template<typename GridwiseGemm , typename GemmDesc , bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
__global__ void kernel_grouped_gemm_xdl_splitk (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
 
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , ck::index_t QueryGroupNumber, bool HasMainKBlockLoop>
__global__ void kernel_grouped_query_attention_wmma (const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
 
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void kernel_multi_query_attention_wmma (const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute)
 
template<typename GridwiseNormalizationBwd , typename DYDataType , typename XDataType , typename GammaDataType , typename MeanInvStdDataType , typename DXDataType , typename GridDesc_M_K >
__global__ void kernel_normalization_bwd_data (const GridDesc_M_K dy_grid_desc_m_k, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K mean_grid_desc_m_k, const GridDesc_M_K inv_std_grid_desc_m_k, const GridDesc_M_K dx_grid_desc_m_k, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DXDataType *const __restrict__ p_dx_global)
 
template<typename GridwiseReduction , typename DYDataType , typename XDataType , typename MeanInvStdDataType , typename DGammaDataType , typename DBetaDataType , typename GridDesc_M_K , typename GridDesc_M >
__global__ void kernel_normalization_bwd_gamma_beta (const GridDesc_M_K dy_grid_desc_m_k, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K mean_grid_desc_m_k, const GridDesc_M_K inv_std_grid_desc_m_k, const GridDesc_M dgamma_grid_desc_m, const GridDesc_M dbeta_grid_desc_m, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DGammaDataType *const __restrict__ p_dgamma_global, DBetaDataType *const __restrict__ p_dbeta_global)
 
template<index_t Rank, int NumReduceDim>
std::pair< long_index_t, long_index_tget_2d_lengths (const std::vector< index_t > &inLengths)
 
template<index_t Rank, int NumReduceDim>
std::pair< long_index_t, long_index_tget_2d_lengths (const std::array< index_t, Rank > &inLengths)
 
template<index_t... Ns>
auto make_tuple_from_array_and_index_seq (const std::vector< index_t > &lengths, Sequence< Ns... >)
 
template<index_t arraySize>
auto make_tuple_from_array (const std::vector< index_t > &lengths, Number< arraySize >)
 
template<index_t Rank, index_t NumReduceDim>
std::vector< index_tshuffle_tensor_dimensions (const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims)
 
template<index_t Rank, index_t NumReduceDim>
std::array< index_t, Rank > shuffle_tensor_dimensions (const std::array< index_t, Rank > &origLengthsStrides, const std::array< int, NumReduceDim > &reduceDims)
 
std::string getMaskingSpecializationString (const MaskingSpecialization &s)
 
template<typename TensorDesc , typename TileLengths , typename DoPads >
__host__ constexpr __device__ auto PadTensorDescriptor (const TensorDesc &desc, const TileLengths &tile_lengths, DoPads)
 
template<GemmSpecialization GemmSpec, typename MPerTileType , typename NPerTileType , typename KPerTileType , typename CDesc_MRaw_NRaw >
auto grid_desc (MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc)
 
std::string getTensorSpecializationString (const TensorSpecialization &s)
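
The reduction helpers above pair up: shuffle_tensor_dimensions reorders a length (or stride) array so the reduced dimensions come last, and get_2d_lengths collapses such an arrangement into an (invariant length, reduce length) pair. The sketch below illustrates that reading; the exact ordering contract is an assumption of this sketch, not something this page states.

    // Assumed behavior: reduced dimensions are moved to the back, then the
    // leading (invariant) and trailing (reduced) extents are multiplied out.
    #include <vector>

    void reduction_shape_example()
    {
        std::vector<ck::index_t> lengths     = {8, 16, 32, 64}; // Rank = 4
        std::vector<int>         reduce_dims = {1, 3};          // NumReduceDim = 2

        // Expected shuffled order: invariant dims first, reduced dims last:
        // {8, 32, 16, 64}
        auto shuffled = ck::tensor_operation::device::
            shuffle_tensor_dimensions<4, 2>(lengths, reduce_dims);

        // Expected result: {8 * 32, 16 * 64} = {256, 1024}
        auto [invariant_len, reduce_len] =
            ck::tensor_operation::device::get_2d_lengths<4, 2>(shuffled);
    }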
 

Typedef Documentation

◆ DeviceBatchedGemmPtr

template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using ck::tensor_operation::device::DeviceBatchedGemmPtr = std::unique_ptr<DeviceBatchedGemm<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> >

◆ DeviceBatchNormBwdPtr

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using ck::tensor_operation::device::DeviceBatchNormBwdPtr = std::unique_ptr<DeviceBatchNormBwd<XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim> >

◆ DeviceBatchNormFwdPtr

template<typename XDataType , typename YDataType , typename AccDataType , typename ScaleDataType , typename BiasDataType , typename MeanVarDataType , typename YElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using ck::tensor_operation::device::DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumBatchNormReduceDim> >

◆ DeviceBatchNormInferPtr

template<typename XDataType , typename YDataType , typename AccDataType , typename ScaleDataType , typename BiasDataType , typename MeanVarDataType , typename YElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using ck::tensor_operation::device::DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumBatchNormReduceDim> >

◆ DeviceCGemmPtr

template<typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using ck::tensor_operation::device::DeviceCGemmPtr = std::unique_ptr<DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> >

◆ DeviceConvFwdBiasActivationAddPtr

template<typename InElementwiseOperation , typename WeiElementwiseOperation , typename OutElementwiseOperation >
using ck::tensor_operation::device::DeviceConvFwdBiasActivationAddPtr = std::unique_ptr<DeviceConvFwdBiasActivationAdd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> >

◆ DeviceConvFwdBiasActivationPtr

template<typename InElementwiseOperation , typename WeiElementwiseOperation , typename OutElementwiseOperation >
using ck::tensor_operation::device::DeviceConvFwdBiasActivationPtr = std::unique_ptr<DeviceConvFwdBiasActivation<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> >

◆ DeviceElementwiseNormalizationPtr

template<typename InDataTypeTuple , typename GammaDataType , typename BetaDataType , typename AccDataType , typename YDataType , typename XElementwiseOperation , typename YElementwiseOperation , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceElementwiseNormalizationPtr = std::unique_ptr<DeviceElementwiseNormalization<InDataTypeTuple, GammaDataType, BetaDataType, AccDataType, YDataType, XElementwiseOperation, YElementwiseOperation, Rank, NumReduceDim> >

◆ DeviceElementwisePtr

template<typename InDataTypeTuple , typename OutDataTypeTuple , typename ElementwiseOperation , index_t NumDim>
using ck::tensor_operation::device::DeviceElementwisePtr = std::unique_ptr<DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim> >

◆ DeviceGemmMultipleDMultipleRPtr

template<typename ALayout , typename BLayout , typename DELayout , typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename RsDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename QsElementwiseOperation , typename RsElementwiseOperation >
using ck::tensor_operation::device::DeviceGemmMultipleDMultipleRPtr = std::unique_ptr<DeviceGemmMultipleDMultipleR<ALayout, BLayout, DELayout, ADataType, BDataType, DsDataType, EDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation> >

◆ DeviceGemmReducePtr

template<ck::index_t NumDTensor, ck::index_t NumReduce>
using ck::tensor_operation::device::DeviceGemmReducePtr = std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce> >

◆ DeviceGemmSplitKPtr

template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename ComputeType = CDataType>
using ck::tensor_operation::device::DeviceGemmSplitKPtr = std::unique_ptr<DeviceGemmSplitK<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ComputeType> >

◆ DeviceGemmStreamKPtr

template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using ck::tensor_operation::device::DeviceGemmStreamKPtr = std::unique_ptr<DeviceGemmStreamK<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> >

◆ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle

template<index_t NDimSpatial, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename ADataType , typename BDataType , typename AccDataType , typename CShuffleDataType , typename DsDataType , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , ConvolutionForwardSpecialization ConvForwardSpecialization, GemmSpecialization GemmSpec, index_t NumGemmKPrefetchStage, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, index_t ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , index_t CDEBlockTransferScalarPerVector_NPerBlock, typename AComputeDataType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value, Number<0>, ADataType>()), typename BComputeDataType = AComputeDataType, LoopScheduler LoopSched = make_default_loop_scheduler()>
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched>

◆ DeviceGroupedConvFwdMultipleD

template<index_t NDimSpatial, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename ComputeType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value, Number<0>, ADataType>())>
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ComputeType>

Grouped Convolution Forward.

Note
This alias is deprecated (kept for backwards compatibility). Please use DeviceGroupedConvFwdMultipleABD instead.

Template Parameters
NDimSpatial - Number of spatial dimensions.
ALayout - Input layout (also used for a1, a2, ...).
BLayout - Weight layout (also used for b1, b2, ...).
DsLayout - Ds layouts.
ELayout - Output layout.
ADataType - Input data type. Pass a tuple if there are multiple A tensors.
BDataType - Weight data type. Pass a tuple if there are multiple B tensors.
DsDataType - D data types.
EDataType - Output data type.
AElementwiseOperation - A elementwise operation.
BElementwiseOperation - B elementwise operation.
CDEElementwiseOperation - CDE elementwise operation.
ComputeType - Compute data type (default: ADataType, or its first element if a tuple is passed).

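For illustration, a hypothetical instantiation of the deprecated alias next to its preferred spelling. The layout tags (NHWGC, GKYXC, NHWGK), ck::Tuple<> and PassThrough are common CK choices assumed here, not prescribed by this page:

// Requires the CK headers that define these interfaces, e.g.
// #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
namespace ctc = ck::tensor_layout::convolution;

// Deprecated spelling:
using ConvFwdDeprecated = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD<
    2,            // NDimSpatial
    ctc::NHWGC,   // ALayout (input)
    ctc::GKYXC,   // BLayout (weight)
    ck::Tuple<>,  // DsLayout (no extra D tensors)
    ctc::NHWGK,   // ELayout (output)
    float, float, // ADataType, BDataType
    ck::Tuple<>,  // DsDataType
    float,        // EDataType
    PassThrough, PassThrough, PassThrough>;

// Preferred spelling of the same interface:
using ConvFwdPreferred = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD<
    2, ctc::NHWGC, ctc::GKYXC, ck::Tuple<>, ctc::NHWGK,
    float, float, ck::Tuple<>, float,
    PassThrough, PassThrough, PassThrough>;
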
◆ DeviceMultipleReducePtr

template<index_t Rank, index_t NumReduceDim, index_t NumReduction, typename InElementwiseOperationTuple , typename AccElementwiseOperationTuple >
using ck::tensor_operation::device::DeviceMultipleReducePtr = std::unique_ptr<DeviceMultipleReduce<Rank, NumReduceDim, NumReduction, InElementwiseOperationTuple, AccElementwiseOperationTuple> >

◆ DeviceNormalizationBwdDataPtr

template<typename DYDataType , typename XDataType , typename GammaDataType , typename MeanInvStdDataType , typename DXDataType , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceNormalizationBwdDataPtr = std::unique_ptr<DeviceNormalizationBwdData<DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType, Rank, NumReduceDim> >

◆ DeviceNormalizationBwdGammaBetaPtr

template<typename DYDataType , typename XDataType , typename MeanInvStdDataType , typename DGammaDataType , typename DBetaDataType , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaPtr = std::unique_ptr<DeviceNormalizationBwdGammaBeta<DYDataType, XDataType, MeanInvStdDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim> >

◆ DeviceNormalizationFwdPtr

template<typename XDataType , typename GammaDataType , typename BetaDataType , typename YDataType , typename SaveMeanInvStdDataType , typename YElementwiseOperation , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceNormalizationFwdPtr = std::unique_ptr<DeviceNormalizationFwd<XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim> >

◆ DeviceReduceMultiDPtr

template<typename InDataType , typename DsDataType , typename AccDataType , typename OutDataType , index_t Rank, index_t NumReduceDim, typename ReduceOperation , typename InElementwiseOperation , typename OutElementwiseOperation >
using ck::tensor_operation::device::DeviceReduceMultiDPtr = std::unique_ptr<DeviceReduceMultiD<InDataType, DsDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, OutElementwiseOperation> >

◆ DeviceReducePtr

template<typename InDataType , typename AccDataType , typename OutDataType , index_t Rank, index_t NumReduceDim, typename ReduceOperation , typename InElementwiseOperation , typename AccElementwiseOperation , bool PropagateNan, bool OutputIndex>
using ck::tensor_operation::device::DeviceReducePtr = std::unique_ptr<DeviceReduce<InDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, AccElementwiseOperation, PropagateNan, OutputIndex> >

◆ DeviceSoftmaxPtr

template<typename InDataType , typename AccDataType , typename OutDataType , typename InElementwiseOp , typename AccElementwiseOp , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceSoftmaxPtr = std::unique_ptr<DeviceSoftmax<InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim> >

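All of the *Ptr aliases above are owning std::unique_ptr handles to abstract device-operation interfaces; client code typically collects library-generated instances behind such a handle, queries each for support, and runs one that qualifies. A self-contained sketch of that pattern (BaseOp, DummyOp and the member names are stand-ins, not the actual CK interface):

#include <iostream>
#include <memory>
#include <vector>

struct BaseOp
{
    virtual ~BaseOp() = default;
    virtual bool IsSupported() const = 0;
    virtual void Run() const = 0;
};
using BaseOpPtr = std::unique_ptr<BaseOp>; // analogous to DeviceSoftmaxPtr etc.

struct DummyOp final : BaseOp
{
    bool IsSupported() const override { return true; }
    void Run() const override { std::cout << "running instance\n"; }
};

int main()
{
    std::vector<BaseOpPtr> instances;
    instances.push_back(std::make_unique<DummyOp>());
    for (const auto& op : instances)
        if (op->IsSupported()) // pick a supported instance and run it
            op->Run();
}
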
◆ is_tuple

template<typename T >
using ck::tensor_operation::device::is_tuple = decltype(std::declval<T&>().IsTuple())

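This is the detection-idiom probe used in the AComputeDataType/ComputeType defaults above: is_tuple<T> is well-formed exactly when T has a member IsTuple(), and is_detected turns that into a compile-time boolean. A self-contained sketch with a simplified detector (the real is_detected is defined elsewhere in ck):

#include <type_traits>
#include <utility>

// Simplified detector in the spirit of std::experimental::is_detected:
template <typename Void, template <typename...> class Op, typename... Args>
struct detector : std::false_type {};

template <template <typename...> class Op, typename... Args>
struct detector<std::void_t<Op<Args...>>, Op, Args...> : std::true_type {};

template <template <typename...> class Op, typename... Args>
using is_detected = detector<void, Op, Args...>;

// The alias documented above: well-formed only for types with IsTuple().
template <typename T>
using is_tuple = decltype(std::declval<T&>().IsTuple());

struct TupleLike { void IsTuple(); };

static_assert(is_detected<is_tuple, TupleLike>::value);
static_assert(!is_detected<is_tuple, float>::value);
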
Enumeration Type Documentation

◆ ConvolutionBackwardDataSpecialization

Enumerator
Default 
Filter1x1Stride1Pad0 

◆ ConvolutionBackwardWeightSpecialization

Enumerator
Default 
Filter1x1Stride1Pad0 
Filter1x1Pad0 
OddC 

◆ ConvolutionForwardSpecialization

Enumerator
Default 
Filter1x1Pad0 
Filter1x1Stride1Pad0 
OddC 
Filter3x3 

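Each enumerator encodes a compile-time assumption about the filter that lets a kernel simplify its indexing (for example, a 1x1 filter with unit stride and no padding reduces the implicit-GEMM view to a plain GEMM). An illustrative, non-library chooser for the 2D case; OddC (presumably an odd input-channel count) is left out of the sketch:

enum struct ConvolutionForwardSpecialization
{ Default, Filter1x1Pad0, Filter1x1Stride1Pad0, OddC, Filter3x3 };

ConvolutionForwardSpecialization
choose_conv_fwd_spec(int Y, int X, int stride_h, int stride_w, int pad_h, int pad_w)
{
    const bool filter_1x1 = (Y == 1 && X == 1);
    const bool no_pad     = (pad_h == 0 && pad_w == 0);
    const bool stride_1   = (stride_h == 1 && stride_w == 1);

    if (filter_1x1 && no_pad && stride_1)
        return ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
    if (filter_1x1 && no_pad)
        return ConvolutionForwardSpecialization::Filter1x1Pad0;
    if (Y == 3 && X == 3)
        return ConvolutionForwardSpecialization::Filter3x3;
    return ConvolutionForwardSpecialization::Default;
}
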
◆ GemmSpecialization

Enumerator
Default 
MPadding 
NPadding 
KPadding 
MNPadding 
MKPadding 
NKPadding 
MNKPadding 
OPadding 
MOPadding 
NOPadding 
KOPadding 
MNOPadding 
MKOPadding 
NKOPadding 
MNKOPadding 

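The enumerators name which of the GEMM dimensions M, N, K are padded up to a multiple of the corresponding tile size; the O variants extend the same scheme to the fourth dimension O used by the fused GEMM+GEMM / attention kernels further down (cf. the M, N, K, O parameters of kernel_batched_gemm_softmax_gemm_wmma_cshuffle). An illustrative, non-library sketch of choosing a specialization from runtime problem sizes:

enum struct GemmSpecialization
{ Default, MPadding, NPadding, KPadding, MNPadding, MKPadding, NKPadding, MNKPadding };

// Pad exactly the dimensions that are not already multiples of their tiles.
GemmSpecialization choose_gemm_spec(int M, int N, int K,
                                    int MPerBlock, int NPerBlock, int KPerBlock)
{
    const bool pm = (M % MPerBlock != 0);
    const bool pn = (N % NPerBlock != 0);
    const bool pk = (K % KPerBlock != 0);

    if (pm && pn && pk) return GemmSpecialization::MNKPadding;
    if (pm && pn)       return GemmSpecialization::MNPadding;
    if (pm && pk)       return GemmSpecialization::MKPadding;
    if (pn && pk)       return GemmSpecialization::NKPadding;
    if (pm)             return GemmSpecialization::MPadding;
    if (pn)             return GemmSpecialization::NPadding;
    if (pk)             return GemmSpecialization::KPadding;
    return GemmSpecialization::Default;
}
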
◆ MaskingSpecialization

Enumerator
MaskDisabled 
MaskOutUpperTriangle 

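The masking modes govern which attention scores are kept; interpreting MaskOutUpperTriangle as causal masking (a common use, assumed here) gives the following predicate on a score position:

// Sketch: with MaskOutUpperTriangle, score entries whose column index exceeds
// the row index are excluded; with MaskDisabled nothing is masked.
bool is_masked_out(int row, int col, bool mask_out_upper_triangle)
{
    return mask_out_upper_triangle && (col > row);
}
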
◆ TensorSpecialization

Enumerator
Default 
Packed 

Function Documentation

◆ CalculateMaxRead()

template<index_t NumDim1, index_t NumDim2>
auto ck::tensor_operation::device::CalculateMaxRead ( const std::vector< index_t > &  lengths,
const std::vector< index_t > &  strides 
)

Calculates the maximum number of subsequent elements of the fast-changing dimension (FCD) that are consecutive in memory.

Example: NumDimM = 2, NumDimK = 3

    A lengths = [  2,   3,  4, 5, 6]
    A strides = [360, 120, 30, 6, 1]
                 |   M   | |   K   |

It follows from the strides that K is the FCD and that all subsequent elements of K are consecutive in memory. Had the strides been [360, 120, 6, 24, 1] instead, only 6 subsequent elements of K would be consecutive in memory.

Assumes that the dimensions are split into two groups of NumDim1 and NumDim2 dimensions.

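An illustrative reimplementation of the idea (the library's actual algorithm and signature may differ): walk the trailing group from the innermost dimension outwards and multiply lengths while the layout stays packed, i.e. while each stride equals the running element count.

#include <cassert>
#include <vector>

int max_consecutive_fcd_elements(const std::vector<int>& lengths,
                                 const std::vector<int>& strides,
                                 int num_dims_in_group /* e.g. NumDimK */)
{
    assert(lengths.size() == strides.size());
    int consecutive = 1;
    for (int i = static_cast<int>(lengths.size()) - 1;
         i >= static_cast<int>(lengths.size()) - num_dims_in_group; --i)
    {
        if (strides[i] != consecutive) break; // layout no longer packed
        consecutive *= lengths[i];
    }
    return consecutive;
}

// For the example above:
//   max_consecutive_fcd_elements({2,3,4,5,6}, {360,120,30,6,1}, 3) == 120
//   max_consecutive_fcd_elements({2,3,4,5,6}, {360,120,6,24,1}, 3) == 6
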
◆ get_2d_lengths() [1/2]

template<index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> ck::tensor_operation::device::get_2d_lengths ( const std::array< index_t, Rank > &  inLengths)

◆ get_2d_lengths() [2/2]

template<index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> ck::tensor_operation::device::get_2d_lengths ( const std::vector< index_t > &  inLengths)

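Both overloads presumably collapse a Rank-dimensional length array, whose NumReduceDim reduce dimensions have already been moved to the end (see shuffle_tensor_dimensions at the bottom of this page), into the pair (product of invariant lengths, product of reduce lengths). A sketch under that assumption, with int/int64_t standing in for index_t/long_index_t:

#include <array>
#include <cstdint>
#include <utility>

template <int Rank, int NumReduceDim>
std::pair<int64_t, int64_t> get_2d_lengths(const std::array<int, Rank>& inLengths)
{
    int64_t invariant_total = 1;
    int64_t reduce_total    = 1;
    for (int i = 0; i < Rank - NumReduceDim; ++i)
        invariant_total *= inLengths[i]; // leading invariant dimensions
    for (int i = Rank - NumReduceDim; i < Rank; ++i)
        reduce_total *= inLengths[i];    // trailing reduce dimensions
    return {invariant_total, reduce_total};
}

// get_2d_lengths<4, 2>({8, 16, 32, 64}) -> {8 * 16, 32 * 64} == {128, 2048}
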
◆ getConvBackwardDataSpecializationString()

std::string ck::tensor_operation::device::getConvBackwardDataSpecializationString ( const ConvolutionBackwardDataSpecialization s)
inline

◆ getConvBackwardWeightSpecializationString()

std::string ck::tensor_operation::device::getConvBackwardWeightSpecializationString ( const ConvolutionBackwardWeightSpecialization s)
inline

◆ getConvForwardSpecializationString()

std::string ck::tensor_operation::device::getConvForwardSpecializationString ( const ConvolutionForwardSpecialization s)
inline

◆ getGemmSpecializationString()

std::string ck::tensor_operation::device::getGemmSpecializationString ( const GemmSpecialization s)
inline

◆ getMaskingSpecializationString()

std::string ck::tensor_operation::device::getMaskingSpecializationString ( const MaskingSpecialization s)
inline

◆ getTensorSpecializationString()

std::string ck::tensor_operation::device::getTensorSpecializationString ( const TensorSpecialization s)
inline

◆ grid_desc()

template<GemmSpecialization GemmSpec, typename MPerTileType , typename NPerTileType , typename KPerTileType , typename CDesc_MRaw_NRaw >
auto ck::tensor_operation::device::grid_desc ( MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType >  matrix_padder,
CDesc_MRaw_NRaw  conv_desc 
)

◆ is_GNDHWC_GKZYXC_GNDHWK()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_GNDHWC_GKZYXC_GNDHWK ( )
constexpr

◆ is_GNHWC_GKYXC_GNHWK()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_GNHWC_GKYXC_GNHWK ( )
constexpr

◆ is_GNSpatialC_GKSpatial_GNSpatialK()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_GNSpatialC_GKSpatial_GNSpatialK ( )
constexpr

◆ is_GNWC_GKXC_GNWK()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_GNWC_GKXC_GNWK ( )
constexpr

◆ is_NDHWGC_GKZYXC_NDHWGK()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_NDHWGC_GKZYXC_NDHWGK ( )
constexpr

◆ is_NGCDHW_GKZYXC_NGKDHW()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_NGCDHW_GKZYXC_NGKDHW ( )
constexpr

◆ is_NGCHW_GKYXC_NGKHW()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_NGCHW_GKYXC_NGKHW ( )
constexpr

◆ is_NGCSpatial_GKSpatial_NGKSpatial()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_NGCSpatial_GKSpatial_NGKSpatial ( )
constexpr

◆ is_NGCW_GKXC_NGKW()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_NGCW_GKXC_NGKW ( )
constexpr

◆ is_NHWGC_GKYXC_NHWGK()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_NHWGC_GKYXC_NHWGK ( )
constexpr

◆ is_NSpatialGC_GKSpatial_NSpatialGK()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_NSpatialGC_GKSpatial_NSpatialGK ( )
constexpr

◆ is_NWGC_GKXC_NWGK()

template<typename InLayout , typename WeiLayout , typename OutLayout >
constexpr bool ck::tensor_operation::device::is_NWGC_GKXC_NWGK ( )
constexpr

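These layout predicates all follow one pattern: compare the (input, weight, output) layout triple against a fixed combination. A plausible, illustrative shape of one of them, with local stand-in tags replacing the real tags from ck::tensor_layout::convolution:

#include <type_traits>

struct NHWGC {}; struct GKYXC {}; struct NHWGK {}; // stand-in layout tags

template <typename InLayout, typename WeiLayout, typename OutLayout>
constexpr bool is_NHWGC_GKYXC_NHWGK()
{
    return std::is_same_v<InLayout, NHWGC> &&
           std::is_same_v<WeiLayout, GKYXC> &&
           std::is_same_v<OutLayout, NHWGK>;
}

static_assert(is_NHWGC_GKYXC_NHWGK<NHWGC, GKYXC, NHWGK>());
static_assert(!is_NHWGC_GKYXC_NHWGK<NHWGC, GKYXC, GKYXC>());
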
◆ kernel_batched_gemm_dlops_bwd_weight()

template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AGridDesc_B_K0_M0_M1_K1 , typename BGridDesc_B_K0_N0_N1_K1 , typename CGridDesc_M0_M10_M11_N0_N10_N11 , typename Block2CTileMap , typename ComputePtrOffsetOfBatch , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_dlops_bwd_weight ( const FloatAB *__restrict__  p_a_grid,
const FloatAB *__restrict__  p_b_grid,
FloatC *__restrict__  p_c_grid,
const index_t  batch_count,
const AGridDesc_B_K0_M0_M1_K1  a_grid_desc_kbatch_k0_m0_m1_k1,
const BGridDesc_B_K0_N0_N1_K1  b_grid_desc_kbatch_k0_n0_n1_k1,
const CGridDesc_M0_M10_M11_N0_N10_N11  c_grid_desc_m0_m10_m11_n0_n10_n11,
const Block2CTileMap  block_2_ctile_map,
const ComputePtrOffsetOfBatch  compute_ptr_offset_of_batch 
)

◆ kernel_batched_gemm_e_permute_xdl()

template<typename GridwiseGemm , typename ABDataType , typename EDataType , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename ComputePtrOffsetOfBatch , typename Block2ETileMap , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_e_permute_xdl ( const ABDataType *__restrict__  p_a_grid,
const ABDataType *__restrict__  p_b_grid,
EDataType *__restrict__  p_e_grid,
const index_t  batch_count,
const AGridDesc_AK0_M_AK1  a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1  b_grid_desc_bk0_n_bk1,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock  e_grid_desc_mblock_mperblock_nblock_nperblock,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CDEElementwiseOperation  cde_element_op,
const ComputePtrOffsetOfBatch  compute_ptr_offset_of_batch,
const Block2ETileMap  block_2_etile_map 
)

◆ kernel_batched_gemm_gemm_xdl_cshuffle_v1()

template<typename GridwiseGemm , typename A0B0B1DataType , typename D0sPointer , typename D1sPointer , typename E1DataType , typename A0ElementwiseOperation , typename B0ElementwiseOperation , typename CDE0ElementwiseOperation , typename B1ElementwiseOperation , typename CDE1ElementwiseOperation , typename A0GridDesc_AK0_M_AK1 , typename B0GridDesc_BK0_N_BK1 , typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 , typename B1GridDesc_BK0_N_BK1 , typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2E1TileMap , typename ComputeBasePtrOfStridedBatch , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_gemm_xdl_cshuffle_v1 ( const A0B0B1DataType *__restrict__  p_a0_grid,
const A0B0B1DataType *__restrict__  p_b0_grid,
D0sPointer  p_d0s_grid,
const A0B0B1DataType *__restrict__  p_b1_grid,
D1sPointer  p_d1s_grid,
E1DataType *__restrict__  p_e1_grid,
const A0ElementwiseOperation  a0_element_op,
const B0ElementwiseOperation  b0_element_op,
const CDE0ElementwiseOperation  cde0_element_op,
const B1ElementwiseOperation  b1_element_op,
const CDE1ElementwiseOperation  cde1_element_op,
const A0GridDesc_AK0_M_AK1  a0_grid_desc_ak0_m_ak1,
const B0GridDesc_BK0_N_BK1  b0_grid_desc_bk0_n_bk1,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5  d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const B1GridDesc_BK0_N_BK1  b1_grid_desc_bk0_n_bk1,
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock  d1s_grid_desc_mblock_mperblock_nblock_nperblock,
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock  e1_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2E1TileMap  block_2_e1tile_map,
const index_t  batch_count,
const ComputeBasePtrOfStridedBatch  compute_base_ptr_of_batch 
)

◆ kernel_batched_gemm_reduce_xdl_cshuffle_v1()

template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename ReducePtrsGlobal , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename ReduceInElementwiseOperations , typename ReduceAccElementwiseOperations , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename ReduceGridDescriptor_MBlock_MPerBlock , typename ComputeBasePrtOfBatch , typename Block2CTileMap , bool HasMainK0BlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_reduce_xdl_cshuffle_v1 ( const FloatAB *__restrict__  p_a_grid,
const FloatAB *__restrict__  p_b_grid,
FloatC *__restrict__  p_c_grid,
ReducePtrsGlobal  p_reduces_grid,
const index_t  batch_count,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CElementwiseOperation  c_element_op,
const ReduceInElementwiseOperations  reduce_in_element_ops,
const ReduceAccElementwiseOperations  reduce_out_element_ops,
const AGridDesc_AK0_M_AK1  a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1  b_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock  c_grid_desc_mblock_mperblock_nblock_nperblock,
const ReduceGridDescriptor_MBlock_MPerBlock  reduce_grid_desc_mblock_mperblock,
const ComputeBasePrtOfBatch  compute_base_ptr_of_batch_,
const Block2CTileMap  block_2_ctile_map 
)

◆ kernel_batched_gemm_softmax_gemm_wmma_cshuffle()

template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_softmax_gemm_wmma_cshuffle ( const ADataType *__restrict__  p_a_grid,
const B0DataType *__restrict__  p_b0_grid,
const B1DataType *__restrict__  p_b1_grid,
CDataType *__restrict__  p_c_grid,
index_t  M,
index_t  N,
index_t  K,
index_t  O,
index_t  G0,
index_t  G1,
float  alpha,
bool  input_permute,
bool  output_permute 
)

◆ kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1() [1/2]

template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , typename C0MatrixMask , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1 ( const FloatAB *__restrict__  p_a_grid,
const FloatAB *__restrict__  p_b_grid,
const FloatAB *__restrict__  p_b1_grid,
FloatC *__restrict__  p_c_grid,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const AccElementwiseOperation  acc_element_op,
const B1ElementwiseOperation  b1_element_op,
const CElementwiseOperation  c_element_op,
const AGridDesc_AK0_M_AK1  a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1  b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1  b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock  c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap  block_2_ctile_map,
const index_t  batch_count,
const ComputeBasePtrOfStridedBatch  compute_base_ptr_of_batch,
const C0MatrixMask  c0_matrix_mask 
)

◆ kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1() [2/2]

template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename D0sPointer , typename AElementwiseOperation , typename BElementwiseOperation , typename C0DEElementwiseOperation , typename B1ElementwiseOperation , typename C1DEElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , typename C0MatrixMask , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1 ( const FloatAB *__restrict__  p_a_grid,
const FloatAB *__restrict__  p_b_grid,
const FloatAB *__restrict__  p_b1_grid,
FloatC *__restrict__  p_c_grid,
D0sPointer  p_d0s_grid,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const C0DEElementwiseOperation  c0de_element_op,
const B1ElementwiseOperation  b1_element_op,
const C1DEElementwiseOperation  c1de_element_op,
const AGridDesc_AK0_M_AK1  a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1  b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1  b1_grid_desc_bk0_n_bk1,
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock  c1_grid_desc_mblock_mperblock_nblock_nperblock,
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5  d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
const Block2CTileMap  block_2_ctile_map,
const index_t  batch_count,
const ComputeBasePtrOfStridedBatch  compute_base_ptr_of_batch,
const C0MatrixMask  c0_matrix_mask 
)

◆ kernel_batched_gemm_xdl()

template<typename GridwiseGemm , typename ABDataType , typename DsPointer , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , typename Block2ETileMap , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_xdl ( const ABDataType *__restrict__  p_a_grid,
const ABDataType *__restrict__  p_b_grid,
DsPointer  p_ds_grid,
EDataType *__restrict__  p_e_grid,
const index_t  batch_count,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CDEElementwiseOperation  cde_element_op,
const AGridDesc_AK0_M_AK1  a_grid_desc_k0_m_k1,
const BGridDesc_BK0_N_BK1  b_grid_desc_k0_n_k1,
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock  ds_grid_desc_mblock_mperblock_nblock_nperblock,
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock  e_grid_desc_mblock_mperblock_nblock_nperblock_,
const ComputePtrOffsetOfBatch  compute_ptr_offset_of_batch,
const Block2ETileMap  block_2_etile_map 
)

◆ kernel_batched_gemm_xdlops_bwd_weight()

template<typename GridwiseGemm , typename FloatA , typename FloatB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_B_K0_M_K1 , typename BGridDesc_B_K0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputePtrOffsetOfBatch , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_xdlops_bwd_weight ( const FloatA *__restrict__  p_a_grid,
const FloatB *__restrict__  p_b_grid,
FloatC *__restrict__  p_c_grid,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CElementwiseOperation  c_element_op,
const index_t  batch_count,
const AGridDesc_B_K0_M_K1  a_b_k0_m_k1_grid_desc,
const BGridDesc_B_K0_N_K1  b_b_k0_n_k1_grid_desc,
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock  c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap  block_2_ctile_map,
const ComputePtrOffsetOfBatch  compute_ptr_offset_of_batch 
)

◆ kernel_batched_gemm_xdlops_v2r3()

template<typename DeviceOp , typename GridwiseGemm , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_xdlops_v2r3 ( const typename DeviceOp::Argument  karg)

◆ kernel_gemm_dl_multiple_d()

template<typename GridwiseGemm , typename ABDataType , typename DsPointer , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename AGridDesc_K0_M0_M1_K1 , typename BGridDesc_K0_N0_N1_K1 , typename DsGridDesc_M0_M10_M11_N0_N10_N11 , typename CGridDesc_M0_M10_M11_N0_N10_N11 , typename ComputePtrOffsetOfBatch , typename Block2CTileMap , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_gemm_dl_multiple_d ( const ABDataType *__restrict__  p_a_grid,
const ABDataType *__restrict__  p_b_grid,
DsPointer  p_ds_grid,
EDataType *__restrict__  p_e_grid,
const index_t  batch_count,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CDEElementwiseOperation  cde_element_op,
const AGridDesc_K0_M0_M1_K1  a_grid_desc_k0_m0_m1_k1,
const BGridDesc_K0_N0_N1_K1  b_grid_desc_k0_n0_n1_k1,
const DsGridDesc_M0_M10_M11_N0_N10_N11  ds_grid_desc_m0_m10_m11_n0_n10_n11,
const CGridDesc_M0_M10_M11_N0_N10_N11  e_grid_desc_m0_m10_m11_n0_n10_n11,
const ComputePtrOffsetOfBatch  compute_ptr_offset_of_batch,
const Block2CTileMap  block_2_ctile_map 
)

◆ kernel_gemm_gemm_xdl_cshuffle_v1()

template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_gemm_gemm_xdl_cshuffle_v1 ( const FloatAB *__restrict__  p_a_grid,
const FloatAB *__restrict__  p_b_grid,
const FloatAB *__restrict__  p_b1_grid,
FloatC *__restrict__  p_c_grid,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const AccElementwiseOperation  acc_element_op,
const B1ElementwiseOperation  b1_element_op,
const CElementwiseOperation  c_element_op,
const AGridDesc_AK0_M_AK1  a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1  b_grid_desc_bk0_n_bk1,
const B1GridDesc_BK0_N_BK1  b1_grid_desc_bk0_n_bk1,
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock  c_grid_desc_mblock_mperblock_nblock_nperblock,
const Block2CTileMap  block_2_ctile_map,
const index_t  batch_count,
const ComputeBasePtrOfStridedBatch  compute_base_ptr_of_batch 
)

◆ kernel_gemm_xdlops_v2r3_for_conv3d()

template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AGridDesc_K0_M_K1 , typename BGridDesc_K0_N_K1 , typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename Block2CTileMap , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_gemm_xdlops_v2r3_for_conv3d ( const FloatAB *__restrict__  p_a_grid,
const FloatAB *__restrict__  p_b_grid,
FloatC *__restrict__  p_c_grid,
const index_t  num_batches,
const index_t  a_batch_stride,
const index_t  b_batch_stride,
const index_t  c_batch_stride,
const AGridDesc_K0_M_K1  a_grid_desc_k0_m_k1,
const BGridDesc_K0_N_K1  b_grid_desc_k0_n_k1,
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2  c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CElementwiseOperation  c_element_op,
const Block2CTileMap  block_2_ctile_map 
)

◆ kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3()

template<typename GridwiseGemm , typename AGridDesc_AK0_M_K1 , typename BGridDesc_BK0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , index_t NumGroupsToMerge, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, index_t MinimumOccupancy = 1, TailNumber TailNum = TailNumber::Full>
__global__ void ck::tensor_operation::device::kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3 ( typename GridwiseGemm::Argument  karg,
[[maybe_unused]] const AGridDesc_AK0_M_K1  a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1  b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock  c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch  compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t  num_k_per_block 
)

◆ kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds()

template<typename GridwiseGemm , typename AGridDesc_AK0_M_K1 , typename BGridDesc_BK0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , index_t NumGroupsToMerge, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, index_t MinimumOccupancy = 1, TailNumber TailNum = TailNumber::Full>
__global__ void ck::tensor_operation::device::kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds ( typename GridwiseGemm::Argument  karg,
[[maybe_unused]] const AGridDesc_AK0_M_K1  a_grid_desc_ak0_m_ak1,
[[maybe_unused]] const BGridDesc_BK0_N_K1  b_grid_desc_bk0_n_bk1,
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock  c_grid_desc_mblock_mperblock_nblock_nperblock,
[[maybe_unused]] const ComputePtrOffsetOfBatch  compute_ptr_offset_of_batch,
[[maybe_unused]] const index_t  num_k_per_block 
)

◆ kernel_grouped_gemm_multiple_d_dl()

template<typename GridwiseGemm , typename GemmDesc , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_multiple_d_dl ( const void CK_CONSTANT_ADDRESS_SPACE gemm_descs_const,
const index_t  group_count,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CDEElementwiseOperation  cde_element_op 
)

◆ kernel_grouped_gemm_multiple_d_xdl()

template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename ALayout , typename BLayout , typename DsLayout , typename ELayout , index_t KPerBlock, typename OffsettedBlockToCTileMap , typename LocalBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , BlockGemmPipelineScheduler BlkGemmPipeSched, BlockGemmPipelineVersion BlkGemmPipelineVer>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_multiple_d_xdl ( const void CK_CONSTANT_ADDRESS_SPACE gemm_descs_const,
const index_t  group_count,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CDEElementwiseOperation  cde_element_op 
)

Entry point kernel for a device-wide Grouped GEMM operation.

Parameters
[in] gemm_descs_const - Pointer to the array of GEMM descriptor structures.
[in] group_count - The number of GEMMs processed together.

Template Parameters
GridwiseGemm - The specific GridwiseGEMM algorithm implementation.
GemmDesc - The structure holding all descriptors and other data needed for grouped GEMM calculation and work distribution.
LocalBlock2ETileMap - The structure mapping workgroup ids to the data tiles they process and to the output tiles.

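To make the work distribution concrete, here is a heavily simplified, hypothetical sketch of how such an entry point can map workgroups to groups; the real kernel reads GemmDesc structures from CK_CONSTANT_ADDRESS_SPACE and uses the Block2ETileMap machinery rather than this linear search:

#include <hip/hip_runtime.h>

// Hypothetical per-group descriptor: which contiguous block-id range works on
// this group (real descriptors also carry pointers, sizes and strides).
struct GroupDesc
{
    int block_begin;
    int block_end;
};

__global__ void grouped_gemm_dispatch(const GroupDesc* descs, int group_count)
{
    const int bid = blockIdx.x;
    for (int g = 0; g < group_count; ++g) // linear search shown for clarity
    {
        if (bid >= descs[g].block_begin && bid < descs[g].block_end)
        {
            // ... run the GEMM tile of group g assigned to
            //     local block id (bid - descs[g].block_begin) ...
            break;
        }
    }
}
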
◆ kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1()

template<typename GridwiseGemm , typename GroupKernelArg , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1 ( const void CK_CONSTANT_ADDRESS_SPACE group_kernel_args,
const index_t  group_count,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const AccElementwiseOperation  acc_element_op,
const B1ElementwiseOperation  b1_element_op,
const CElementwiseOperation  c_element_op 
)

◆ kernel_grouped_gemm_xdl()

template<typename GridwiseGemm , typename GemmDesc , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_xdl ( const void CK_CONSTANT_ADDRESS_SPACE gemm_descs_const,
const index_t  group_count,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CDEElementwiseOperation  c_element_op 
)

◆ kernel_grouped_gemm_xdl_fixed_nk() [1/2]

template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, typename AsLayout , typename BsLayout , typename DsLayout , typename ELayout , typename Block2ETileMap , typename GroupedGemmBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_xdl_fixed_nk ( const void CK_CONSTANT_ADDRESS_SPACE gemm_descs_const,
const index_t  group_count,
const index_t  grid_size_grp,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CDEElementwiseOperation  cde_element_op 
)

◆ kernel_grouped_gemm_xdl_fixed_nk() [2/2]

template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, bool Zeroing, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename DsDataType , typename Block2ETileMap , typename GroupedGemmBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_xdl_fixed_nk ( const void CK_CONSTANT_ADDRESS_SPACE gemm_descs_const,
uint32_t *  barrier_count,
const index_t  barrier_size_grp,
const index_t  group_count,
const index_t  grid_size_grp,
const index_t  KBatch,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CDEElementwiseOperation  c_element_op 
)

◆ kernel_grouped_gemm_xdl_splitk()

template<typename GridwiseGemm , typename GemmDesc , bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_xdl_splitk ( const void CK_CONSTANT_ADDRESS_SPACE gemm_descs_const,
const index_t  group_count,
const AElementwiseOperation  a_element_op,
const BElementwiseOperation  b_element_op,
const CElementwiseOperation  c_element_op 
)

◆ kernel_grouped_query_attention_wmma()

template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , ck::index_t QueryGroupNumber, bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_query_attention_wmma ( const ADataType *__restrict__  p_a_grid,
const B0DataType *__restrict__  p_b0_grid,
const B1DataType *__restrict__  p_b1_grid,
CDataType *__restrict__  p_c_grid,
index_t  M,
index_t  N,
index_t  K,
index_t  O,
index_t  G0,
index_t  G1,
float  alpha,
bool  input_permute,
bool  output_permute 
)

◆ kernel_multi_query_attention_wmma()

template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_multi_query_attention_wmma ( const ADataType *__restrict__  p_a_grid,
const B0DataType *__restrict__  p_b0_grid,
const B1DataType *__restrict__  p_b1_grid,
CDataType *__restrict__  p_c_grid,
index_t  M,
index_t  N,
index_t  K,
index_t  O,
index_t  G0,
index_t  G1,
float  alpha,
bool  input_permute,
bool  output_permute 
)

◆ kernel_normalization_bwd_data()

template<typename GridwiseNormalizationBwd , typename DYDataType , typename XDataType , typename GammaDataType , typename MeanInvStdDataType , typename DXDataType , typename GridDesc_M_K >
__global__ void ck::tensor_operation::device::kernel_normalization_bwd_data ( const GridDesc_M_K  dy_grid_desc_m_k,
const GridDesc_M_K  x_grid_desc_m_k,
const GridDesc_M_K  gamma_grid_desc_m_k,
const GridDesc_M_K  mean_grid_desc_m_k,
const GridDesc_M_K  inv_std_grid_desc_m_k,
const GridDesc_M_K  dx_grid_desc_m_k,
index_t  num_k_block_tile_iteration,
const DYDataType *const __restrict__  p_dy_global,
const XDataType *const __restrict__  p_x_global,
const GammaDataType *const __restrict__  p_gamma_global,
const MeanInvStdDataType *const __restrict__  p_mean_global,
const MeanInvStdDataType *const __restrict__  p_inv_std_global,
DXDataType *const __restrict__  p_dx_global 
)

◆ kernel_normalization_bwd_gamma_beta()

template<typename GridwiseReduction , typename DYDataType , typename XDataType , typename MeanInvStdDataType , typename DGammaDataType , typename DBetaDataType , typename GridDesc_M_K , typename GridDesc_M >
__global__ void ck::tensor_operation::device::kernel_normalization_bwd_gamma_beta ( const GridDesc_M_K  dy_grid_desc_m_k,
const GridDesc_M_K  x_grid_desc_m_k,
const GridDesc_M_K  mean_grid_desc_m_k,
const GridDesc_M_K  inv_std_grid_desc_m_k,
const GridDesc_M  dgamma_grid_desc_m,
const GridDesc_M  dbeta_grid_desc_m,
index_t  num_k_block_tile_iteration,
const DYDataType *const __restrict__  p_dy_global,
const XDataType *const __restrict__  p_x_global,
const MeanInvStdDataType *const __restrict__  p_mean_global,
const MeanInvStdDataType *const __restrict__  p_inv_std_global,
DGammaDataType *const __restrict__  p_dgamma_global,
DBetaDataType *const __restrict__  p_dbeta_global 
)

◆ kernel_wmma_cross_attention_forward()

template<typename DeviceOp , typename GridwiseOp , typename QDataType , typename KVDataType , typename ODataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_wmma_cross_attention_forward ( const QDataType *__restrict__  p_q_grid,
const KVDataType *__restrict__  p_kv_grid,
ODataType *__restrict__  p_out_grid,
index_t  batch_size,
index_t  q_sequence_length,
index_t  kv_sequence_length,
index_t  head_count,
index_t  head_size,
float  alpha 
)

◆ kernel_wmma_self_attention_forward()

template<typename DeviceOp , typename GridwiseOp , typename QKVDataType , typename ODataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_wmma_self_attention_forward ( const QKVDataType *__restrict__  p_qkv_grid,
ODataType *__restrict__  p_out_grid,
index_t  batch_size,
index_t  sequence_length,
index_t  head_count,
index_t  head_size,
float  alpha 
)

◆ make_tuple_from_array()

template<index_t arraySize>
auto ck::tensor_operation::device::make_tuple_from_array ( const std::vector< index_t > &  lengths,
Number< arraySize >   
)

◆ make_tuple_from_array_and_index_seq()

template<index_t... Ns>
auto ck::tensor_operation::device::make_tuple_from_array_and_index_seq ( const std::vector< index_t > &  lengths,
Sequence< Ns... >   
)

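A sketch of the index-sequence expansion these helpers rely on; std::tuple and std::index_sequence stand in for ck's Tuple and Sequence so the example is self-contained:

#include <cstddef>
#include <tuple>
#include <utility>
#include <vector>

template <std::size_t... Ns>
auto make_tuple_from_array_and_index_seq(const std::vector<int>& lengths,
                                         std::index_sequence<Ns...>)
{
    return std::make_tuple(lengths[Ns]...); // one tuple element per index
}

// std::vector<int> lengths{4, 5, 6};
// auto t = make_tuple_from_array_and_index_seq(lengths,
//                                              std::make_index_sequence<3>{});
// -> std::tuple<int, int, int>{4, 5, 6}
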
◆ PadTensorDescriptor()

template<typename TensorDesc , typename TileLengths , typename DoPads >
__host__ constexpr __device__ auto ck::tensor_operation::device::PadTensorDescriptor ( const TensorDesc &  desc,
const TileLengths &  tile_lengths,
DoPads   
)
constexpr

◆ shuffle_tensor_dimensions() [1/2]

template<index_t Rank, index_t NumReduceDim>
std::array<index_t, Rank> ck::tensor_operation::device::shuffle_tensor_dimensions ( const std::array< index_t, Rank > &  origLengthsStrides,
const std::array< int, NumReduceDim > &  reduceDims 
)

◆ shuffle_tensor_dimensions() [2/2]

template<index_t Rank, index_t NumReduceDim>
std::vector<index_t> ck::tensor_operation::device::shuffle_tensor_dimensions ( const std::vector< index_t > &  origLengthsStrides,
const std::vector< int > &  reduceDims 
)
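A hedged sketch of what both overloads presumably do: keep the invariant dimensions in their original relative order and append the reduce dimensions at the end, which is exactly the precondition get_2d_lengths above relies on. Shown for the std::vector overload, with int standing in for index_t:

#include <vector>

std::vector<int> shuffle_tensor_dimensions(const std::vector<int>& origLengthsStrides,
                                           const std::vector<int>& reduceDims)
{
    const int rank = static_cast<int>(origLengthsStrides.size());
    std::vector<bool> is_reduce(rank, false);
    for (int d : reduceDims) is_reduce[d] = true;

    std::vector<int> shuffled;
    shuffled.reserve(rank);
    for (int i = 0; i < rank; ++i)   // invariant dimensions first
        if (!is_reduce[i]) shuffled.push_back(origLengthsStrides[i]);
    for (int d : reduceDims)         // then the reduce dimensions
        shuffled.push_back(origLengthsStrides[d]);
    return shuffled;
}

// shuffle_tensor_dimensions({10, 20, 30, 40}, {1, 3}) -> {10, 30, 20, 40}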