22     template <
typename Problem>
 
   27         constexpr 
index_t data_bytes = 
sizeof(
typename Problem::ADataType);
 
   28         static_assert(copy_bytes % data_bytes == 0);
 
   29         return copy_bytes / data_bytes;
 
   32     template <
typename Problem>
 
   35         constexpr 
index_t copy_bytes = [&]() { 
return 16; }();
 
   36         constexpr 
index_t data_bytes = 
sizeof(
typename Problem::GDataType);
 
   37         static_assert(copy_bytes % data_bytes == 0);
 
   38         return copy_bytes / data_bytes;
 
   41     template <
typename Problem>
 
   44         constexpr 
index_t copy_bytes = [&]() { 
return 16; }();
 
   45         constexpr 
index_t data_bytes = 
sizeof(
typename Problem::DDataType);
 
   46         static_assert(copy_bytes % data_bytes == 0);
 
   47         return copy_bytes / data_bytes;
 
   50     template <
typename Problem>
 
   53         if constexpr(Problem::Traits::OAtomic == 1)
 
   56             static_assert(
sizeof(
typename Problem::ODataType) == 2);
 
   59         else if constexpr(Problem::Traits::OAtomic == 2)
 
   66             return 16 / 
sizeof(
typename Problem::ODataType);
 
   70     template <
typename DataType_>
 
   77     template <
typename Problem>
 
   80         return GetSmemKPack<typename Problem::ADataType>();
 
   84     template <
typename Problem>
 
   88         return 16 / 
sizeof(
typename Problem::YDataType);
 
   91     template <
typename Problem>
 
   94         constexpr 
auto a_sld_desc = MakeLdsLoadDesc_A<Problem>();
 
   95         constexpr 
auto a_sst_desc = MakeLdsStoreDesc_A<Problem>();
 
   96         static_assert(a_sld_desc.get_element_space_size() == a_sst_desc.get_element_space_size());
 
   97         return a_sld_desc.get_element_space_size();
 
  100     template <
typename Problem>
 
  103         constexpr 
auto bridge_sld_desc = MakeBridgeLdsLoadDesc<Problem>();
 
  104         constexpr 
auto bridge_sst_desc = MakeBridgeLdsStoreDesc<Problem>();
 
  105         static_assert(bridge_sld_desc.get_element_space_size() ==
 
  106                       bridge_sst_desc.get_element_space_size());
 
  107         return bridge_sld_desc.get_element_space_size();
 
  110     template <
typename Problem>
 
  113         constexpr 
index_t a_lds      = GetSmemSize_A<Problem>();
 
  114         constexpr 
index_t bridge_lds = GetSmemSize_Bridge<Problem>();
 
  115         return max(a_lds, bridge_lds);
 
  118     template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
 
  121         constexpr 
index_t K_vec = Alignment;
 
  122         constexpr 
index_t K_rem = KPerBlock / K_vec;
 
  129             static_assert(K_wav <= NumWarps, 
"not not support thread has repeat along K yet");
 
  130             constexpr 
index_t M_wav = NumWarps / K_wav;
 
  131             static_assert(MPerBlock % M_wav == 0, 
"this tile size is too small please check");
 
  132             constexpr 
index_t M_rep = MPerBlock / M_wav;
 
  145             constexpr 
index_t K_lan = K_rem;
 
  147             constexpr 
index_t M_wav = NumWarps;
 
  148             static_assert(MPerBlock % (M_lan * M_wav) == 0,
 
  149                           "this tile size is too small please check");
 
  150             constexpr 
index_t M_rep = MPerBlock / (M_lan * M_wav);
 
  163     template <index_t MPerBlock, index_t KPerBlock, index_t NumWarps, index_t Alignment>
 
  166         constexpr 
index_t K_vec = Alignment;
 
  167         constexpr 
index_t K_rem = KPerBlock / K_vec;
 
  174             static_assert(K_wav <= NumWarps, 
"do not support thread has repeat along K yet");
 
  175             constexpr 
index_t M_wav = NumWarps / K_wav;
 
  176             static_assert(MPerBlock % M_wav == 0, 
"this tile size is too small please check");
 
  177             constexpr 
index_t M_rep = MPerBlock / M_wav;
 
  190             constexpr 
index_t K_lan = K_rem;
 
  192             constexpr 
index_t M_wav = NumWarps;
 
  193             static_assert(MPerBlock % (M_lan * M_wav) == 0,
 
  194                           "this tile size is too small please check");
 
  195             constexpr 
index_t M_rep = MPerBlock / (M_lan * M_wav);
 
  210     template <
index_t WarpPerBlock_N_,
 
  229     template <
typename Problem>
 
  232         constexpr 
index_t Block_M_   = Problem::BlockShape::Block_M0;
 
  233         constexpr 
