21 template <
typename GridwiseGemm,
 
   24           typename AGridDesc_AK0_M_AK1,
 
   25           typename BGridDesc_BK0_N_BK1,
 
   26           typename CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
 
   27           typename AElementwiseOperation,
 
   28           typename BElementwiseOperation,
 
   29           typename CElementwiseOperation,
 
   30           typename Block2CTileMap,
 
   31           bool HasMainK0BlockLoop>
 
   33 #if CK_USE_LAUNCH_BOUNDS 
   37             const FloatAB* __restrict__ p_a_grid,
 
   38             const FloatAB* __restrict__ p_b_grid,
 
   39             FloatC* __restrict__ p_c_grid,
 
   40             const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
 
   41             const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
 
   42             const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
 
   43                 c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
   44             const AElementwiseOperation a_element_op,
 
   45             const BElementwiseOperation b_element_op,
 
   46             const CElementwiseOperation c_element_op,
 
   47             const Block2CTileMap block_2_ctile_map)
 
   49 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ 
   51     __shared__ 
char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 
   53     GridwiseGemm::template Run<HasMainK0BlockLoop>(
 
   58         a_grid_desc_ak0_m_ak1,
 
   59         b_grid_desc_bk0_n_bk1,
 
   60         c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
   69     ignore = a_grid_desc_ak0_m_ak1;
 
   70     ignore = b_grid_desc_bk0_n_bk1;
 
   71     ignore = c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
 
   75     ignore = block_2_ctile_map;
 
   83     typename FloatCShuffle,
 
   86     typename AGridDesc_AK0_M_AK1,
 
   87     typename BGridDesc_BK0_N_BK1,
 
   88     typename CGridDesc_M_N,
 
   89     typename AElementwiseOperation,
 
   90     typename BElementwiseOperation,
 
   91     typename CElementwiseOperation,
 
  101     typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  102     typename ABlockTransferThreadClusterArrangeOrder,
 
  103     typename ABlockTransferSrcAccessOrder,
 
  104     index_t ABlockTransferSrcVectorDim,
 
  105     index_t ABlockTransferSrcScalarPerVector,
 
  106     index_t ABlockTransferDstScalarPerVector_K1,
 
  107     bool AThreadTransferSrcResetCoordinateAfterRun,
 
  108     bool ABlockLdsExtraM,
 
  109     typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  110     typename BBlockTransferThreadClusterArrangeOrder,
 
  111     typename BBlockTransferSrcAccessOrder,
 
  112     index_t BBlockTransferSrcVectorDim,
 
  113     index_t BBlockTransferSrcScalarPerVector,
 
  114     index_t BBlockTransferDstScalarPerVector_K1,
 
  115     bool BThreadTransferSrcResetCoordinateAfterRun,
 
  116     bool BBlockLdsExtraN,
 
  117     index_t CShuffleMXdlPerWavePerShuffle,
 
  118     index_t CShuffleNXdlPerWavePerShuffle,
 
  119     typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
 
  120     index_t CBlockTransferScalarPerVector_NWaveNPerXdl,
 
  121     index_t NumGemmKPrefetchStage = 1,
 
  135     static constexpr 
auto AK0 = 
Number<KPerBlock / AK1Value>{};
 
  136     static constexpr 
auto BK0 = 
Number<KPerBlock / BK1Value>{};
 
  143         decltype(GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;
 
  147         constexpr 
auto max_lds_align = 
AK1;
 
  150         constexpr 
auto a_block_desc_ak0_m_ak1 = [&]() {
 
  151             if constexpr(ABlockLdsExtraM)
 
  164         return a_block_desc_ak0_m_ak1;
 
  169         constexpr 
auto max_lds_align = 
BK1;
 
  172         constexpr 
auto b_block_desc_bk0_n_bk1 = [&]() {
 
  173             if constexpr(BBlockLdsExtraN)
 
  186         return b_block_desc_bk0_n_bk1;
 
  189     __host__ __device__ 
static constexpr 
auto 
  192         constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
  193         constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
  196             c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
 
  205         return c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
 
  215         constexpr 
auto a_block_space_size_aligned =
 
  218         constexpr 
auto b_block_space_size_aligned =
 
  222         constexpr 
auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
 
  225         constexpr 
auto c_block_size =
 
  226             c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
 
  227                 .GetElementSpaceSize();
 
  229         return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
 
  231                          c_block_size * 
sizeof(FloatCShuffle));
 
  235     template <
typename Block2CTileMap>
 
  236     __host__ __device__ 
static constexpr 
bool 
  238                   const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
 
  239                   const CGridDesc_M_N& c_grid_desc_m_n,
 
  240                   const Block2CTileMap& block_2_ctile_map)
 
  246         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
 
  247                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
 
  248                       "Invalid tuning param!");
 
  250         const auto M = a_grid_desc_ak0_m_ak1.GetLength(
I1);
 
  251         const auto N = b_grid_desc_bk0_n_bk1.GetLength(
I1);
 
  252         const auto K = a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2);
 
  254         if(!(M == c_grid_desc_m_n.GetLength(
I0) && N == c_grid_desc_m_n.GetLength(
I1)))
 
  257         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
 
  261         const auto num_k_loop = K / KPerBlock;
 
  263         if(!GridwiseGemmPipe::IsSupported(num_k_loop))
 
  268         if(!block_2_ctile_map.CheckValidity(c_grid_desc_m_n))
 
  279         const index_t num_loop = K / KPerBlock;
 
  281         return GridwiseGemmPipe::CalculateHasMainLoop(num_loop);
 
  284     __host__ __device__ 
