/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3r1.hpp Source File
device_gemm_xdl_cshuffle_v3r1.hpp
  
Go to the documentation of this file.
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
std::string getGemmSpecializationString(const GemmSpecialization &s)
Definition: gemm_specialization.hpp:33
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
BlockGemmPipelineVersion
Definition: blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp:13
@ One
@ Seven
@ Even
@ Odd
@ Four
@ Two
@ Full
@ Three
@ Five
@ Six
__global__ void kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:58
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
@ Intrawave
@ Interwave
__global__ void kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg)
Definition: gridwise_gemm_xdl_cshuffle_streamk_v3.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:241
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:596
CDataType * p_c_grid
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:261
index_t M
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:222
__host__ void Print() const
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:203
index_t N
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:223
index_t KBatch
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:228
index_t K
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:224
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:66
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:610
static constexpr __host__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdl_cshuffle_v3.hpp:1004
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:603
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:88
Definition: integral_constant.hpp:10
Definition: reduction_operator.hpp:37
Definition: device_base.hpp:50
void * p_workspace_
Definition: device_base.hpp:57
Definition: device_base.hpp:61
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:142
const std::array< const void *, NumDTensor > p_ds
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:171
std::array< ck::index_t, NumDTensor > StrideDs
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:172
Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, const std::array< const void *, NumDTensor > p_ds_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< ck::index_t, NumDTensor > StrideDs_, index_t StrideC_, index_t k_batch_)
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:143
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:208
float RunReduce(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:209
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:546
float Run(const Argument &arg_, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:281
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:87
DeviceReduceThreadWiseMultiD< ReduceDataType, DsDataType, GemmAccDataType, CDataType, 3, 1, ReduceAdd, PassThrough, OutElementwiseOperation, 256, CShuffleBlockTransferScalarPerVector_NPerBlock, 1, 0, CShuffleBlockTransferScalarPerVector_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, decltype(DsVectorLengthSequence)> DeviceReduceInstance
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:204
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, std::array< const void *, NumDTensor > p_ds, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:605
static constexpr index_t NumDTensor
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:88
size_t GetWorkSpaceSize(const BaseArgument *p_arg) const override
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:686
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:553
static constexpr auto DsVectorLengthSequence
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:178
ck::reduce::Add ReduceAdd
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:175
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:559
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:636
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:578
GridwiseGemm_xdl_cshuffle_v3< ALayout, BLayout, CLayout, ADataType, BDataType, GemmAccDataType, CShuffleDataType, ReduceDataType, AElementwiseOperation, BElementwiseOperation, PassThrough, GemmSpec, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, ComputeTypeB > GridwiseGemm
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:139
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, const std::array< const void *, NumDTensor > p_ds, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, std::array< ck::index_t, NumDTensor > StrideDs, index_t StrideC, index_t KBatch, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:583
static auto MakeInvoker()
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:602
CElementwiseOperation OutElementwiseOperation
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:176
ck::tensor_operation::element_wise::PassThrough PassThrough
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:90
std::string GetTypeString() const override
Definition: device_gemm_xdl_cshuffle_v3r1.hpp:642
Definition: device_gemm_v2.hpp:57
Definition: device_reduce_threadwise_multi_d.hpp:47
Definition: unary_element_wise_operation.hpp:241
Definition: flush_cache.hpp:137