include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File
grouped_convolution_forward_kernel.hpp
Go to the documentation of this file.
391 return concat('_', "grouped_convolution_forward", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
Definition: cluster_descriptor.hpp:13
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
@ Filter1x1Stride1Pad0
@ Filter1x1Pad0
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:529
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:412
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
__device__ X atomic_add(X *p_dst, const X &x)
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_forward_kernel.hpp:22
index_t k_batch
Definition: grouped_convolution_forward_kernel.hpp:282
array< index_t, NonSpatialDims+GroupedConvTraitsType::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_forward_kernel.hpp:275
array< index_t, GroupedConvTraitsType::NDimSpatial > input_left_pads
Definition: grouped_convolution_forward_kernel.hpp:279
array< index_t, NonSpatialDims+GroupedConvTraitsType::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_forward_kernel.hpp:273
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeCDescriptor_M_N< typename GroupedConvTraitsType::OutLayout >())> CGridDescMN
Definition: grouped_convolution_forward_kernel.hpp:270
const void * wei_ptr
Definition: grouped_convolution_forward_kernel.hpp:289
long_index_t group_stride_c
Definition: grouped_convolution_forward_kernel.hpp:299
index_t GemmBatch
Definition: grouped_convolution_forward_kernel.hpp:286
index_t GemmN
Definition: grouped_convolution_forward_kernel.hpp:284
index_t GemmM
Definition: grouped_convolution_forward_kernel.hpp:283
long_index_t group_stride_b
Definition: grouped_convolution_forward_kernel.hpp:298
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_forward_kernel.hpp:295
index_t GemmK
Definition: grouped_convolution_forward_kernel.hpp:285
array< index_t, GroupedConvTraitsType::NDimSpatial > input_right_pads
Definition: grouped_convolution_forward_kernel.hpp:280
AGridDescMK a_grid_desc_m_k
Definition: grouped_convolution_forward_kernel.hpp:293
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K< typename GroupedConvTraitsType::InLayout >())> AGridDescMK
Definition: grouped_convolution_forward_kernel.hpp:264
array< index_t, NonSpatialDims+GroupedConvTraitsType::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_forward_kernel.hpp:274
static constexpr index_t NumDTensor
Definition: grouped_convolution_forward_kernel.hpp:27
BGridDescNK b_grid_desc_n_k
Definition: grouped_convolution_forward_kernel.hpp:294
array< index_t, GroupedConvTraitsType::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_forward_kernel.hpp:278
array< index_t, GroupedConvTraitsType::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_forward_kernel.hpp:277
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeBDescriptor_N_K< typename GroupedConvTraitsType::WeiLayout >())> BGridDescNK
Definition: grouped_convolution_forward_kernel.hpp:267
void * out_ptr
Definition: grouped_convolution_forward_kernel.hpp:291
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_forward_kernel.hpp:272
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvFwdHostArgs &args)
Definition: grouped_convolution_forward_kernel.hpp:37
long_index_t group_stride_a
Definition: grouped_convolution_forward_kernel.hpp:297
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_forward_kernel.hpp:290
const void * in_ptr
Definition: grouped_convolution_forward_kernel.hpp:288
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
The Grouped Convolution Forward kernel template.
Definition: grouped_convolution_forward_kernel.hpp:345
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_forward_kernel.hpp:579
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_forward_kernel.hpp:350
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_forward_kernel.hpp:351
static constexpr auto I2
Definition: grouped_convolution_forward_kernel.hpp:379
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_forward_kernel.hpp:347
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_forward_kernel.hpp:368
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_forward_kernel.hpp:367
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_forward_kernel.hpp:366
static CK_TILE_DEVICE void RunGemm2LDS(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvFwdKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_convolution_forward_kernel.hpp:720
remove_cvref_t< typename GroupedConvTraitsType::DsLayout > DsLayout
Definition: grouped_convolution_forward_kernel.hpp:359
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:414
static constexpr auto I1
Definition: grouped_convolution_forward_kernel.hpp:378
static CK_TILE_DEVICE void RunGemm(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *smem_ptr_0, const GroupedConvFwdKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_convolution_forward_kernel.hpp:670
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
Definition: grouped_convolution_forward_kernel.hpp:755
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_forward_kernel.hpp:375
static constexpr index_t NDimSpatial
Definition: grouped_convolution_forward_kernel.hpp:346
static constexpr index_t NumDTensor
Definition: grouped_convolution_forward_kernel.hpp:362
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_forward_kernel.hpp:354
static constexpr CK_TILE_HOST GroupedConvFwdKernelArgsSpecialized MakeKernelArgs(const GroupedConvFwdHostArgs &hostArgs)
Definition: grouped_convolution_forward_kernel.hpp:404
remove_cvref_t< typename GroupedConvTraitsType::OutLayout > OutLayout
Definition: grouped_convolution_forward_kernel.hpp:358
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_forward_kernel.hpp:388
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_forward_kernel.hpp:361
static constexpr index_t KernelBlockSize
Definition: grouped_convolution_forward_kernel.hpp:364
static constexpr auto I0
Definition: grouped_convolution_forward_kernel.hpp:377
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_forward_kernel.hpp:349
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_forward_kernel.hpp:352
remove_cvref_t< typename GroupedConvTraitsType::WeiLayout > WeiLayout
Definition: grouped_convolution_forward_kernel.hpp:357
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: grouped_convolution_forward_kernel.hpp:620
static constexpr CK_TILE_HOST auto GridSize(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:395
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_forward_kernel.hpp:401
static constexpr auto I3
Definition: grouped_convolution_forward_kernel.hpp:380
GroupedConvFwdKernelArgs< GroupedConvTraitsType > GroupedConvFwdKernelArgsSpecialized
Definition: grouped_convolution_forward_kernel.hpp:372
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_forward_kernel.hpp:353
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_forward_kernel.hpp:370
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_forward_kernel.hpp:409
remove_cvref_t< typename GroupedConvTraitsType::InLayout > InLayout
Definition: grouped_convolution_forward_kernel.hpp:356
static CK_TILE_DEVICE auto MakeGemmTensorViews(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:540
Definition: transform_conv_fwd_to_gemm.hpp:19
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:52