/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/flatmm/kernel/mixed_prec_flatmm_kernel.hpp Source File
mixed_prec_flatmm_kernel.hpp
Go to the documentation of this file.
57 return concat('_', "mixed_prec_gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
__device__ X atomic_add(X *p_dst, const X &x)
Definition: mixed_prec_flatmm_kernel.hpp:18
static constexpr int N_Pack
Definition: mixed_prec_flatmm_kernel.hpp:40
static constexpr auto I4
Definition: mixed_prec_flatmm_kernel.hpp:48
static constexpr index_t KernelBlockSize
Definition: mixed_prec_flatmm_kernel.hpp:31
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: mixed_prec_flatmm_kernel.hpp:277
static CK_TILE_HOST const std::string GetName()
Definition: mixed_prec_flatmm_kernel.hpp:54
static constexpr auto I0
Definition: mixed_prec_flatmm_kernel.hpp:44
static constexpr auto I1
Definition: mixed_prec_flatmm_kernel.hpp:45
static constexpr auto I2
Definition: mixed_prec_flatmm_kernel.hpp:46
CK_TILE_DEVICE void operator()(FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> kargs, int partition_idx=blockIdx.x) const
Definition: mixed_prec_flatmm_kernel.hpp:413
static constexpr int QuantPackedSize
Definition: mixed_prec_flatmm_kernel.hpp:39
static constexpr bool UsePersistentKernel
Definition: mixed_prec_flatmm_kernel.hpp:32
static constexpr CK_TILE_HOST auto GridSize(const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs)
Definition: mixed_prec_flatmm_kernel.hpp:63
typename Underlying::SplitKBatchOffset SplitKBatchOffset
Definition: mixed_prec_flatmm_kernel.hpp:101
static constexpr auto I3
Definition: mixed_prec_flatmm_kernel.hpp:47
static CK_TILE_DEVICE void RunFlatmm(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_ping, void *smem_ptr_pong, const FlatmmKernelArgs< ScaleM, ScaleN, DsDataType::size()> &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Definition: mixed_prec_flatmm_kernel.hpp:347
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_flat_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition: mixed_prec_flatmm_kernel.hpp:105
static constexpr index_t NumDTensor
Definition: mixed_prec_flatmm_kernel.hpp:42
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: mixed_prec_flatmm_kernel.hpp:210
Definition: flatmm_kernel.hpp:362
Definition: flatmm_kernel.hpp:229
Definition: flatmm_kernel.hpp:249
static constexpr CK_TILE_HOST auto BlockSize()
Definition: flatmm_kernel.hpp:330
remove_cvref_t< typename FlatmmPipeline::BlockGemmShape > BlockGemmShape
Definition: flatmm_kernel.hpp:253
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: flatmm_kernel.hpp:250
remove_cvref_t< typename EpiloguePipeline::DsLayout > DsLayout
Definition: flatmm_kernel.hpp:258
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: flatmm_kernel.hpp:259
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: flatmm_kernel.hpp:266
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: flatmm_kernel.hpp:254
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPingSize()
Definition: flatmm_kernel.hpp:352
remove_cvref_t< FlatmmPipeline_ > FlatmmPipeline
Definition: flatmm_kernel.hpp:251
remove_cvref_t< typename FlatmmPipeline::CLayout > ELayout
Definition: flatmm_kernel.hpp:257
remove_cvref_t< typename FlatmmPipeline::ADataType > ADataType
Definition: flatmm_kernel.hpp:263
remove_cvref_t< typename FlatmmPipeline::BLayout > BLayout
Definition: flatmm_kernel.hpp:256
remove_cvref_t< typename FlatmmPipeline::ALayout > ALayout
Definition: flatmm_kernel.hpp:255
remove_cvref_t< typename FlatmmPipeline::BDataType > BDataType
Definition: flatmm_kernel.hpp:264
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemPongSize()
Definition: flatmm_kernel.hpp:356
Definition: integral_constant.hpp:13
Unsigned representation of a conventional biased Float32 exponent.
Definition: e8m0.hpp:27
Definition: type_traits.hpp:115
Definition: numeric.hpp:81
Definition: sequence.hpp:49