20 template <
typename GridwiseGemm,
 
   23           typename AGridDesc_K0_M_K1,
 
   24           typename BGridDesc_K0_N_K1,
 
   25           typename CGridDesc_M_N,
 
   26           bool HasMainKBlockLoop>
 
   28 #if CK_USE_LAUNCH_BOUNDS 
   31 #if CK_USE_WAVES_PER_EU 
   32         __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
 
   35                                 const FloatAB* __restrict__ p_b_grid,
 
   36                                 FloatC* __restrict__ p_c_grid,
 
   37                                 const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1,
 
   38                                 const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
 
   39                                 const CGridDesc_M_N c_grid_desc_m_n)
 
   41 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ 
   43     __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   45     GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
 
   56     ignore                = a_grid_desc_k0_m_k1;
 
   57     ignore                = b_grid_desc_k0_n_k1;
 
   62 template <
typename Gr
idwiseGemm, 
bool HasMainKBlockLoop>
 
   64 #if CK_USE_LAUNCH_BOUNDS 
   67 #if CK_USE_WAVES_PER_EU 
   68         __attribute__((amdgpu_waves_per_eu(CK_MIN_WAVES_PER_EU, CK_MAX_WAVES_PER_EU)))
 
   72 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ 
   74     __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   76     const auto a_grid_desc_k0_m_k1 =
 
   78             karg.M, karg.MPadded, karg.K, karg.K0, karg.StrideA));
 
   79     const auto b_grid_desc_k0_n_k1 =
 
   81             karg.K, karg.N, karg.NPadded, karg.K0, karg.StrideB));
 
   83         karg.M, karg.MPadded, karg.N, karg.NPadded, karg.StrideC));
 
   85     GridwiseGemm::template Run<HasMainKBlockLoop>(karg.p_a_grid,
 
  102           typename AElementwiseOperation,
 
  103           typename BElementwiseOperation,
 
  104           typename CElementwiseOperation,
 
  113           typename ABlockTransferThreadClusterLengths_K0_M_K1,
 
  114           typename ABlockTransferThreadClusterArrangeOrder,
 
  115           typename ABlockTransferSrcAccessOrder,
 
  116           index_t ABlockTransferSrcVectorDim,
 
  117           index_t ABlockTransferSrcScalarPerVector,
 
  118           index_t ABlockTransferDstScalarPerVector_K1,
 
  119           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  120           bool ABlockLdsExtraM,
 
  121           typename BBlockTransferThreadClusterLengths_K0_N_K1,
 
  122           typename BBlockTransferThreadClusterArrangeOrder,
 
  123           typename BBlockTransferSrcAccessOrder,
 
  124           index_t BBlockTransferSrcVectorDim,
 
  125           index_t BBlockTransferSrcScalarPerVector,
 
  126           index_t BBlockTransferDstScalarPerVector_K1,
 
  127           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  128           bool BBlockLdsExtraN,
 
  129           typename CThreadTransferSrcDstAccessOrder,
 
  130           index_t CThreadTransferSrcDstVectorDim,
 
  131           index_t CThreadTransferDstScalarPerVector,
 
  132           index_t NumGemmKPrefetchStage = 1,
 
  156     template <
typename CGr
idDesc_M_N>
 
  159         return std::make_tuple(Block2CTileMap::CalculateGridSize(c_grid_desc_m_n), 1, 1);
 
  203             std::cout << 
"problem {" 
  212                       << 
"K0:" << 
K0 << 
"}" << std::endl;
 
  230                           const FloatAB* p_b_grid_,
 
  238             : 
Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_},
 
  251         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage, LoopSched>())>;
 
  257 #if CK_GFX90A_DENORM_WORKAROUND 
  265         constexpr 
auto max_lds_align = 
K1;
 
  268         constexpr 