static constexpr 
auto 
  286         const CGridDesc_M_N& c_grid_desc_m_n)
 
  288         const auto M = c_grid_desc_m_n.GetLength(
I0);
 
  289         const auto N = c_grid_desc_m_n.GetLength(
I1);
 
  291         const auto MBlock = M / MPerBlock;
 
  292         const auto NBlock = N / NPerBlock;
 
  294         constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
  295         constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
  297         const auto c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
 
  307         return c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl;
 
  325     template <
bool HasMainK0BlockLoop, 
typename Block2CTileMap>
 
  326     __device__ 
static void 
  327     Run(
const FloatAB* __restrict__ p_a_grid,
 
  328         const FloatAB* __restrict__ p_b_grid,
 
  329         FloatC* __restrict__ p_c_grid,
 
  330         void* __restrict__ p_shared,
 
  331         const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
 
  332         const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
 
  334             c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  335         const AElementwiseOperation& a_element_op,
 
  336         const BElementwiseOperation& b_element_op,
 
  337         const CElementwiseOperation& c_element_op,
 
  338         const Block2CTileMap& block_2_ctile_map)
 
  340         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  341             p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
  342         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  343             p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
 
  344         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  346             c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
 
  347                 .GetElementSpaceSize());
 
  350         const auto block_work_idx =
 
  353         if(!block_2_ctile_map.ValidCTileIndex(
 
  356                    c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
 
  358                    c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
 
  365         const index_t m_block_data_idx_on_grid =
 
  366             __builtin_amdgcn_readfirstlane(block_work_idx[
I0] * MPerBlock);
 
  368         const index_t n_block_data_idx_on_grid =
 
  369             __builtin_amdgcn_readfirstlane(block_work_idx[
I1] * NPerBlock);
 
  381         auto a_blockwise_copy =
 
  383                                                 AElementwiseOperation,
 
  387                                                 ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  388                                                 ABlockTransferThreadClusterArrangeOrder,
 
  391                                                 decltype(a_grid_desc_ak0_m_ak1),
 
  392                                                 decltype(a_block_desc_ak0_m_ak1),
 
  393                                                 ABlockTransferSrcAccessOrder,
 
  395                                                 ABlockTransferSrcVectorDim,
 
  397                                                 ABlockTransferSrcScalarPerVector,
 
  398                                                 ABlockTransferDstScalarPerVector_K1,
 
  401                                                 AThreadTransferSrcResetCoordinateAfterRun,
 
  403                                                 NumGemmKPrefetchStage>(
 
  404                 a_grid_desc_ak0_m_ak1,
 
  407                 a_block_desc_ak0_m_ak1,
 
  412         auto b_blockwise_copy =
 
  414                                                 BElementwiseOperation,
 
  418                                                 BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  419                                                 BBlockTransferThreadClusterArrangeOrder,
 
  422                                                 decltype(b_grid_desc_bk0_n_bk1),
 
  423                                                 decltype(b_block_desc_bk0_n_bk1),
 
  424                                                 BBlockTransferSrcAccessOrder,
 
  426                                                 BBlockTransferSrcVectorDim,
 
  428                                                 BBlockTransferSrcScalarPerVector,
 
  429                                                 BBlockTransferDstScalarPerVector_K1,
 
  432                                                 BThreadTransferSrcResetCoordinateAfterRun,
 
  434                                                 NumGemmKPrefetchStage>(
 
  435                 b_grid_desc_bk0_n_bk1,
 
  438                 b_block_desc_bk0_n_bk1,
 
  450         constexpr 
bool is_single_rate_mfma =
 
  458         constexpr 
auto is_scale_mfma = 
false;
 
  462                 selected_mfma.k_per_blk);
 
  464         auto blockwise_gemm =
 
  469                                                                 decltype(a_block_desc_ak0_m_ak1),
 
  470                                                                 decltype(b_block_desc_bk0_n_bk1),
 
  477         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
 
  481             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
  483         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  484             static_cast<FloatAB*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
  486         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  487             static_cast<FloatAB*
>(p_shared) + a_block_space_size_aligned,
 
  488             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
  494         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
  495             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
  498         GridwiseGemmPipe::template Run<HasMainK0BlockLoop>(a_grid_desc_ak0_m_ak1,
 
  499                                                            a_block_desc_ak0_m_ak1,
 
  503                                                            a_block_slice_copy_step,
 
  504                                                            b_grid_desc_bk0_n_bk1,
 
  505                                                            b_block_desc_bk0_n_bk1,
 
  509                                                            b_block_slice_copy_step,
 
  512                                                            num_k_block_main_loop);
 
  516             static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
  517                               NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
  520             constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
  521             constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
  524             constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
  525                 blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  529             constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
  530                 blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  532             constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
  533             constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
  534             constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
  535             constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
  536             constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
  537             constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
  538             constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
  539             constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
  541             constexpr 
auto c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl =
 
  544             auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  545                 static_cast<FloatCShuffle*
>(p_shared),
 
  546                 c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl
 
  547                     .GetElementSpaceSize());
 
  550                 c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  577             const auto c_thread_mtx_on_block =
 
  578                 blockwise_gemm.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
  580             const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
  581             const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
  583             const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
  589             const auto m_thread_data_on_block_idx =
 
  590                 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
  593             const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
  599             const auto n_thread_data_on_block_idx =
 
  600                 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
  604             auto c_thread_copy_vgpr_to_lds =
 
  607                                                    decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
  608                                                    decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
  610                                                    Sequence<CShuffleMXdlPerWavePerShuffle,
 
  611                                                             CShuffleNXdlPerWavePerShuffle,
 
  624                     c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  627                                      m_thread_data_on_block_idx[
I1],
 
  628                                      n_thread_data_on_block_idx[
I1],
 
  629                                      m_thread_data_on_block_idx[
I2],
 
  630                                      m_thread_data_on_block_idx[
I3],
 
  631                                      m_thread_data_on_block_idx[
I4],
 
  632                                      n_thread_data_on_block_idx[
I2]),
 
  638                 CElementwiseOperation,      
 
  639                 CGlobalMemoryDataOperation, 
 
  641                          CShuffleMXdlPerWavePerShuffle,
 
  644                          CShuffleNXdlPerWavePerShuffle,
 
  646                 CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
 
  650                 decltype(c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
 
  651                 decltype(c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl),
 
  654                 CBlockTransferScalarPerVector_NWaveNPerXdl, 
 
  657                 {c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  659                  c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  663             constexpr 
auto mxdlperwave_forward_step =
 
  665             constexpr 
auto nxdlperwave_forward_step =
 
  667             constexpr 
auto nxdlperwave_backward_step =
 
  671                 constexpr 
auto mxdlperwave = mxdlperwave_iter;
 
  675                            CShuffleNXdlPerWavePerShuffle>{}([&](
auto nxdlperwave_iter) {
 
  676                     constexpr 
bool nxdlperwave_forward_sweep =
 
  677                         (mxdlperwave % (2 * CShuffleMXdlPerWavePerShuffle) == 0);
 
  679                     constexpr 
index_t nxdlperwave_value =
 
  680                         nxdlperwave_forward_sweep
 
  682                             : (NXdlPerWave - nxdlperwave_iter - CShuffleNXdlPerWavePerShuffle);
 
  690                     c_thread_copy_vgpr_to_lds.Run(
 
  691                         c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  694                         c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  695                         c_shuffle_block_buf);
 
  701                     c_block_copy_lds_to_global.Run(
 
  702                         c_block_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  704                         c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  708                     if constexpr(nxdlperwave_forward_sweep &&
 
  709                                  (nxdlperwave < NXdlPerWave - CShuffleNXdlPerWavePerShuffle))
 
  711                         c_block_copy_lds_to_global.MoveDstSliceWindow(
 
  712                             c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  713                             nxdlperwave_forward_step);
 
  715                     else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
 
  717                         c_block_copy_lds_to_global.MoveDstSliceWindow(
 
  718                             c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  719                             nxdlperwave_backward_step);
 
  724                 if constexpr(mxdlperwave < MXdlPerWave - CShuffleMXdlPerWavePerShuffle)
 
  726                     c_block_copy_lds_to_global.MoveDstSliceWindow(
 
  727                         c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl,
 
  728                         mxdlperwave_forward_step);
 
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
 
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
InMemoryDataOperationEnum
Definition: ck.hpp:278
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_aligned(const Tuple< Lengths... > &lengths, Align align)
Definition: tensor_descriptor_helper.hpp:132
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:300
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:18
 
__global__ void kernel_gemm_xdlops_v3r1(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation a_element_op, const BElementwiseOperation b_element_op, const CElementwiseOperation c_element_op, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:36
 
Definition: block_to_ctile_map.hpp:260
 
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
 
Definition: gridwise_gemm_xdlops_v3r1.hpp:124
 
static __device__ void Run(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, void *__restrict__ p_shared, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl &c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const CElementwiseOperation &c_element_op, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:327
 
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v3r1.hpp:277
 
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdlops_v3r1.hpp:167
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_v3r1.hpp:140
 
static constexpr auto I6
Definition: gridwise_gemm_xdlops_v3r1.hpp:131
 
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n, index_t, index_t)
Definition: gridwise_gemm_xdlops_v3r1.hpp:311
 
static constexpr auto I5
Definition: gridwise_gemm_xdlops_v3r1.hpp:130
 
static constexpr auto I7
Definition: gridwise_gemm_xdlops_v3r1.hpp:132
 
static constexpr auto I4
Definition: gridwise_gemm_xdlops_v3r1.hpp:129
 
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const CGridDesc_M_N &c_grid_desc_m_n, const Block2CTileMap &block_2_ctile_map)
Definition: gridwise_gemm_xdlops_v3r1.hpp:237
 
remove_cvref_t< decltype(GridwiseGemmPipeline_Selector< PipelineVer, NumGemmKPrefetchStage >())> GridwiseGemmPipe
Definition: gridwise_gemm_xdlops_v3r1.hpp:143
 
static constexpr auto I0
Definition: gridwise_gemm_xdlops_v3r1.hpp:125
 
remove_cvref_t< decltype(MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(CGridDesc_M_N{}))> CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
Definition: gridwise_gemm_xdlops_v3r1.hpp:320
 
static constexpr auto AK0
Definition: gridwise_gemm_xdlops_v3r1.hpp:135
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_xdlops_v3r1.hpp:285
 
remove_cvref_t< decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1))> DefaultBlock2CTileMap
Definition: gridwise_gemm_xdlops_v3r1.hpp:323
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_NXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl()
Definition: gridwise_gemm_xdlops_v3r1.hpp:190
 
static constexpr auto BK0
Definition: gridwise_gemm_xdlops_v3r1.hpp:136
 
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdlops_v3r1.hpp:145
 
static constexpr auto BK1
Definition: gridwise_gemm_xdlops_v3r1.hpp:138
 
static constexpr auto I1
Definition: gridwise_gemm_xdlops_v3r1.hpp:126
 
static constexpr auto I3
Definition: gridwise_gemm_xdlops_v3r1.hpp:128
 
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_v3r1.hpp:208
 
static constexpr auto AK1
Definition: gridwise_gemm_xdlops_v3r1.hpp:137
 
static constexpr auto I2
Definition: gridwise_gemm_xdlops_v3r1.hpp:127
 
Definition: xdlops_gemm.hpp:942
 
Definition: sequence.hpp:43
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: unary_element_wise_operation.hpp:308