/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp Source File#
device_gemm_multiple_d_xdl_cshuffle.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__global__ void kernel_gemm_multiple_d_xdl_cshuffle(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:40
__host__ constexpr __device__ auto transform_tuples(F f, const X &x)
Definition: tuple_helper.hpp:86
@ Default
@ Interwave
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:221
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:396
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:204
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:403
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:254
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, [[maybe_unused]] const Block2ETileMap &)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:329
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:187
__host__ static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc_M_N &ds_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:242
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: functional2.hpp:31
Definition: tensor_layout.hpp:21
Definition: tensor_layout.hpp:16
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:313
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:404
index_t MRaw_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:418
const BDataType * p_b_grid_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:392
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:403
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:400
BElementwiseOperation b_element_op_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:414
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:406
index_t KRaw_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:420
index_t NRaw_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:419
EDataType * p_e_grid_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:394
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:397
Block2ETileMap block_2_etile_map_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:410
AElementwiseOperation a_element_op_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:413
void Print() const
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:380
DsGridDesc_M_N ds_grid_desc_m_n_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:399
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:407
GridwiseGemm::DsGridPointer p_ds_grid_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:393
const ADataType * p_a_grid_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:391
CDEElementwiseOperation cde_element_op_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:415
Argument(const void *p_a_grid, const void *p_b_grid, std::array< const void *, NumDTensor > p_ds_grid, void *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:314
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:398
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:714
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:751
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))> AGridDesc_AK0_M_AK1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:730
index_t NRaw
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:765
index_t MRaw
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:764
AElementwiseOperation a_element_op
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:759
constexpr index_t GetGridSize() const
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:821
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))> BGridDesc_BK0_N_BK1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:733
DsGridDesc_M_N ds_grid_desc_m_n
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:746
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:750
constexpr bool IsValid() const
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:809
EGridDesc_M_N e_grid_desc_m_n
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:747
Block2ETileMap block_2_etile_map
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:756
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{}))> AGridDesc_M_K
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:722
constexpr Descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:770
AGridDesc_M_K a_grid_desc_m_k
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:744
remove_cvref_t< decltype(ds_tuple())> DsGridDesc_M_N
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:725
remove_cvref_t< decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:739
constexpr index_t GetBlockSize() const
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:819
index_t KRaw
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:766
bool has_main_k_block_loop
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:768
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))> Block2ETileMap
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:741
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{}))> EGridDesc_M_N
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:727
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:752
BElementwiseOperation b_element_op
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:760
remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_tuple()))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:736
CDEElementwiseOperation cde_element_op
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:761
static constexpr auto ds_tuple()
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:715
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:753
BGridDesc_N_K b_grid_desc_n_k
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:745
remove_cvref_t< decltype(DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{}))> BGridDesc_N_K
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:724
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:425
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:493
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:428
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:159
static auto MakeDsGridDescriptor_M_N(const std::array< index_t, NumDTensor > &MRaws, const std::array< index_t, NumDTensor > &NRaws, const std::array< index_t, NumDTensor > &DsStride)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:227
static constexpr auto I1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:165
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched, PipelineVer > GridwiseGemm
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:291
std::string GetTypeString() const override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:674
remove_cvref_t< decltype(MakeDsGridDescriptor_M_N({}, {}, {}))> DsGridDesc_M_N
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:243
static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:500
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:242
static auto MakeArgument(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:601
remove_cvref_t< decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:305
static constexpr auto make_descriptor(ADesc a, BDesc b, DsDesc ds, EDesc e, AElementwiseOperation a_element_op=AElementwiseOperation{}, BElementwiseOperation b_element_op=BElementwiseOperation{}, CDEElementwiseOperation cde_element_op=CDEElementwiseOperation{})
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:829
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> Block2ETileMap
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:309
static constexpr auto I0
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:164
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:241
static constexpr auto matrix_padder
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:169
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:244
static __device__ void Run(const Desc &desc, const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, DsPointer p_ds_grid, EDataType *__restrict__ p_e_grid)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:842
static constexpr index_t NumDTensor
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:162
static auto MakeInvoker()
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:632
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:636
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:596
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:668
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:580
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:190
static constexpr auto I3
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:167
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:209
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:172
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:296
static constexpr auto I2
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:166
remove_cvref_t< decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))> DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:302
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_gemm_multiple_d_xdl_cshuffle.hpp:299
Definition: device_gemm_multiple_d.hpp:34
Definition: matrix_padder.hpp:180