/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp Source File
add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 
8 namespace ck_tile {
9 
10 // X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
11 template <typename ADataType_,
12  typename BDataType_,
13  typename GammaDataType_,
14  typename ComputeDataType_,
15  typename XDataType_,
16  typename YScaleDataType_,
17  typename QYDataType_,
18  typename BlockShape_,
19  bool kPadN_,
20  bool kSaveX_,
21  bool kThreePass_>
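22 struct AddRmsnorm2dRdquantFwdPipelineProblem
23 {
24  using ADataType = remove_cvref_t<ADataType_>;
25  using BDataType = remove_cvref_t<BDataType_>;
26  using GammaDataType = remove_cvref_t<GammaDataType_>;
27  using ComputeDataType = remove_cvref_t<ComputeDataType_>;
28  using XDataType = remove_cvref_t<XDataType_>;
29  using YScaleDataType = remove_cvref_t<YScaleDataType_>;
30  using QYDataType = remove_cvref_t<QYDataType_>;
31  using BlockShape = remove_cvref_t<BlockShape_>;
32 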
33  static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
34  static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
35 
36  static constexpr bool kPadN = kPadN_;
37  static constexpr bool kSaveX = kSaveX_;
38  static constexpr bool kThreePass = kThreePass_;
39 };
40 
41 } // namespace ck_tile
Definition: cluster_descriptor.hpp:13
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:23
static constexpr bool kPadN
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:36
remove_cvref_t< ADataType_ > ADataType
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:24
static constexpr bool kNeedCrossLaneSync
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:33
static constexpr bool kNeedCrossWarpSync
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:34
static constexpr bool kThreePass
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:38
remove_cvref_t< QYDataType_ > QYDataType
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:30
remove_cvref_t< ComputeDataType_ > ComputeDataType
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:27
static constexpr bool kSaveX
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:37
remove_cvref_t< GammaDataType_ > GammaDataType
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:26
remove_cvref_t< BlockShape_ > BlockShape
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:31
remove_cvref_t< BDataType_ > BDataType
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:25
remove_cvref_t< XDataType_ > XDataType
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:28
remove_cvref_t< YScaleDataType_ > YScaleDataType
Definition: add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp:29