index_t Block_K_   = Problem::BlockShape::Block_K0;
 
  234         constexpr 
index_t NumWarps_  = Problem::BlockShape::NumWarps;
 
  235         constexpr 
index_t Alignment_ = GetAlignment_A<Problem>();
 
  242     template <
typename Problem>
 
  245         constexpr 
auto PermuteEnum = Problem::Traits::PermuteEnum;
 
  247         using S_ = 
typename Problem::BlockShape;
 
  257                                                       GetAlignment_G<Problem>()>();
 
  261     template <
typename Problem>
 
  264         constexpr 
auto PermuteEnum = Problem::Traits::PermuteEnum;
 
  265         using S_                   = 
typename Problem::BlockShape;
 
  273                                                       GetAlignment_D<Problem>()>();
 
  277     template <
typename Problem>
 
  284         constexpr 
auto c_block_outer_dstr_encoding =
 
  294             c_block_outer_dstr_encoding, 
typename WarpGemm::CWarpDstrEncoding{});
 
  299     template <
typename Problem>
 
  303         constexpr 
index_t Block_M = Problem::BlockShape::Block_M0;
 
  304         constexpr 
index_t Block_K = Problem::BlockShape::Block_K0;
 
  307         constexpr 
index_t NumWarps = Problem::BlockShape::NumWarps;
 
  309         constexpr 
index_t KPack   = GetSmemKPack_A<Problem>(); 
 
  310         constexpr 
index_t KVector = GetAlignment_A<Problem>(); 
 
  311         constexpr 
index_t KPad    = KPack;                     
 
  313         static_assert(Block_K % KVector == 0);
 
  314         constexpr 
index_t LanesPerK = Block_K / KVector; 
 
  315         if constexpr(LanesPerK >= WarpSize)
 
  318             static_assert(LanesPerK % WarpSize == 0);
 
  319             constexpr 
index_t wavesPerK = LanesPerK / WarpSize;
 
  320             if constexpr(wavesPerK > NumWarps)
 
  326                 constexpr 
index_t wavesPerM     = NumWarps / wavesPerK;
 
  327                 constexpr 
index_t NumIssues     = Block_M / wavesPerM;
 
  335                                number<wavesPerK*(WarpSize * KVector + KPad)>{}, 
 
  351                 return lds_block_desc_issues_warps_lanes;
 
  357             static_assert(WarpSize % LanesPerK == 0);
 
  358             constexpr 
index_t LaneGroups = WarpSize / LanesPerK; 
 
  359             constexpr 
index_t NumIssues  = Block_M / (LaneGroups * NumWarps);
 
  384             return lds_block_desc_issues_warps_lanes;
 
  388     template <
typename Problem>
 
  398         constexpr 
index_t Block_M = Problem::BlockShape::Block_M0;
 
  399         constexpr 
index_t Block_K = Problem::BlockShape::Block_K0;
 
  402         constexpr 
index_t NumWarps = Problem::BlockShape::NumWarps;
 
  404         constexpr 
index_t KPack   = GetSmemKPack_A<Problem>(); 
 
  405         constexpr 
index_t KVector = GetAlignment_A<Problem>(); 
 
  406         constexpr 
index_t KPad    = KPack;                     
 
  408         static_assert(Block_K % KVector == 0);
 
  409         constexpr 
index_t LanesPerK = Block_K / KVector; 
 
  410         if constexpr(LanesPerK >= WarpSize)
 
  413             static_assert(LanesPerK % WarpSize == 0);
 
  414             constexpr 
index_t wavesPerK = LanesPerK / WarpSize;
 
  415             if constexpr(wavesPerK >= NumWarps)
 
  421                 constexpr 
index_t wavesPerM     = NumWarps / wavesPerK;
 
  422                 constexpr 
index_t NumIssues     = Block_M / wavesPerM;
 
  430                                number<wavesPerK*(WarpSize * KVector + KPad)>{}, 
 
  452             static_assert(WarpSize % LanesPerK == 0);
 
  453             constexpr 
index_t LaneGroups = WarpSize / LanesPerK; 
 
  454             constexpr 
index_t NumIssues  = Block_M / (LaneGroups * NumWarps);
 
  483     template <
typename Problem>
 
  486         constexpr 
index_t Block_M = Problem::BlockShape::Block_M0;
 
  487         constexpr 
index_t Block_N = Problem::BlockShape::Block_N0;
 
  489         constexpr 
index_t KVector = GetSmemKPack_Y<Problem>(); 
 
  492         constexpr 
auto desc =
 
  500     template <
typename Problem>
 
  503         constexpr 
index_t Block_M = Problem::BlockShape::Block_M0;
 
  504         constexpr 
