/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_e_permute_xdl.hpp Source File
device_batched_gemm_e_permute_xdl.hpp
Go to the documentation of this file.
26 * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
34 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
37 * impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
39 * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
42 * \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
43 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__global__ void kernel_batched_gemm_e_permute_xdl(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, const index_t batch_count, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CDEElementwiseOperation cde_element_op, const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch, const Block2ETileMap block_2_etile_map)
Definition: device_batched_gemm_e_permute_xdl.hpp:63
GemmSpecialization
Definition: gemm_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:78
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:221
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:396
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:204
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const DsGridDesc_M_N &ds_grid_desc_m_n, const EGridDesc_M_N &e_grid_desc_m_n, [[maybe_unused]] const Block2ETileMap &)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:329
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_multiple_d_xdl_cshuffle.hpp:187
Definition: sequence.hpp:43
Definition: tuple.hpp:186
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_batched_gemm_e_permute.hpp:12
Definition: device_batched_gemm_e_permute.hpp:27
Definition: device_batched_gemm_e_permute_xdl.hpp:399
void Print() const
Definition: device_batched_gemm_e_permute_xdl.hpp:458
EDataType * p_e_grid_
Definition: device_batched_gemm_e_permute_xdl.hpp:469
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_batched_gemm_e_permute_xdl.hpp:480
BGridDesc_N_K b_grid_desc_n_k_
Definition: device_batched_gemm_e_permute_xdl.hpp:476
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_batched_gemm_e_permute_xdl.hpp:481
CDEElementwiseOperation cde_element_op_
Definition: device_batched_gemm_e_permute_xdl.hpp:494
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition: device_batched_gemm_e_permute_xdl.hpp:486
EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n_
Definition: device_batched_gemm_e_permute_xdl.hpp:483
EGridDesc_M_N e_grid_desc_m_n_
Definition: device_batched_gemm_e_permute_xdl.hpp:477
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, EDataType *p_e_grid, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_batched_gemm_e_permute_xdl.hpp:400
const ADataType * p_a_grid_
Definition: device_batched_gemm_e_permute_xdl.hpp:467
index_t BatchCount_
Definition: device_batched_gemm_e_permute_xdl.hpp:472
const BDataType * p_b_grid_
Definition: device_batched_gemm_e_permute_xdl.hpp:468
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock
Definition: device_batched_gemm_e_permute_xdl.hpp:482
AGridDesc_M_K a_grid_desc_m_k_
Definition: device_batched_gemm_e_permute_xdl.hpp:475
Block2ETileMap block_2_etile_map_
Definition: device_batched_gemm_e_permute_xdl.hpp:489
BElementwiseOperation b_element_op_
Definition: device_batched_gemm_e_permute_xdl.hpp:493
AElementwiseOperation a_element_op_
Definition: device_batched_gemm_e_permute_xdl.hpp:492
Definition: device_batched_gemm_e_permute_xdl.hpp:299
ComputePtrOffsetOfStridedBatch(index_t Batchstride_A, index_t Batchstride_B, EGridDesc_G0_G1_M_N e_grid_desc_g0_g1_m_n)
Definition: device_batched_gemm_e_permute_xdl.hpp:300
__host__ constexpr __device__ long_index_t GetCPtrOffset(index_t g_idx) const
Definition: device_batched_gemm_e_permute_xdl.hpp:319
__host__ constexpr __device__ long_index_t GetAPtrOffset(index_t g_idx) const
Definition: device_batched_gemm_e_permute_xdl.hpp:309
__host__ constexpr __device__ long_index_t GetBPtrOffset(index_t g_idx) const
Definition: device_batched_gemm_e_permute_xdl.hpp:314
Definition: device_batched_gemm_e_permute_xdl.hpp:499
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_batched_gemm_e_permute_xdl.hpp:502
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_batched_gemm_e_permute_xdl.hpp:566
Definition: device_batched_gemm_e_permute_xdl.hpp:171
decltype(MakeEGridDescriptor_G0_G1_M_N(1, 1, 1, 1, 1, 1, 1, 1)) EGridDesc_G0_G1_M_N
Definition: device_batched_gemm_e_permute_xdl.hpp:296
static auto MakeInvoker()
Definition: device_batched_gemm_e_permute_xdl.hpp:632
static auto MakeBGridDescriptor_N_K(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_batched_gemm_e_permute_xdl.hpp:199
std::string GetTypeString() const override
Definition: device_batched_gemm_e_permute_xdl.hpp:676
static auto MakeEGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t stride_M, index_t stride_N)
Definition: device_batched_gemm_e_permute_xdl.hpp:218
decltype(MakeAGridDescriptor_M_K(1, 1, 1)) AGridDesc_M_K
Definition: device_batched_gemm_e_permute_xdl.hpp:293
static auto MakeAGridDescriptor_M_K(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_batched_gemm_e_permute_xdl.hpp:181
static constexpr auto I1
Definition: device_batched_gemm_e_permute_xdl.hpp:175
static constexpr auto matrix_padder
Definition: device_batched_gemm_e_permute_xdl.hpp:178
decltype(MakeBGridDescriptor_N_K(1, 1, 1)) BGridDesc_N_K
Definition: device_batched_gemm_e_permute_xdl.hpp:294
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, EDataType *p_e, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: device_batched_gemm_e_permute_xdl.hpp:599
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_e, index_t M, index_t N, index_t K, index_t stride_A, index_t stride_B, index_t batch_stride_A, index_t batch_stride_B, BatchedGemmEPermuteDesc batched_gemm_e_permute_desc, index_t BatchCount, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op) override
Definition: device_batched_gemm_e_permute_xdl.hpp:636
static constexpr auto I0
Definition: device_batched_gemm_e_permute_xdl.hpp:174
static constexpr bool IsValidCompilationParameter()
Definition: device_batched_gemm_e_permute_xdl.hpp:573
GridwiseGemmMultipleD_xdl_cshuffle< ADataType, BDataType, ComputeDataType, AccDataType, CShuffleDataType, ck::Tuple<>, EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, InMemoryDataOperationEnum::Set, AGridDesc_M_K, BGridDesc_N_K, Tuple<>, EGridDesc_M_N, NumPrefetch, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CDEBlockTransferScalarPerVector_NPerBlock, LoopSched > GridwiseGemm
Definition: device_batched_gemm_e_permute_xdl.hpp:383
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_batched_gemm_e_permute_xdl.hpp:594
decltype(MakeEGridDescriptor_M_N(1, 1, 1, 1)) EGridDesc_M_N
Definition: device_batched_gemm_e_permute_xdl.hpp:295
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(AGridDesc_M_K{}))> AGridDesc_AK0_M_AK1
Definition: device_batched_gemm_e_permute_xdl.hpp:387
static constexpr auto I2
Definition: device_batched_gemm_e_permute_xdl.hpp:176
typename GridwiseGemm::DefaultBlock2ETileMap Block2ETileMap
Definition: device_batched_gemm_e_permute_xdl.hpp:395
static bool IsSupportedArgument(const Argument &arg)
Definition: device_batched_gemm_e_permute_xdl.hpp:579
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_batched_gemm_e_permute_xdl.hpp:670
remove_cvref_t< decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(BGridDesc_N_K{}))> BGridDesc_BK0_N_BK1
Definition: device_batched_gemm_e_permute_xdl.hpp:390
static auto MakeEGridDescriptor_G0_G1_M_N(index_t G0, index_t G1, index_t MRaw, index_t NRaw, index_t stride_G0, index_t stride_G1, index_t stride_M, index_t stride_N)
Definition: device_batched_gemm_e_permute_xdl.hpp:226
ADataType ComputeDataType
Definition: device_batched_gemm_e_permute_xdl.hpp:333
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{})) EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
Definition: device_batched_gemm_e_permute_xdl.hpp:394
Definition: matrix_padder.hpp:180