44 : GemmProblem(M_, N_, K_, stride_A_, stride_B_, stride_C_),
58 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
80 return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
115 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
121 const std::size_t k_id = blockIdx.z)
123 constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
125 const index_t KRead = (kargs.K + K_t - 1) / K_t * K1;
127 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
131 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
136 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
140 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
145 if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
162 if constexpr(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
167 std::cerr << "Conditions not met for Kbatch >1 !" << std::endl;
172 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
174 if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
176 std::cerr << "Can't support K that is not a multiple of KPerBlock"
181 if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
183 std::cerr << "K is not a multiple of vector load size for A tensor!" << std::endl;
189 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
191 std::cerr << "Can't support M that is not a multiple of MPerBlock"
196 if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
198 std::cerr << "M is not a multiple of vector load size for A tensor!" << std::endl;
203 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
205 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
207 std::cerr << "Can't support N that is not a multiple of NPerBlock"
212 if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
214 std::cerr << "N is not a multiple of vector load size for B tensor!" << std::endl;
220 if(kargs.K % TilePartitioner::KPerBlock != 0 && GemmPipeline::kPadK == false)
222 std::cerr << "Can't support K that is not a multiple of KPerBlock"
227 if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
229 std::cerr << "K is not a multiple of vector load size for B tensor!" << std::endl;
234 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
236 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
238 std::cerr << "Can't support N that is not a multiple of NPerBlock"
243 if(kargs.N % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
245 std::cerr << "N is not a multiple of vector load size for C tensor!" << std::endl;
251 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
253 std::cerr << "Can't support M that is not a multiple of MPerBlock"
258 if(kargs.M % EpiloguePipeline::template GetVectorSizeC<CDataType>() != 0)
260 std::cerr << "M is not a multiple of vector load size for C tensor!" << std::endl;
267 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
274 const auto& a_tensor_view = [&]() {
275 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
277 return make_naive_tensor_view<address_space_enum::global>(
281 number<GemmPipeline::GetVectorSizeA()>{},
286 return make_naive_tensor_view<address_space_enum::global>(
290 number<GemmPipeline::GetVectorSizeA()>{},
295 const auto& b_tensor_view = [&]() {
296 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
298 return make_naive_tensor_view<address_space_enum::global>(
302 number<GemmPipeline::GetVectorSizeB()>{},
307 return make_naive_tensor_view<address_space_enum::global>(
311 number<GemmPipeline::GetVectorSizeB()>{},
317 const auto& c_tensor_view = [&]() {
318 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
320 return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
324 number<EpiloguePipeline::template GetVectorSizeC<CDataType>()>{},
329 return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
338 return make_tuple(a_tensor_view, b_tensor_view, c_tensor_view);
341 template <typename TensorView>
344 const auto& a_pad_view = [&]() {
345 const auto& a_tensor_view = views.at(I0);
346 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
362 const auto& b_pad_view = [&]() {
363 const auto& b_tensor_view = views.at(I1);
364 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
381 const auto& c_pad_view = [&]() {
382 const auto& c_tensor_view = views.at(I2);
383 if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
399 return make_tuple(a_pad_view, b_pad_view, c_pad_view);
402 template <typename PadView>
406 const auto& a_pad_view = views.at(I0);
407 const auto& b_pad_view = views.at(I1);
408 const auto& c_pad_view = views.at(I2);
410 const auto& a_block_window = [&]() {
411 if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
427 const auto& b_block_window = [&]() {
428 if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
449 return make_tuple(a_block_window, b_block_window, c_block_window);
464 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
475 const auto& gemm_tensor_views_tuple =
476 MakeGemmTensorViews<DstInMemOp>(a_ptr, b_ptr, c_ptr, kargs, splitk_batch_offset);
481 const index_t num_loop = TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k);
484 const auto& a_block_window = gemm_tile_windows.at(I0);
485 const auto& b_block_window = gemm_tile_windows.at(I1);
486 const auto& c_block_tile =
490 auto& c_block_window = gemm_tile_windows.at(I2);
493 .template operator()<decltype(c_block_window), decltype(c_block_tile), DstInMemOp>(
494 c_block_window, c_block_tile, smem_ptr);
499 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockIdx.x);
500 const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
501 const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
516 RunGemm(a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
522 if constexpr(!(EpiloguePipeline::template GetVectorSizeC<CDataType>() % 2 != 0 &&
525 RunGemm<memory_operation_enum::atomic_add>(
526 a_ptr, b_ptr, c_ptr, smem_ptr, kargs, splitk_batch_offset, i_m, i_n);
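The split-K bookkeeping at lines 123-145 rounds each batch's share of K up to a multiple of the warp-tile K extent (K1). A minimal host-side sketch of that arithmetic follows. Only the KRead rounding on line 125 is taken verbatim from the listing; K_t = k_batch * K1 and the last-batch remainder handling are assumptions consistent with the if(k_id < kargs.k_batch - 1) branch on line 145, whose body is not shown here.

// Hedged sketch of the SplitKBatchOffset arithmetic shown above.
// Assumed (not visible in the excerpt): K_t = k_batch * K1, and the last
// batch takes whatever remains of K. Only KRead's rounding mirrors line 125.
#include <cstdint>
#include <cstdio>

using index_t = std::int32_t; // matches the documented ck_tile index_t

void print_split_k(index_t K, index_t k_batch, index_t K1 /* WarpTile K extent */)
{
    const index_t K_t   = k_batch * K1;              // assumption
    const index_t KRead = (K + K_t - 1) / K_t * K1;  // ceil(K / K_t) * K1, as on line 125

    for(index_t k_id = 0; k_id < k_batch; ++k_id)
    {
        // All batches but the last read KRead; the last reads the remainder (assumption).
        const index_t splitted_k =
            (k_id < k_batch - 1) ? KRead : K - KRead * (k_batch - 1);
        std::printf("k_id=%d: splitted_k=%d, starting K index=%d\n",
                    k_id, splitted_k, k_id * KRead);
    }
}

// Example: K=1000, k_batch=4, K1=16 gives K_t=64 and KRead=256, so the four
// batches read 256, 256, 256 and 232 elements of K respectively.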
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
GemmHostArgs
Definition: gemm_kernel.hpp:32
CK_TILE_HOST GemmHostArgs()=default
void * c_ptr
Definition: gemm_kernel.hpp:54
CK_TILE_HOST GemmHostArgs(const void *a_ptr_, const void *b_ptr_, void *c_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
Definition: gemm_kernel.hpp:34
const void * a_ptr
Definition: gemm_kernel.hpp:52
const void * b_ptr
Definition: gemm_kernel.hpp:53
index_t k_batch
Definition: gemm_kernel.hpp:55
GemmKernelArgs
Definition: gemm_kernel.hpp:86
index_t M
Definition: gemm_kernel.hpp:90
index_t N
Definition: gemm_kernel.hpp:91
const void * b_ptr
Definition: gemm_kernel.hpp:88
const void * a_ptr
Definition: gemm_kernel.hpp:87
index_t k_batch
Definition: gemm_kernel.hpp:96
index_t stride_A
Definition: gemm_kernel.hpp:93
void * c_ptr
Definition: gemm_kernel.hpp:89
index_t stride_B
Definition: gemm_kernel.hpp:94
index_t K
Definition: gemm_kernel.hpp:92
index_t stride_C
Definition: gemm_kernel.hpp:95
SplitKBatchOffset
Definition: gemm_kernel.hpp:119
index_t b_k_split_offset
Definition: gemm_kernel.hpp:156
__device__ SplitKBatchOffset(const GemmKernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition: gemm_kernel.hpp:120
index_t a_k_split_offset
Definition: gemm_kernel.hpp:155
index_t splitted_k
Definition: gemm_kernel.hpp:157
Definition: gemm_kernel.hpp:60
static CK_TILE_DEVICE void RunGemm(const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, void *smem_ptr, const GemmKernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: gemm_kernel.hpp:465
remove_cvref_t< typename GemmPipeline::ADataType > ADataType
Definition: gemm_kernel.hpp:69
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: gemm_kernel.hpp:78
remove_cvref_t< typename GemmPipeline::ALayout > ALayout
Definition: gemm_kernel.hpp:64
remove_cvref_t< typename EpiloguePipeline::ODataType > CDataType
Definition: gemm_kernel.hpp:72
static constexpr index_t KernelBlockSize
Definition: gemm_kernel.hpp:67
static CK_TILE_HOST bool IsSupportedArgument(const GemmKernelArgs &kargs)
Definition: gemm_kernel.hpp:160
remove_cvref_t< typename GemmPipeline::BDataType > BDataType
Definition: gemm_kernel.hpp:70
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: gemm_kernel.hpp:404
static constexpr auto I0
Definition: gemm_kernel.hpp:74
CK_TILE_DEVICE void operator()(GemmKernelArgs kargs) const
Definition: gemm_kernel.hpp:497
static constexpr auto I1
Definition: gemm_kernel.hpp:75
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: gemm_kernel.hpp:113
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: gemm_kernel.hpp:62
remove_cvref_t< typename GemmPipeline::CLayout > CLayout
Definition: gemm_kernel.hpp:66
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: gemm_kernel.hpp:342
remove_cvref_t< typename GemmPipeline::BLayout > BLayout
Definition: gemm_kernel.hpp:65
static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType *a_ptr, const BDataType *b_ptr, CDataType *c_ptr, const GemmKernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition: gemm_kernel.hpp:268
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: gemm_kernel.hpp:61
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: gemm_kernel.hpp:63
static constexpr auto I2
Definition: gemm_kernel.hpp:76
static constexpr CK_TILE_HOST GemmKernelArgs MakeKernelArgs(const GemmHostArgs &hostArgs)
Definition: gemm_kernel.hpp:99
static constexpr CK_TILE_HOST auto BlockSize()
Definition: gemm_kernel.hpp:83
GemmProblem
Definition: gemm_kernel.hpp:15
index_t stride_C
Definition: gemm_kernel.hpp:28
index_t stride_B
Definition: gemm_kernel.hpp:27
CK_TILE_HOST GemmProblem(index_t M_, index_t N_, index_t K_, index_t stride_A_, index_t stride_B_, index_t stride_C_)
Definition: gemm_kernel.hpp:17
CK_TILE_HOST GemmProblem()=default
index_t K
Definition: gemm_kernel.hpp:25
index_t stride_A
Definition: gemm_kernel.hpp:26
index_t N
Definition: gemm_kernel.hpp:24
index_t M
Definition: gemm_kernel.hpp:23
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:114
Definition: sequence.hpp:52
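Taken together, the host-facing members documented above (the GemmHostArgs constructor, MakeKernelArgs, IsSupportedArgument, GridSize, BlockSize and the device-side operator()) suggest a setup flow along the lines of the sketch below. This is a hedged sketch, not the library's own launch path: GemmKernelT stands for some concrete instantiation of the kernel class defined at gemm_kernel.hpp:60, the __global__ wrapper gemm_entry is hypothetical glue, and the ck_tile:: namespace qualification is assumed.

// Hedged host-side sketch built only from the members documented above.
// 'GemmKernelT' and 'gemm_entry' are placeholders, not names from this header.
#include <hip/hip_runtime.h>

template <typename Kernel, typename KernelArgs>
__global__ void gemm_entry(KernelArgs kargs)
{
    // Forwards to the documented CK_TILE_DEVICE operator()(GemmKernelArgs) const.
    Kernel{}(kargs);
}

template <typename GemmKernelT>
void launch_gemm(const void* a, const void* b, void* c,
                 ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K,
                 ck_tile::index_t stride_A, ck_tile::index_t stride_B,
                 ck_tile::index_t stride_C, ck_tile::index_t k_batch, hipStream_t stream)
{
    // Argument order follows the documented GemmHostArgs constructor.
    ck_tile::GemmHostArgs host_args(a, b, c, k_batch, M, N, K, stride_A, stride_B, stride_C);

    const auto kargs = GemmKernelT::MakeKernelArgs(host_args);
    if(!GemmKernelT::IsSupportedArgument(kargs))
        return; // padding / vector-size checks from IsSupportedArgument failed

    const dim3 grid  = GemmKernelT::GridSize(M, N, k_batch); // dim3(TilePartitioner::GridSize(M, N), 1, KBatch)
    const dim3 block = GemmKernelT::BlockSize();

    // Dynamic LDS of 0 is assumed here; the kernel also exposes GetSmemSize(),
    // so a given instantiation may instead expect that amount at launch.
    gemm_entry<GemmKernelT><<<grid, block, 0, stream>>>(kargs);
}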