/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp Source File#
rmsnorm2d_fwd_kernel.hpp
Go to the documentation of this file.
150 if (kFusedAdd != Rmsnorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Rmsnorm2dFusedAddEnumName<kFusedAdd>::name;
151 if (kFusedQuant != Rmsnorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Rmsnorm2dFusedQuantEnumName<kFusedQuant>::name;
175 _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
176 _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_DEVICE auto make_null_tile_window(const WindowLengths &window_lengths)
Definition: null_tile_window.hpp:66
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
@ NO_SPECIFIC_MODEL
@ T5_MODEL_LIKE
@ SMOOTH_DYNAMIC_QUANT
@ DYNAMIC_QUANT
@ PRE_ADD_STORE
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
#define _TS_
#define _SS_
Definition: rmsnorm2d_fwd_traits.hpp:20
Definition: rmsnorm2d_fwd_traits.hpp:34
Definition: rmsnorm2d_fwd_kernel.hpp:79
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:82
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:86
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:81
Definition: rmsnorm2d_fwd_kernel.hpp:130
Definition: rmsnorm2d_fwd_kernel.hpp:14
void * p_y_residual
Definition: rmsnorm2d_fwd_kernel.hpp:21
const void * p_x_residual
Definition: rmsnorm2d_fwd_kernel.hpp:16
const void * p_sm_scale
Definition: rmsnorm2d_fwd_kernel.hpp:17
Definition: rmsnorm2d_fwd_kernel.hpp:39
CK_TILE_DEVICE void operator()(Kargs kargs) const
Definition: rmsnorm2d_fwd_kernel.hpp:183
XDataType XResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:54
remove_cvref_t< typename Problem::UnquantYDataType > UnquantYDataType
Definition: rmsnorm2d_fwd_kernel.hpp:51
remove_cvref_t< Epilogue_ > Epilogue
Definition: rmsnorm2d_fwd_kernel.hpp:41
remove_cvref_t< typename Problem::YScaleDataType > YScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:50
static constexpr bool kTwoPass
Definition: rmsnorm2d_fwd_kernel.hpp:65
static constexpr bool kSaveInvRms
Definition: rmsnorm2d_fwd_kernel.hpp:58
static constexpr CK_TILE_HOST auto GridSize(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:122
typename Pipeline::Problem Problem
Definition: rmsnorm2d_fwd_kernel.hpp:42
remove_cvref_t< typename Problem::InvRmsDataType > InvRmsDataType
Definition: rmsnorm2d_fwd_kernel.hpp:48
static CK_TILE_HOST std::string GetName()
Definition: rmsnorm2d_fwd_kernel.hpp:142
remove_cvref_t< typename Problem::YDataType > YDataType
Definition: rmsnorm2d_fwd_kernel.hpp:47
static constexpr auto kFusedQuant
Definition: rmsnorm2d_fwd_kernel.hpp:67
remove_cvref_t< typename Problem::ComputeDataType > ComputeDataType
Definition: rmsnorm2d_fwd_kernel.hpp:46
remove_cvref_t< Pipeline_ > Pipeline
Definition: rmsnorm2d_fwd_kernel.hpp:40
remove_cvref_t< typename Problem::SmoothScaleDataType > SmoothScaleDataType
Definition: rmsnorm2d_fwd_kernel.hpp:49
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: rmsnorm2d_fwd_kernel.hpp:140
static constexpr index_t Block_M
Definition: rmsnorm2d_fwd_kernel.hpp:61
static constexpr auto kFusedAdd
Definition: rmsnorm2d_fwd_kernel.hpp:66
XDataType YResidualDataType
Definition: rmsnorm2d_fwd_kernel.hpp:55
static constexpr bool kHasGamma
Definition: rmsnorm2d_fwd_kernel.hpp:57
remove_cvref_t< typename Problem::XDataType > XDataType
Definition: rmsnorm2d_fwd_kernel.hpp:44
static constexpr index_t kBlockSize
Definition: rmsnorm2d_fwd_kernel.hpp:73
static constexpr index_t ThreadPerWarp_N
Definition: rmsnorm2d_fwd_kernel.hpp:70
static constexpr CK_TILE_HOST auto BlockSize()
Definition: rmsnorm2d_fwd_kernel.hpp:127
static constexpr CK_TILE_HOST Kargs MakeKargs(const Hargs &hargs)
Definition: rmsnorm2d_fwd_kernel.hpp:102
remove_cvref_t< typename Problem::GammaDataType > GammaDataType
Definition: rmsnorm2d_fwd_kernel.hpp:45
static constexpr index_t Block_N
Definition: rmsnorm2d_fwd_kernel.hpp:62
static constexpr bool kSaveUnquant
Definition: rmsnorm2d_fwd_kernel.hpp:59
static constexpr index_t Repeat_N
Definition: rmsnorm2d_fwd_kernel.hpp:72
static constexpr auto kUseModelSensitiveRMSNorm
Definition: rmsnorm2d_fwd_kernel.hpp:68
static constexpr index_t Vector_N
Definition: rmsnorm2d_fwd_kernel.hpp:71
Definition: integral_constant.hpp:13
Definition: sequence.hpp:49