auto a_block_desc_k0_m_k1 = [&]() {
 
  269             if constexpr(ABlockLdsExtraM)
 
  282         return a_block_desc_k0_m_k1;
 
  287         constexpr 
auto max_lds_align = 
K1;
 
  290         constexpr 
auto b_block_desc_k0_n_k1 = [&]() {
 
  291             if constexpr(BBlockLdsExtraN)
 
  304         return b_block_desc_k0_n_k1;
 
  314         constexpr 
auto max_lds_align = 
K1;
 
  316         constexpr 
auto a_block_space_size_aligned =
 
  319         constexpr 
auto b_block_space_size_aligned =
 
  322         return (a_block_space_size_aligned + b_block_space_size_aligned) * 
sizeof(FloatAB);
 
  325     template <
typename AGr
idDesc_K0_M_K1, 
typename BGr
idDesc_K0_N_K1, 
typename CGr
idDesc_M_N>
 
  326     __host__ __device__ 
static constexpr 
bool 
  328                   const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
 
  329                   const CGridDesc_M_N& c_grid_desc_m_n)
 
  332                       "wrong! K1 need to be known at compile-time");
 
  334         static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
 
  335                           (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
 
  336                       "Invalid tuning param!");
 
  338         const auto M  = a_grid_desc_k0_m_k1.GetLength(
I1);
 
  339         const auto N  = b_grid_desc_k0_n_k1.GetLength(
I1);
 
  340         const auto K0 = a_grid_desc_k0_m_k1.GetLength(
I0);
 
  342         if(!(M == c_grid_desc_m_n.GetLength(
I0) && N == c_grid_desc_m_n.GetLength(
I1) &&
 
  343              K0 == b_grid_desc_k0_n_k1.GetLength(
I0) && 
K1 == a_grid_desc_k0_m_k1.GetLength(
I2) &&
 
  344              K1 == b_grid_desc_k0_n_k1.GetLength(
I2)))
 
  347         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
 
  351         const auto num_k_loop = K0 / K0PerBlock;
 
  353         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
 
  365                       "wrong! K1 need to be known at compile-time");
 
  367         static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
 
  368                           (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
 
  369                       "Invalid tuning param!");
 
  373         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
 
  386         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
 
  389     template <
typename CGr
idDesc>
 
  390     __host__ __device__ 
static constexpr 
auto 
  393         constexpr 
auto max_lds_align = 
K1;
 
  396         constexpr 
auto a_block_desc_k0_m_k1 = [&]() {
 
  397             if constexpr(ABlockLdsExtraM)
 
  411         constexpr 
auto b_block_desc_k0_n_k1 = [&]() {
 
  412             if constexpr(BBlockLdsExtraN)
 
  425         using BlockwiseGemm =
 
  430                                                                 decltype(a_block_desc_k0_m_k1),
 
  431                                                                 decltype(b_block_desc_k0_n_k1),
 
  438         return BlockwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n);
 
  444     template <
bool HasMainKBlockLoop,
 
  445               typename AGridDesc_K0_M_K1,
 
  446               typename BGridDesc_K0_N_K1,
 
  447               typename CGridDesc_M_N>
 
  448     __device__ 
static void Run(
const FloatAB* p_a_grid,
 
  449                                const FloatAB* p_b_grid,
 
  451                                void* __restrict__ p_shared,
 
  452                                const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
 
  453                                const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
 
  454                                const CGridDesc_M_N& c_grid_desc_m_n)
 
  456         const auto c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
  459         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  460             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
 
  461         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  462             p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
 
  463         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  464             p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
 
  466         const AElementwiseOperation a_element_op{};
 
  467         const BElementwiseOperation b_element_op{};
 
  468         const CElementwiseOperation c_element_op{};
 
  470         const auto block_2_ctile_map =
 
  474         const auto block_work_idx =
 
  477         if(!block_2_ctile_map.ValidCTileIndex(
 
  479                make_tuple(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I0),
 
  480                           c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I1))))
 
  486         const index_t m_block_data_idx_on_grid =
 
  487             __builtin_amdgcn_readfirstlane(block_work_idx[
I0] * MPerBlock);
 
  489         const index_t n_block_data_idx_on_grid =
 
  490             __builtin_amdgcn_readfirstlane(block_work_idx[
I1] * NPerBlock);
 
  493         constexpr 
