/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_dl.hpp Source File#
device_gemm_dl.hpp
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_dl_v1r3.hpp:33
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_gemm_dl_v1r3.hpp:93
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:129
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:208
__host__ static constexpr __device__ auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition: gridwise_gemm_dl_v1r3.hpp:188
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_dl_v1r3.hpp:146
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_v1r3.hpp:153
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:241
__host__ static constexpr __device__ bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_v1r3.hpp:160
__host__ static constexpr __device__ auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition: gridwise_gemm_dl_v1r3.hpp:168
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_dl.hpp:252
index_t M_raw_
Definition: device_gemm_dl.hpp:320
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_
Definition: device_gemm_dl.hpp:306
CGridDesc_M_N c_grid_desc_m_n_
Definition: device_gemm_dl.hpp:308
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_
Definition: device_gemm_dl.hpp:311
index_t M01_
Definition: device_gemm_dl.hpp:317
index_t N01_
Definition: device_gemm_dl.hpp:318
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_
Definition: device_gemm_dl.hpp:312
index_t K_raw_
Definition: device_gemm_dl.hpp:322
CDataType * p_c_grid_
Definition: device_gemm_dl.hpp:304
index_t N_raw_
Definition: device_gemm_dl.hpp:321
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_
Definition: device_gemm_dl.hpp:307
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t M01, index_t N01, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_dl.hpp:253
const BDataType * p_b_grid_
Definition: device_gemm_dl.hpp:303
AElementwiseOperation a_element_op_
Definition: device_gemm_dl.hpp:325
BElementwiseOperation b_element_op_
Definition: device_gemm_dl.hpp:326
DefaultBlock2CTileMap block_2_ctile_map_
Definition: device_gemm_dl.hpp:314
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_
Definition: device_gemm_dl.hpp:310
CElementwiseOperation c_element_op_
Definition: device_gemm_dl.hpp:327
const ADataType * p_a_grid_
Definition: device_gemm_dl.hpp:302
Definition: device_gemm_dl.hpp:332
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_dl.hpp:335
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_dl.hpp:479
Definition: device_gemm_dl.hpp:77
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition: device_gemm_dl.hpp:248
static constexpr auto I0
Definition: device_gemm_dl.hpp:78
static constexpr auto I2
Definition: device_gemm_dl.hpp:80
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition: device_gemm_dl.hpp:239
virtual std::string GetTypeString() const override
Definition: device_gemm_dl.hpp:621
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_dl.hpp:548
decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1)) BGridDesc_K0_N_K1
Definition: device_gemm_dl.hpp:201
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_dl.hpp:492
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition: device_gemm_dl.hpp:242
static constexpr auto I3
Definition: device_gemm_dl.hpp:81
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_dl.hpp:615
static auto MakeInvoker()
Definition: device_gemm_dl.hpp:582
decltype(MakeAGridDescriptor_K0_M_K1(1, 1, 1)) AGridDesc_K0_M_K1
Definition: device_gemm_dl.hpp:200
static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
Definition: device_gemm_dl.hpp:126
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition: device_gemm_dl.hpp:244
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override
Definition: device_gemm_dl.hpp:585
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition: device_gemm_dl.hpp:246
static constexpr auto I5
Definition: device_gemm_dl.hpp:83
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition: device_gemm_dl.hpp:202
static constexpr auto I1
Definition: device_gemm_dl.hpp:79
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_dl.hpp:486
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op)
Definition: device_gemm_dl.hpp:553
static constexpr auto I4
Definition: device_gemm_dl.hpp:82
static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
Definition: device_gemm_dl.hpp:165
static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
Definition: device_gemm_dl.hpp:87
static constexpr auto K1Number
Definition: device_gemm_dl.hpp:85
Definition: device_gemm.hpp:22