/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp Source File#
device_gemm_wmma.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_wmma(const ADataType *__restrict__ p_a_grid, const BDataType *__restrict__ p_b_grid, CDataType *__restrict__ p_c_grid, const AGridDesc a_grid_desc, const BGridDesc b_grid_desc, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_wmma.hpp:36
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
@ Default
@ Interwave
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_wmma.hpp:123
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_wmma.hpp:530
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_wmma.hpp:570
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc &a_grid_desc, const BGridDesc &b_grid_desc, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_wmma.hpp:412
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_wmma.hpp:511
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_wmma.hpp:503
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_wmma.hpp:572
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: tensor_layout.hpp:21
Definition: tensor_layout.hpp:16
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm.hpp:22
Definition: device_gemm_wmma.hpp:295
CElementwiseOperation c_element_op_
Definition: device_gemm_wmma.hpp:357
AElementwiseOperation a_element_op_
Definition: device_gemm_wmma.hpp:355
AGridDesc a_grid_desc_
Definition: device_gemm_wmma.hpp:347
index_t M01_
Definition: device_gemm_wmma.hpp:353
index_t NRaw_
Definition: device_gemm_wmma.hpp:360
index_t N01_
Definition: device_gemm_wmma.hpp:354
const ADataType * p_a_grid_
Definition: device_gemm_wmma.hpp:344
BGridDesc b_grid_desc_k0_n_k1_
Definition: device_gemm_wmma.hpp:348
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_
Definition: device_gemm_wmma.hpp:352
index_t MRaw_
Definition: device_gemm_wmma.hpp:359
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_wmma.hpp:296
GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock
Definition: device_gemm_wmma.hpp:351
index_t KRaw_
Definition: device_gemm_wmma.hpp:361
BElementwiseOperation b_element_op_
Definition: device_gemm_wmma.hpp:356
CGridDesc_M_N c_grid_desc_m_n_
Definition: device_gemm_wmma.hpp:349
CDataType * p_c_grid_
Definition: device_gemm_wmma.hpp:346
const BDataType * p_b_grid_
Definition: device_gemm_wmma.hpp:345
Definition: device_gemm_wmma.hpp:366
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_wmma.hpp:438
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_wmma.hpp:369
Definition: device_gemm_wmma.hpp:76
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition: device_gemm_wmma.hpp:221
static constexpr auto K1Number
Definition: device_gemm_wmma.hpp:85
static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_gemm_wmma.hpp:113
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_wmma.hpp:604
static auto MakeInvoker()
Definition: device_gemm_wmma.hpp:571
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_wmma.hpp:537
static constexpr auto AEnableLds_manu
Definition: device_gemm_wmma.hpp:104
static constexpr auto AEnableLds
Definition: device_gemm_wmma.hpp:107
static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_gemm_wmma.hpp:167
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_wmma.hpp:451
static constexpr auto BEnableLds
Definition: device_gemm_wmma.hpp:108
static constexpr auto I3
Definition: device_gemm_wmma.hpp:80
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition: device_gemm_wmma.hpp:242
static constexpr auto MaxVectorLoadA
Definition: device_gemm_wmma.hpp:90
static constexpr auto I1
Definition: device_gemm_wmma.hpp:78
static constexpr auto AEnableLds_auto
Definition: device_gemm_wmma.hpp:93
std::string GetTypeString() const override
Definition: device_gemm_wmma.hpp:610
static constexpr auto I6
Definition: device_gemm_wmma.hpp:83
static constexpr auto I0
Definition: device_gemm_wmma.hpp:77
static constexpr auto I2
Definition: device_gemm_wmma.hpp:79
static constexpr auto I5
Definition: device_gemm_wmma.hpp:82
decltype(MakeBGridDescriptor(1, 1, 1)) BGridDesc
Definition: device_gemm_wmma.hpp:241
static constexpr auto MWaves
Definition: device_gemm_wmma.hpp:87
GridwiseGemm_Wmma< BlockSize, ADataType, BDataType, AccDataType, CShuffleDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc, BGridDesc, CGridDesc_M_N, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, K1, MRepeat, NRepeat, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, AEnableLds, ABlockLdsAddExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BEnableLds, BBlockLdsAddExtraN, CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, NumPrefetch, LoopSched, PipelineVer > GridwiseGemm
Definition: device_gemm_wmma.hpp:291
static constexpr auto NWaves
Definition: device_gemm_wmma.hpp:88
static constexpr auto matrix_padder
Definition: device_gemm_wmma.hpp:110
static constexpr auto WmmaK
Definition: device_gemm_wmma.hpp:89
decltype(MakeAGridDescriptor(1, 1, 1)) AGridDesc
Definition: device_gemm_wmma.hpp:240
static constexpr auto BEnableLds_auto
Definition: device_gemm_wmma.hpp:97
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_wmma.hpp:445
static constexpr auto MaxVectorLoadB
Definition: device_gemm_wmma.hpp:91
static constexpr auto BEnableLds_manu
Definition: device_gemm_wmma.hpp:105
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_wmma.hpp:542
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_wmma.hpp:574
static constexpr auto I4
Definition: device_gemm_wmma.hpp:81
Definition: matrix_padder.hpp:180