auto max_lds_align = 
K1;
 
  502         auto a_blockwise_copy =
 
  504                                                 AElementwiseOperation,
 
  508                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
 
  509                                                 ABlockTransferThreadClusterArrangeOrder,
 
  512                                                 decltype(a_grid_desc_k0_m_k1),
 
  513                                                 decltype(a_block_desc_k0_m_k1),
 
  514                                                 ABlockTransferSrcAccessOrder,
 
  516                                                 ABlockTransferSrcVectorDim,
 
  518                                                 ABlockTransferSrcScalarPerVector,
 
  519                                                 ABlockTransferDstScalarPerVector_K1,
 
  522                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
  524                                                 NumGemmKPrefetchStage>(
 
  528                 a_block_desc_k0_m_k1,
 
  533         auto b_blockwise_copy =
 
  535                                                 BElementwiseOperation,
 
  539                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
 
  540                                                 BBlockTransferThreadClusterArrangeOrder,
 
  543                                                 decltype(b_grid_desc_k0_n_k1),
 
  544                                                 decltype(b_block_desc_k0_n_k1),
 
  545                                                 BBlockTransferSrcAccessOrder,
 
  547                                                 BBlockTransferSrcVectorDim,
 
  549                                                 BBlockTransferSrcScalarPerVector,
 
  550                                                 BBlockTransferDstScalarPerVector_K1,
 
  553                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
  555                                                 NumGemmKPrefetchStage>(
 
  559                 b_block_desc_k0_n_k1,
 
  575             decltype(a_block_desc_k0_m_k1),
 
  576             decltype(b_block_desc_k0_n_k1),
 
  584         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
 
  587         constexpr 
auto a_block_space_size_aligned =
 
  590         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  591             static_cast<FloatABAdjusted*
>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
 
  593         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  595             b_block_desc_k0_n_k1.GetElementSpaceSize());
 
  597         constexpr 
auto a_block_slice_copy_step = 
make_multi_index(K0PerBlock, 0, 0);
 
  598         constexpr 
auto b_block_slice_copy_step = 
make_multi_index(K0PerBlock, 0, 0);
 
  601         const auto K0                       = a_grid_desc_k0_m_k1.GetLength(
I0);
 
  602         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);
 
  604         GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
 
  605                                                           a_block_desc_k0_m_k1,
 
  609                                                           a_block_slice_copy_step,
 
  611                                                           b_block_desc_k0_n_k1,
 
  615                                                           b_block_slice_copy_step,
 
  618                                                           num_k_block_main_loop);
 
  622             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
  623                 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  625             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
  626                 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  628             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I0);
 
  629             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I1);
 
  630             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I2);
 
  631             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I3);
 
  632             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I4);
 
  633             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I5);
 
  634             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I6);
 
  635             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetLength(
I7);
 
  639             const auto c_thread_mtx_on_block =
 
  640                 blockwise_gemm.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
  642             const index_t m_thread_data_on_grid =
 
  643                 m_block_data_idx_on_grid + c_thread_mtx_on_block[
I0];
 
  645             const index_t n_thread_data_on_grid =
 
  646                 n_block_data_idx_on_grid + c_thread_mtx_on_block[
I1];
 
  648             const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor =
 
  654             const auto m_thread_data_on_grid_idx =
 
  655                 m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
  663             const auto n_thread_data_on_grid_idx =
 
  664                 n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
  670                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
  671                                                    decltype(c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
  672                                                    CElementwiseOperation,
 
  674                                                    CThreadTransferSrcDstAccessOrder,
 
  675                                                    CThreadTransferSrcDstVectorDim,
 
  676                                                    CThreadTransferDstScalarPerVector,
 
  677                                                    CGlobalMemoryDataOperation,
 
  680                     c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  682                                      n_thread_data_on_grid_idx[
I0],
 
  683                                      m_thread_data_on_grid_idx[
I1],
 
  684                                      n_thread_data_on_grid_idx[
I1],
 
  685                                      m_thread_data_on_grid_idx[
I2],
 
  686                                      m_thread_data_on_grid_idx[
I3],
 
  687                                      m_thread_data_on_grid_idx[
I4],
 
  688                                      n_thread_data_on_grid_idx[
I2]),
 
  691             c_thread_copy.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  694                               c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  708           typename AElementwiseOperation,
 
  709           typename BElementwiseOperation,
 
  710           typename CElementwiseOperation,
 
  720           typename ABlockTransferThreadClusterLengths_K0_M_K1,
 
  721           typename ABlockTransferThreadClusterArrangeOrder,
 
  722           typename ABlockTransferSrcAccessOrder,
 
  723           index_t ABlockTransferSrcVectorDim,
 
  724           index_t ABlockTransferSrcScalarPerVector,
 
  725           index_t ABlockTransferDstScalarPerVector_K1,
 
  726           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  727           bool ABlockLdsExtraM,
 
  728           typename BBlockTransferThreadClusterLengths_K0_N_K1,
 
  729           typename BBlockTransferThreadClusterArrangeOrder,
 
  730           typename BBlockTransferSrcAccessOrder,
 
  731           index_t BBlockTransferSrcVectorDim,
 
  732           index_t BBlockTransferSrcScalarPerVector,
 
  733           index_t BBlockTransferDstScalarPerVector_K1,
 
  734           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  735           bool BBlockLdsExtraN,
 
  736           typename CThreadTransferSrcDstAccessOrder,
 
  737           index_t CThreadTransferSrcDstVectorDim,
 
  738           index_t CThreadTransferDstScalarPerVector,
 
  739           index_t NumGemmKPrefetchStage = 1,
 
  747                                               CGlobalMemoryDataOperation,
 
  748                                               AElementwiseOperation,
 
  749                                               BElementwiseOperation,
 
  750                                               CElementwiseOperation,
 
  759                                               ABlockTransferThreadClusterLengths_K0_M_K1,
 
  760                                               ABlockTransferThreadClusterArrangeOrder,
 
  761                                               ABlockTransferSrcAccessOrder,
 
  762                                               ABlockTransferSrcVectorDim,
 
  763                                               ABlockTransferSrcScalarPerVector,
 
  764                                               ABlockTransferDstScalarPerVector_K1,
 
  765                                               AThreadTransferSrcResetCoordinateAfterRun,
 
  767                                               BBlockTransferThreadClusterLengths_K0_N_K1,
 
  768                                               BBlockTransferThreadClusterArrangeOrder,
 
  769                                               BBlockTransferSrcAccessOrder,
 
  770                                               BBlockTransferSrcVectorDim,
 
  771                                               BBlockTransferSrcScalarPerVector,
 
  772                                               BBlockTransferDstScalarPerVector_K1,
 
  773                                               BThreadTransferSrcResetCoordinateAfterRun,
 
  775                                               CThreadTransferSrcDstAccessOrder,
 
  776                                               CThreadTransferSrcDstVectorDim,
 
  777                                               CThreadTransferDstScalarPerVector,
 
  778                                               NumGemmKPrefetchStage,
 
  787                                                 CGlobalMemoryDataOperation,
 
  788                                                 AElementwiseOperation,
 
  789                                                 BElementwiseOperation,
 
  790                                                 CElementwiseOperation,
 
  799                                                 ABlockTransferThreadClusterLengths_K0_M_K1,
 
  800                                                 ABlockTransferThreadClusterArrangeOrder,
 
  801                                                 ABlockTransferSrcAccessOrder,
 
  802                                                 ABlockTransferSrcVectorDim,
 
  803                                                 ABlockTransferSrcScalarPerVector,
 
  804                                                 ABlockTransferDstScalarPerVector_K1,
 
  805                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
  807                                                 BBlockTransferThreadClusterLengths_K0_N_K1,
 
  808                                                 BBlockTransferThreadClusterArrangeOrder,
 
  809                                                 BBlockTransferSrcAccessOrder,
 
  810                                                 BBlockTransferSrcVectorDim,
 
  811                                                 BBlockTransferSrcScalarPerVector,
 
  812                                                 BBlockTransferDstScalarPerVector_K1,
 
  813                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
  815                                                 CThreadTransferSrcDstAccessOrder,
 
  816                                                 CThreadTransferSrcDstVectorDim,
 
  817                                                 CThreadTransferDstScalarPerVector,
 
  818                                                 NumGemmKPrefetchStage,
 
  829     __device__ 
