const void* b_shuffle_ptr_,
template <typename TilePartitioner_, typename FlatmmPipeline_, typename EpiloguePipeline_>
return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, FlatmmPipeline::GetName());
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
return max(FlatmmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
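// Reading of the two helpers above: GridSize lays output tiles along
// blockIdx.x and split-K batches along blockIdx.z (the y-dimension stays 1),
// so blockIdx.z selects the K slice that SplitKBatchOffset decodes below.
// The shared-memory requirement is the max of what the flatmm pipeline and
// the epilogue each need, since both are handed the same smem_ptr and use
// the LDS buffer in different phases.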
const std::size_t k_id = blockIdx.z)
constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
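// Worked example of the split-K bookkeeping above, with assumed values and
// assuming K_t = kargs.k_batch * K1 (which the surrounding code suggests;
// that line is elided in this listing). With K = 1000, k_batch = 4, and a
// warp-tile K extent of K1 = 32:
//   K_t   = 4 * 32 = 128
//   KRead = (1000 + 127) / 128 * 32 = 8 * 32 = 256,
// so each of the first k_batch - 1 batches covers a 256-wide K slice and
// the last batch takes the remaining 1000 - 3 * 256 = 232 columns of K.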
if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
std::cerr << "Conditions not met for k_batch > 1!" << std::endl;
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
    std::cerr << "Can't support K that is not a multiple of KPerBlock"
if(kargs.K % FlatmmPipeline::GetVectorSizeA() != 0)
    std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
    std::cerr << "Can't support M that is not a multiple of MPerBlock"
if(kargs.M % FlatmmPipeline::GetVectorSizeA() != 0)
    std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
    std::cerr << "Can't support N that is not a multiple of NPerBlock"
if(kargs.N % FlatmmPipeline::GetVectorSizeB() != 0)
    std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
if(kargs.K % TilePartitioner::KPerBlock != 0 && FlatmmPipeline::kPadK == false)
    std::cerr << "Can't support K that is not a multiple of KPerBlock"
if(kargs.K % FlatmmPipeline::GetVectorSizeB() != 0)
    std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
if(kargs.N % TilePartitioner::NPerBlock != 0 && FlatmmPipeline::kPadN == false)
    std::cerr << "Can't support N that is not a multiple of NPerBlock"
if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
    std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
if(kargs.M % TilePartitioner::MPerBlock != 0 && FlatmmPipeline::kPadM == false)
    std::cerr << "Can't support M that is not a multiple of MPerBlock"
if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
    std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
const auto& a_tensor_view = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
return make_naive_tensor_view<address_space_enum::global>(
    number<FlatmmPipeline::GetVectorSizeA()>{},
return make_naive_tensor_view<address_space_enum::global>(
    number<FlatmmPipeline::GetVectorSizeA()>{},
index_t kFlatK = FlatmmPipeline::flatKPerWarp *
                 (splitk_batch_offset.splitted_k / BlockGemmShape::WarpTile::at(number<2>{}));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
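// Illustrative arithmetic for the flattened-B extents above (the
// flatKPerWarp value is assumed): with flatKPerWarp = 2048, a warp-tile K
// of 32, and splitted_k = 256,
//   kFlatK = 2048 * (256 / 32) = 16384,
// and with N = 4096, K = 1024,
//   kFlatN = 4096 * 1024 / 16384 = 256,
// i.e. the pre-shuffled B tensor is viewed as a kFlatN x kFlatK matrix
// rather than its original K x N shape.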
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
    number<FlatmmPipeline::GetVectorSizeB()>{},
const auto& c_tensor_view = [&]() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
return make_naive_tensor_view<address_space_enum::global>(
    number<EpiloguePipeline::GetVectorSizeC()>{},
return make_naive_tensor_view<address_space_enum::global>(
return make_tuple(a_tensor_view, b_flat_tensor_view, c_tensor_view);
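// Minimal sketch of what each elided make_naive_tensor_view call above
// plausibly passes (the argument lists are cut off in this listing; the
// extents and strides below are an assumption for the row-major A case):
//
//   return make_naive_tensor_view<address_space_enum::global>(
//       a_ptr,                                               // base pointer
//       make_tuple(kargs.M, splitk_batch_offset.splitted_k), // logical extents
//       make_tuple(kargs.stride_A, 1),                       // row-major strides
//       number<FlatmmPipeline::GetVectorSizeA()>{},          // guaranteed vector access width
//       number<1>{});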
template <typename TensorView>
const auto& a_pad_view = [&]() {
const auto& a_tensor_view = views.at(I0);
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
const auto& b_flat_tensor_view = views.at(I1);
const auto& c_pad_view = [&]() {
const auto& c_tensor_view = views.at(I2);
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
return make_tuple(a_pad_view, b_flat_tensor_view, c_pad_view);
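// Hedged sketch of the padding step performed in the branches above:
// pad_tensor_view extends a view up to the next tile boundary so partial
// tiles can be addressed safely. The DoPads sequence shown is an assumption
// mirroring the pipeline's kPad flags, not the literal elided code:
//
//   return pad_tensor_view(
//       a_tensor_view,
//       make_tuple(number<TilePartitioner::MPerBlock>{},
//                  number<TilePartitioner::KPerBlock>{}),
//       sequence<false, FlatmmPipeline::kPadK>{});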
template <typename PadView>
const auto& a_pad_view = views.at(I0);
const auto& b_flat_pad_view = views.at(I1);
const auto& c_pad_view = views.at(I2);
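// I0/I1/I2 are the class's static constexpr compile-time indices, so
// views.at(I0) etc. select the A, flat-B, and C entries of the tuples
// returned by the Make* helpers above.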
const auto& a_block_window = [&]() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
const auto& b_flat_block_window =
    {static_cast<int>(i_n / BlockGemmShape::WarpTile::at(idxN)), 0});
return make_tuple(a_block_window, b_flat_block_window, c_block_window);
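// Minimal sketch of the window-construction pattern used in the elided
// branches above (the window lengths and origin are an assumption for the
// row-major A case):
//
//   auto a_block_window = make_tile_window(
//       a_pad_view,
//       make_tuple(number<TilePartitioner::MPerBlock>{},
//                  number<TilePartitioner::KPerBlock>{}),
//       {i_m, 0}); // this block's origin within the padded A view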
const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
    a_ptr, b_flat_ptr, c_ptr, kargs, splitk_batch_offset);
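// Note: the views are created with the epilogue's memory operation, so a
// split-K run can presumably switch the C access from plain stores to
// atomic accumulation via the DstInMemOp template parameter above.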
const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
const auto& a_block_window = gemm_tile_windows.at(I0);
const auto& b_flat_block_window = gemm_tile_windows.at(I1);
const auto& d_block_window = gemm_tile_windows.at(I2);
a_block_window, b_flat_block_window, num_loop, smem_ptr);
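// (The pipeline invocation above consumes the A and flat-B windows for
// num_loop iterations and yields the accumulated c_block_tile handed to
// the epilogue below.)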
auto& c_block_window = gemm_tile_windows.at(I2);
EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
    c_block_window, c_block_tile, d_block_window, smem_ptr);
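// The `.template operator()<...>` spelling is required here: EpiloguePipeline
// is a dependent type inside this class template, so without the `template`
// disambiguator the compiler would parse the `<` as a less-than operator.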
const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
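// __builtin_amdgcn_readfirstlane broadcasts lane 0's value to the whole
// wavefront, telling the compiler the tile offsets are wave-uniform so they
// can live in scalar registers (SGPRs) rather than per-lane VGPRs.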
EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
RunFlatmm(a_ptr, b_flat_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
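// End-to-end hedged sketch of driving this kernel from the host.
// MakeKernelArgs, IsSupportedArgument, GridSize, and BlockSize are declared
// in this header; `Kernel` is a placeholder alias for a concrete
// instantiation, and the launch line is left to the caller's runtime:
//
//   FlatmmHostArgs host_args(a_ptr, b_shuffle_ptr, c_ptr, k_batch,
//                            M, N, K, stride_A, stride_B, stride_C);
//   auto kargs = Kernel::MakeKernelArgs(host_args);
//   if(!Kernel::IsSupportedArgument(kargs))
//       return; // fall back to another configuration
//   const dim3 grids  = Kernel::GridSize(M, N, k_batch);
//   const dim3 blocks = Kernel::BlockSize();
//   // launch Kernel{}(kargs) over (grids, blocks) with the HIP runtime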