/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_cgemm_4gemm_xdl_cshuffle.hpp Source File
device_cgemm_4gemm_xdl_cshuffle.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__global__ void kernel_gemm_xdl_cshuffle_v1(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:25
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition: gridwise_elementwise_2d.hpp:29
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_elementwise_2d.hpp:162
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:414
index_t N
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:456
index_t K
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:457
__host__ void Print() const
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:437
index_t M
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:455
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:114
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:557
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:142
static __host__ auto CalculateAK0(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:152
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:137
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:661
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:361
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_xdl_cshuffle_v1.hpp:132
Definition: sequence.hpp:43
Definition: tuple.hpp:117
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:172
typename GridwiseGemm::Problem Problem
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:173
CGridDesc_M_N c_grid_desc_m_n
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:218
CDataType * p_c_grid_real
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:214
const BDataType * p_b_grid_imag
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:213
const ADataType * p_a_grid_imag
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:211
const ADataType * p_a_grid_real
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:210
CDataType * p_aux_grid
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:216
CDataType * p_aux_2_grid
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:217
CDataType * p_c_grid_imag
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:215
Argument(const ADataType *p_a_grid_real_, const ADataType *p_a_grid_imag_, const BDataType *p_b_grid_real_, const BDataType *p_b_grid_imag_, CDataType *p_c_grid_real_, CDataType *p_c_grid_imag_, CDataType *p_workspace, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:175
const BDataType * p_b_grid_real
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:212
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:223
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:224
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:465
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:76
static constexpr auto I2
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:81
static std::size_t GetCElementSpaceSize(index_t M, index_t N, index_t StrideC)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:588
static auto MakeDescriptor_M_N(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:109
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:489
static constexpr bool IsValidCompilationParameter()
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:472
static auto MakeInvoker()
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:526
std::size_t GetWorkspaceSize(index_t M, index_t N, [[maybe_unused]] index_t K, [[maybe_unused]] index_t StrideA, [[maybe_unused]] index_t StrideB, index_t StrideC) const override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:596
static constexpr index_t MPerThread
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:83
decltype(MakeDescriptor_M_N({1, 1}, {1, 1})) CGridDesc_M_N
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:168
static constexpr auto I1
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:80
static constexpr auto I0
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:79
static constexpr auto CScalarPerVector
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:90
std::string GetTypeString() const override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:569
static auto PadDescriptor_M_N(Desc_M_N desc)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:93
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a_real, const void *p_a_imag, const void *p_b_real, const void *p_b_imag, void *p_c_real, void *p_c_imag, void *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, index_t=1) override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:529
static constexpr auto BScalarPerVector
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:89
static auto MakeArgument(const ADataType *p_a_real, const ADataType *p_a_imag, const BDataType *p_b_real, const BDataType *p_b_imag, CDataType *p_c_real, CDataType *p_c_imag, CDataType *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:494
static bool IsSupportedArgument(const Argument &arg)
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:478
std::size_t GetWorkSpaceSize(const BaseArgument *base_arg) const override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:606
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, CDataType, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, GemmSpec, InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemm
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:166
static constexpr index_t NPerThread
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:85
static constexpr auto AScalarPerVector
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:88
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_cgemm_4gemm_xdl_cshuffle.hpp:563
Definition: device_cgemm.hpp:15
Definition: binary_element_wise_operation.hpp:14
Definition: binary_element_wise_operation.hpp:237