static auto 
  832         const auto a_grid_desc_m_k = [&]() {
 
  846             const auto KPad  = K0Pad * K1Value;
 
  881     __device__ 
static auto 
  884         const auto b_grid_desc_k_n = [&]() {
 
  898             const auto KPad  = K0Pad * K1Value;
 
  934     __device__ 
static auto 
  937         const auto c_grid_desc_m_n = [&]() {
 
  971                       "wrong! K1 need to be known at compile-time");
 
  973         static_assert((MPerBlock % (MPerXDL * MXdlPerWave) == 0) &&
 
  974                           (NPerBlock % (NXdlPerWave * NPerXDL)) == 0,
 
  975                       "Invalid tuning param!");
 
  982             if(!(problem.M % MPerBlock == 0))
 
  993             if(!(problem.N % NPerBlock == 0))
 
 1004             if(!(problem.K0 % K0PerBlock == 0))
 
 1012             if(problem.K % ABlockTransferSrcScalarPerVector != 0)
 
 1019             if(problem.M % ABlockTransferSrcScalarPerVector != 0)
 
 1027             if(problem.N % BBlockTransferSrcScalarPerVector != 0)
 
 1034             if(problem.K % BBlockTransferSrcScalarPerVector != 0)
 
 1043         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
 
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
Definition: blockwise_gemm_xdlops.hpp:606
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
InMemoryDataOperationEnum
Definition: ck.hpp:278
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
 
ushort bhalf_t
Definition: data_type.hpp:29
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
 
typename conditional< predicate, X, Y >::type conditional_t
Definition: functional.hpp:115
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__global__ void kernel_gemm_xdlops_v2r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1, const CGridDesc_M_N c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v2r3.hpp:34
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
LoopScheduler
Definition: loop_scheduler.hpp:15
 
int32_t index_t
Definition: ck.hpp:300
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
 
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
 
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
 
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
 
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
 
Definition: gridwise_gemm_xdlops_v2r3.hpp:228
 
const FloatAB * p_a_grid
Definition: gridwise_gemm_xdlops_v2r3.hpp:245
 
FloatC * p_c_grid
Definition: gridwise_gemm_xdlops_v2r3.hpp:247
 
const FloatAB * p_b_grid
Definition: gridwise_gemm_xdlops_v2r3.hpp:246
 
__host__ Argument(const FloatAB *p_a_grid_, const FloatAB *p_b_grid_, FloatC *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition: gridwise_gemm_xdlops_v2r3.hpp:229
 
Definition: gridwise_gemm_xdlops_v2r3.hpp:182
 
index_t NPadded
Definition: gridwise_gemm_xdlops_v2r3.hpp:222
 
index_t StrideC
Definition: gridwise_gemm_xdlops_v2r3.hpp:220
 
index_t M
Definition: gridwise_gemm_xdlops_v2r3.hpp:215
 
index_t StrideA
Definition: gridwise_gemm_xdlops_v2r3.hpp:218
 
index_t N
Definition: gridwise_gemm_xdlops_v2r3.hpp:216
 
index_t K
Definition: gridwise_gemm_xdlops_v2r3.hpp:217
 
index_t StrideB
Definition: gridwise_gemm_xdlops_v2r3.hpp:219
 
index_t K0
Definition: gridwise_gemm_xdlops_v2r3.hpp:223
 
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_)
Definition: gridwise_gemm_xdlops_v2r3.hpp:183
 
__host__ void Print() const
Definition: gridwise_gemm_xdlops_v2r3.hpp:201
 
index_t MPadded
Definition: gridwise_gemm_xdlops_v2r3.hpp:221
 
Definition: gridwise_gemm_xdlops_v2r3.hpp:781
 
static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_xdlops_v2r3.hpp:935
 
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdlops_v2r3.hpp:968
 
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v2r3.hpp:138
 
static constexpr auto K1
Definition: gridwise_gemm_xdlops_v2r3.hpp:147
 
static __device__ auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t K0, index_t StrideA)
Definition: gridwise_gemm_xdlops_v2r3.hpp:830
 
static __device__ auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t NPad, index_t K0, index_t StrideB)
Definition: gridwise_gemm_xdlops_v2r3.hpp:882
 
