// Excerpt from ck_tile's gemm_kernel.hpp source listing; elided lines are marked with "// ..." or "/* ... */".

// GemmHostArgs: the GEMM kernel host arguments (fragment).
template <index_t NumDTensor = 0>
struct GemmHostArgs
{
    // ... (constructor taking, among other parameters:
    //      const std::array<const void*, NumDTensor>& ds_ptr_,
    //      const std::array<index_t, NumDTensor>& stride_Ds_)

    // Pointers to the Ds input tensors in device memory.
    const std::array<const void*, NumDTensor> ds_ptr;
    // ...
};

// GemmKernelArgs: the GEMM kernel device arguments (fragment).
template <index_t NumDTensor = 0>
struct GemmKernelArgs
{
    // ...
    // Pointers to the Ds input tensors in device memory.
    const std::array<const void*, NumDTensor> ds_ptr;
    // ...
};
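/* Example (illustrative only): filling GemmHostArgs for a 2-way split-K problem with one
   D tensor, assuming row-major A/D/E, column-major B, and device buffers a_dev, b_dev,
   d0_dev, e_dev allocated elsewhere:

       const ck_tile::index_t M = 3840, N = 4096, K = 2048;
       ck_tile::GemmHostArgs<1> args(a_dev, b_dev, {d0_dev}, e_dev,
                                     2, M, N, K, // k_batch, M, N, K
                                     K,          // stride_A (row-major A)
                                     K,          // stride_B (column-major B)
                                     {N},        // stride_Ds
                                     N);         // stride_E (row-major E)
*/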
// GemmKernel: the GEMM kernel template (fragment).
template <typename TilePartitioner_, typename GemmPipeline_, typename EpiloguePipeline_>
struct GemmKernel
{
    // ...

    // Compile-time detection of an optional GemmPipeline::UsePersistentKernel flag
    // (helper struct, name elided; it inspects decltype(T::UsePersistentKernel)):
    template <typename T>
    // ...
    static constexpr bool value = []() {
        // ... (false when the pipeline does not declare the flag)
        return GemmPipeline::UsePersistentKernel;
        // ...
    }();
    // ... (layout/data-type aliases, NumDTensor, I0..I3 constants)

    static_assert(DsLayout::size() == DsDataType::size(),
                  "The size of DsLayout and DsDataType should be the same");

    // GetName(): human-readable kernel identifier.
    static CK_TILE_HOST const std::string GetName()
    {
        // ...
        return concat('_', "gemm", gemm_prec_str<ADataType, BDataType>, GemmPipeline::GetName());
    }
    // GridSize(): non-persistent launch grid - one workgroup per output tile, K-batch in grid.z.
    static constexpr CK_TILE_HOST auto GridSize(index_t M, index_t N, index_t KBatch)
    {
        return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
    }

    // MaxOccupancyGridSize(): occupancy-bound 1D grid for the persistent kernel on the current device.
    static CK_TILE_HOST auto MaxOccupancyGridSize(const stream_config& s) -> dim3
    {
        const auto kernel = kentry<KernelBlockSize, 1, Kernel, KernelArgs>;
        int occupancy     = 0;
        hip_check_error(hipOccupancyMaxActiveBlocksPerMultiprocessor(
            &occupancy, kernel, KernelBlockSize, 0));
        const int grid_size = get_available_compute_units(s) * occupancy;
        return dim3(grid_size, 1, 1);
    }
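    /* Example (illustrative only; `Kernel` stands for a concrete GemmKernel instantiation
       and `sc` for a ck_tile::stream_config):

           const dim3 blocks  = Kernel::PersistentKernel
                                    ? Kernel::MaxOccupancyGridSize(sc)  // occupancy-bound grid
                                    : Kernel::GridSize(M, N, k_batch);  // one workgroup per tile
           const auto threads = Kernel::BlockSize();
    */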
    // ... (BlockSize() and MakeKernelArgs() elided)

    // GetSmemSize(): LDS bytes needed per workgroup.
    static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
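    /* For example, a GEMM pipeline needing 32 KiB of LDS and an epilogue needing 16 KiB give
       GetSmemSize() == 32 KiB: the two phases run back to back, so the same workgroup scratch
       is reused (max) rather than allocated separately (sum). */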
    // SplitKBatchOffset: computes this k-batch's offsets into A/B and the K slice it owns.
    struct SplitKBatchOffset
    {
        __device__ SplitKBatchOffset(const KernelArgs& kargs, const std::size_t k_id = blockIdx.z)
        {
            constexpr auto K1   = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
            const index_t K_t   = __builtin_amdgcn_readfirstlane(kargs.k_batch * K1);
            const index_t KRead = __builtin_amdgcn_readfirstlane((kargs.K + K_t - 1) / K_t * K1);

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, ALayout>) { /* ... */ }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>) { /* ... */ }

            if constexpr(std::is_same_v<tensor_layout::gemm::RowMajor, BLayout>) { /* ... */ }
            else if constexpr(std::is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>) { /* ... */ }

            if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
            {
                splitted_k = __builtin_amdgcn_readfirstlane(KRead);
            }
            // ... (the last k-batch takes the remaining K)
        }

        index_t a_k_split_offset;
        index_t b_k_split_offset;
        index_t splitted_k;
    };
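    /* Worked example (illustrative numbers): K = 2000, k_batch = 3, warp-tile K1 = 16.
       Then K_t = 48 and KRead = ceil(2000 / 48) * 16 = 672, so the three k-batches start at
       K offsets 0, 672 and 1344; batches 0 and 1 each cover 672 elements of K
       (splitted_k = KRead) and the last batch covers the remaining 2000 - 2 * 672 = 656.
       Every split therefore begins on a K1-aligned boundary. */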
    // IsSupportedArgument(): host-side validation of the problem against tile and vector-size
    // constraints of the chosen pipelines.
    static CK_TILE_HOST bool IsSupportedArgument(const KernelArgs& kargs)
    {
        if constexpr(EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
                     /* ... */)
        {
            // ... (additional restriction on split-K output updates with odd C vector size)
        }

        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                              /* ... */);
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeA() != 0)
            {
                CK_TILE_ERROR("K is not a multiple of vector load size for A tensor!");
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                CK_TILE_ERROR("Can't support M that is not a multiple of MPerBlock without padding!");
                return false;
            }
            if(kargs.M % GemmPipeline::GetVectorSizeA() != 0)
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for A tensor!");
                return false;
            }
        }

        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                CK_TILE_ERROR("Can't support N that is not a multiple of NPerBlock without padding!");
                return false;
            }
            if(kargs.N % GemmPipeline::GetVectorSizeB() != 0)
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for B tensor!");
                return false;
            }
        }
        else
        {
            if(kargs.K % (TilePartitioner::KPerBlock * kargs.k_batch) != 0 &&
               GemmPipeline::kPadK == false)
            {
                CK_TILE_ERROR("Can't support K that is not a multiple of k_batch * KPerBlock "
                              /* ... */);
                return false;
            }
            if(kargs.K % GemmPipeline::GetVectorSizeB() != 0)
            {
                CK_TILE_ERROR("K is not a multiple of vector load size for B tensor!");
                return false;
            }
        }

        bool DTensorIsValid = true;
        // The following checks run for every D tensor; DiLayout and `index` refer to the i-th D.
        if(std::is_same_v<DiLayout, ELayout> == false)
        {
            DTensorIsValid = false;
        }
        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                CK_TILE_ERROR("Can't support N for tensor D that is not a multiple of "
                              "NPerBlock without padding!");
                DTensorIsValid = false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeD(index) != 0)
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for D tensor!");
                DTensorIsValid = false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                CK_TILE_ERROR("Can't support M for tensor D that is not a multiple of "
                              "MPerBlock without padding!");
                DTensorIsValid = false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeD(index) != 0)
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for D tensor!");
                DTensorIsValid = false;
            }
        }

        if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
        {
            if(kargs.N % TilePartitioner::NPerBlock != 0 && GemmPipeline::kPadN == false)
            {
                CK_TILE_ERROR("Can't support N that is not a multiple of NPerBlock without padding!");
                return false;
            }
            if(kargs.N % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                CK_TILE_ERROR("N is not a multiple of vector load size for C tensor!");
                return false;
            }
        }
        else
        {
            if(kargs.M % TilePartitioner::MPerBlock != 0 && GemmPipeline::kPadM == false)
            {
                CK_TILE_ERROR("Can't support M that is not a multiple of MPerBlock without padding!");
                return false;
            }
            if(kargs.M % EpiloguePipeline::GetVectorSizeC() != 0)
            {
                CK_TILE_ERROR("M is not a multiple of vector load size for C tensor!");
                return false;
            }
        }

        return DTensorIsValid;
    }
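    /* Typical host-side flow (illustrative only; `try_validate` is a hypothetical helper):

           template <typename Kernel, ck_tile::index_t NumDs>
           bool try_validate(const ck_tile::GemmHostArgs<NumDs>& host_args)
           {
               const auto kargs = Kernel::MakeKernelArgs(host_args); // device-side argument struct
               return Kernel::IsSupportedArgument(kargs);            // false => pick another config
           }
    */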
    // MakeGemmTensorViews(): wraps the raw global-memory pointers of A, B, Ds and E in tensor views.
    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    static CK_TILE_DEVICE auto MakeGemmTensorViews(const ADataType* a_ptr,
                                                   const BDataType* b_ptr,
                                                   const std::array<const void*, NumDTensor>& ds_ptr,
                                                   EDataType* e_ptr,
                                                   const KernelArgs& kargs,
                                                   const SplitKBatchOffset& splitk_batch_offset)
    {
        static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");

        const auto& a_tensor_view = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    /* a_ptr, lengths, strides, */ number<GemmPipeline::GetVectorSizeA()>{} /* , ... */);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global>(
                    /* a_ptr, lengths, strides, */ number<GemmPipeline::GetVectorSizeA()>{} /* , ... */);
            }
        }();

        const auto& b_tensor_view = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
            {
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                    constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc     = /* ... */;
                    // ... (merged back into an (N, K) descriptor)
                    return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        /* b_ptr, lengths, strides, */ number<GemmPipeline::GetVectorSizeB()>{} /* , ... */);
                }
            }
            else // column-major B
            {
                if constexpr(TilePartitioner::BlockGemmShape::PermuteB)
                {
                    constexpr index_t K1          = GemmPipeline::GetSmemPackB();
                    constexpr index_t VectorSizeB = std::min(K1, GemmPipeline::GetVectorSizeB());
                    const auto b_k0_n_k1_desc     = /* ... */;
                    // ... (merged back into an (N, K) descriptor)
                    return make_tensor_view<address_space_enum::global>(b_ptr, b_n_k_desc);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        /* b_ptr, lengths, strides, */ number<GemmPipeline::GetVectorSizeB()>{} /* , ... */);
                }
            }
        }();

        // One view per D tensor, built with generate_tuple; DDataType_, DiLayout and i refer
        // to the i-th D tensor.
        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        /* lengths, strides, */ number<EpiloguePipeline::GetVectorSizeD(i)>{} /* , ... */);
                }
                else
                {
                    return make_naive_tensor_view<address_space_enum::global>(
                        static_cast<const DDataType_*>(ds_ptr[i]),
                        /* lengths, strides, */ number<EpiloguePipeline::GetVectorSizeD(i)>{} /* , ... */);
                }
            },
            number<NumDTensor>{});

        const auto& e_tensor_view = [&]() {
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    /* e_ptr, lengths, strides, */ number<EpiloguePipeline::GetVectorSizeC()>{} /* , ... */);
            }
            else
            {
                return make_naive_tensor_view<address_space_enum::global, DstInMemOp>(
                    /* e_ptr, ... */);
            }
        }();

        return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, e_tensor_view);
    }
    // MakeGemmPadViews(): pads the tensor views up to block-tile multiples where requested.
    template <typename TensorView>
    static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            { /* ... return pad_tensor_view(a_tensor_view, ...); */ }
            else
            { /* ... */ }
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I1);
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            { /* ... return pad_tensor_view(b_tensor_view, ...); */ }
            else
            { /* ... */ }
        }();

        // Per-D-tensor padding (DiLayout refers to the i-th D):
        const auto& d_tensor_view = views.at(I2);
        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
        { /* ... */ }
        else
        { /* ... */ }

        const auto& e_pad_view = [&]() {
            const auto& e_tensor_view = views.at(I3);
            if constexpr(std::is_same_v<ELayout, tensor_layout::gemm::RowMajor>)
            { /* ... return pad_tensor_view(e_tensor_view, ...); */ }
            else
            { /* ... */ }
        }();

        return make_tuple(a_pad_view, b_pad_view, ds_pad_view, e_pad_view);
    }
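    /* Sketch of the padding idea (a reconstruction under assumptions - the exact transforms
       live on elided lines): pad_tensor_view() extends a view so its extents become multiples
       of the block tile and windows may safely reach past the real extent, e.g. for a
       row-major A view

           pad_tensor_view(a_tensor_view,
                           make_tuple(number<TilePartitioner::MPerBlock>{},
                                      number<TilePartitioner::KPerBlock>{}),
                           sequence<false, GemmPipeline::kPadK>{});

       where the boolean sequence selects which dimensions actually get padded. */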
    // MakeGemmTileWindows(): positions per-workgroup tile windows over the padded views at the
    // output tile starting at (i_m, i_n).
    template <typename PadView>
    static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
    {
        const auto& a_pad_view  = views.at(I0);
        const auto& b_pad_view  = views.at(I1);
        const auto& ds_pad_view = views.at(I2);
        const auto& e_pad_view  = views.at(I3);

        const auto& a_block_window = [&]() {
            if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
            { /* ... return make_tile_window(a_pad_view, ...); */ }
            else
            { /* ... */ }
        }();

        const auto& b_block_window = [&]() {
            if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
            { /* ... return make_tile_window(b_pad_view, ...); */ }
            else
            { /* ... */ }
        }();

        // Per-D-tensor windows (DiLayout refers to the i-th D):
        if constexpr(std::is_same_v<DiLayout, tensor_layout::gemm::RowMajor>)
        { /* ... */ }
        else
        { /* ... */ }

        // ... (e_block_window)
        return make_tuple(a_block_window, b_block_window, ds_block_window, e_block_window);
    }
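    /* How the three helpers compose inside RunGemm / RunGemm2LDS (local names here are
       illustrative; see the fragments below):

           const auto views  = MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                                   a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
           const auto padded = MakeGemmPadViews(views);
           auto windows      = MakeGemmTileWindows(padded, block_idx_m, block_idx_n);

       i.e. raw pointers -> global tensor views -> padded views -> per-workgroup tile windows. */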
    // RunGemm(): runs a single GEMM problem cooperatively with the whole workgroup (single LDS buffer).
    template <bool UseDefaultScheduler = true>
    static CK_TILE_DEVICE void RunGemm(const ADataType* a_ptr,
                                       const BDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       EDataType* e_ptr,
                                       void* smem_ptr_0,
                                       const KernelArgs& kargs,
                                       const SplitKBatchOffset& splitk_batch_offset,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n)
    {
        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
            a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
        // ... (pad views, tile windows)
        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        // Main GEMM loop (call head elided): accumulates c_block_tile from
        //   (a_block_window, b_block_window, num_loop, smem_ptr_0).

        // Epilogue: combines the D tensors with the accumulator and writes E.
        auto& c_block_window = gemm_tile_windows.at(I3);
        EpiloguePipeline{}
            .template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
                c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }
    // RunGemm2LDS(): as RunGemm, but double-buffers the pipeline through two LDS allocations.
    static CK_TILE_DEVICE void RunGemm2LDS(const ADataType* a_ptr,
                                           const BDataType* b_ptr,
                                           const std::array<const void*, NumDTensor>& ds_ptr,
                                           EDataType* e_ptr,
                                           void* __restrict__ smem_ptr_0,
                                           void* __restrict__ smem_ptr_1,
                                           const KernelArgs& kargs,
                                           const SplitKBatchOffset& splitk_batch_offset,
                                           const index_t block_idx_m,
                                           const index_t block_idx_n)
    {
        const auto& gemm_tensor_views_tuple = MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
            a_ptr, b_ptr, ds_ptr, e_ptr, kargs, splitk_batch_offset);
        // ... (pad views, tile windows)
        const index_t num_loop = __builtin_amdgcn_readfirstlane(
            TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));

        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        // Main GEMM loop (call head elided): accumulates c_block_tile from
        //   (a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1).

        auto& c_block_window = gemm_tile_windows.at(I3);
        EpiloguePipeline{}
            .template operator()<decltype(c_block_window), decltype(c_block_tile), decltype(d_block_window)>(
                c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }
    // operator() (non-persistent kernel): one workgroup computes one output tile for one k-batch.
    template <bool U = !PersistentKernel, typename = std::enable_if_t<U>>
    CK_TILE_DEVICE void operator()(KernelArgs kargs) const
    {
        const auto blockId  = __builtin_amdgcn_readfirstlane(blockIdx.x);
        const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(blockId);
        const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
        const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
        // ... (SplitKBatchOffset, per-batch pointer offsets, LDS allocation)

        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {
            // ... (RunGemm2LDS path; guarded by a check involving
            //      EpiloguePipeline::GetVectorSizeC() % 2 != 0 && ...)
        }
        else
        {
            // ... (same guard: EpiloguePipeline::GetVectorSizeC() % 2 != 0 && ...)
            constexpr auto scheduler_type = (GemmPipeline::NumWaveGroups == 1);
            RunGemm<scheduler_type>(a_ptr, /* b_ptr, ds_ptr, e_ptr, smem, kargs,
                                              splitk_batch_offset, i_m, i_n */);
        }
    }
    // operator() (persistent kernel): each workgroup loops over output tiles and k-batches
    // until all of them are processed.
    template <bool U = PersistentKernel, typename = std::enable_if_t<U>, typename = void>
    CK_TILE_DEVICE void operator()(KernelArgs kargs) const
    {
        const auto grid_size = __builtin_amdgcn_readfirstlane(get_grid_size());
        const auto num_tiles =
            __builtin_amdgcn_readfirstlane(TilePartitioner::GridSize(kargs.M, kargs.N));
        const auto num_work = __builtin_amdgcn_readfirstlane(num_tiles * kargs.k_batch);
        auto block_id       = __builtin_amdgcn_readfirstlane(get_block_id());
        // ...
        while(block_id < num_work)
        {
            // Decompose the linear work id into an output tile and a k-batch.
            const auto tile_idx = __builtin_amdgcn_readfirstlane(block_id % num_tiles);
            const auto [iM, iN] = TilePartitioner{kargs.M, kargs.N}.GetOutputTileIndex(tile_idx);
            const index_t i_m   = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
            const index_t i_n   = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
            // ...
            const auto k_batch = __builtin_amdgcn_readfirstlane(block_id / num_tiles);
            // ... (SplitKBatchOffset, per-batch pointer offsets)

            if constexpr(GemmPipeline::DoubleSmemBuffer == true)
            {
                if constexpr(!(EpiloguePipeline::MemoryOperation == /* ... */ &&
                               EpiloguePipeline::GetVectorSizeC() % 2 != 0 && /* ... */))
                {
                    // ... (RunGemm2LDS)
                }
                // ...
            }
            else
            {
                if constexpr(!(EpiloguePipeline::MemoryOperation == /* ... */ &&
                               EpiloguePipeline::GetVectorSizeC() % 2 != 0 && /* ... */))
                {
                    // ... (RunGemm with the current splitk_batch_offset, i_m, i_n)
                }
                // ...
            }

            // Advance to this workgroup's next piece of work.
            block_id += grid_size;
            if(block_id >= num_work)
            {
                // ...
            }
            // ...
        }
    }
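    /* Worked example (illustrative numbers): for a 64-tile output grid (num_tiles = 64),
       k_batch = 2 and a 120-workgroup persistent launch (grid_size = 120), num_work = 128.
       Workgroup 7 first handles block_id = 7, i.e. tile 7 with k-batch 0; after
       block_id += 120 it handles block_id = 127, i.e. tile 127 % 64 = 63 with
       k-batch 127 / 64 = 1, and then stops because 247 >= num_work. */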