/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp Source File
device_grouped_gemm_xdl.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
GemmSpecialization
Definition: gemm_specialization.hpp:11
__global__ void kernel_grouped_gemm_xdl(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:35
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: amd_address_space.hpp:35
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:221
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:396
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:204
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:403
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:254
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, [[maybe_unused]] const Block2ETileMap &)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:329
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:187
__host__ static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:242
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: functional2.hpp:31
Definition: device_base.hpp:50
void * p_workspace_
Definition: device_base.hpp:57
Definition: device_base.hpp:61
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:101
Definition: device_grouped_gemm_xdl.hpp:360
std::vector< GemmBiasTransKernelArg > gemm_desc_kernel_arg_
Definition: device_grouped_gemm_xdl.hpp:498
std::vector< Tuple< index_t, index_t > > a_mtx_mraw_kraw_
Definition: device_grouped_gemm_xdl.hpp:499
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:361
AElementwiseOperation a_element_op_
Definition: device_grouped_gemm_xdl.hpp:494
CDEElementwiseOperation c_element_op_
Definition: device_grouped_gemm_xdl.hpp:496
std::vector< Tuple< index_t, index_t > > b_mtx_nraw_kraw_
Definition: device_grouped_gemm_xdl.hpp:500
index_t skipped_group_count_
Definition: device_grouped_gemm_xdl.hpp:492
index_t grid_size_
Definition: device_grouped_gemm_xdl.hpp:502
index_t group_count_
Definition: device_grouped_gemm_xdl.hpp:491
BElementwiseOperation b_element_op_
Definition: device_grouped_gemm_xdl.hpp:495
Definition: device_grouped_gemm_xdl.hpp:333
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_grouped_gemm_xdl.hpp:344
DsGridDesc_M_N ds_grid_desc_m_n_
Definition: device_grouped_gemm_xdl.hpp:343
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_grouped_gemm_xdl.hpp:348
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_gemm_xdl.hpp:351
const ADataType * a_ptr_
Definition: device_grouped_gemm_xdl.hpp:335
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_grouped_gemm_xdl.hpp:342
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_grouped_gemm_xdl.hpp:341
GroupedGemmBlock2ETileMap block_2_etile_map_
Definition: device_grouped_gemm_xdl.hpp:354
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_grouped_gemm_xdl.hpp:347
EDataType * e_ptr_
Definition: device_grouped_gemm_xdl.hpp:338
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_grouped_gemm_xdl.hpp:350
ck::index_t BlockStart_
Definition: device_grouped_gemm_xdl.hpp:355
GridwiseGemm::DsGridPointer ds_ptr_
Definition: device_grouped_gemm_xdl.hpp:337
ck::index_t BlockEnd_
Definition: device_grouped_gemm_xdl.hpp:355
const BDataType * b_ptr_
Definition: device_grouped_gemm_xdl.hpp:336
Definition: device_grouped_gemm_xdl.hpp:292
__host__ __device__ bool ValidCTileIndex(const CTileIdx &c_tile_idx, const CTileDim &c_tile_dim) const
Definition: device_grouped_gemm_xdl.hpp:317
GroupedGemmBlock2ETileMap()
Definition: device_grouped_gemm_xdl.hpp:296
Block2ETileMap block_2_etile_map_
Definition: device_grouped_gemm_xdl.hpp:328
GroupedGemmBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n, ck::index_t BlockStart)
Definition: device_grouped_gemm_xdl.hpp:302
ck::index_t BlockStart_
Definition: device_grouped_gemm_xdl.hpp:329
ck::tensor_operation::device::DeviceGroupedGemm_Xdl::GroupedGemmBlock2ETileMap::CalculateBottomIndex
__host__ constexpr __device__ auto CalculateBottomIndex(const TopIdx &idx_top) const
Definition: device_grouped_gemm_xdl.hpp:309
__host__ bool CheckValidity(const EGridDesc_M_N &e_grid_desc_m_n) const
Definition: device_grouped_gemm_xdl.hpp:323
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition: device_grouped_gemm_xdl.hpp:294
Definition: device_grouped_gemm_xdl.hpp:507
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_grouped_gemm_xdl.hpp:510
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_grouped_gemm_xdl.hpp:602
Definition: device_grouped_gemm_xdl.hpp:145
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_grouped_gemm_xdl.hpp:283
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_grouped_gemm_xdl.hpp:194
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_grouped_gemm_xdl.hpp:226
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl.hpp:652
std::string GetTypeString() const override
Definition: device_grouped_gemm_xdl.hpp:689
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition: device_grouped_gemm_xdl.hpp:212
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm_xdl.hpp:734
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_grouped_gemm_xdl.hpp:225
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_grouped_gemm_xdl.hpp:175
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &p_Ds, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation c_element_op) override
Definition: device_grouped_gemm_xdl.hpp:669
ADataType ComputeDataType
Definition: device_grouped_gemm_xdl.hpp:230
remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_gemm_xdl.hpp:286
static constexpr auto I1
Definition: device_grouped_gemm_xdl.hpp:151
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_grouped_gemm_xdl.hpp:280
remove_cvref_t< decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_grouped_gemm_xdl.hpp:289
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_grouped_gemm_xdl.hpp:647
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition: device_grouped_gemm_xdl.hpp:729
static constexpr auto I2
Definition: device_grouped_gemm_xdl.hpp:152
static constexpr auto matrix_padder
Definition: device_grouped_gemm_xdl.hpp:154
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition: device_grouped_gemm_xdl.hpp:227
static bool IsSupportedArgument(const Argument &arg)
Definition: device_grouped_gemm_xdl.hpp:609
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_grouped_gemm_xdl.hpp:683
static auto MakeInvoker()
Definition: device_grouped_gemm_xdl.hpp:665
static constexpr index_t NumDTensor
Definition: device_grouped_gemm_xdl.hpp:148
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_grouped_gemm_xdl.hpp:228
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumPrefetch, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemm
Definition: device_grouped_gemm_xdl.hpp:276
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_grouped_gemm_xdl.hpp:157
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_grouped_gemm_xdl.hpp:717
static constexpr auto I0
Definition: device_grouped_gemm_xdl.hpp:150
Definition: device_grouped_gemm.hpp:105
Definition: device_grouped_gemm.hpp:86
Definition: matrix_padder.hpp:180