DeviceMoeGemm< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > Struct Template Reference#
Classes |
Public Types |
Public Member Functions |
Static Public Member Functions |
Static Public Attributes |
List of all members
ck::tensor_operation::device::DeviceMoeGemm< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > Struct Template Reference
#include <device_moe_gemm.hpp>
Inheritance diagram for ck::tensor_operation::device::DeviceMoeGemm< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >:
Classes | |
| struct | Invoker |
Public Types | |
| using | GridwiseGemm = GridwiseMoeGemm< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB > |
| using | Argument = typename GridwiseGemm::Argument |
Public Member Functions | |
| int | GetPreShuffleParameters () override |
| bool | IsSupportedArgument (const BaseArgument *p_arg) override |
| std::unique_ptr< BaseArgument > | MakeArgumentPointer (const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override |
| std::unique_ptr< BaseInvoker > | MakeInvokerPointer () override |
| std::string | GetTypeString () const override |
Public Member Functions inherited from ck::tensor_operation::device::BaseOperator | |
| BaseOperator ()=default | |
| BaseOperator (const BaseOperator &)=default | |
| BaseOperator & | operator= (const BaseOperator &)=default |
| virtual std::string | GetTypeIdName () const |
| virtual std::optional< std::string > | GetObjectName () const |
| virtual std::optional< std::string > | GetTemplateInfo () const |
| virtual std::string | GetTypeIdHashCode () const |
| virtual size_t | GetWorkSpaceSize (const BaseArgument *) const |
| virtual void | SetWorkSpacePointer (BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const |
| virtual | ~BaseOperator () |
Static Public Member Functions | |
| static constexpr bool | IsValidCompilationParameter () |
| static bool | IsSupportedArgument (const Argument &arg) |
| static auto | MakeArgument (const void *p_sorted_token_ids, const void *p_sorted_expert_ids, const void *p_max_token_id, const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t NumTokens, index_t TopK, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) |
| static auto | MakeInvoker () |
Static Public Attributes | |
| static constexpr index_t | NumDTensor = DsDataType::Size() |
| static constexpr index_t | APackedSize |
| static constexpr index_t | BPackedSize |
Static Public Attributes inherited from ck::tensor_operation::device::DeviceGemmMultipleDSplitKBPreShuffle< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation > | |
| static constexpr index_t | NumDTensor |
Member Typedef Documentation
◆ Argument
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
| using ck::tensor_operation::device::DeviceMoeGemm< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::Argument = typename GridwiseGemm::Argument |
◆ GridwiseGemm
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
| using ck::tensor_operation::device::DeviceMoeGemm< ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, DsDataType, CDataType, GemmAccDataType, CShuffleDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB >::GridwiseGemm = GridwiseMoeGemm<ALayout, BLayout, DsLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, DsDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, 
CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEShuffleBlockTransferScalarPerVectors, BlkGemmPipeSched, BlkGemmPipelineVer, ActivationOP, NSwizzle, IsInputGemm, MulRoutedWeight, PerTokenQuant, IndexType, ComputeTypeA, ComputeTypeB, LDSTypeA, LDSTypeB> |
Member Function Documentation
◆ GetPreShuffleParameters()
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline override virtual |
◆ GetTypeString()
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline override virtual |
Reimplemented from ck::tensor_operation::device::BaseOperator.
◆ IsSupportedArgument() [1/2]
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline static |
◆ IsSupportedArgument() [2/2]
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline override virtual |
Reimplemented from ck::tensor_operation::device::BaseOperator.
◆ IsValidCompilationParameter()
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline static constexpr |
◆ MakeArgument()
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline static |
◆ MakeArgumentPointer()
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline override virtual |
◆ MakeInvoker()
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline static |
◆ MakeInvokerPointer()
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
inline override virtual |
Member Data Documentation
◆ APackedSize
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
static constexpr |
Initial value:
◆ BPackedSize
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
static constexpr |
Initial value:
◆ NumDTensor
template<typename ALayout , typename BLayout , typename DsLayout , typename CLayout , typename ADataType , typename BDataType , typename DsDataType , typename CDataType , typename GemmAccDataType , typename CShuffleDataType , typename AElementwiseOperation , typename BElementwiseOperation , typename CElementwiseOperation , GemmSpecialization GemmSpec, index_t BlockSize, index_t MPerBlock, index_t NPerBlock, index_t KPerBlock, index_t AK1, index_t BK1, index_t MPerXDL, index_t NPerXDL, index_t MXdlPerWave, index_t NXdlPerWave, typename ABlockTransferThreadClusterLengths_AK0_M_AK1 , typename ABlockTransferThreadClusterArrangeOrder , typename ABlockTransferSrcAccessOrder , index_t ABlockTransferSrcVectorDim, index_t ABlockTransferSrcScalarPerVector, index_t ABlockTransferDstScalarPerVector_AK1, bool ABlockLdsExtraM, typename BBlockTransferThreadClusterLengths_BK0_N_BK1 , typename BBlockTransferThreadClusterArrangeOrder , typename BBlockTransferSrcAccessOrder , index_t BBlockTransferSrcVectorDim, index_t BBlockTransferSrcScalarPerVector, index_t BBlockTransferDstScalarPerVector_BK1, bool BBlockLdsExtraN, index_t CShuffleMXdlPerWavePerShuffle, index_t CShuffleNXdlPerWavePerShuffle, typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock , typename CDEShuffleBlockTransferScalarPerVectors , BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1, index_t ActivationOP = 0, bool NSwizzle = false, bool IsInputGemm = true, bool MulRoutedWeight = true, bool PerTokenQuant = true, typename IndexType = index_t, typename ComputeTypeA = CDataType, typename ComputeTypeB = ComputeTypeA, typename LDSTypeA = ComputeTypeA, typename LDSTypeB = ComputeTypeB>
|
static constexpr |
The documentation for this struct was generated from the following file:
- /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.1/include/ck/tensor_operation/gpu/device/impl/device_moe_gemm.hpp
Public Member Functions inherited from