index_t Block_N = Problem::BlockShape::Block_N0;
 
  506         constexpr 
index_t KVector = GetSmemKPack_Y<Problem>(); 
 
  509         constexpr 
auto desc =
 
  517     template <
typename Problem>
 
  520         constexpr 
index_t WarpPerBlock_N = Problem::BlockShape::WarpPerBlock_N0;
 
  521         constexpr 
index_t Repeat_N       = Problem::BlockShape::Repeat_N0;
 
  522         constexpr 
index_t Repeat_M       = Problem::BlockShape::Repeat_M0;
 
  524         constexpr 
index_t kAMLane     = 16;
 
  525         constexpr 
index_t kABKLane    = 4;
 
  526         constexpr 
index_t kABKPerLane = 4;
 
  528         constexpr 
index_t KPack = kABKPerLane;
 
  559     template <
typename Problem>
 
  562         using S_ = 
typename Problem::BlockShape;
 
  567         if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
 
  568                      std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
 
  569                      S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
 
  575         else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::int8_t> &&
 
  576                           std::is_same_v<typename Problem::GDataType, ck_tile::int8_t> &&
 
  577                           S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
 
  585     template <
typename Problem>
 
  591         using S_                = 
typename Problem::BlockShape;
 
  595         if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
 
  596                      std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
 
  597                      S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
 
  598                      S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
 
  604             constexpr 
auto seq_all =
 
  617         else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
 
  618                           std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
 
  619                           S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16 &&
 
  620                           S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
 
  626             constexpr 
auto seq_all =
 
  637     template <
typename Problem>
 
  643         using S_                = 
typename Problem::BlockShape;
 
  646         if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
 
  647                      std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
 
  648                      S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
 
  649                      S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
 
  655             constexpr 
auto seq_all =
 
  668         else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
 
  669                           std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
 
  670                           S_::Warp_M1 == 32 && S_::Warp_N1 == 32 && S_::Warp_K1 == 16 &&
 
  671                           S_::Block_M0 == 32 && S_::Block_N0 == 256 && S_::Block_K0 == 128 &&
 
  677             constexpr 
auto seq_all =
 
  688     template <
typename Problem>
 
  691         using S_               = 
typename Problem::BlockShape;
 
  694         if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
 
  695                      std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
 
  696                      S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 16)
 
  702         else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::int8_t> &&
 
  703                           std::is_same_v<typename Problem::DDataType, ck_tile::int8_t> &&
 
  704                           S_::Warp_M0 == 32 && S_::Warp_N0 == 32 && S_::Warp_K0 == 32)
 
  712     template <
typename Problem>
 
  717         using CDataType = 
typename WarpGemm::CDataType;
 
  719         constexpr 
auto c_block_outer_dstr_encoding =
 
  729             c_block_outer_dstr_encoding, 
typename WarpGemm::CWarpDstrEncoding{});
 
  731         auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
 
  732         return c_block_tensor;
 
  735     template <
typename Problem>
 
  740         using CDataType = 
typename WarpGemm::CDataType;
 
  742         constexpr 
auto c_block_outer_dstr_encoding =
 
  752             c_block_outer_dstr_encoding, 
typename WarpGemm::CWarpDstrEncoding{});
 
  754         auto c_block_tensor         = make_static_distributed_tensor<CDataType>(c_block_dstr);
 
  755         return c_block_tensor;
 
  759     template <
typename Problem>
 
  766         constexpr 
auto y_outer_dstr_enc =
 
  775             y_outer_dstr_enc, 
typename WarpGemm::AWarpDstrEncoding{});
 
  780     template <
typename Problem>
 
  783         constexpr 
auto y_block_dstr = MakeYTileDistribution<Problem>();
 
  784         auto y_block_tensor =
 
  785             make_static_distributed_tensor<typename Problem::YDataType>(y_block_dstr);
 
  786         return y_block_tensor;
 
  789     template <
typename Problem>
 
  792         using S_ = 
typename Problem::BlockShape;
 
  793         if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::bf16_t> &&
 
  794                      std::is_same_v<typename Problem::GDataType, ck_tile::bf16_t> &&
 
  795                      S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
 
  796                      S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
 
  800         else if constexpr(std::is_same_v<typename Problem::ADataType, ck_tile::fp16_t> &&
 
  801                           std::is_same_v<typename Problem::GDataType, ck_tile::fp16_t> &&
 
  802                           S_::Block_M0 == 32 && S_::Block_N0 == 512 && S_::Block_K0 == 128 &&
 
  803                           S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32)
 
  809     template <
typename Problem>
 
  812         using S_ = 
typename Problem::BlockShape;
 
  813         using T_ = 
