include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp Source File
gridwise_gemm_wmma_cshuffle_v3.hpp
Go to the documentation of this file.
GemmSpecialization
Definition: gemm_specialization.hpp:11
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
Definition: block_to_ctile_map.hpp:270
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:393
const AElementwiseOperation a_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:442
__host__ __device__ bool IsAtomicAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:432
const CDEElementwiseOperation cde_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:444
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, std::array< const void *, NumDTensor > p_ds_grid_, EDataType *p_e_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CDEElementwiseOperation cde_element_op_, bool is_reduce_=false)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:394
EDataType * p_e_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:440
bool is_reduce
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:447
const BDataType * p_b_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:438
const BElementwiseOperation b_element_op
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:443
const ADataType * p_a_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:437
__host__ __device__ bool IsReduceAdd() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:427
DsGridPointer p_ds_grid
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:439
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:327
index_t KPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:384
std::array< index_t, NumDTensor > StrideDs
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:378
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, std::array< index_t, NumDTensor > StrideDs_, index_t StrideE_, index_t KBatch_)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:328
index_t BK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:386
index_t KRead
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:383
index_t K
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:375
index_t M
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:373
index_t N
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:374
index_t StrideB
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:377
index_t StrideE
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:379
index_t MBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:387
index_t AK0
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:385
index_t MPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:381
index_t NPadded
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:382
index_t KBatch
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:380
index_t StrideA
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:376
__host__ void Print() const
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:355
index_t NBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:388
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:451
index_t c_reduce_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:502
index_t b_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:501
__device__ SplitKBatchOffset(Argument &karg, index_t k_id)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:453
index_t a_k_split_offset
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:500
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:106
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:108
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:678
static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:853
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:172
static constexpr auto AK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:121
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:178
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:228
static constexpr auto I4
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:112
__host__ static __device__ auto MakeDEGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideDE)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:440
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:202
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:814
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:162
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:109
static constexpr auto BK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:124
static constexpr auto I7
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:115
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:540
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:157
static constexpr auto AK1Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:123
decltype(MakeDsGridPointer()) DsGridPointer
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:515
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, decltype(MakeAWmmaTileDescriptor(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBWmmaTileDescriptor(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:850
static constexpr auto I3
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:111
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:197
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:140
static constexpr auto BK0Number
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:122
static constexpr auto I5
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:113
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:312
static constexpr auto I6
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:114
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:110
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:529
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:517
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:190
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:502
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:167
"Universal" GEMM kernel with SplitK support.
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:230
static constexpr auto I0
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:108
static constexpr index_t APackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:133
static __device__ void Run(void *p_shared, const SplitKBatchOffset &splitk_batch_offset, Argument &karg)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:596
static constexpr __device__ auto MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DEGridDesc &de_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:853
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:172
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:178
__host__ static __device__ auto MakeAGridDescriptor_AK0_M_AK1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA, index_t AK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:228
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:202
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:162
static constexpr auto I1
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:109
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:157
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:197
typename Base::BlockwiseGemmPipe BlockwiseGemmPipe
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:505
static constexpr index_t BPackedSize
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:140
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, DsGridPointer &p_ds_grid, EDataType *p_e_grid, void *p_shared, const Problem &problem, AElementwiseOperation a_element_op, BElementwiseOperation b_element_op, CDEElementwiseOperation cde_element_op)
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:517
__host__ static __device__ auto MakeBGridDescriptor_BK0_N_BK1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB, index_t BK0)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:312
static constexpr auto I2
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:110
static constexpr __device__ auto MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const DsGridDesc &ds_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:529
__host__ static __device__ auto MakeDsGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, std::array< index_t, NumDTensor > StrideDs)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:517
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:318
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:190
static __device__ index_t GetKBlockPerScale()
Definition: gridwise_gemm_wmma_cshuffle_v3.hpp:512
static constexpr index_t NumDTensor
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:502
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_wmma_cshuffle_v3_common.hpp:167
Definition: functional2.hpp:33
Definition: device_base.hpp:51