/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_waveletmodel_cshuffle.hpp Source File
device_gemm_xdl_waveletmodel_cshuffle.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_xdl_waveletmodel_cshuffle(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const EElementwiseOperation e_element_op, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap block_2_etile_map)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:314
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:298
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:177
remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:338
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:226
__host__ static constexpr __device__ index_t CalculateGridSize(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:270
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:335
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:282
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:261
AElementwiseOperation a_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:327
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:314
const BDataType * p_b_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:309
Block2ETileMap block_2_etile_map_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:324
CDEElementwiseOperation cde_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:329
const ADataType * p_a_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:308
ck::tensor_operation::device::DeviceGemm_Xdl_WaveletModel_CShuffle::Argument::b_grid_desc_bk0_n_bk1_
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:319
BElementwiseOperation b_element_op_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:328
EDataType * p_e_grid_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:310
void Print() const
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:299
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:313
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, EDataType *p_e_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:262
GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:321
ck::tensor_operation::device::DeviceGemm_Xdl_WaveletModel_CShuffle::Argument::a_grid_desc_ak0_m_ak1_
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:318
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:315
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:334
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:337
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:412
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:135
static constexpr auto matrix_padder
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:142
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:252
static constexpr auto I1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:139
static auto MakeInvoker()
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:465
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:163
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:469
static constexpr auto I2
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:140
static constexpr auto I0
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:138
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:497
typename GridwiseGemm::DefaultBlock2ETileMap Block2ETileMap
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:257
decltype(MakeEGridDescriptor_M_N< ELayout >(1, 1, 1)) EGridDesc_M_N
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:202
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:255
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:433
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:145
GridwiseGemm_k0mk1_k0nk1_mn_xdl_waveletmodel_cshuffle< ADataType, GemmAcEDataType, CShuffleDataType, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, EGridDesc_M_N, NumGemmKPrefetchStage, TileLoadThreadGroupSize, TileMathThreadGroupSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock > GridwiseGemm
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:248
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:200
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:419
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideE)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:182
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, EDataType *p_e, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideE, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:438
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:201
std::string GetTypeString() const override
Definition: device_gemm_xdl_waveletmodel_cshuffle.hpp:503
Definition: device_gemm.hpp:22
Definition: matrix_padder.hpp:180