typename Problem::Traits;
 
  814         if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
 
  815                      std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
 
  816                      std::is_same_v<typename Problem::TopkWeightDataType, float> &&
 
  817                      S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
 
  818                      S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
 
  819                      T_::PipeInterleave == 
false)
 
  824         else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
 
  825                           std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
 
  826                           std::is_same_v<typename Problem::TopkWeightDataType, float> &&
 
  827                           S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
 
  828                           S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
 
  829                           T_::PipeInterleave == 
false)
 
  834         else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::bf16_t> &&
 
  835                           std::is_same_v<typename Problem::DDataType, ck_tile::bf16_t> &&
 
  836                           std::is_same_v<typename Problem::TopkWeightDataType, float> &&
 
  837                           S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
 
  838                           S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
 
  839                           T_::PipeInterleave == 
true)
 
  844         else if constexpr(std::is_same_v<typename Problem::YDataType, ck_tile::fp16_t> &&
 
  845                           std::is_same_v<typename Problem::DDataType, ck_tile::fp16_t> &&
 
  846                           std::is_same_v<typename Problem::TopkWeightDataType, float> &&
 
  847                           S_::Block_M1 == 32 && S_::Block_N1 == 128 && S_::Block_K1 == 512 &&
 
  848                           S_::Warp_M0 == 16 && S_::Warp_N0 == 16 && S_::Warp_K0 == 32 &&
 
  849                           T_::PipeInterleave == 
true)
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
 
constexpr CK_TILE_HOST_DEVICE auto make_embed_tile_distribution_encoding(OuterDstr, InnerDstr)
Definition: tile_distribution_encoding.hpp:539
 
Definition: cluster_descriptor.hpp:13
 
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
 
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:255
 
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1672
 
int32_t index_t
Definition: integer.hpp:9
 
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
 
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
 
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:184
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
 
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
 
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
 
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:385
 
Definition: flatmm_32x512x128_1x4x1_16x16x32.hpp:524
 
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:18
 
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:74
 
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32_itl.hpp:265
 
Definition: flatmm_sn_32x128x512_1x4x1_16x16x32.hpp:318
 
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:15
 
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_D()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:262
 
static constexpr CK_TILE_HOST_DEVICE auto MakeCBlockTile_Gemm1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:736
 
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_O()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:278
 
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack_Y()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:85
 
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:71
 
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_G()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:33
 
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsStoreDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:501
 
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:92
 
static constexpr CK_TILE_HOST_DEVICE auto MakeYTileDistribution()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:760
 
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize_Bridge()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:101
 
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_SimpleMxK_Async()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:164
 
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_SimpleMxK()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:119
 
static constexpr CK_TILE_HOST_DEVICE auto GetWarpGemm0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:560
 
static constexpr CK_TILE_HOST_DEVICE index_t GetAsyncCopyDwords()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:16
 
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_O()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:51
 
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_D()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:42
 
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsStoreForUKDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:518
 
static constexpr CK_TILE_HOST_DEVICE auto GetAlignment_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:23
 
static constexpr CK_TILE_HOST_DEVICE auto GetSequencer_0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:586
 
static constexpr CK_TILE_HOST_DEVICE auto MakeYBlockTile()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:781
 
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:230
 
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsStoreDesc_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:300
 
static constexpr CK_TILE_HOST_DEVICE auto GetSequencer_1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:638
 
static constexpr CK_TILE_HOST_DEVICE auto MakeBridgeLdsLoadDesc()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:484
 
static constexpr CK_TILE_HOST_DEVICE ck_tile::index_t GetSmemSize()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:111
 
static constexpr CK_TILE_HOST_DEVICE auto MakeLdsLoadDesc_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:389
 
static constexpr CK_TILE_HOST_DEVICE auto MakeCBlockTile_Gemm0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:713
 
static constexpr CK_TILE_HOST_DEVICE auto GetUK_1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:810
 
static constexpr CK_TILE_HOST_DEVICE auto GetSmemKPack_A()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:78
 
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_Nr_Kr_W()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:216
 
static constexpr CK_TILE_HOST_DEVICE auto GetUK_0()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:790
 
static constexpr CK_TILE_HOST_DEVICE auto MakeGlobalTileDistribution_G()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:243
 
static constexpr CK_TILE_HOST_DEVICE auto GetWarpGemm1()
Definition: fused_moegemm_pipeline_flatmm_policy.hpp:689
 
Definition: warp_gemm_attribute_mfma.hpp:670
 
Definition: warp_gemm_attribute_mfma_impl.hpp:1544
 
Definition: warp_gemm_attribute_mfma_impl.hpp:448
 
Definition: warp_gemm_impl.hpp:11
 
Definition: integral_constant.hpp:13
 
Definition: sequence.hpp:52
 
Definition: tile_distribution_encoding.hpp:26
 
Definition: tuple.hpp:192