Definition: gridwise_gemm_xdlops_v2r3.hpp:136
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v2r3.hpp:149
 
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_xdlops_v2r3.hpp:168
 
static __device__ void Run(const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid, void *__restrict__ p_shared, const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v2r3.hpp:448
 
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v2r3.hpp:327
 
__host__ static constexpr __device__ auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
Definition: gridwise_gemm_xdlops_v2r3.hpp:285
 
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v2r3.hpp:382
 
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v2r3.hpp:144
 
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v2r3.hpp:139
 
static __host__ auto CalculateGridSize(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v2r3.hpp:157
 
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v2r3.hpp:142
 
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_xdlops_v2r3.hpp:173
 
FloatAB FloatABAdjusted
Definition: gridwise_gemm_xdlops_v2r3.hpp:260
 
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v2r3.hpp:137
 
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v2r3.hpp:141
 
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_xdlops_v2r3.hpp:163
 
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v2r3.hpp:138
 
static constexpr auto K1
Definition: gridwise_gemm_xdlops_v2r3.hpp:147
 
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdlops_v2r3.hpp:362
 
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v2r3.hpp:143
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v2r3.hpp:391
 
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage, LoopSched >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v2r3.hpp:251
 
__host__ static constexpr __device__ auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
Definition: gridwise_gemm_xdlops_v2r3.hpp:263
 
static __host__ auto CalculateK0(index_t K)
Definition: gridwise_gemm_xdlops_v2r3.hpp:178
 
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v2r3.hpp:140
 
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v2r3.hpp:307
 
static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_xdlops_v2r3.hpp:151
 
Definition: sequence.hpp:43
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Definition: integral_constant.hpp:20
 
Definition: is_known_at_compile_time.hpp:14
 
Definition: device_base.hpp:51
 
Definition: unary_element_wise_operation.hpp:308