/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_xdl.hpp Source File#
device_batched_gemm_xdl.hpp
Go to the documentation of this file.
28 * strided batched, but we can easily extend to other layouts. The returned offset can be either \p
36 * \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
39 * device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
40 * DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
43 * \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
44 * Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion) to
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__global__ void kernel_batched_gemm_xdlops_v2r3(const typename DeviceOp::Argument karg)
Definition: device_batched_gemm_xdl.hpp:53
Definition: ck.hpp:264
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
@ Default
@ Interwave
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdlops_v2r3.hpp:781
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdlops_v2r3.hpp:968
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v2r3.hpp:382
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_batched_gemm.hpp:25
Definition: device_batched_gemm_xdl.hpp:223
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch
Definition: device_batched_gemm_xdl.hpp:250
const BDataType * p_b_grid
Definition: device_batched_gemm_xdl.hpp:247
const ADataType * p_a_grid
Definition: device_batched_gemm_xdl.hpp:246
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t Batch_)
Definition: device_batched_gemm_xdl.hpp:224
index_t Batch
Definition: device_batched_gemm_xdl.hpp:249
CDataType * p_c_grid
Definition: device_batched_gemm_xdl.hpp:248
Definition: device_batched_gemm_xdl.hpp:145
__host__ constexpr __device__ long_index_t GetBPtrOffset(index_t g_idx) const
Definition: device_batched_gemm_xdl.hpp:158
__host__ constexpr __device__ long_index_t GetCPtrOffset(index_t g_idx) const
Definition: device_batched_gemm_xdl.hpp:163
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC)
Definition: device_batched_gemm_xdl.hpp:146
__host__ constexpr __device__ long_index_t GetAPtrOffset(index_t g_idx) const
Definition: device_batched_gemm_xdl.hpp:153
Definition: device_batched_gemm_xdl.hpp:255
float Run(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_batched_gemm_xdl.hpp:258
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_batched_gemm_xdl.hpp:297
Definition: device_batched_gemm_xdl.hpp:137
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_batched_gemm_xdl.hpp:321
static constexpr auto I0
Definition: device_batched_gemm_xdl.hpp:138
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t Batch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_batched_gemm_xdl.hpp:358
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_batched_gemm_xdl.hpp:391
static auto MakeInvoker()
Definition: device_batched_gemm_xdl.hpp:355
std::string GetTypeString() const override
Definition: device_batched_gemm_xdl.hpp:397
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, index_t BatchStrideA, index_t BatchStrideB, index_t BatchStrideC, index_t Batch)
Definition: device_batched_gemm_xdl.hpp:326
static constexpr auto I1
Definition: device_batched_gemm_xdl.hpp:139
static bool IsSupportedArgument(const Problem &problem)
Definition: device_batched_gemm_xdl.hpp:310
static constexpr auto K1Number
Definition: device_batched_gemm_xdl.hpp:142
static constexpr auto I2
Definition: device_batched_gemm_xdl.hpp:140
static constexpr bool IsValidCompilationParameter()
Definition: device_batched_gemm_xdl.hpp:304
typename GridwiseGemm::Problem Problem
Definition: device_batched_gemm_xdl.hpp:219