device Namespace Reference
ck::tensor_operation::device Namespace Reference
Typedefs | |
template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation > | |
using | DeviceBatchedGemmPtr = std::unique_ptr< DeviceBatchedGemm< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > > |
template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim> | |
using | DeviceBatchNormBwdPtr = std::unique_ptr< DeviceBatchNormBwd< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim > > |
template<typename XDataType , typename YDataType , typename AccDataType , typename ScaleDataType , typename BiasDataType , typename MeanVarDataType , typename YElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim> | |
using | DeviceBatchNormFwdPtr = std::unique_ptr< DeviceBatchNormFwd< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumBatchNormReduceDim > > |
template<typename XDataType , typename YDataType , typename AccDataType , typename ScaleDataType , typename BiasDataType , typename MeanVarDataType , typename YElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim> | |
using | DeviceBatchNormInferPtr = std::unique_ptr< DeviceBatchNormInfer< XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumBatchNormReduceDim > > |
template<typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation > | |
using | DeviceCGemmPtr = std::unique_ptr< DeviceCGemm< AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > > |
template<typename InElementwiseOperation , typename WeiElementwiseOperation , typename OutElementwiseOperation > | |
using | DeviceConvFwdBiasActivationPtr = std::unique_ptr< DeviceConvFwdBiasActivation< InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation > > |
template<typename InElementwiseOperation , typename WeiElementwiseOperation , typename OutElementwiseOperation > | |
using | DeviceConvFwdBiasActivationAddPtr = std::unique_ptr< DeviceConvFwdBiasActivationAdd< InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation > > |
template<typename InDataTypeTuple , typename OutDataTypeTuple , typename ElementwiseOperation , index_t NumDim> | |
using | DeviceElementwisePtr = std::unique_ptr< DeviceElementwise< InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, NumDim > > |
template<typename InDataTypeTuple , typename GammaDataType , typename BetaDataType , typename AccDataType , typename YDataType , typename XElementwiseOperation , typename YElementwiseOperation , index_t Rank, index_t NumReduceDim> | |
using | DeviceElementwiseNormalizationPtr = std::unique_ptr< DeviceElementwiseNormalization< InDataTypeTuple, GammaDataType, BetaDataType, AccDataType, YDataType, XElementwiseOperation, YElementwiseOperation, Rank, NumReduceDim > > |
template<typename ALayout , typename BLayout , typename DELayout , typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename RsDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename QsElementwiseOperation , typename RsElementwiseOperation > | |
using | DeviceGemmMultipleDMultipleRPtr = std::unique_ptr< DeviceGemmMultipleDMultipleR< ALayout, BLayout, DELayout, ADataType, BDataType, DsDataType, EDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation > > |
template<ck::index_t NumDTensor, ck::index_t NumReduce> | |
using | DeviceGemmReducePtr = std::unique_ptr< DeviceGemmReduce< NumDTensor, NumReduce > > |
template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename ComputeType = CDataType> | |
using | DeviceGemmSplitKPtr = std::unique_ptr< DeviceGemmSplitK< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ComputeType > > |
template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation > | |
using | DeviceGemmStreamKPtr = std::unique_ptr< DeviceGemmStreamK< ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > > |
template<typename T > | |
using | is_tuple = decltype(std::declval< T & >().IsTuple()) |
template<index_t NDimSpatial, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename ComputeType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value, Number<0>, ADataType>())> | |
using | DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ComputeType > |
Grouped Convolution Forward. More... | |
template<index_t Rank, index_t NumReduceDim, index_t NumReduction, typename InElementwiseOperationTuple , typename AccElementwiseOperationTuple > | |
using | DeviceMultipleReducePtr = std::unique_ptr< DeviceMultipleReduce< Rank, NumReduceDim, NumReduction, InElementwiseOperationTuple, AccElementwiseOperationTuple > > |
template<typename DYDataType , typename XDataType , typename GammaDataType , typename MeanInvStdDataType , typename DXDataType , index_t Rank, index_t NumReduceDim> | |
using | DeviceNormalizationBwdDataPtr = std::unique_ptr< DeviceNormalizationBwdData< DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType, Rank, NumReduceDim > > |
template<typename DYDataType , typename XDataType , typename MeanInvStdDataType , typename DGammaDataType , typename DBetaDataType , index_t Rank, index_t NumReduceDim> | |
using | DeviceNormalizationBwdGammaBetaPtr = std::unique_ptr< DeviceNormalizationBwdGammaBeta< DYDataType, XDataType, MeanInvStdDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim > > |
template<typename XDataType , typename GammaDataType , typename BetaDataType , typename YDataType , typename SaveMeanInvStdDataType , typename YElementwiseOperation , index_t Rank, index_t NumReduceDim> | |
using | DeviceNormalizationFwdPtr = std::unique_ptr< DeviceNormalizationFwd< XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim > > |
template<typename InDataType , typename AccDataType , typename OutDataType , index_t Rank, index_t NumReduceDim, typename ReduceOperation , typename InElementwiseOperation , typename AccElementwiseOperation , bool PropagateNan, bool OutputIndex> | |
using | DeviceReducePtr = std::unique_ptr< DeviceReduce< InDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, AccElementwiseOperation, PropagateNan, OutputIndex > > |
template<typename InDataType , typename DsDataType , typename AccDataType , typename OutDataType , index_t Rank, index_t NumReduceDim, typename ReduceOperation , typename InElementwiseOperation , typename OutElementwiseOperation > | |
using | DeviceReduceMultiDPtr = std::unique_ptr< DeviceReduceMultiD< InDataType, DsDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, OutElementwiseOperation > > |
template<typename InDataType , typename AccDataType , typename OutDataType , typename InElementwiseOp , typename AccElementwiseOp , index_t Rank, index_t NumReduceDim> | |
using | DeviceSoftmaxPtr = std::unique_ptr< DeviceSoftmax< InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim > > |
template<index_t NDimSpatial, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename ADataType , typename BDataType , typename AccDataType , typename CShuffleDataType , typename DsDataType , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , ConvolutionForwardSpecialization ConvForwardSpecialization, GemmSpecialization GemmSpec, index_t NumGemmKPrefetchStage, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, index_t ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , index_t CDEBlockTransferScalarPerVector_NPerBlock, typename AComputeDataType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value, Number<0>, ADataType>()), typename BComputeDataType = AComputeDataType, LoopScheduler LoopSched = make_default_loop_scheduler()> | |
using | DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched > |
Enumerations | |
enum class | ConvolutionBackwardDataSpecialization { Default , Filter1x1Stride1Pad0 } |
enum class | ConvolutionBackwardWeightSpecialization { Default , Filter1x1Stride1Pad0 , Filter1x1Pad0 , OddC } |
enum class | ConvolutionForwardSpecialization { Default , Filter1x1Pad0 , Filter1x1Stride1Pad0 , OddC , Filter3x3 } |
enum class | GemmSpecialization { Default , MPadding , NPadding , KPadding , MNPadding , MKPadding , NKPadding , MNKPadding , OPadding , MOPadding , NOPadding , KOPadding , MNOPadding , MKOPadding , NKOPadding , MNKOPadding } |
enum class | MaskingSpecialization { MaskDisabled , MaskOutUpperTriangle } |
enum class | TensorSpecialization { Default , Packed } |
Functions | |
std::string | getConvBackwardDataSpecializationString (const ConvolutionBackwardDataSpecialization &s) |
std::string | getConvBackwardWeightSpecializationString (const ConvolutionBackwardWeightSpecialization &s) |
std::string | getConvForwardSpecializationString (const ConvolutionForwardSpecialization &s) |
std::string | getGemmSpecializationString (const GemmSpecialization &s) |
template<typename GridwiseGemm , typename ABDataType , typename EDataType , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename ComputePtrOffsetOfBatch , typename Block2ETileMap , bool HasMainKBlockLoop> | |
__global__ void | kernel_batched_gemm_e_permute_xdl (const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) |
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , bool HasMainKBlockLoop> | |
__global__ void | kernel_gemm_gemm_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) |
template<typename GridwiseGemm , typename ABDataType , typename DsPointer , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , typename Block2ETileMap , bool HasMainKBlockLoop> | |
__global__ void | kernel_batched_gemm_xdl (const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_k0_m_k1, const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map) |
template<typename GridwiseGemm , typename ABDataType , typename DsPointer , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename AGridDesc_K0_M0_M1_K1 , typename BGridDesc_K0_N0_N1_K1 , typename DsGridDesc_M0_M10_M11_N0_N10_N11 , typename CGridDesc_M0_M10_M11_N0_N10_N11 , typename ComputePtrOffsetOfBatch , typename Block2CTileMap , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> | |
__global__ void | kernel_gemm_dl_multiple_d (const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const DsGridDesc_M0_M10_M11_N0_N10_N11 ds_grid_desc_m0_m10_m11_n0_n10_n11, const CGridDesc_M0_M10_M11_N0_N10_N11 e_grid_desc_m0_m10_m11_n0_n10_n11, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2CTileMap block_2_ctile_map) |
template<typename GridwiseGemm , typename A0B0B1DataType , typename D0sPointer , typename D1sPointer , typename E1DataType , typename A0ElementwiseOperation , typename B0ElementwiseOperation , typename CDE0ElementwiseOperation , typename B1ElementwiseOperation , typename CDE1ElementwiseOperation , typename A0GridDesc_AK0_M_AK1 , typename B0GridDesc_BK0_N_BK1 , typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 , typename B1GridDesc_BK0_N_BK1 , typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2E1TileMap , typename ComputeBasePtrOfStridedBatch , bool HasMainKBlockLoop> | |
__global__ void | kernel_batched_gemm_gemm_xdl_cshuffle_v1 (const A0B0B1DataType *__restrict__ p_a0_grid, const A0B0B1DataType *__restrict__ p_b0_grid, D0sPointer p_d0s_grid, const A0B0B1DataType *__restrict__ p_b1_grid, D1sPointer p_d1s_grid, E1DataType *__restrict__ p_e1_grid, const A0ElementwiseOperation a0_element_op, const B0ElementwiseOperation b0_element_op, const CDE0ElementwiseOperation cde0_element_op, const B1ElementwiseOperation b1_element_op, const CDE1ElementwiseOperation cde1_element_op, const A0GridDesc_AK0_M_AK1 a0_grid_desc_ak0_m_ak1, const B0GridDesc_BK0_N_BK1 b0_grid_desc_bk0_n_bk1, const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock d1s_grid_desc_mblock_mperblock_nblock_nperblock, const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e1_grid_desc_mblock_mperblock_nblock_nperblock, const Block2E1TileMap block_2_e1tile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch) |
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename ReducePtrsGlobal , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename ReduceInElementwiseOperations , typename ReduceAccElementwiseOperations , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename ReduceGridDescriptor_MBlock_MPerBlock , typename ComputeBasePrtOfBatch , typename Block2CTileMap , bool HasMainK0BlockLoop> | |
__global__ void | kernel_batched_gemm_reduce_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, ReducePtrsGlobal p_reduces_grid, const index_t batch_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const ComputeBasePrtOfBatch compute_base_ptr_of_batch_, const Block2CTileMap block_2_ctile_map) |
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop> | |
__global__ void | kernel_batched_gemm_softmax_gemm_wmma_cshuffle (const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute) |
template<typename DeviceOp , typename GridwiseOp , typename QKVDataType , typename ODataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop> | |
__global__ void | kernel_wmma_self_attention_forward (const QKVDataType *__restrict__ p_qkv_grid, ODataType *__restrict__ p_out_grid, index_t batch_size, index_t sequence_length, index_t head_count, index_t head_size, float alpha) |
template<typename DeviceOp , typename GridwiseOp , typename QDataType , typename KVDataType , typename ODataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop> | |
__global__ void | kernel_wmma_cross_attention_forward (const QDataType *__restrict__ p_q_grid, const KVDataType *__restrict__ p_kv_grid, ODataType *__restrict__ p_out_grid, index_t batch_size, index_t q_sequence_length, index_t kv_sequence_length, index_t head_count, index_t head_size, float alpha) |
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename D0sPointer , typename AElementwiseOperation , typename BElementwiseOperation , typename C0DEElementwiseOperation , typename B1ElementwiseOperation , typename C1DEElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , typename C0MatrixMask , bool HasMainKBlockLoop> | |
__global__ void | kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, D0sPointer p_d0s_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const C0DEElementwiseOperation c0de_element_op, const B1ElementwiseOperation b1_element_op, const C1DEElementwiseOperation c1de_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c1_grid_desc_mblock_mperblock_nblock_nperblock, const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) |
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , typename C0MatrixMask , bool HasMainKBlockLoop> | |
__global__ void | kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1 (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, const FloatAB *__restrict__ p_b1_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const B1GridDesc_BK0_N_BK1 b1_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const index_t batch_count, const ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch, const C0MatrixMask c0_matrix_mask) |
template<typename DeviceOp , typename GridwiseGemm , bool HasMainKBlockLoop> | |
__global__ void | kernel_batched_gemm_xdlops_v2r3 (const typename DeviceOp::Argument karg) |
template<index_t NumDim1, index_t NumDim2> | |
auto | CalculateMaxRead (const std::vector< index_t > &lengths, const std::vector< index_t > &strides) |
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AGridDesc_K0_M_K1 , typename BGridDesc_K0_N_K1 , typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename Block2CTileMap , bool HasMainKBlockLoop> | |
__global__ void | kernel_gemm_xdlops_v2r3_for_conv3d (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const index_t num_batches, const index_t a_batch_stride, const index_t b_batch_stride, const index_t c_batch_stride, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map) |
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AGridDesc_B_K0_M0_M1_K1 , typename BGridDesc_B_K0_N0_N1_K1 , typename CGridDesc_M0_M10_M11_N0_N10_N11 , typename Block2CTileMap , typename ComputePtrOffsetOfBatch , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> | |
__global__ void | kernel_batched_gemm_dlops_bwd_weight (const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const index_t batch_count, const AGridDesc_B_K0_M0_M1_K1 a_grid_desc_kbatch_k0_m0_m1_k1, const BGridDesc_B_K0_N0_N1_K1 b_grid_desc_kbatch_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) |
template<typename GridwiseGemm , typename FloatA , typename FloatB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_B_K0_M_K1 , typename BGridDesc_B_K0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputePtrOffsetOfBatch , bool HasMainKBlockLoop> | |
__global__ void | kernel_batched_gemm_xdlops_bwd_weight (const FloatA *__restrict__ p_a_grid, const FloatB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const index_t batch_count, const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc, const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const Block2CTileMap block_2_ctile_map, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch) |
template<typename GridwiseGemm , typename AGridDesc_AK0_M_K1 , typename BGridDesc_BK0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , index_t NumGroupsToMerge, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, index_t MinimumOccupancy = 1, TailNumber TailNum = TailNumber::Full> | |
__global__ void | kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3 (typename GridwiseGemm::Argument karg, [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block) |
template<typename GridwiseGemm , typename AGridDesc_AK0_M_K1 , typename BGridDesc_BK0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , index_t NumGroupsToMerge, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, index_t MinimumOccupancy = 1, TailNumber TailNum = TailNumber::Full> | |
__global__ void | kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds (typename GridwiseGemm::Argument karg, [[maybe_unused]] const AGridDesc_AK0_M_K1 a_grid_desc_ak0_m_ak1, [[maybe_unused]] const BGridDesc_BK0_N_K1 b_grid_desc_bk0_n_bk1, [[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, [[maybe_unused]] const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, [[maybe_unused]] const index_t num_k_per_block) |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_NWGC_GKXC_NWGK () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_GNWC_GKXC_GNWK () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_NGCW_GKXC_NGKW () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_NHWGC_GKYXC_NHWGK () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_GNHWC_GKYXC_GNHWK () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_NGCHW_GKYXC_NGKHW () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_NDHWGC_GKZYXC_NDHWGK () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_GNDHWC_GKZYXC_GNDHWK () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_NGCDHW_GKZYXC_NGKDHW () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_NSpatialGC_GKSpatial_NSpatialGK () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_GNSpatialC_GKSpatial_GNSpatialK () |
template<typename InLayout , typename WeiLayout , typename OutLayout > | |
constexpr bool | is_NGCSpatial_GKSpatial_NGKSpatial () |
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, typename AsLayout , typename BsLayout , typename DsLayout , typename ELayout , typename Block2ETileMap , typename GroupedGemmBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop> | |
__global__ void | kernel_grouped_gemm_xdl_fixed_nk (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const index_t grid_size_grp, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) |
template<typename GridwiseGemm , typename GemmDesc , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop> | |
__global__ void | kernel_grouped_gemm_multiple_d_dl (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) |
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename ALayout , typename BLayout , typename DsLayout , typename ELayout , index_t KPerBlock, typename OffsettedBlockToCTileMap , typename LocalBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , BlockGemmPipelineScheduler BlkGemmPipeSched, BlockGemmPipelineVersion BlkGemmPipelineVer> | |
__global__ void | kernel_grouped_gemm_multiple_d_xdl (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op) |
Entry point kernel for device-wide Grouped GEMM operation. More... | |
template<typename GridwiseGemm , typename GroupKernelArg , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop> | |
__global__ void | kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1 (const void CK_CONSTANT_ADDRESS_SPACE *group_kernel_args, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const AccElementwiseOperation acc_element_op, const B1ElementwiseOperation b1_element_op, const CElementwiseOperation c_element_op) |
template<typename GridwiseGemm , typename GemmDesc , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , bool HasMainKBlockLoop> | |
__global__ void | kernel_grouped_gemm_xdl (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op) |
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, bool Zeroing, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename DsDataType , typename Block2ETileMap , typename GroupedGemmBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop> | |
__global__ void | kernel_grouped_gemm_xdl_fixed_nk (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, uint32_t *barrier_count, const index_t barrier_size_grp, const index_t group_count, const index_t grid_size_grp, const index_t KBatch, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op) |
template<typename GridwiseGemm , typename GemmDesc , bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough> | |
__global__ void | kernel_grouped_gemm_xdl_splitk (const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op) |
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , ck::index_t QueryGroupNumber, bool HasMainKBlockLoop> | |
__global__ void | kernel_grouped_query_attention_wmma (const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute) |
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop> | |
__global__ void | kernel_multi_query_attention_wmma (const ADataType *__restrict__ p_a_grid, const B0DataType *__restrict__ p_b0_grid, const B1DataType *__restrict__ p_b1_grid, CDataType *__restrict__ p_c_grid, index_t M, index_t N, index_t K, index_t O, index_t G0, index_t G1, float alpha, bool input_permute, bool output_permute) |
template<typename GridwiseNormalizationBwd , typename DYDataType , typename XDataType , typename GammaDataType , typename MeanInvStdDataType , typename DXDataType , typename GridDesc_M_K > | |
__global__ void | kernel_normalization_bwd_data (const GridDesc_M_K dy_grid_desc_m_k, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K gamma_grid_desc_m_k, const GridDesc_M_K mean_grid_desc_m_k, const GridDesc_M_K inv_std_grid_desc_m_k, const GridDesc_M_K dx_grid_desc_m_k, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const GammaDataType *const __restrict__ p_gamma_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DXDataType *const __restrict__ p_dx_global) |
template<typename GridwiseReduction , typename DYDataType , typename XDataType , typename MeanInvStdDataType , typename DGammaDataType , typename DBetaDataType , typename GridDesc_M_K , typename GridDesc_M > | |
__global__ void | kernel_normalization_bwd_gamma_beta (const GridDesc_M_K dy_grid_desc_m_k, const GridDesc_M_K x_grid_desc_m_k, const GridDesc_M_K mean_grid_desc_m_k, const GridDesc_M_K inv_std_grid_desc_m_k, const GridDesc_M dgamma_grid_desc_m, const GridDesc_M dbeta_grid_desc_m, index_t num_k_block_tile_iteration, const DYDataType *const __restrict__ p_dy_global, const XDataType *const __restrict__ p_x_global, const MeanInvStdDataType *const __restrict__ p_mean_global, const MeanInvStdDataType *const __restrict__ p_inv_std_global, DGammaDataType *const __restrict__ p_dgamma_global, DBetaDataType *const __restrict__ p_dbeta_global) |
template<index_t Rank, int NumReduceDim> | |
std::pair< long_index_t, long_index_t > | get_2d_lengths (const std::vector< index_t > &inLengths) |
template<index_t Rank, int NumReduceDim> | |
std::pair< long_index_t, long_index_t > | get_2d_lengths (const std::array< index_t, Rank > &inLengths) |
template<index_t... Ns> | |
auto | make_tuple_from_array_and_index_seq (const std::vector< index_t > &lengths, Sequence< Ns... >) |
template<index_t arraySize> | |
auto | make_tuple_from_array (const std::vector< index_t > &lengths, Number< arraySize >) |
template<index_t Rank, index_t NumReduceDim> | |
std::vector< index_t > | shuffle_tensor_dimensions (const std::vector< index_t > &origLengthsStrides, const std::vector< int > &reduceDims) |
template<index_t Rank, index_t NumReduceDim> | |
std::array< index_t, Rank > | shuffle_tensor_dimensions (const std::array< index_t, Rank > &origLengthsStrides, const std::array< int, NumReduceDim > &reduceDims) |
std::string | getMaskingSpecializationString (const MaskingSpecialization &s) |
template<typename TensorDesc , typename TileLengths , typename DoPads > | |
__host__ constexpr __device__ auto | PadTensorDescriptor (const TensorDesc &desc, const TileLengths &tile_lengths, DoPads) |
template<GemmSpecialization GemmSpec, typename MPerTileType , typename NPerTileType , typename KPerTileType , typename CDesc_MRaw_NRaw > | |
auto | grid_desc (MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc) |
std::string | getTensorSpecializationString (const TensorSpecialization &s) |
Typedef Documentation
◆ DeviceBatchedGemmPtr
template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using ck::tensor_operation::device::DeviceBatchedGemmPtr = typedef std::unique_ptr<DeviceBatchedGemm<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> > |
◆ DeviceBatchNormBwdPtr
template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using ck::tensor_operation::device::DeviceBatchNormBwdPtr = typedef std::unique_ptr<DeviceBatchNormBwd<XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim> > |
◆ DeviceBatchNormFwdPtr
template<typename XDataType , typename YDataType , typename AccDataType , typename ScaleDataType , typename BiasDataType , typename MeanVarDataType , typename YElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using ck::tensor_operation::device::DeviceBatchNormFwdPtr = typedef std::unique_ptr<DeviceBatchNormFwd<XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumBatchNormReduceDim> > |
◆ DeviceBatchNormInferPtr
template<typename XDataType , typename YDataType , typename AccDataType , typename ScaleDataType , typename BiasDataType , typename MeanVarDataType , typename YElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim>
using ck::tensor_operation::device::DeviceBatchNormInferPtr = typedef std::unique_ptr<DeviceBatchNormInfer<XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumBatchNormReduceDim> > |
◆ DeviceCGemmPtr
template<typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using ck::tensor_operation::device::DeviceCGemmPtr = typedef std::unique_ptr< DeviceCGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> > |
◆ DeviceConvFwdBiasActivationAddPtr
template<typename InElementwiseOperation , typename WeiElementwiseOperation , typename OutElementwiseOperation >
using ck::tensor_operation::device::DeviceConvFwdBiasActivationAddPtr = typedef std::unique_ptr<DeviceConvFwdBiasActivationAdd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> > |
◆ DeviceConvFwdBiasActivationPtr
template<typename InElementwiseOperation , typename WeiElementwiseOperation , typename OutElementwiseOperation >
using ck::tensor_operation::device::DeviceConvFwdBiasActivationPtr = typedef std::unique_ptr<DeviceConvFwdBiasActivation<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation> > |
◆ DeviceElementwiseNormalizationPtr
template<typename InDataTypeTuple , typename GammaDataType , typename BetaDataType , typename AccDataType , typename YDataType , typename XElementwiseOperation , typename YElementwiseOperation , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceElementwiseNormalizationPtr = typedef std::unique_ptr<DeviceElementwiseNormalization<InDataTypeTuple, GammaDataType, BetaDataType, AccDataType, YDataType, XElementwiseOperation, YElementwiseOperation, Rank, NumReduceDim> > |
◆ DeviceElementwisePtr
template<typename InDataTypeTuple , typename OutDataTypeTuple , typename ElementwiseOperation , index_t NumDim>
using ck::tensor_operation::device::DeviceElementwisePtr = typedef std::unique_ptr<DeviceElementwise<InDataTypeTuple, OutDataTypeTuple, ElementwiseOperation, UnaryOperation, Scale, NumDim> > |
◆ DeviceGemmMultipleDMultipleRPtr
template<typename ALayout , typename BLayout , typename DELayout , typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename RsDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename QsElementwiseOperation , typename RsElementwiseOperation >
using ck::tensor_operation::device::DeviceGemmMultipleDMultipleRPtr = typedef std::unique_ptr<DeviceGemmMultipleDMultipleR<ALayout, BLayout, DELayout, ADataType, BDataType, DsDataType, EDataType, RsDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, QsElementwiseOperation, RsElementwiseOperation> > |
◆ DeviceGemmReducePtr
template<ck::index_t NumDTensor, ck::index_t NumReduce>
using ck::tensor_operation::device::DeviceGemmReducePtr = typedef std::unique_ptr<DeviceGemmReduce<NumDTensor, NumReduce> > |
◆ DeviceGemmSplitKPtr
template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename ComputeType = CDataType>
using ck::tensor_operation::device::DeviceGemmSplitKPtr = typedef std::unique_ptr<DeviceGemmSplitK<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ComputeType> > |
◆ DeviceGemmStreamKPtr
template<typename ALayout , typename BLayout , typename CLayout , typename ADataType , typename BDataType , typename CDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation >
using ck::tensor_operation::device::DeviceGemmStreamKPtr = typedef std::unique_ptr<DeviceGemmStreamK<ALayout, BLayout, CLayout, ADataType, BDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> > |
◆ DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
template<index_t NDimSpatial, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename ADataType , typename BDataType , typename AccDataType , typename CShuffleDataType , typename DsDataType , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , ConvolutionForwardSpecialization ConvForwardSpecialization, GemmSpecialization GemmSpec, index_t NumGemmKPrefetchStage, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, index_t ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, index_t BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , index_t CDEBlockTransferScalarPerVector_NPerBlock, typename AComputeDataType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value, Number<0>, ADataType>()), typename BComputeDataType = AComputeDataType, LoopScheduler LoopSched = make_default_loop_scheduler()>
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle = typedef DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ConvForwardSpecialization, GemmSpec, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, AComputeDataType, BComputeDataType, LoopSched> |
◆ DeviceGroupedConvFwdMultipleD
template<index_t NDimSpatial, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename ComputeType = decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value, Number<0>, ADataType>())>
using ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD = typedef DeviceGroupedConvFwdMultipleABD<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, ADataType, BDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, ComputeType> |
Grouped Convolution Forward.
- Note
- This structure is deprecated (left for backwards compatibility). Please use DeviceGroupedConvFwdMultipleABD.
- Template Parameters
-
NDimSpatial Number of spatial dimensions. ALayout Input layout (also for a1, a2...). BLayout Weight layout (also for b1, b2...). DsLayout Ds layouts. ELayout Output layout. ADataType Input data type. Pass a tuple if there are multiple A tensors. BDataType Weight data type. Pass a tuple if there are multiple B tensors. DsDataType D data types. EDataType Output data type. AElementwiseOperation A elementwise operation. BElementwiseOperation B elementwise operation. CDEElementwiseOperation CDE elementwise operation. ComputeType Compute data type (default: ADataType, or its first element if a tuple is passed).
◆ DeviceMultipleReducePtr
template<index_t Rank, index_t NumReduceDim, index_t NumReduction, typename InElementwiseOperationTuple , typename AccElementwiseOperationTuple >
using ck::tensor_operation::device::DeviceMultipleReducePtr = typedef std::unique_ptr<DeviceMultipleReduce<Rank, NumReduceDim, NumReduction, InElementwiseOperationTuple, AccElementwiseOperationTuple> > |
◆ DeviceNormalizationBwdDataPtr
template<typename DYDataType , typename XDataType , typename GammaDataType , typename MeanInvStdDataType , typename DXDataType , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceNormalizationBwdDataPtr = typedef std::unique_ptr<DeviceNormalizationBwdData<DYDataType, XDataType, GammaDataType, MeanInvStdDataType, DXDataType, Rank, NumReduceDim> > |
◆ DeviceNormalizationBwdGammaBetaPtr
template<typename DYDataType , typename XDataType , typename MeanInvStdDataType , typename DGammaDataType , typename DBetaDataType , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaPtr = typedef std::unique_ptr<DeviceNormalizationBwdGammaBeta<DYDataType, XDataType, MeanInvStdDataType, DGammaDataType, DBetaDataType, Rank, NumReduceDim> > |
◆ DeviceNormalizationFwdPtr
template<typename XDataType , typename GammaDataType , typename BetaDataType , typename YDataType , typename SaveMeanInvStdDataType , typename YElementwiseOperation , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceNormalizationFwdPtr = typedef std::unique_ptr<DeviceNormalizationFwd<XDataType, GammaDataType, BetaDataType, YDataType, SaveMeanInvStdDataType, YElementwiseOperation, Rank, NumReduceDim> > |
◆ DeviceReduceMultiDPtr
template<typename InDataType , typename DsDataType , typename AccDataType , typename OutDataType , index_t Rank, index_t NumReduceDim, typename ReduceOperation , typename InElementwiseOperation , typename OutElementwiseOperation >
using ck::tensor_operation::device::DeviceReduceMultiDPtr = typedef std::unique_ptr<DeviceReduceMultiD<InDataType, DsDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, OutElementwiseOperation> > |
◆ DeviceReducePtr
template<typename InDataType , typename AccDataType , typename OutDataType , index_t Rank, index_t NumReduceDim, typename ReduceOperation , typename InElementwiseOperation , typename AccElementwiseOperation , bool PropagateNan, bool OutputIndex>
using ck::tensor_operation::device::DeviceReducePtr = typedef std::unique_ptr<DeviceReduce<InDataType, AccDataType, OutDataType, Rank, NumReduceDim, ReduceOperation, InElementwiseOperation, AccElementwiseOperation, PropagateNan, OutputIndex> > |
◆ DeviceSoftmaxPtr
template<typename InDataType , typename AccDataType , typename OutDataType , typename InElementwiseOp , typename AccElementwiseOp , index_t Rank, index_t NumReduceDim>
using ck::tensor_operation::device::DeviceSoftmaxPtr = typedef std::unique_ptr<DeviceSoftmax<InDataType, AccDataType, OutDataType, InElementwiseOp, AccElementwiseOp, Rank, NumReduceDim> > |
◆ is_tuple
template<typename T >
using ck::tensor_operation::device::is_tuple = typedef decltype(std::declval<T&>().IsTuple()) |
Enumeration Type Documentation
◆ ConvolutionBackwardDataSpecialization
◆ ConvolutionBackwardWeightSpecialization
◆ ConvolutionForwardSpecialization
◆ GemmSpecialization
|
strong |
◆ MaskingSpecialization
◆ TensorSpecialization
Function Documentation
◆ CalculateMaxRead()
template<index_t NumDim1, index_t NumDim2>
auto ck::tensor_operation::device::CalculateMaxRead | ( | const std::vector< index_t > & | lengths, |
const std::vector< index_t > & | strides | ||
) |
Calculates the maximum number of subsequent elements of the fastest-changing dimension that are consecutive in memory.
Example: NumDimM = 2, NumDimK = 3; A shape = [2, 3, 4, 5, 6], where the first 2 dimensions form the M group and the last 3 dimensions form the K group; A strides = [360, 120, 30, 6, 1]. It follows from the strides that K is the fastest-changing dimension (FCD) and all the subsequent elements of K are consecutive in memory. But if the strides were [360, 120, 6, 24, 1], then only 6 subsequent elements of K would be consecutive in memory.
Assumes that the dimensions are split into two groups of NumDim1 and NumDim2 dimensions, respectively.
◆ get_2d_lengths() [1/2]
template<index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> ck::tensor_operation::device::get_2d_lengths | ( | const std::array< index_t, Rank > & | inLengths | ) |
◆ get_2d_lengths() [2/2]
template<index_t Rank, int NumReduceDim>
std::pair<long_index_t, long_index_t> ck::tensor_operation::device::get_2d_lengths | ( | const std::vector< index_t > & | inLengths | ) |
◆ getConvBackwardDataSpecializationString()
|
inline |
◆ getConvBackwardWeightSpecializationString()
|
inline |
◆ getConvForwardSpecializationString()
|
inline |
◆ getGemmSpecializationString()
|
inline |
◆ getMaskingSpecializationString()
|
inline |
◆ getTensorSpecializationString()
|
inline |
◆ grid_desc()
template<GemmSpecialization GemmSpec, typename MPerTileType , typename NPerTileType , typename KPerTileType , typename CDesc_MRaw_NRaw >
auto ck::tensor_operation::device::grid_desc | ( | MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > | matrix_padder, |
CDesc_MRaw_NRaw | conv_desc | ||
) |
◆ is_GNDHWC_GKZYXC_GNDHWK()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_GNHWC_GKYXC_GNHWK()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_GNSpatialC_GKSpatial_GNSpatialK()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_GNWC_GKXC_GNWK()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_NDHWGC_GKZYXC_NDHWGK()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_NGCDHW_GKZYXC_NGKDHW()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_NGCHW_GKYXC_NGKHW()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_NGCSpatial_GKSpatial_NGKSpatial()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_NGCW_GKXC_NGKW()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_NHWGC_GKYXC_NHWGK()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_NSpatialGC_GKSpatial_NSpatialGK()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ is_NWGC_GKXC_NWGK()
template<typename InLayout , typename WeiLayout , typename OutLayout >
|
constexpr |
◆ kernel_batched_gemm_dlops_bwd_weight()
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AGridDesc_B_K0_M0_M1_K1 , typename BGridDesc_B_K0_N0_N1_K1 , typename CGridDesc_M0_M10_M11_N0_N10_N11 , typename Block2CTileMap , typename ComputePtrOffsetOfBatch , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_dlops_bwd_weight | ( | const FloatAB *__restrict__ | p_a_grid, |
const FloatAB *__restrict__ | p_b_grid, | ||
FloatC *__restrict__ | p_c_grid, | ||
const index_t | batch_count, | ||
const AGridDesc_B_K0_M0_M1_K1 | a_grid_desc_kbatch_k0_m0_m1_k1, | ||
const BGridDesc_B_K0_N0_N1_K1 | b_grid_desc_kbatch_k0_n0_n1_k1, | ||
const CGridDesc_M0_M10_M11_N0_N10_N11 | c_grid_desc_m0_m10_m11_n0_n10_n11, | ||
const Block2CTileMap | block_2_ctile_map, | ||
const ComputePtrOffsetOfBatch | compute_ptr_offset_of_batch | ||
) |
◆ kernel_batched_gemm_e_permute_xdl()
template<typename GridwiseGemm , typename ABDataType , typename EDataType , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename ComputePtrOffsetOfBatch , typename Block2ETileMap , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_e_permute_xdl | ( | const ABDataType *__restrict__ | p_a_grid, |
const ABDataType *__restrict__ | p_b_grid, | ||
EDataType *__restrict__ | p_e_grid, | ||
const index_t | batch_count, | ||
const AGridDesc_AK0_M_AK1 | a_grid_desc_ak0_m_ak1, | ||
const BGridDesc_BK0_N_BK1 | b_grid_desc_bk0_n_bk1, | ||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock | e_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CDEElementwiseOperation | cde_element_op, | ||
const ComputePtrOffsetOfBatch | compute_ptr_offset_of_batch, | ||
const Block2ETileMap | block_2_etile_map | ||
) |
◆ kernel_batched_gemm_gemm_xdl_cshuffle_v1()
template<typename GridwiseGemm , typename A0B0B1DataType , typename D0sPointer , typename D1sPointer , typename E1DataType , typename A0ElementwiseOperation , typename B0ElementwiseOperation , typename CDE0ElementwiseOperation , typename B1ElementwiseOperation , typename CDE1ElementwiseOperation , typename A0GridDesc_AK0_M_AK1 , typename B0GridDesc_BK0_N_BK1 , typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 , typename B1GridDesc_BK0_N_BK1 , typename D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2E1TileMap , typename ComputeBasePtrOfStridedBatch , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_gemm_xdl_cshuffle_v1 | ( | const A0B0B1DataType *__restrict__ | p_a0_grid, |
const A0B0B1DataType *__restrict__ | p_b0_grid, | ||
D0sPointer | p_d0s_grid, | ||
const A0B0B1DataType *__restrict__ | p_b1_grid, | ||
D1sPointer | p_d1s_grid, | ||
E1DataType *__restrict__ | p_e1_grid, | ||
const A0ElementwiseOperation | a0_element_op, | ||
const B0ElementwiseOperation | b0_element_op, | ||
const CDE0ElementwiseOperation | cde0_element_op, | ||
const B1ElementwiseOperation | b1_element_op, | ||
const CDE1ElementwiseOperation | cde1_element_op, | ||
const A0GridDesc_AK0_M_AK1 | a0_grid_desc_ak0_m_ak1, | ||
const B0GridDesc_BK0_N_BK1 | b0_grid_desc_bk0_n_bk1, | ||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 | d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, | ||
const B1GridDesc_BK0_N_BK1 | b1_grid_desc_bk0_n_bk1, | ||
const D1sGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock | d1s_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const E1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock | e1_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const Block2E1TileMap | block_2_e1tile_map, | ||
const index_t | batch_count, | ||
const ComputeBasePtrOfStridedBatch | compute_base_ptr_of_batch | ||
) |
◆ kernel_batched_gemm_reduce_xdl_cshuffle_v1()
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename ReducePtrsGlobal , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename ReduceInElementwiseOperations , typename ReduceAccElementwiseOperations , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename ReduceGridDescriptor_MBlock_MPerBlock , typename ComputeBasePrtOfBatch , typename Block2CTileMap , bool HasMainK0BlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_reduce_xdl_cshuffle_v1 | ( | const FloatAB *__restrict__ | p_a_grid, |
const FloatAB *__restrict__ | p_b_grid, | ||
FloatC *__restrict__ | p_c_grid, | ||
ReducePtrsGlobal | p_reduces_grid, | ||
const index_t | batch_count, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CElementwiseOperation | c_element_op, | ||
const ReduceInElementwiseOperations | reduce_in_element_ops, | ||
const ReduceAccElementwiseOperations | reduce_out_element_ops, | ||
const AGridDesc_AK0_M_AK1 | a_grid_desc_ak0_m_ak1, | ||
const BGridDesc_BK0_N_BK1 | b_grid_desc_bk0_n_bk1, | ||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock | c_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const ReduceGridDescriptor_MBlock_MPerBlock | reduce_grid_desc_mblock_mperblock, | ||
const ComputeBasePrtOfBatch | compute_base_ptr_of_batch_, | ||
const Block2CTileMap | block_2_ctile_map | ||
) |
◆ kernel_batched_gemm_softmax_gemm_wmma_cshuffle()
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_softmax_gemm_wmma_cshuffle | ( | const ADataType *__restrict__ | p_a_grid, |
const B0DataType *__restrict__ | p_b0_grid, | ||
const B1DataType *__restrict__ | p_b1_grid, | ||
CDataType *__restrict__ | p_c_grid, | ||
index_t | M, | ||
index_t | N, | ||
index_t | K, | ||
index_t | O, | ||
index_t | G0, | ||
index_t | G1, | ||
float | alpha, | ||
bool | input_permute, | ||
bool | output_permute | ||
) |
◆ kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1() [1/2]
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , typename C0MatrixMask , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1 | ( | const FloatAB *__restrict__ | p_a_grid, |
const FloatAB *__restrict__ | p_b_grid, | ||
const FloatAB *__restrict__ | p_b1_grid, | ||
FloatC *__restrict__ | p_c_grid, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const AccElementwiseOperation | acc_element_op, | ||
const B1ElementwiseOperation | b1_element_op, | ||
const CElementwiseOperation | c_element_op, | ||
const AGridDesc_AK0_M_AK1 | a_grid_desc_ak0_m_ak1, | ||
const BGridDesc_BK0_N_BK1 | b_grid_desc_bk0_n_bk1, | ||
const B1GridDesc_BK0_N_BK1 | b1_grid_desc_bk0_n_bk1, | ||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock | c_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const Block2CTileMap | block_2_ctile_map, | ||
const index_t | batch_count, | ||
const ComputeBasePtrOfStridedBatch | compute_base_ptr_of_batch, | ||
const C0MatrixMask | c0_matrix_mask | ||
) |
◆ kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1() [2/2]
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename D0sPointer , typename AElementwiseOperation , typename BElementwiseOperation , typename C0DEElementwiseOperation , typename B1ElementwiseOperation , typename C1DEElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , typename C0MatrixMask , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_softmax_gemm_xdl_cshuffle_v1 | ( | const FloatAB *__restrict__ | p_a_grid, |
const FloatAB *__restrict__ | p_b_grid, | ||
const FloatAB *__restrict__ | p_b1_grid, | ||
FloatC *__restrict__ | p_c_grid, | ||
D0sPointer | p_d0s_grid, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const C0DEElementwiseOperation | c0de_element_op, | ||
const B1ElementwiseOperation | b1_element_op, | ||
const C1DEElementwiseOperation | c1de_element_op, | ||
const AGridDesc_AK0_M_AK1 | a_grid_desc_ak0_m_ak1, | ||
const BGridDesc_BK0_N_BK1 | b_grid_desc_bk0_n_bk1, | ||
const B1GridDesc_BK0_N_BK1 | b1_grid_desc_bk0_n_bk1, | ||
const C1GridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock | c1_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const D0sGridDescriptor_M0_N0_M1_N1_M2_N2_M3_N3_N4_N5 | d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5, | ||
const Block2CTileMap | block_2_ctile_map, | ||
const index_t | batch_count, | ||
const ComputeBasePtrOfStridedBatch | compute_base_ptr_of_batch, | ||
const C0MatrixMask | c0_matrix_mask | ||
) |
◆ kernel_batched_gemm_xdl()
template<typename GridwiseGemm , typename ABDataType , typename DsPointer , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , typename Block2ETileMap , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_xdl | ( | const ABDataType *__restrict__ | p_a_grid, |
const ABDataType *__restrict__ | p_b_grid, | ||
DsPointer | p_ds_grid, | ||
EDataType *__restrict__ | p_e_grid, | ||
const index_t | batch_count, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CDEElementwiseOperation | cde_element_op, | ||
const AGridDesc_AK0_M_AK1 | a_grid_desc_k0_m_k1, | ||
const BGridDesc_BK0_N_BK1 | b_grid_desc_k0_n_k1, | ||
const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock | ds_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock | e_grid_desc_mblock_mperblock_nblock_nperblock_, | ||
const ComputePtrOffsetOfBatch | compute_ptr_offset_of_batch, | ||
const Block2ETileMap | block_2_etile_map | ||
) |
◆ kernel_batched_gemm_xdlops_bwd_weight()
template<typename GridwiseGemm , typename FloatA , typename FloatB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_B_K0_M_K1 , typename BGridDesc_B_K0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputePtrOffsetOfBatch , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_xdlops_bwd_weight | ( | const FloatA *__restrict__ | p_a_grid, |
const FloatB *__restrict__ | p_b_grid, | ||
FloatC *__restrict__ | p_c_grid, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CElementwiseOperation | c_element_op, | ||
const index_t | batch_count, | ||
const AGridDesc_B_K0_M_K1 | a_b_k0_m_k1_grid_desc, | ||
const BGridDesc_B_K0_N_K1 | b_b_k0_n_k1_grid_desc, | ||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock | c_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const Block2CTileMap | block_2_ctile_map, | ||
const ComputePtrOffsetOfBatch | compute_ptr_offset_of_batch | ||
) |
◆ kernel_batched_gemm_xdlops_v2r3()
template<typename DeviceOp , typename GridwiseGemm , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_batched_gemm_xdlops_v2r3 | ( | const typename DeviceOp::Argument | karg | ) |
◆ kernel_gemm_dl_multiple_d()
template<typename GridwiseGemm , typename ABDataType , typename DsPointer , typename EDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , typename AGridDesc_K0_M0_M1_K1 , typename BGridDesc_K0_N0_N1_K1 , typename DsGridDesc_M0_M10_M11_N0_N10_N11 , typename CGridDesc_M0_M10_M11_N0_N10_N11 , typename ComputePtrOffsetOfBatch , typename Block2CTileMap , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_gemm_dl_multiple_d | ( | const ABDataType *__restrict__ | p_a_grid, |
const ABDataType *__restrict__ | p_b_grid, | ||
DsPointer | p_ds_grid, | ||
EDataType *__restrict__ | p_e_grid, | ||
const index_t | batch_count, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CDEElementwiseOperation | cde_element_op, | ||
const AGridDesc_K0_M0_M1_K1 | a_grid_desc_k0_m0_m1_k1, | ||
const BGridDesc_K0_N0_N1_K1 | b_grid_desc_k0_n0_n1_k1, | ||
const DsGridDesc_M0_M10_M11_N0_N10_N11 | ds_grid_desc_m0_m10_m11_n0_n10_n11, | ||
const CGridDesc_M0_M10_M11_N0_N10_N11 | e_grid_desc_m0_m10_m11_n0_n10_n11, | ||
const ComputePtrOffsetOfBatch | compute_ptr_offset_of_batch, | ||
const Block2CTileMap | block_2_ctile_map | ||
) |
◆ kernel_gemm_gemm_xdl_cshuffle_v1()
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , typename AGridDesc_AK0_M_AK1 , typename BGridDesc_BK0_N_BK1 , typename B1GridDesc_BK0_N_BK1 , typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock , typename Block2CTileMap , typename ComputeBasePtrOfStridedBatch , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_gemm_gemm_xdl_cshuffle_v1 | ( | const FloatAB *__restrict__ | p_a_grid, |
const FloatAB *__restrict__ | p_b_grid, | ||
const FloatAB *__restrict__ | p_b1_grid, | ||
FloatC *__restrict__ | p_c_grid, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const AccElementwiseOperation | acc_element_op, | ||
const B1ElementwiseOperation | b1_element_op, | ||
const CElementwiseOperation | c_element_op, | ||
const AGridDesc_AK0_M_AK1 | a_grid_desc_ak0_m_ak1, | ||
const BGridDesc_BK0_N_BK1 | b_grid_desc_bk0_n_bk1, | ||
const B1GridDesc_BK0_N_BK1 | b1_grid_desc_bk0_n_bk1, | ||
const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock | c_grid_desc_mblock_mperblock_nblock_nperblock, | ||
const Block2CTileMap | block_2_ctile_map, | ||
const index_t | batch_count, | ||
const ComputeBasePtrOfStridedBatch | compute_base_ptr_of_batch | ||
) |
◆ kernel_gemm_xdlops_v2r3_for_conv3d()
template<typename GridwiseGemm , typename FloatAB , typename FloatC , typename AGridDesc_K0_M_K1 , typename BGridDesc_K0_N_K1 , typename CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , typename Block2CTileMap , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_gemm_xdlops_v2r3_for_conv3d | ( | const FloatAB *__restrict__ | p_a_grid, |
const FloatAB *__restrict__ | p_b_grid, | ||
FloatC *__restrict__ | p_c_grid, | ||
const index_t | num_batches, | ||
const index_t | a_batch_stride, | ||
const index_t | b_batch_stride, | ||
const index_t | c_batch_stride, | ||
const AGridDesc_K0_M_K1 | a_grid_desc_k0_m_k1, | ||
const BGridDesc_K0_N_K1 | b_grid_desc_k0_n_k1, | ||
const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 | c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CElementwiseOperation | c_element_op, | ||
const Block2CTileMap | block_2_ctile_map | ||
) |
◆ kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3()
template<typename GridwiseGemm , typename AGridDesc_AK0_M_K1 , typename BGridDesc_BK0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , index_t NumGroupsToMerge, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, index_t MinimumOccupancy = 1, TailNumber TailNum = TailNumber::Full>
__global__ void ck::tensor_operation::device::kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3 | ( | typename GridwiseGemm::Argument | karg, |
[[maybe_unused]] const AGridDesc_AK0_M_K1 | a_grid_desc_ak0_m_ak1, | ||
[[maybe_unused]] const BGridDesc_BK0_N_K1 | b_grid_desc_bk0_n_bk1, | ||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock | c_grid_desc_mblock_mperblock_nblock_nperblock, | ||
[[maybe_unused]] const ComputePtrOffsetOfBatch | compute_ptr_offset_of_batch, | ||
[[maybe_unused]] const index_t | num_k_per_block | ||
) |
◆ kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds()
template<typename GridwiseGemm , typename AGridDesc_AK0_M_K1 , typename BGridDesc_BK0_N_K1 , typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock , typename ComputePtrOffsetOfBatch , index_t NumGroupsToMerge, bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, index_t MinimumOccupancy = 1, TailNumber TailNum = TailNumber::Full>
__global__ void ck::tensor_operation::device::kernel_grouped_conv_bwd_weight_xdl_cshuffle_v3_2lds | ( | typename GridwiseGemm::Argument | karg, |
[[maybe_unused]] const AGridDesc_AK0_M_K1 | a_grid_desc_ak0_m_ak1, | ||
[[maybe_unused]] const BGridDesc_BK0_N_K1 | b_grid_desc_bk0_n_bk1, | ||
[[maybe_unused]] const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock | c_grid_desc_mblock_mperblock_nblock_nperblock, | ||
[[maybe_unused]] const ComputePtrOffsetOfBatch | compute_ptr_offset_of_batch, | ||
[[maybe_unused]] const index_t | num_k_per_block | ||
) |
◆ kernel_grouped_gemm_multiple_d_dl()
template<typename GridwiseGemm , typename GemmDesc , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_multiple_d_dl | ( | const void CK_CONSTANT_ADDRESS_SPACE * | gemm_descs_const, |
const index_t | group_count, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CDEElementwiseOperation | cde_element_op | ||
) |
◆ kernel_grouped_gemm_multiple_d_xdl()
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, typename ADataType , typename BDataType , typename DsDataType , typename EDataType , typename ALayout , typename BLayout , typename DsLayout , typename ELayout , index_t KPerBlock, typename OffsettedBlockToCTileMap , typename LocalBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , BlockGemmPipelineScheduler BlkGemmPipeSched, BlockGemmPipelineVersion BlkGemmPipelineVer>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_multiple_d_xdl | ( | const void CK_CONSTANT_ADDRESS_SPACE * | gemm_descs_const, |
const index_t | group_count, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CDEElementwiseOperation | cde_element_op | ||
) |
Entry point kernel for device-wide Grouped GEMM operation.
- Parameters
-
[in] gemm_descs_const The pointer to the array of GEMM descriptor structures. [in] group_count The number of GEMMs processed together.
- Template Parameters
-
GridwiseGemm The specific GridwiseGEMM algorithm implementation. GemmDesc The structure holding all descriptors and other data needed for the grouped GEMM calculation and work distribution. LocalBlock2ETileMap The structure providing the mapping between workgroup ids, the data tiles to process, and the output tiles.
◆ kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1()
template<typename GridwiseGemm , typename GroupKernelArg , typename AElementwiseOperation , typename BElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_softmax_gemm_xdl_cshuffle_v1 | ( | const void CK_CONSTANT_ADDRESS_SPACE * | group_kernel_args, |
const index_t | group_count, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const AccElementwiseOperation | acc_element_op, | ||
const B1ElementwiseOperation | b1_element_op, | ||
const CElementwiseOperation | c_element_op | ||
) |
◆ kernel_grouped_gemm_xdl()
template<typename GridwiseGemm , typename GemmDesc , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_xdl | ( | const void CK_CONSTANT_ADDRESS_SPACE * | gemm_descs_const, |
const index_t | group_count, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CDEElementwiseOperation | c_element_op | ||
) |
◆ kernel_grouped_gemm_xdl_fixed_nk() [1/2]
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, typename AsLayout , typename BsLayout , typename DsLayout , typename ELayout , typename Block2ETileMap , typename GroupedGemmBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_xdl_fixed_nk | ( | const void CK_CONSTANT_ADDRESS_SPACE * | gemm_descs_const, |
const index_t | group_count, | ||
const index_t | grid_size_grp, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CDEElementwiseOperation | cde_element_op | ||
) |
◆ kernel_grouped_gemm_xdl_fixed_nk() [2/2]
template<typename GridwiseGemm , typename GemmDesc , GemmSpecialization GemmSpec, bool Zeroing, typename ALayout , typename BLayout , typename DsLayout , typename ELayout , typename DsDataType , typename Block2ETileMap , typename GroupedGemmBlock2ETileMap , typename AElementwiseOperation , typename BElementwiseOperation , typename CDEElementwiseOperation , InMemoryDataOperationEnum EGlobalMemoryDataOperation, bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_xdl_fixed_nk | ( | const void CK_CONSTANT_ADDRESS_SPACE * | gemm_descs_const, |
uint32_t * | barrier_count, | ||
const index_t | barrier_size_grp, | ||
const index_t | group_count, | ||
const index_t | grid_size_grp, | ||
const index_t | KBatch, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CDEElementwiseOperation | c_element_op | ||
) |
◆ kernel_grouped_gemm_xdl_splitk()
template<typename GridwiseGemm , typename GemmDesc , bool HasMainKBlockLoop, InMemoryDataOperationEnum CGlobalMemoryDataOperation, typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough, typename CElementwiseOperation = ck::tensor_operation::element_wise::PassThrough>
__global__ void ck::tensor_operation::device::kernel_grouped_gemm_xdl_splitk | ( | const void CK_CONSTANT_ADDRESS_SPACE * | gemm_descs_const, |
const index_t | group_count, | ||
const AElementwiseOperation | a_element_op, | ||
const BElementwiseOperation | b_element_op, | ||
const CElementwiseOperation | c_element_op | ||
) |
◆ kernel_grouped_query_attention_wmma()
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , ck::index_t QueryGroupNumber, bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_grouped_query_attention_wmma | ( | const ADataType *__restrict__ | p_a_grid, |
const B0DataType *__restrict__ | p_b0_grid, | ||
const B1DataType *__restrict__ | p_b1_grid, | ||
CDataType *__restrict__ | p_c_grid, | ||
index_t | M, | ||
index_t | N, | ||
index_t | K, | ||
index_t | O, | ||
index_t | G0, | ||
index_t | G1, | ||
float | alpha, | ||
bool | input_permute, | ||
bool | output_permute | ||
) |
◆ kernel_multi_query_attention_wmma()
template<typename DeviceOp , typename GridwiseOp , typename ADataType , typename B0DataType , typename B1DataType , typename CDataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_multi_query_attention_wmma | ( | const ADataType *__restrict__ | p_a_grid, |
const B0DataType *__restrict__ | p_b0_grid, | ||
const B1DataType *__restrict__ | p_b1_grid, | ||
CDataType *__restrict__ | p_c_grid, | ||
index_t | M, | ||
index_t | N, | ||
index_t | K, | ||
index_t | O, | ||
index_t | G0, | ||
index_t | G1, | ||
float | alpha, | ||
bool | input_permute, | ||
bool | output_permute | ||
) |
◆ kernel_normalization_bwd_data()
template<typename GridwiseNormalizationBwd , typename DYDataType , typename XDataType , typename GammaDataType , typename MeanInvStdDataType , typename DXDataType , typename GridDesc_M_K >
__global__ void ck::tensor_operation::device::kernel_normalization_bwd_data | ( | const GridDesc_M_K | dy_grid_desc_m_k, |
const GridDesc_M_K | x_grid_desc_m_k, | ||
const GridDesc_M_K | gamma_grid_desc_m_k, | ||
const GridDesc_M_K | mean_grid_desc_m_k, | ||
const GridDesc_M_K | inv_std_grid_desc_m_k, | ||
const GridDesc_M_K | dx_grid_desc_m_k, | ||
index_t | num_k_block_tile_iteration, | ||
const DYDataType *const __restrict__ | p_dy_global, | ||
const XDataType *const __restrict__ | p_x_global, | ||
const GammaDataType *const __restrict__ | p_gamma_global, | ||
const MeanInvStdDataType *const __restrict__ | p_mean_global, | ||
const MeanInvStdDataType *const __restrict__ | p_inv_std_global, | ||
DXDataType *const __restrict__ | p_dx_global | ||
) |
◆ kernel_normalization_bwd_gamma_beta()
template<typename GridwiseReduction , typename DYDataType , typename XDataType , typename MeanInvStdDataType , typename DGammaDataType , typename DBetaDataType , typename GridDesc_M_K , typename GridDesc_M >
__global__ void ck::tensor_operation::device::kernel_normalization_bwd_gamma_beta | ( | const GridDesc_M_K | dy_grid_desc_m_k, |
const GridDesc_M_K | x_grid_desc_m_k, | ||
const GridDesc_M_K | mean_grid_desc_m_k, | ||
const GridDesc_M_K | inv_std_grid_desc_m_k, | ||
const GridDesc_M | dgamma_grid_desc_m, | ||
const GridDesc_M | dbeta_grid_desc_m, | ||
index_t | num_k_block_tile_iteration, | ||
const DYDataType *const __restrict__ | p_dy_global, | ||
const XDataType *const __restrict__ | p_x_global, | ||
const MeanInvStdDataType *const __restrict__ | p_mean_global, | ||
const MeanInvStdDataType *const __restrict__ | p_inv_std_global, | ||
DGammaDataType *const __restrict__ | p_dgamma_global, | ||
DBetaDataType *const __restrict__ | p_dbeta_global | ||
) |
◆ kernel_wmma_cross_attention_forward()
template<typename DeviceOp , typename GridwiseOp , typename QDataType , typename KVDataType , typename ODataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_wmma_cross_attention_forward | ( | const QDataType *__restrict__ | p_q_grid, |
const KVDataType *__restrict__ | p_kv_grid, | ||
ODataType *__restrict__ | p_out_grid, | ||
index_t | batch_size, | ||
index_t | q_sequence_length, | ||
index_t | kv_sequence_length, | ||
index_t | head_count, | ||
index_t | head_size, | ||
float | alpha | ||
) |
◆ kernel_wmma_self_attention_forward()
template<typename DeviceOp , typename GridwiseOp , typename QKVDataType , typename ODataType , typename AElementwiseOperation , typename B0ElementwiseOperation , typename AccElementwiseOperation , typename B1ElementwiseOperation , typename CElementwiseOperation , bool HasMainKBlockLoop>
__global__ void ck::tensor_operation::device::kernel_wmma_self_attention_forward | ( | const QKVDataType *__restrict__ | p_qkv_grid, |
ODataType *__restrict__ | p_out_grid, | ||
index_t | batch_size, | ||
index_t | sequence_length, | ||
index_t | head_count, | ||
index_t | head_size, | ||
float | alpha | ||
) |
◆ make_tuple_from_array()
template<index_t arraySize>
auto ck::tensor_operation::device::make_tuple_from_array | ( | const std::vector< index_t > & | lengths, |
Number< arraySize > | |||
) |
◆ make_tuple_from_array_and_index_seq()
template<index_t... Ns>
auto ck::tensor_operation::device::make_tuple_from_array_and_index_seq | ( | const std::vector< index_t > & | lengths, |
Sequence< Ns... > | |||
) |
◆ PadTensorDescriptor()
template<typename TensorDesc , typename TileLengths , typename DoPads >
|
constexpr |
◆ shuffle_tensor_dimensions() [1/2]
template<index_t Rank, index_t NumReduceDim>
std::array<index_t, Rank> ck::tensor_operation::device::shuffle_tensor_dimensions | ( | const std::array< index_t, Rank > & | origLengthsStrides, |
const std::array< int, NumReduceDim > & | reduceDims | ||
) |