21 template <
typename Problem_, 
typename Policy_ = FusedMoeGemmPipelineFlatmmPolicy>
 
   42     using Traits = 
typename Problem::Traits;
 
   60         if constexpr(Problem::kBlockPerCu != -1)
 
   61             return Problem::kBlockPerCu;
 
   69     static constexpr 
const char* 
name = 
"fused_moe_flatmm";
 
   74         return Policy::template GetSmemSize_A<Problem>();
 
   79         return Policy::template GetSmemSize<Problem>();
 
   85         constexpr 
auto a_dist = Policy::template MakeGlobalTileDistribution_A<Problem>();
 
   86         const auto a_coord    = a_dist.calculate_index();
 
   93         constexpr 
auto o_dist = Policy::template MakeOGlobalTileDistribution<Problem>();
 
   94         const auto o_coord    = o_dist.calculate_index();
 
   98     template <
typename AWindow, 
typename GWindow, 
typename DWindow, 
typename OWindow>
 
  100                                    const GWindow& g_window_,
 
  101                                    const DWindow& d_window_,
 
  108         _Pragma(
"clang diagnostic push") _Pragma(
"clang diagnostic ignored \"-Wc++20-extensions\"");
 
  109         constexpr 
auto NEG1  = 
number<-1>{};
 
  118             Policy::template GetSmemSize_A<Problem>());
 
  120         auto g_view = g_window_.get_bottom_tensor_view();
 
  122         auto u_view = [&]() {
 
  129                 index_t nr_0 = intermediate_size / BlockShape::Block_Nr0;
 
  130                 index_t kr_0 = hidden_size / BlockShape::Block_Kr0;
 
  133                     g_window_.get_bottom_tensor_view().get_buffer_view().p_data_;
 
  136                 const auto u_view_ = make_naive_tensor_view<address_space_enum::global>(
 
  142                 const auto u_view_1_ =
 
  153             a_window_, Policy::template MakeGlobalTileDistribution_A<Problem>());
 
  156                                     Policy::template MakeGlobalTileDistribution_G<Problem>(),
 
  160                                     Policy::template MakeGlobalTileDistribution_D<Problem>(),
 
  163             o_window_, Policy::template MakeGlobalTileDistribution_O<Problem>());
 
  165         using g_thread_type = decltype(
load_tile(g_win));
 
  166         using d_thread_type = decltype(
load_tile(d_win));
 
  168         using WarpGemm0  = decltype(Policy::template GetWarpGemm0<Problem>());
 
  169         using WarpGemm1  = decltype(Policy::template GetWarpGemm1<Problem>());
 
  170         auto warp_gemm_0 = WarpGemm0{};
 
  171         auto warp_gemm_1 = WarpGemm1{};
 
  176                                  smem_0, Policy::template MakeLdsStoreDesc_A<Problem>()),
 
  177                              Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
 
  182                                  smem_1, Policy::template MakeLdsStoreDesc_A<Problem>()),
 
  183                              Policy::template MakeLdsStoreDesc_A<Problem>().get_lengths(),
 
  186         auto a_sld_win0 = [&]() {
 
  187             using WG                        = WarpGemm0;
 
  197                 a_outer_dstr_enc, 
typename WG::AWarpDstrEncoding{});
 
  199                 make_tensor_view<address_space_enum::lds>(
 
  200                     smem_0, Policy::template MakeLdsLoadDesc_A<Problem>()),
 
  201                 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
 
  207         auto a_sld_win1 = [&]() {
 
  208             using WG                        = WarpGemm0;
 
  218                 a_outer_dstr_enc, 
typename WG::AWarpDstrEncoding{});
 
  220                 make_tensor_view<address_space_enum::lds>(
 
  221                     smem_1, Policy::template MakeLdsLoadDesc_A<Problem>()),
 
  222                 Policy::template MakeLdsLoadDesc_A<Problem>().get_lengths(),
 
  227         auto bridge_sst_win = [&]() {
 
  229                 make_tensor_view<address_space_enum::lds>(
 
  231                     Policy::template MakeBridgeLdsStoreDesc<Problem>()),
 
  232                 Policy::template MakeBridgeLdsStoreDesc<Problem>().get_lengths(),
 
  236         auto bridge_sld_win = [&]() {
 
  238                 make_tensor_view<address_space_enum::lds>(
 
  240                     Policy::template MakeBridgeLdsLoadDesc<Problem>()),
 
  241                 Policy::template MakeBridgeLdsLoadDesc<Problem>().get_lengths(),
 
  243                 Policy::template MakeYTileDistribution<Problem>());
 
  249         constexpr 
auto issues_a = 
number<a_win.get_num_of_access()>{};
 
  250         constexpr 
auto issues_g = 
number<g_win.get_num_of_access()>{};
 
  253         constexpr 
auto issues_gemm0 =
 
  254             number<BlockShape::Repeat_M0 * BlockShape::Repeat_N0 * BlockShape::Repeat_K0 *
 
  255                    warp_gemm_0.get_num_of_access()>{};
 
  256         constexpr 
auto issues_gemm1 =
 
  257             number<BlockShape::Repeat_M1 * BlockShape::Repeat_N1 * BlockShape::Repeat_K1 *
 
  258                    warp_gemm_1.get_num_of_access()>{};
 
  262             (hidden_size + BlockShape::Block_K0 - 1) / BlockShape::Block_K0;
 
  264             (hidden_size + BlockShape::Block_N1 - 1) / BlockShape::Block_N1;
 
  266         using a_thread_type = decltype(
load_tile(a_sld_win0));
 
  270             auto& a_store_, 
auto i_access, PreNop = {})
 
  274         auto move_a = [&]() {
 
  277         auto sld_a = [&](
auto& a_, 
auto& win_, 
auto i_access) {
 
  282             auto& g_, 
auto i_access, PreNop = {})
 
  287                 if constexpr(i_access.
value == 0)
 
  289                     g_win.bottom_tensor_view_ = g_view;
 
  291                 else if constexpr(i_access.
value == issues_g / 2)
 
  293                     g_win.bottom_tensor_view_ = u_view;
 
  298         auto move_g = [&]() {
 
  304             auto& d_, 
auto i_access, PreNop = {})
 
  308         auto move_d = [&]() {
 
  314             auto& o_, 
auto i_access, PreNop = {})
 
  319         auto acc_0  = Policy::template MakeCBlockTile_Gemm0<Problem>();
 
  321             [&](
auto) { 
return Policy::template MakeCBlockTile_Gemm1<Problem>(); }, 
number<2>{});
 
  325         (
auto& t_c, 
auto& t_a, 
auto& t_b, 
auto i_access, PostNop = {}) {
 
  328             constexpr 
auto repeat_sub = WarpGemm::get_num_of_access();
 
  329             constexpr 
auto repeat_m = BlockShape::Repeat_M0;
 
  331             constexpr 
auto repeat_k = BlockShape::Repeat_K0;
 
  333             constexpr 
auto i_sub = i_access % repeat_sub;
 
  334             constexpr 
auto i_k = (i_access / repeat_sub) % repeat_k;
 
  335             constexpr 
auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
 
  336             constexpr 
auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
 
  338             using AWarpTensor = 
typename WarpGemm::AWarpTensor;
 
  339             using BWarpTensor = 
typename WarpGemm::BWarpTensor;
 
  340             using CWarpTensor = 
typename WarpGemm::CWarpTensor;
 
  341             using AWarpDstr = 
typename WarpGemm::AWarpDstr;
 
  342             using BWarpDstr = 
typename WarpGemm::BWarpDstr;
 
  343             using CWarpDstr = 
typename WarpGemm::CWarpDstr;
 
  349             constexpr 
auto a_warp_y_lengths = 
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  350             constexpr 
auto b_warp_y_lengths = 
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  351             constexpr 
auto c_warp_y_lengths = 
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  354             w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
 
  359             w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
 
  364             w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
 
  370             t_c.set_y_sliced_thread_data(
 
  373                         w_c.get_thread_buffer());
 
  379         (
auto& t_c, 
auto& t_a, 
auto& t_b, 
auto i_access, PostNop = {}) {
 
  382             constexpr 
auto repeat_sub = WarpGemm::get_num_of_access();
 
  383             constexpr 
auto repeat_m = BlockShape::Repeat_M0;
 
  385             constexpr 
auto repeat_k = BlockShape::Repeat_K0;
 
  387             constexpr 
auto i_sub = i_access % repeat_sub;
 
  388             constexpr 
auto i_k = (i_access / repeat_sub) % repeat_k;
 
  389             constexpr 
auto i_m = (i_access / (repeat_sub * repeat_k )) % repeat_m;
 
  390             constexpr 
auto i_n = (i_access / (repeat_sub * repeat_k )) / repeat_m;
 
  392             using AWarpTensor = 
typename WarpGemm::AWarpTensor;
 
  393             using BWarpTensor = 
typename WarpGemm::BWarpTensor;
 
  394             using CWarpTensor = 
typename WarpGemm::CWarpTensor;
 
  395             using AWarpDstr = 
typename WarpGemm::AWarpDstr;
 
  396             using BWarpDstr = 
typename WarpGemm::BWarpDstr;
 
  397             using CWarpDstr = 
typename WarpGemm::CWarpDstr;
 
  403             constexpr 
auto a_warp_y_lengths = 
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  404             constexpr 
auto b_warp_y_lengths = 
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  405             constexpr 
auto c_warp_y_lengths = 
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
 
  408             w_a.get_thread_buffer() = t_a.get_y_sliced_thread_data(
 
  413             w_b.get_thread_buffer() = t_b.get_y_sliced_thread_data(
 
  418             w_c.get_thread_buffer() = t_c.get_y_sliced_thread_data(
 
  424             t_c.set_y_sliced_thread_data(
 
  427                         w_c.get_thread_buffer());
 
  430         _Pragma(
"clang diagnostic pop");
 
  439         auto pipeline_gemm0 = [&]() {
 
  440             constexpr 
index_t total_loops = issues_gemm0;
 
  441             constexpr 
auto sr             = Policy::template GetSequencer_0<Problem>();
 
  442             static_assert(sr.size() == total_loops);
 
  444             constexpr 
auto c_sld_a_0 = 
MAKE_SC();
 
  445             constexpr 
auto c_gld_a_0 = 
MAKE_SC();
 
  446             constexpr 
auto c_gld_b_0 = 
MAKE_SC();
 
  449                 gemm_0(acc_0, as[I0], gs[I0], i_issue);
 
  450                 constexpr 
index_t slot = sr.at(i_issue);
 
  452                 if constexpr(slot & 
SLD_A)
 
  453                     sld_a(as[I1], a_sld_win1, 
number<
NEXT_SCI(c_sld_a_0, i_issue)>{});
 
  454                 if constexpr(slot & 
GLD_A)
 
  456                 if constexpr(slot & 
GLD_B)
 
  464             constexpr 
auto c_sld_a_1 = 
MAKE_SC();
 
  465             constexpr 
auto c_gld_a_1 = 
MAKE_SC();
 
  466             constexpr 
auto c_gld_b_1 = 
MAKE_SC();
 
  470                 gemm_0(acc_0, as[I1], gs[I1], i_issue);
 
  471                 constexpr 
index_t slot = sr.at(i_issue);
 
  473                 if constexpr(slot & 
SLD_A)
 
  474                     sld_a(as[I0], a_sld_win0, 
number<
NEXT_SCI(c_sld_a_1, i_issue)>{});
 
  475                 if constexpr(slot & 
GLD_A)
 
  477                 if constexpr(slot & 
GLD_B)
 
  486         auto pipeline_gemm0_tail = [&]() {
 
  487             constexpr 
index_t total_loops = issues_gemm0;
 
  488             constexpr 
auto sr             = Policy::template GetSequencer_0<Problem>();
 
  489             static_assert(sr.size() == total_loops);
 
  491             constexpr 
auto c_gld_b_0 = 
MAKE_SC();
 
  495                 gemm_0(acc_0, as[I0], gs[I0], i_issue);
 
  496                 constexpr 
index_t slot = sr.at(i_issue);
 
  498                 if constexpr(slot & 
GLD_B)
 
  503             sld_a(as[I1], a_sld_win1, NEG1);
 
  507                 constexpr 
auto last_nop = [&]() {
 
  508                     if constexpr(i_issue == (total_loops - 1))
 
  513                 gemm_0(acc_0, as[I1], gs[I1], i_issue, last_nop); 
 
  517         auto y = Policy::template MakeYBlockTile<Problem>();
 
  519         auto pipeline_bridge = [&]() {
 
  521             auto y_pre = cast_tile<YDataType>(acc_0);
 
  530         auto pipeline_gemm1 = [&]() {
 
  531             constexpr 
index_t total_loops = issues_gemm1;
 
  532             constexpr 
auto sr             = Policy::template GetSequencer_1<Problem>();
 
  533             static_assert(sr.size() == total_loops);
 
  535             constexpr 
auto c_gld_b_0 = 
MAKE_SC();
 
  536             constexpr 
auto c_gst_o_0 = 
MAKE_SC();
 
  537             constexpr 
auto c_gld_b_1 = 
MAKE_SC();
 
  538             constexpr 
auto c_gst_o_1 = 
MAKE_SC();
 
  542                 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
 
  543                 constexpr 
index_t slot = sr.at(i_issue);
 
  544                 if constexpr(slot & 
GLD_B)
 
  547                 if constexpr(slot & 
GST_O)
 
  549                     auto out = cast_tile<ODataType>(acc_1s[I0]);
 
  558                 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
 
  559                 constexpr 
index_t slot = sr.at(i_issue);
 
  560                 if constexpr(slot & 
GLD_B)
 
  563                 if constexpr(slot & 
GST_O)
 
  565                     auto out = cast_tile<ODataType>(acc_1s[I1]);
 
  572         auto pipeline_gemm1_head = [&]() {
 
  573             constexpr 
index_t total_loops = issues_gemm1;
 
  574             constexpr 
auto sr             = Policy::template GetSequencer_1<Problem>();
 
  575             static_assert(sr.size() == total_loops);
 
  577             constexpr 
auto c_gld_b_0 = 
MAKE_SC();
 
  581                 gemm_1(acc_1s[I0], y, ds[I0], i_issue);
 
  582                 constexpr 
index_t slot = sr.at(i_issue);
 
  583                 if constexpr(slot & 
GLD_B)
 
  588         auto pipeline_gemm1_tail = [&]() {
 
  589             constexpr 
index_t total_loops = issues_gemm1;
 
  590             constexpr 
auto sr             = Policy::template GetSequencer_1<Problem>();
 
  591             static_assert(sr.size() == total_loops);
 
  593             constexpr 
auto c_gst_o_0 = 
MAKE_SC();
 
  597                 gemm_1(acc_1s[I1], y, ds[I1], i_issue);
 
  599                 constexpr 
index_t slot = sr.at(i_issue);
 
  600                 if constexpr(slot & 
GST_O)
 
  602                     auto out = cast_tile<ODataType>(acc_1s[I0]);
 
  607                 auto out = cast_tile<ODataType>(acc_1s[I1]);
 
  608                 atomic_add_o(out, NEG1);
 
  614         gld_a(a_sst_win0, NEG1, TRUE);
 
  615         gld_g(gs[I0], NEG1, TRUE);
 
  621         gld_a(a_sst_win1, NEG1); 
 
  629         const index_t iters_0 = (num_blocks_k0 - 2) / 2;
 
  631         while(i_0++ < iters_0)
 
  635         pipeline_gemm0_tail();
 
  639         const index_t iters_1 = (num_blocks_n1 - 2) / 2;
 
  641         pipeline_gemm1_head();
 
  642         while(i_1++ < iters_1)
 
  646         pipeline_gemm1_tail();
 
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_LDS_ADDR
Definition: config.hpp:56
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:420
Definition: cluster_descriptor.hpp:13
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_ &&lds_tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: load_tile.hpp:149
tuple_array< T, N > statically_indexed_array
Definition: statically_indexed_array.hpp:16
int32_t index_t
Definition: integer.hpp:9
CK_TILE_DEVICE void lds_load_fence(index_t cnt=0)
Definition: amd_buffer_addressing.hpp:624
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:27
constexpr CK_TILE_HOST_DEVICE auto to_sequence(tuple< number< Is >... >)
Definition: sequence.hpp:1046
constexpr CK_TILE_HOST_DEVICE auto merge_sequences(Seqs...)
Definition: sequence.hpp:817
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
CK_TILE_DEVICE auto load_tile_raw(T &tile, const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Loads a tile of data using inline assembly.
Definition: load_tile.hpp:106
CK_TILE_DEVICE void move_tile_window(null_tile_window< WindowLengths > &, const typename null_tile_window< WindowLengths >::BottomTensorIndex &)
Definition: null_tile_window.hpp:92
constexpr CK_TILE_DEVICE auto make_tile_window_linear(const TensorView_ &tensor_view, const WindowLengths_ &window_lengths, const multi_index< TensorView_::get_num_of_dimension()> &origin, const StaticTileDistribution_ &tile_distribution, LinearBottomDims_={})
Definition: tile_window_linear.hpp:1124
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:400
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
CK_TILE_DEVICE void clear_tile(DstrTensors &dstr_tensor)
Definition: tile_elementwise.hpp:145
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt=0)
Definition: arch.hpp:95
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1017
CK_TILE_DEVICE void update_tile_raw(tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor, number< i_access >={}, bool_constant< oob_conditional_check >={}, bool_constant< pre_nop >={})
Definition: update_tile.hpp:68
#define NEXT_SCI(c_, static_i_)
Definition: static_counter.hpp:109
#define MAKE_SC()
Definition: static_counter.hpp:104
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:23
static constexpr index_t SLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:54
static constexpr index_t kAlignmentA
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:49
static constexpr bool PadHiddenSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:46
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:72
typename Problem::DScaleDataType DScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:36
typename Problem::BlockShape BlockShape
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:27
static constexpr index_t kAlignmentO
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:52
typename Problem::IndexDataType IndexDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:39
static constexpr index_t kBlockPerCu
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:59
remove_cvref_t< Policy_ > Policy
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:25
static constexpr const char * name
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:69
static constexpr index_t GLD_B
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:56
static constexpr index_t GLD_A
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:55
remove_cvref_t< Problem_ > Problem
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:24
typename Problem::ADataType ADataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:29
static CK_TILE_HOST_DEVICE auto GetOCoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:91
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:77
typename Problem::GDataType GDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:30
static constexpr index_t kAlignmentG
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:50
typename Problem::DDataType DDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:31
static CK_TILE_HOST_DEVICE auto GetACoord()
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:83
typename Problem::ODataType ODataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:33
typename Problem::GScaleDataType GScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:35
typename Problem::AccDataType AccDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:32
static constexpr index_t GST_O
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:57
typename Problem::TopkWeightDataType TopkWeightDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:38
static constexpr index_t kAlignmentD
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:51
typename Problem::YDataType YDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:40
static constexpr bool IsGateOnly
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:44
typename Problem::Traits Traits
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:42
typename Problem::YSmoothScaleDataType YSmoothScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:37
static constexpr bool PadIntermediateSize
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:47
typename Problem::AScaleDataType AScaleDataType
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:34
static constexpr bool UseSmoothQuant
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:45
CK_TILE_DEVICE auto operator()(const AWindow &a_window_, const GWindow &g_window_, const DWindow &d_window_, OWindow &o_window_, TopkWeightDataType, CK_TILE_LDS_ADDR void *smem, index_t hidden_size, index_t intermediate_size)
Definition: fused_moegemm_pipeline_flatmm_ex.hpp:99
Definition: integral_constant.hpp:13
static constexpr value_type value
Definition: integral_constant.hpp:16
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192