template <index_t NumDTensor = 0>
struct FlatmmHostArgs // host-side argument bundle
 
    const std::array<const void*, NumDTensor>& ds_ptr_,

    const std::array<index_t, NumDTensor>& stride_Ds_,

    const std::array<const void*, NumDTensor> ds_ptr;
 
template <index_t NumDTensor = 0>
struct FlatmmKernelArgs // device-side kernel arguments
 
    const std::array<const void*, NumDTensor> ds_ptr;
 
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
struct FlatmmKernel
 
    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");
 
        // inside GetName():
        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
 
        // inside GridSize(M, N, KBatch):
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
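For orientation, a minimal host-side sketch of how these entry points fit together. Only MakeKernelArgs, IsSupportedArgument, GridSize, BlockSize and GetSmemSize come from this header; the device pointers, extents, strides and the launch mechanism are assumptions for illustration:

    using Kernel = FlatmmKernel<TilePartitioner, FlatmmPipeline, EpiloguePipeline>;

    // a_dev, b_shuffled_dev, e_dev and the extent/stride values are hypothetical inputs.
    FlatmmHostArgs<> host_args(a_dev, b_shuffled_dev, /*ds_ptr=*/{}, e_dev,
                               /*k_batch=*/1, M, N, K,
                               stride_A, stride_B, /*stride_Ds=*/{}, stride_E);

    auto kargs = Kernel::MakeKernelArgs(host_args);
    if(Kernel::IsSupportedArgument(kargs))
    {
        const dim3 grids  = Kernel::GridSize(M, N, host_args.k_batch);
        const dim3 blocks = Kernel::BlockSize();
        // launch via your runtime of choice (e.g. hipLaunchKernelGGL),
        // passing Kernel::GetSmemSize() bytes of dynamic LDS.
    }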
 
        // inside GetSmemSize():
        return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
 
            // inside the SplitKBatchOffset constructor:
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            const index_t K_t   = kargs.k_batch * K1; // per-split K granularity (reconstructed)
            const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
 
            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)

            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)

            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
 
            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
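A worked example of the split-K bookkeeping above, with illustrative numbers (the last-batch remainder handling is inferred from the branch shown, so treat it as an assumption):

    // K = 1000, k_batch = 4, warp-tile K depth K1 = 32:
    //   K_t   = k_batch * K1 = 128
    //   KRead = (1000 + 127) / 128 * 32 = 8 * 32 = 256
    // Batches with k_id < k_batch - 1 each cover splitted_k = 256 elements of K;
    // the last batch takes the remainder, 1000 - 3 * 256 = 232.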
 
        // inside IsSupportedArgument(kargs):
        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&

                std::cerr << "Conditions not met for KBatch > 1!" << std::endl;
 
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)

            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)

                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                          << " without padding!" << std::endl;

            if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)

                std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;

            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)

                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                          << " without padding!" << std::endl;

            if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)

                std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
 
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)

            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)

                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                          << " without padding!" << std::endl;

            if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)

                std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;

            if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)

                std::cerr << "Can't support K that is not a multiple of KPerBlock"
                          << " without padding!" << std::endl;

            if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)

                std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
 
        bool DTensorIsValid = true;

            if(std::is_same_v<DiLayout, ELayout> == false)

                DTensorIsValid = false;

            if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)

                if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)

                    CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
                                  "NPerBlock without padding!");
                    DTensorIsValid = false;

                if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)

                    CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
                    DTensorIsValid = false;

                if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)

                    CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
                                  "MPerBlock without padding!");
                    DTensorIsValid = false;

                if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)

                    CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
                    DTensorIsValid = false;
 
        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)

            if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)

                std::cerr << "Can't support N that is not a multiple of NPerBlock"
                          << " without padding!" << std::endl;

            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)

                std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;

            if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)

                std::cerr << "Can't support M that is not a multiple of MPerBlock"
                          << " without padding!" << std::endl;

            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)

                std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
 
        return DTensorIsValid;
 
    // MakeGemmTensorViews(a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset):
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>

                        const std::array<const void*, NumDTensor>& ds_ptr,
 
        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)

                return make_naive_tensor_view<address_space_enum::global>(

                    number<FlatmmPipeline::GetVectorSizeA()>{},

                return make_naive_tensor_view<address_space_enum::global>(

                    number<FlatmmPipeline::GetVectorSizeA()>{},
 
        index_t kFlatK = FlatmmPipeline::flatKPerWarp * (splitk_batch_offset.splitted_k /
                                                         BlockGemmShape::WarpTile::at(number<2>{}));
        index_t kFlatN = kargs.N * kargs.K / kFlatK;
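The two lines above re-tile the N x K weight tensor into a "flat" kFlatN x kFlatK layout derived from flatKPerWarp and the split-K depth. A quick sanity check with made-up numbers:

    // Illustrative: N = 512, K = 256, K1 = 32, flatKPerWarp = 512, splitted_k = 256.
    //   kFlatK = 512 * (256 / 32) = 4096
    //   kFlatN = 512 * 256 / 4096 = 32
    // kFlatN * kFlatK == N * K (131072): same elements, re-tiled for flat, vectorized reads.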
 
        const auto& b_flat_tensor_view = [&]() {
            return make_naive_tensor_view<address_space_enum::global>(

                number<FlatmmPipeline::GetVectorSizeB()>{},
 
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)

                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),

                        number<EpiloguePipeline::GetVectorSizeD(i)>{},

                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),

                        number<EpiloguePipeline::GetVectorSizeD(i)>{},
 
        const auto& e_tensor_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)

                return make_naive_tensor_view<address_space_enum::global>(

                    number<EpiloguePipeline::GetVectorSizeC()>{},

                return make_naive_tensor_view<address_space_enum::global>(

        // Tuple order fixes the I0..I3 indexing used downstream: A, flat B, Ds, E.
        return make_tuple(a_tensor_view, b_flat_tensor_view, ds_tensor_view, e_tensor_view);
 
    // MakeGemmPadViews(views):
    template <typename TensorView>

        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
 
        const auto& b_flat_tensor_view = views.at(I1);

                const auto& d_tensor_view = views.at(I2);

                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)

        const auto& e_pad_view = [&]() {
            const auto& e_tensor_view = views.at(I3);
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)

        return make_tuple(a_pad_view, b_flat_tensor_view, ds_pad_view, e_pad_view);
 
    // MakeGemmTileWindows(views, i_m, i_n):
    template <typename PadView>

        const auto& a_pad_view      = views.at(I0);
        const auto& b_flat_pad_view = views.at(I1);
        const auto& ds_pad_view     = views.at(I2);
        const auto& e_pad_view      = views.at(I3);
 
        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)

        const auto& b_flat_block_window =

                             {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(I1)), 0});

                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)

        return make_tuple(a_block_window, b_flat_block_window, ds_block_window, e_block_window);
 
    // RunFlatmm(a_ptr, b_flat_ptr, ds_ptr, e_ptr, smem_ptr, kargs, splitk_batch_offset,
    //           block_idx_m, block_idx_n):
    template <bool UseDefaultScheduler = true>

                                         const std::array<const void*, NumDTensor>& ds_ptr,

        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_flat_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);

        const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
 
        const auto& a_block_window      = gemm_tile_windows.at(I0);
        const auto& b_flat_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window      = gemm_tile_windows.at(I2);

            a_block_window, b_flat_block_window, num_loop, smem_ptr);

        // With the default scheduler every warp runs the epilogue; otherwise only warp 0 does.
        if(UseDefaultScheduler || (get_warp_id() == 0))

            auto& c_block_window = gemm_tile_windows.at(I3);

            operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
                c_block_window, c_block_tile, d_block_window, smem_ptr);
 
        // in operator()(kargs):
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
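A small numeric example of the index math above (illustrative values):

    // With MPerBlock = NPerBlock = 128 and GetOutputTileIndex(blockIdx.x) == (2, 3):
    //   i_m = 2 * 128 = 256; i_n = 3 * 128 = 384  // top-left corner of this block's output tile
    // __builtin_amdgcn_readfirstlane broadcasts lane 0's value, marking the offsets
    // wavefront-uniform so they can be kept in scalar registers.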
 
                       EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&

            constexpr auto scheduler_type = (FlatmmPipeline::NumWaveGroups == 1);
            RunFlatmm<scheduler_type>(a_ptr,
 