30 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
34 const std::array<const void*, NumBTensor>& bs_ptr_,
35 const std::array<const void*, NumDTensor>& ds_ptr_,
41 const std::array<index_t, NumATensor>& stride_As_,
42 const std::array<index_t, NumBTensor>& stride_Bs_,
43 const std::array<index_t, NumDTensor>& stride_Ds_,
60 const std::array<const void*, NumATensor> as_ptr;
61 const std::array<const void*, NumBTensor> bs_ptr;
62 const std::array<const void*, NumDTensor> ds_ptr;
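A minimal host-side usage sketch of the UniversalGemmHostArgs constructor shown above, assuming a single A and B tensor and no D tensors; the device pointers (a_dev, b_dev, e_dev) and leading dimensions (lda, ldb, lde) are illustrative placeholders, not taken from the header:

#include <array>
// Hypothetical usage sketch of UniversalGemmHostArgs<1, 1, 0>.
std::array<const void*, 1> as{a_dev};   // one A tensor
std::array<const void*, 1> bs{b_dev};   // one B tensor
std::array<const void*, 0> ds{};        // no D tensors
ck_tile::UniversalGemmHostArgs<1, 1, 0> host_args(
    as, bs, ds, e_dev,
    /*k_batch_=*/1, M, N, K,
    /*stride_As_=*/{lda}, /*stride_Bs_=*/{ldb}, /*stride_Ds_=*/{},
    /*stride_E_=*/lde);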
84 template <index_t NumATensor = 1, index_t NumBTensor = 1, index_t NumDTensor = 0>
88 const std::array<const void*, NumATensor> as_ptr;
90 const std::array<const void*, NumBTensor> bs_ptr;
92 const std::array<const void*, NumDTensor> ds_ptr;
152 template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
204 template <typename T>
207 static constexpr bool value = []() {
209 return GemmPipeline::UsePersistentKernel;
228 static_assert(AsLayout::size() == AsDataType::size(),
229 "The size of AsLayout and AsDataType should be the same");
231 static_assert(BsLayout::size() == BsDataType::size(),
232 "The size of BsLayout and BsDataType should be the same");
234 static_assert(DsLayout::size() == DsDataType::size(),
235 "The size of DsLayout and DsDataType should be the same");
243 return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>(), GemmPipeline::GetName());
249 return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
261 const auto kernel = kentry<KernelBlockSize, 1, Kernel, KernelArgs>;
264 hipOccupancyMaxActiveBlocksPerMultiprocessor(&occupancy, kernel, KernelBlockSize, 0));
265 const int grid_size = get_available_compute_units(s) * occupancy;
266 return dim3(grid_size, 1, 1);
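The two launch helpers above differ only in how the grid is sized: GridSize puts one workgroup per output tile with k_batch in grid.z, while MaxOccupancyGridSize launches just enough workgroups to saturate the device for the persistent kernel. A stand-alone sketch of that arithmetic, with hypothetical helper names and assuming the tile partitioner simply counts M/N tiles:

struct Grid { unsigned x, y, z; };
// One workgroup per (M-tile, N-tile); split-K batches go into grid.z.
inline Grid classic_grid(int M, int N, int KBatch, int MPerBlock, int NPerBlock)
{
    const int m_tiles = (M + MPerBlock - 1) / MPerBlock;
    const int n_tiles = (N + NPerBlock - 1) / NPerBlock;
    return {static_cast<unsigned>(m_tiles * n_tiles), 1u, static_cast<unsigned>(KBatch)};
}
// Persistent launch: only as many workgroups as the GPU keeps resident at once.
inline Grid persistent_grid(int compute_units, int occupancy_per_cu)
{
    return {static_cast<unsigned>(compute_units * occupancy_per_cu), 1u, 1u};
}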
290 return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
297 constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
298 const index_t K_t = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
299 const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);
303 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, AiLayout>)
307 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, AiLayout>)
310 __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_As[index]);
316 if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BiLayout>)
319 __builtin_amdgcn_readfirstlane(k_id * KRead * kargs.stride_Bs[index]);
321 else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BiLayout>)
327 if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
329 splitted_k = __builtin_amdgcn_readfirstlane(KRead);
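A stand-alone sketch of the split-K arithmetic above: each k-batch reads a KRead-wide slice of K, rounded up to the warp-tile K (K1). The helper name is hypothetical and the remainder handling for the last batch is an assumption, since that branch is not shown here:

struct KSplit { int k_offset; int splitted_k; };
inline KSplit split_k(int K, int k_batch, int k_id, int K1)
{
    const int K_t   = k_batch * K1;
    const int KRead = (K + K_t - 1) / K_t * K1;        // per-batch K slice, rounded up to K1
    const int off   = k_id * KRead;                    // element offset along K for this batch
    const int len   = (k_id < k_batch - 1) ? KRead     // full slice for all but the last batch
                                           : K - off;  // assumed remainder for the last batch
    return {off, len};
}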
344 if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
357 bool AsTesnorIsValid = {true};
360 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
362 if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
363 GemmPipeline::kPadK == false)
368 "Can't support K that is not a multiple of k_batch * KPerBlock "
371 AsTesnorIsValid = false;
373 if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
377 CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
379 AsTesnorIsValid = false;
384 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
389 "Can't support M that is not a multiple of MPerBlock without padding!");
391 AsTesnorIsValid = false;
393 if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
397 CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
399 AsTesnorIsValid = false;
404 bool BsTesnorIsValid = {true};
407 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
409 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
414 "Can't support N that is not a multiple of NPerBlock without padding!");
416 BsTesnorIsValid = false;
418 if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
422 CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
424 BsTesnorIsValid = false;
429 if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
430 GemmPipeline::kPadK == false)
435 "Can't support K that is not a multiple of k_batch * KPerBlock "
438 BsTesnorIsValid = false;
440 if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
444 CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
446 BsTesnorIsValid = false;
451 bool DTesnorIsValid = {true};
454 if(std::is_same_v<DiLayout, ELayout> == false)
456 DTesnorIsValid = false;
458 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
460 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
464 CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
465 "NPerBlock without padding!");
467 DTesnorIsValid = false;
469 if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
473 CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
475 DTesnorIsValid = false;
480 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
484 CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
485 "MPerBlock without padding!");
487 DTesnorIsValid = false;
489 if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
493 CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
495 DTesnorIsValid = false;
500 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
502 if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
507 "Can't support N that is not a multiple of NPerBlock without padding!");
511 if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
515 CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
522 if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
527 "Can't support M that is not a multiple of MPerBlock without padding!");
531 if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
535 CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
540 return AsTesnorIsValid && BsTesnorIsValid && DTesnorIsValid;
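The checks above reduce to divisibility tests against the block tile and vector-access sizes; a minimal host-side sketch of the same rules, assuming row-major A and E and column-major B (hypothetical helper, not the library's IsSupportedArgument):

struct TileCfg { int MPerBlock, NPerBlock, KPerBlock, vecA, vecB, vecC; bool padM, padN, padK; };
inline bool args_supported(int M, int N, int K, int k_batch, const TileCfg& t)
{
    bool ok = true;
    if(!t.padK && K % (t.KPerBlock * k_batch) != 0) ok = false;  // K vs k_batch * KPerBlock
    if(!t.padM && M % t.MPerBlock != 0)             ok = false;  // M vs MPerBlock
    if(!t.padN && N % t.NPerBlock != 0)             ok = false;  // N vs NPerBlock
    if(K % t.vecA != 0 || K % t.vecB != 0)          ok = false;  // vectorized A/B loads along K
    if(N % t.vecC != 0)                             ok = false;  // vectorized E/C stores along N
    return ok;
}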
543 template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
546 const std::array<const BDataType*, NumBTensor>& bs_ptr,
547 const std::array<const void*, NumDTensor>& ds_ptr,
552 static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
558 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
560 return make_naive_tensor_view<address_space_enum::global>(
561 static_cast<const AiDataType*>(as_ptr[i]),
564 number<GemmPipeline::GetVectorSizeA()>{},
569 return make_naive_tensor_view<address_space_enum::global>(
570 static_cast<const AiDataType*>(as_ptr[i]),
573 number<GemmPipeline::GetVectorSizeA()>{},
583 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::RowMajor>)
585 if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
587 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
589 constexpr index_t VectorSizeB =
590 std::min(K1, GemmPipeline::GetVectorSizeB());
591 const auto b_k0_n_k1_desc =
602 return make_tensor_view<address_space_enum::global>(
603 static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
607 return make_naive_tensor_view<address_space_enum::global>(
611 number<GemmPipeline::GetVectorSizeB()>{},
617 if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
619 constexpr index_t K1 = GemmPipeline::GetSmemPackB();
621 constexpr index_t VectorSizeB =
622 std::min(K1, GemmPipeline::GetVectorSizeB());
623 const auto b_k0_n_k1_desc =
634 return make_tensor_view<address_space_enum::global>(
635 static_cast<const BiDataType*>(bs_ptr[i]), b_n_k_desc);
639 if constexpr(GemmPipeline::Preshuffle)
642 GemmPipeline::BlockGemmShape::flatKPerWarp *
644 TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
645 index_t kFlatN = kargs.N * kargs.K / kFlatK;
647 return make_naive_tensor_view<address_space_enum::global>(
651 number<GemmPipeline::GetVectorSizeB()>{},
656 return make_naive_tensor_view<address_space_enum::global>(
660 number<GemmPipeline::GetVectorSizeB()>{},
672 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
674 return make_naive_tensor_view<address_space_enum::global>(
675 static_cast<const DDataType_*>(ds_ptr[i]),
678 number<EpiloguePipeline::GetVectorSizeD(i)>{},
683 return make_naive_tensor_view<address_space_enum::global>(
684 static_cast<const DDataType_*>(ds_ptr[i]),
687 number<EpiloguePipeline::GetVectorSizeD(i)>{},
694 const auto& e_tensor_view = [&]() {
695 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
697 return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
701 number<EpiloguePipeline::GetVectorSizeC()>{},
706 return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
715 return make_tuple(as_tensor_view, bs_tensor_view, ds_tensor_view, e_tensor_view);
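Each view built above is essentially a (lengths, strides) descriptor over a raw global pointer, with the layout deciding which dimension is contiguous; a minimal sketch of that rule with hypothetical types (the real tensor_view also records the address space and guaranteed vector size):

struct View2D { int len_row, len_col, stride_row, stride_col; };
inline View2D naive_view(int rows, int cols, int leading_stride, bool row_major)
{
    // Row-major: columns are contiguous; column-major: rows are contiguous.
    return row_major ? View2D{rows, cols, leading_stride, 1}
                     : View2D{rows, cols, 1, leading_stride};
}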
718 template <typename TensorView>
723 const auto& a_tensor_view = views.at(I0);
725 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
742 const auto& b_flat_pad_view = views.at(I1);
746 const auto& b_tensor_view = views.at(I1);
748 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
767 const auto& d_tensor_view = views.at(I2);
769 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
787 const auto& e_pad_view = [&]() {
788 const auto& e_tensor_view = views.at(I3);
789 if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
805 if constexpr(GemmPipeline::Preshuffle)
808 return make_tuple(as_pad_view, b_flat_pad_view, ds_pad_view, e_pad_view);
812 return make_tuple(as_pad_view, bs_pad_view, ds_pad_view, e_pad_view);
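The pad views extend a dimension to the next tile multiple whenever the corresponding kPadM/kPadN/kPadK flag is set, so partial edge tiles become full tiles; the underlying arithmetic is just (hypothetical helper):

inline int padded_length(int len, int tile, bool do_pad)
{
    return do_pad ? (len + tile - 1) / tile * tile : len;  // e.g. 1000 -> 1024 for tile 256
}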
816 template <typename PadView>
820 const auto& as_pad_view = views.at(I0);
821 const auto& bs_pad_view = views.at(I1);
822 const auto& ds_pad_view = views.at(I2);
823 const auto& e_pad_view = views.at(I3);
828 if constexpr(std::is_same_v<AiLayout, tensor_layout::gemm::RowMajor>)
848 if constexpr(GemmPipeline::Preshuffle)
854 {static_cast<int>(i_n / GemmPipeline::BlockGemmShape::WarpTile::at(I1)),
859 if constexpr(std::is_same_v<BiLayout, tensor_layout::gemm::ColumnMajor>)
880 if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
902 return make_tuple(as_block_window, bs_block_window, ds_block_window, e_block_window);
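Each block window above pairs a padded view with the origin of the tile this workgroup owns; a simplified picture of that bookkeeping with a hypothetical struct (the real make_tile_window also carries the thread distribution of the tile):

struct BlockWindow { int origin_m, origin_n, len_m, len_n; };
inline BlockWindow block_window(int i_m, int i_n, int MPerBlock, int NPerBlock)
{
    // The workgroup owning output tile (i_m / MPerBlock, i_n / NPerBlock) reads and
    // writes the [i_m, i_m + MPerBlock) x [i_n, i_n + NPerBlock) patch of the view.
    return {i_m, i_n, MPerBlock, NPerBlock};
}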
919 template <bool UseDefaultScheduler = true>
921 const std::array<const BDataType*, NumBTensor>& bs_ptr,
922 const std::array<const void*, NumDTensor>& ds_ptr,
931 const auto& gemm_tensor_views_tuple =
932 MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
933 as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
938 const index_t num_loop = __builtin_amdgcn_readfirstlane(
939 TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
942 const auto& as_block_window = gemm_tile_windows.at(I0);
943 const auto& bs_block_window = gemm_tile_windows.at(I1);
944 const auto& ds_block_window = gemm_tile_windows.at(I2);
947 as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0);
949 if(UseDefaultScheduler || (get_warp_id() == 0))
952 auto& c_block_window = gemm_tile_windows.at(I3);
955 operator()<decltype(c_block_window), decltype(c_block_tile), decltype(ds_block_window)>(
956 c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
978 const std::array<const BDataType*, NumBTensor>& bs_ptr,
979 const std::array<const void*, NumDTensor>& ds_ptr,
981 void* __restrict__ smem_ptr_0,
982 void* __restrict__ smem_ptr_1,
989 const auto& gemm_tensor_views_tuple =
990 MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
991 as_ptr, bs_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
996 const index_t num_loop = __builtin_amdgcn_readfirstlane(
997 TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
1000 const auto& as_block_window = gemm_tile_windows.at(I0);
1001 const auto& bs_block_window = gemm_tile_windows.at(I1);
1002 const auto& ds_block_window = gemm_tile_windows.at(I2);
1005 as_block_window[I0], bs_block_window[I0], num_loop, smem_ptr_0, smem_ptr_1);
1008 auto& c_block_window = gemm_tile_windows.at(I3);
1011 operator()<decltype(c_block_window), decltype(c_block_tile), decltype(ds_block_window)>(
1012 c_block_window, c_block_tile, ds_block_window, smem_ptr_0);
1016 template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
1019 const auto blockId = __builtin_amdgcn_readfirstlane(blockIdx.x);
1020 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
1021 const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
1022 const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
1027 std::array<const ADataType*, NumATensor> as_ptr;
1033 std::array<const BDataType*, NumBTensor> bs_ptr;
1044 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1048 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1058 splitk_batch_offset,
1066 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1069 constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
1070 RunGemm<scheduler_type>(as_ptr,
1076 splitk_batch_offset,
1084 template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
1087 const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
1088 const auto num_tiles =
1089 __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
1090 const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
1091 auto block_id = __builtin_amdgcn_readfirstlane(get_block_id());
1093 while(block_id < num_work)
1096 const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
1097 const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
1098 const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
1099 const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
1102 const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
1105 std::array<const ADataType*, NumATensor> as_ptr;
1111 std::array<const BDataType*, NumBTensor> bs_ptr;
1122 if constexpr(GemmPipeline::DoubleSmemBuffer == true)
1125 if constexpr(!(EpiloguePipeline::MemoryOperation ==
1127 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1137 splitk_batch_offset,
1144 if constexpr(!(EpiloguePipeline::MemoryOperation ==
1146 EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
1155 splitk_batch_offset,
1161 block_id += grid_size;
1162 if(block_id >= num_work)
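The persistent operator() above strides a fixed-size grid over num_tiles * k_batch work items, decoding the output tile and the k-batch slice from the flat work index; a host-style sketch of that schedule with illustrative names:

inline void persistent_schedule(int grid_size, int num_tiles, int k_batch)
{
    const int num_work = num_tiles * k_batch;
    for(int block_id = 0; block_id < grid_size; ++block_id)          // one launched workgroup each
    {
        for(int work = block_id; work < num_work; work += grid_size)
        {
            const int tile_idx = work % num_tiles;  // which (M, N) output tile
            const int kb       = work / num_tiles;  // which split-K slice
            (void)tile_idx; (void)kb;               // placeholder for RunGemm on this work item
        }
    }
}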
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:255
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1672
CK_TILE_HOST void hip_check_error(hipError_t x)
Definition: hip_check_error.hpp:13
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:529
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:184
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
typename detail::detector< nonesuch, void, Op, Args... >::value_t is_detected
Definition: type_traits.hpp:67
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:412
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
__device__ index_t get_grid_size()
Definition: get_id.hpp:27
typename tuple_element< I, TTuple >::type tuple_element_t
Definition: tuple.hpp:208
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
__device__ X atomic_add(X *p_dst, const X &x)
The Universal GEMM kernel host arguments.
Definition: universal_gemm_kernel.hpp:32
void * c_ptr
Definition: universal_gemm_kernel.hpp:66
const std::array< index_t, NumDTensor > stride_Ds
Definition: universal_gemm_kernel.hpp:73
const std::array< index_t, NumBTensor > stride_Bs
Definition: universal_gemm_kernel.hpp:72
CK_TILE_HOST UniversalGemmHostArgs(const std::array< const void *, NumATensor > &as_ptr_, const std::array< const void *, NumBTensor > &bs_ptr_, const std::array< const void *, NumDTensor > &ds_ptr_, void *e_ptr_, index_t k_batch_, index_t M_, index_t N_, index_t K_, const std::array< index_t, NumATensor > &stride_As_, const std::array< index_t, NumBTensor > &stride_Bs_, const std::array< index_t, NumDTensor > &stride_Ds_, index_t stride_E_)
Definition: universal_gemm_kernel.hpp:33
index_t K
Definition: universal_gemm_kernel.hpp:70
void * e_ptr
Definition: universal_gemm_kernel.hpp:65
index_t M
Definition: universal_gemm_kernel.hpp:68
const std::array< const void *, NumDTensor > ds_ptr
Definition: universal_gemm_kernel.hpp:62
const std::array< const void *, NumATensor > as_ptr
Definition: universal_gemm_kernel.hpp:60
const std::array< index_t, NumATensor > stride_As
Definition: universal_gemm_kernel.hpp:71
index_t N
Definition: universal_gemm_kernel.hpp:69
index_t stride_E
Definition: universal_gemm_kernel.hpp:76
const std::array< const void *, NumBTensor > bs_ptr
Definition: universal_gemm_kernel.hpp:61
index_t stride_C
Definition: universal_gemm_kernel.hpp:77
index_t k_batch
Definition: universal_gemm_kernel.hpp:80
Definition: universal_gemm_kernel.hpp:294
std::array< index_t, NumATensor > as_k_split_offset
Definition: universal_gemm_kernel.hpp:337
index_t splitted_k
Definition: universal_gemm_kernel.hpp:339
__device__ SplitKBatchOffset(const KernelArgs &kargs, const std::size_t k_id=blockIdx.z)
Definition: universal_gemm_kernel.hpp:295
std::array< index_t, NumBTensor > bs_k_split_offset
Definition: universal_gemm_kernel.hpp:338
Definition: universal_gemm_kernel.hpp:203
static constexpr bool value
Definition: universal_gemm_kernel.hpp:207
decltype(T::UsePersistentKernel) has_persistent_type
Definition: universal_gemm_kernel.hpp:205
The GEMM kernel device arguments.
Definition: universal_gemm_kernel.hpp:86
void * e_ptr
The E output tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:94
std::array< index_t, NumBTensor > stride_Bs
The distance between consecutive elements of non-contiguous dimension (in memory) of Bs tensor.
Definition: universal_gemm_kernel.hpp:106
const std::array< const void *, NumDTensor > ds_ptr
The Ds input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:92
std::array< index_t, NumATensor > stride_As
The distance between consecutive elements of non-contiguous dimension (in memory) of As tensor.
Definition: universal_gemm_kernel.hpp:103
const std::array< const void *, NumATensor > as_ptr
The As input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:88
index_t k_batch
Definition: universal_gemm_kernel.hpp:113
index_t N
GEMM's N dimension size.
Definition: universal_gemm_kernel.hpp:98
index_t stride_E
The distance between consecutive elements of non-contiguous dimension (in memory) of E tensor.
Definition: universal_gemm_kernel.hpp:112
index_t K
GEMM's K dimension size.
Definition: universal_gemm_kernel.hpp:100
const std::array< const void *, NumBTensor > bs_ptr
The Bs input tensor's pointer to device memory.
Definition: universal_gemm_kernel.hpp:90
std::array< index_t, NumDTensor > stride_Ds
The distance between consecutive elements of non-contiguous dimension (in memory) of Ds tensor.
Definition: universal_gemm_kernel.hpp:109
index_t M
GEMM's M dimension size.
Definition: universal_gemm_kernel.hpp:96
The Universal GEMM kernel template.
Definition: universal_gemm_kernel.hpp:154
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1017
std::conditional_t< BDataTypeIsTuple, remove_cvref_t< typename GemmPipeline::BDataType >, remove_cvref_t< tuple< typename GemmPipeline::BDataType > >> BsDataType
Definition: universal_gemm_kernel.hpp:189
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: universal_gemm_kernel.hpp:156
static CK_TILE_HOST const std::string GetName()
Definition: universal_gemm_kernel.hpp:240
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: universal_gemm_kernel.hpp:155
CK_TILE_DEVICE void operator()(KernelArgs kargs) const
Definition: universal_gemm_kernel.hpp:1085
static CK_TILE_DEVICE void RunGemm(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *smem_ptr_0, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: universal_gemm_kernel.hpp:920
static constexpr bool BDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:161
static constexpr auto I2
Definition: universal_gemm_kernel.hpp:218
static constexpr bool BLayoutIsTuple
Definition: universal_gemm_kernel.hpp:167
static CK_TILE_DEVICE auto MakeGemmTensorViews(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset)
Definition: universal_gemm_kernel.hpp:545
std::conditional_t< BLayoutIsTuple, remove_cvref_t< typename GemmPipeline::BLayout >, remove_cvref_t< tuple< typename GemmPipeline::BLayout > >> BsLayout
Definition: universal_gemm_kernel.hpp:177
static constexpr index_t NumATensor
Definition: universal_gemm_kernel.hpp:221
static constexpr bool ALayoutIsTuple
Definition: universal_gemm_kernel.hpp:165
remove_cvref_t< std::tuple_element_t< I0, AsDataType > > ADataType
Definition: universal_gemm_kernel.hpp:225
static CK_TILE_DEVICE void RunGemm2LDS(const std::array< const ADataType *, NumATensor > &as_ptr, const std::array< const BDataType *, NumBTensor > &bs_ptr, const std::array< const void *, NumDTensor > &ds_ptr, EDataType *e_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const KernelArgs &kargs, const SplitKBatchOffset &splitk_batch_offset, const index_t block_idx_m, const index_t block_idx_n)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: universal_gemm_kernel.hpp:977
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: universal_gemm_kernel.hpp:818
static constexpr auto I3
Definition: universal_gemm_kernel.hpp:219
std::conditional_t< DDataTypeIsTuple, remove_cvref_t< typename EpiloguePipeline::DsDataType >, remove_cvref_t< tuple< typename EpiloguePipeline::DsDataType > >> DsDataType
Definition: universal_gemm_kernel.hpp:194
static constexpr bool ADataTypeIsTuple
Definition: universal_gemm_kernel.hpp:159
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: universal_gemm_kernel.hpp:719
remove_cvref_t< typename GemmPipeline::CLayout > ELayout
Definition: universal_gemm_kernel.hpp:196
static constexpr index_t NumDTensor
Definition: universal_gemm_kernel.hpp:223
UniversalGemmKernelArgs< AsLayout::size(), BsLayout::size(), DsLayout::size()> KernelArgs
Definition: universal_gemm_kernel.hpp:238
static constexpr bool DDataTypeIsTuple
Definition: universal_gemm_kernel.hpp:163
static constexpr bool PersistentKernel
Definition: universal_gemm_kernel.hpp:214
static constexpr auto I1
Definition: universal_gemm_kernel.hpp:217
static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
Definition: universal_gemm_kernel.hpp:247
std::conditional_t< ADataTypeIsTuple, remove_cvref_t< typename GemmPipeline::ADataType >, remove_cvref_t< tuple< typename GemmPipeline::ADataType > >> AsDataType
Definition: universal_gemm_kernel.hpp:185
remove_cvref_t< std::tuple_element_t< I0, BsDataType > > BDataType
Definition: universal_gemm_kernel.hpp:226
static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config &s) -> dim3
Get the maximum occupancy grid size for the persistent kernel on the current device.
Definition: universal_gemm_kernel.hpp:258
static constexpr index_t KernelBlockSize
Definition: universal_gemm_kernel.hpp:199
static constexpr index_t NumBTensor
Definition: universal_gemm_kernel.hpp:222
static constexpr auto I0
Definition: universal_gemm_kernel.hpp:216
static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs &kargs)
Definition: universal_gemm_kernel.hpp:342
std::conditional_t< ALayoutIsTuple, remove_cvref_t< typename GemmPipeline::ALayout >, remove_cvref_t< tuple< typename GemmPipeline::ALayout > >> AsLayout
Definition: universal_gemm_kernel.hpp:174
std::conditional_t< DLayoutIsTuple, remove_cvref_t< typename EpiloguePipeline::DsLayout >, remove_cvref_t< tuple< typename EpiloguePipeline::DsLayout > >> DsLayout
Definition: universal_gemm_kernel.hpp:181
static constexpr bool DLayoutIsTuple
Definition: universal_gemm_kernel.hpp:169
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: universal_gemm_kernel.hpp:157
static constexpr CK_TILE_HOST auto BlockSize()
Definition: universal_gemm_kernel.hpp:269
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: universal_gemm_kernel.hpp:288
static constexpr CK_TILE_HOST KernelArgs MakeKernelArgs(const UniversalGemmHostArgs< NumATensor, NumBTensor, NumDTensor > &hostArgs)
Definition: universal_gemm_kernel.hpp:272
remove_cvref_t< typename EpiloguePipeline::ODataType > EDataType
Definition: universal_gemm_kernel.hpp:197
Definition: integral_constant.hpp:13
Definition: type_traits.hpp:115
Definition: sequence.hpp:52
Definition: functional.hpp:43
Definition: stream_config.hpp:30
#define CK_TILE_ENV(name)
Definition: env.hpp:145