/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_gemm_reduce_xdl_cshuffle.hpp Source File
device_gemm_reduce_xdl_cshuffle.hpp
Go to the documentation of this file.
23 // Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Because non-c-shuffle
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__global__ void kernel_gemm_reduce_xdl_cshuffle_v1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, ReducePtrsGlobal p_reduces_grid, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const ReduceInElementwiseOperations reduce_in_element_ops, const ReduceAccElementwiseOperations reduce_out_element_ops, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock, const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:40
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation > = integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:149
remove_cvref_t< decltype(MakeReduceGridDescriptor_MBlock_MPerBlock(ReduceGridDesc_M{}))> ReduceGridDescriptor_MBlock_MPerBlock
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:326
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))> DefaultBlock2CTileMap
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:329
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:232
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:280
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:272
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:323
__host__ static constexpr __device__ auto MakeReduceGridDescriptor_MBlock_MPerBlock(const ReduceGridDesc_M &d_grid_desc_m)
Definition: gridwise_gemm_reduce_xdl_cshuffle_v1.hpp:299
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm_reduce_xdl_cshuffle.hpp:438
CGridDesc_M_N c_grid_desc_m_n_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:492
const BDataType * p_b_grid_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:487
ReducePtrsGlobal p_reduces_grid_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:489
BElementwiseOperation b_element_op_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:500
CDataType * p_c_grid_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:488
GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:495
CElementwiseOperation c_element_op_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:501
Argument(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, ReducePtrsGlobal p_reduces_grid, index_t MRaw, index_t NRaw, index_t KRaw, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CElementwiseOperation c_element_op, ReduceInElementwiseOperations reduce_in_element_ops, ReduceAccElementwiseOperations reduce_out_element_ops)
Definition: device_gemm_reduce_xdl_cshuffle.hpp:439
GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:498
ReduceInElementwiseOperations reduce_in_element_ops_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:502
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:490
AElementwiseOperation a_element_op_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:499
ReduceAccElementwiseOperations reduce_out_element_ops_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:503
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:491
GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:497
const ADataType * p_a_grid_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:486
ReduceGridDesc_M reduce_grid_desc_m_
Definition: device_gemm_reduce_xdl_cshuffle.hpp:493
Definition: device_gemm_reduce_xdl_cshuffle.hpp:508
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_reduce_xdl_cshuffle.hpp:632
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_reduce_xdl_cshuffle.hpp:511
Definition: device_gemm_reduce_xdl_cshuffle.hpp:78
static constexpr auto I0
Definition: device_gemm_reduce_xdl_cshuffle.hpp:81
static constexpr auto I2
Definition: device_gemm_reduce_xdl_cshuffle.hpp:83
static constexpr auto I1
Definition: device_gemm_reduce_xdl_cshuffle.hpp:82
static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)
Definition: device_gemm_reduce_xdl_cshuffle.hpp:291
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_reduce_xdl_cshuffle.hpp:639
static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
Definition: device_gemm_reduce_xdl_cshuffle.hpp:188
decltype(MakeCGridDescriptor_M_N(1, 1, 1)) CGridDesc_M_N
Definition: device_gemm_reduce_xdl_cshuffle.hpp:377
GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1< ADataType, GemmAccDataType, CShuffleDataType, CDataType, ReduceAccDataType, ReducePtrsGlobal, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, ReduceOperations, ReduceInElementwiseOperations, ReduceAccElementwiseOperations, InMemoryDataOperationEnum::Set, ReduceGlobalMemoryDataOperation, AGridDesc_AK0_M_AK1, BGridDesc_BK0_N_BK1, CGridDesc_M_N, ReduceGridDesc_M, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferScalarPerVector_NPerBlock, CReduceThreadClusterLengths_MPerBlock_NPerBlock, CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock, CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock, LoopSched > GridwiseGemm
Definition: device_gemm_reduce_xdl_cshuffle.hpp:434
decltype(MakeReduceGridDescriptor_M(1)) ReduceGridDesc_M
Definition: device_gemm_reduce_xdl_cshuffle.hpp:378
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_reduce_xdl_cshuffle.hpp:811
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_reduce_xdl_cshuffle.hpp:659
static bool IsSupportedArgument(const Argument &arg)
Definition: device_gemm_reduce_xdl_cshuffle.hpp:645
static auto MakeReduceGridDescriptor_M(index_t MRaw)
Definition: device_gemm_reduce_xdl_cshuffle.hpp:350
static auto MakeInvoker()
Definition: device_gemm_reduce_xdl_cshuffle.hpp:735
static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
Definition: device_gemm_reduce_xdl_cshuffle.hpp:85
decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1)) AGridDesc_AK0_M_AK1
Definition: device_gemm_reduce_xdl_cshuffle.hpp:375
decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1)) BGridDesc_BK0_N_BK1
Definition: device_gemm_reduce_xdl_cshuffle.hpp:376
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 0 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 0 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 0 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op, ck::index_t=1) override
Definition: device_gemm_reduce_xdl_cshuffle.hpp:739
static auto MakeArgument(const void *p_a, const void *p_b, const void *p_bias, std::array< const void *, 0 > p_ds, void *p_c, std::array< void *, NumReduce > p_reduces, ck::index_t M, ck::index_t N, ck::index_t K, ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC, std::array< ck::index_t, 0 > StrideDs, std::array< void *, 3 > gemm_element_ops, std::array< void *, 0 > d_element_ops, std::array< void *, NumReduce > reduce_in_element_op, std::array< void *, NumReduce > reduce_out_element_op)
Definition: device_gemm_reduce_xdl_cshuffle.hpp:665
static constexpr int NumReduce
Definition: device_gemm_reduce_xdl_cshuffle.hpp:664
std::string GetTypeString() const override
Definition: device_gemm_reduce_xdl_cshuffle.hpp:817
Definition: device_gemm_reduce.hpp:17