/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_splitk_cshuffle.hpp Source File
device_grouped_gemm_xdl_splitk_cshuffle.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
__global__ void kernel_grouped_gemm_xdl_splitk(const void CK_CONSTANT_ADDRESS_SPACE *gemm_descs_const, const index_t group_count, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op)
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:37
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
@ AtomicAdd
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE * cast_pointer_to_constant_address_space(T *p)
Definition: amd_address_space.hpp:35
__device__ T * cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE *p)
Definition: amd_address_space.hpp:24
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: block_to_ctile_map.hpp:539
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:128
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:103
__host__ static __device__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:200
__host__ static constexpr __device__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:440
__host__ static __device__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:205
__host__ static __device__ auto CalculateK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:210
__host__ static constexpr __device__ bool CalculateHasMainK0BlockLoop(index_t K0Padded)
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:609
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:375
remove_cvref_t< decltype(MakeCGridDescriptor_M_N(1, 1, 1))> CGridDesc_M_N
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:661
__host__ static __device__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdlops_v2r4r2.hpp:217
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
void * p_workspace_
Definition: device_base.hpp:57
Definition: device_base.hpp:61
virtual void SetWorkSpacePointer(BaseArgument *p_arg, void *p_workspace, const StreamConfig &=StreamConfig{}) const
Definition: device_base.hpp:101
Definition: device_grouped_gemm_splitk.hpp:33
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:231
index_t skipped_group_count_
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:365
index_t K_BATCH
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:363
void UpdateKBatch(index_t kbatch)
Recalculates the grid size of each GEMM group and updates the block-to-C-tile (B2C) maps.
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:324
std::vector< GemmTransKernelArg > gemm_kernel_args_
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:367
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, index_t kbatch)
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:242
index_t grid_size_
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:368
Argument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs)
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:233
index_t group_count_
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:364
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:209
GemmTransKernelArg(KernelArgument &&karg, GroupedGemmBlock2ETileMap &&b2c_map, index_t block_start, index_t block_end)
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:215
KernelArgument karg_
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:210
index_t block_start_
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:212
GemmTransKernelArg()=default
GroupedGemmBlock2ETileMap block_2_ctile_map_
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:211
index_t block_end_
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:212
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:373
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:508
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:374
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:145
static bool IsSupportedArgument(const Argument &arg)
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:521
std::string GetTypeString() const override
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:606
void SetKBatchSize(BaseArgument *p_arg, index_t kbatch) const override
Sets the k batch size.
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:659
static constexpr auto I1
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:149
static auto MakeInvoker()
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:583
static auto MakeArgument(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &, std::vector< void * > &p_Es, std::vector< GemmDesc > gemm_descs, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation)
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:571
static constexpr auto I2
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:150
static constexpr index_t DefaultKBatch
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:227
static constexpr index_t B2E_M01
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:204
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:637
void SetDeviceKernelArgs(BaseArgument *p_arg, void *p_dev_kernel_args) const override
Sets the device kernel arguments pointer and may copy data to device.
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:672
size_t GetDeviceKernelArgSize(const BaseArgument *p_arg) const override
Gets the device kernel argument size.
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:650
static constexpr auto I3
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:151
OffsettedBlockToCTileMap< Block2ETileMapKSplit > GroupedGemmBlock2ETileMap
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:205
static constexpr bool IsValidCompilationParameter()
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:515
static constexpr index_t NumDTensor
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:146
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4r2< BlockSize, ADataType, BDataType, AccDataType, EDataType, ALayout, BLayout, ELayout, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, GemmSpec, NumGemmKPrefetchStage, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, AK1, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferScalarPerVector_NPerBlock, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, LoopSched, PipelineVer > GridwiseGemm
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:198
std::unique_ptr< BaseArgument > MakeArgumentPointer(std::vector< const void * > &p_As, std::vector< const void * > &p_Bs, std::vector< std::array< const void *, NumDTensor >> &, std::vector< void * > &p_Es, std::vector< GemmDesc > &gemm_descs, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation) override
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:587
typename GridwiseGemm::CGridDesc_M_N CGridDesc_M_N
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:200
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:207
typename GridwiseGemm::Argument KernelArgument
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:206
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:566
static constexpr auto I0
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:148
static void SetKBatchSize(Argument &arg, index_t kbatch)
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:656
static constexpr index_t K0PerBlock
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:153
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_grouped_gemm_xdl_splitk_cshuffle.hpp:600
Definition: device_grouped_gemm.hpp:86
Definition: unary_element_wise_operation.hpp:241