20 template <
typename ABDataType,
 
   21           typename FloatGemmAcc,
 
   22           typename EDataTypeShuffle,
 
   24           typename AElementwiseOperation,
 
   25           typename BElementwiseOperation,
 
   26           typename EElementwiseOperation,
 
   28           typename AGridDesc_M_K,
 
   29           typename BGridDesc_N_K,
 
   30           typename EGridDesc_M_N,
 
   32           index_t TileLoadThreadGroupSize,
 
   33           index_t TileMathThreadGroupSize,
 
   43           typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
   44           typename ABlockTransferThreadClusterArrangeOrder,
 
   45           typename ABlockTransferSrcAccessOrder,
 
   46           index_t ABlockTransferSrcVectorDim,
 
   47           index_t ABlockTransferSrcScalarPerVector,
 
   48           index_t ABlockTransferDstScalarPerVector_AK1,
 
   49           bool AThreadTransferSrcResetCoordinateAfterRun,
 
   51           typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
   52           typename BBlockTransferThreadClusterArrangeOrder,
 
   53           typename BBlockTransferSrcAccessOrder,
 
   54           index_t BBlockTransferSrcVectorDim,
 
   55           index_t BBlockTransferSrcScalarPerVector,
 
   56           index_t BBlockTransferDstScalarPerVector_BK1,
 
   57           bool BThreadTransferSrcResetCoordinateAfterRun,
 
   59           index_t CShuffleMXdlPerWavePerShuffle,
 
   60           index_t CShuffleNXdlPerWavePerShuffle,
 
   61           typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
   62           index_t CShuffleBlockTransferScalarPerVector_NPerBlock>
 
  131     __host__ __device__ 
static constexpr 
auto 
  134         constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
  135         constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
  137         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
  144         return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
 
  157             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
  160             b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
 
  163         constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
  166         constexpr 
auto c_block_size =
 
  167             c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
 
  169         return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
 
  171                          c_block_size * 
sizeof(EDataTypeShuffle));
 
  175     template <
typename Block2ETileMap>
 
  176     __host__ __device__ 
static constexpr 
bool 
  178                   const BGridDesc_N_K& b_grid_desc_n_k,
 
  179                   const EGridDesc_M_N& e_grid_desc_m_n,
 
  180                   const Block2ETileMap& )
 
  182         static_assert((MPerBlock % (MPerXdl * MXdlPerWave) == 0) &&
 
  183                           (NPerBlock % (NXdlPerWave * NPerXdl)) == 0,
 
  184                       "Invalid tuning param!");
 
  186         const auto M = a_grid_desc_m_k.GetLength(
I0);
 
  187         const auto N = b_grid_desc_n_k.GetLength(
I0);
 
  188         const auto K = a_grid_desc_m_k.GetLength(
I1);
 
  191         if(!(M == e_grid_desc_m_n.GetLength(
I0) && N == e_grid_desc_m_n.GetLength(
I1) &&
 
  192              K == b_grid_desc_n_k.GetLength(
I1)))
 
  198         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0))
 
  204         const auto num_k_loop = K / KPerBlock;
 
  206         if(!GridwiseGemmMath::IsSupported(num_k_loop))
 
  216         if(!(a_grid_desc_m_k.GetElementSpaceSize() * 
sizeof(ABDataType) <= TwoGB &&
 
  217              b_grid_desc_n_k.GetElementSpaceSize() * 
sizeof(ABDataType) <= TwoGB &&
 
  218              e_grid_desc_m_n.GetElementSpaceSize() * 
sizeof(EDataType) <= TwoGB))
 
  228         const index_t num_loop = K / KPerBlock;
 
  230         return GridwiseGemmMath::CalculateHasMainLoop(num_loop);
 
  234     __host__ __device__ 
static constexpr 
auto 
  237         const auto M = e_grid_desc_m_n.GetLength(
I0);
 
  238         const auto N = e_grid_desc_m_n.GetLength(
I1);
 
  243         const auto M0 = M / M1;
 
  244         const auto N0 = N / N1;
 
  246         constexpr 
auto M01 = 
I1;
 
  247         constexpr 
auto N01 = 
I1;
 
  249         const auto m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor =
 
  256         const auto cblockid_to_m00_m01_n00_n01_block_cluster_adaptor =
 
  262         const auto cblockid_to_m0_n0_block_cluster_adaptor =
 
  264                                   cblockid_to_m00_m01_n00_n01_block_cluster_adaptor);
 
  266         return cblockid_to_m0_n0_block_cluster_adaptor;
 
  269     __host__ __device__ 
static constexpr 
index_t 
  272         const auto M = e_grid_desc_m_n.GetLength(
I0);
 
  273         const auto N = e_grid_desc_m_n.GetLength(
I1);
 
  275         const index_t grid_size = (M / MPerBlock) * (N / NPerBlock);
 
  281     __host__ __device__ 
static constexpr 
auto 
  284         const auto M = a_grid_desc_m_k.GetLength(
I0);
 
  285         const auto K = a_grid_desc_m_k.GetLength(
I1);
 
  287         const auto AK0 = K / 
AK1;
 
  297     __host__ __device__ 
static constexpr 
auto 
  300         const auto N = b_grid_desc_n_k.GetLength(
I0);
 
  301         const auto K = b_grid_desc_n_k.GetLength(
I1);
 
  303         const auto BK0 = K / 
BK1;
 
  313     template <
typename EGr
idDescriptor_M_N>
 
  315         const EGridDescriptor_M_N& e_grid_desc_m_n)
 
  317         const auto M = e_grid_desc_m_n.GetLength(
I0);
 
  318         const auto N = e_grid_desc_m_n.GetLength(
I1);
 
  320         const auto MBlock = M / MPerBlock;
 
  321         const auto NBlock = N / NPerBlock;
 
  330         return e_grid_desc_mblock_mperblock_nblock_nperblock;
 
  340     template <
bool HasMainKBlockLoop,
 
  341               typename AGridDesc_AK0_M_AK1,
 
  342               typename BGridDesc_BK0_N_BK1,
 
  343               typename Block2ETileMap>
 
  344     __device__ 
static void Run(
const ABDataType* __restrict__ p_a_grid,
 
  345                                const ABDataType* __restrict__ p_b_grid,
 
  346                                EDataType* __restrict__ p_e_grid,
 
  347                                void* __restrict__ p_shared,
 
  348                                const AElementwiseOperation& a_element_op,
 
  349                                const BElementwiseOperation& b_element_op,
 
  350                                const EElementwiseOperation& e_element_op,
 
  351                                const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
 
  352                                const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
 
  354                                    e_grid_desc_mblock_mperblock_nblock_nperblock,
 
  355                                const Block2ETileMap& block_2_etile_map)
 
  371             a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
 
  373         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  374             static_cast<ABDataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
 
  376         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  377             static_cast<ABDataType*
>(p_shared) + a_block_space_size_aligned,
 
  378             b_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
  383         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
 
  384             (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
 
  388         const auto block_work_idx =
 
  392         const index_t m_block_data_idx_on_grid =
 
  393             __builtin_amdgcn_readfirstlane(block_work_idx[
I0] * MPerBlock);
 
  395         const index_t n_block_data_idx_on_grid =
 
  396             __builtin_amdgcn_readfirstlane(block_work_idx[
I1] * NPerBlock);
 
  402             const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  403                 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
 
  404             const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  405                 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
 
  408             auto a_blockwise_copy =
 
  410                                                     AElementwiseOperation,
 
  414                                                     ABlockTransferThreadClusterLengths_AK0_M_AK1,
 
  415                                                     ABlockTransferThreadClusterArrangeOrder,
 
  418                                                     decltype(a_grid_desc_ak0_m_ak1),
 
  419                                                     decltype(a_block_desc_ak0_m_ak1),
 
  420                                                     ABlockTransferSrcAccessOrder,
 
  422                                                     ABlockTransferSrcVectorDim,
 
  424                                                     ABlockTransferSrcScalarPerVector,
 
  425                                                     ABlockTransferDstScalarPerVector_AK1,
 
  428                                                     AThreadTransferSrcResetCoordinateAfterRun,
 
  430                                                     NumGemmKPrefetchStage>(
 
  431                     a_grid_desc_ak0_m_ak1,
 
  434                     a_block_desc_ak0_m_ak1,
 
  439             auto b_blockwise_copy =
 
  441                                                     BElementwiseOperation,
 
  445                                                     BBlockTransferThreadClusterLengths_BK0_N_BK1,
 
  446                                                     BBlockTransferThreadClusterArrangeOrder,
 
  449                                                     decltype(b_grid_desc_bk0_n_bk1),
 
  450                                                     decltype(b_block_desc_bk0_n_bk1),
 
  451                                                     BBlockTransferSrcAccessOrder,
 
  453                                                     BBlockTransferSrcVectorDim,
 
  455                                                     BBlockTransferSrcScalarPerVector,
 
  456                                                     BBlockTransferDstScalarPerVector_BK1,
 
  459                                                     BThreadTransferSrcResetCoordinateAfterRun,
 
  461                                                     NumGemmKPrefetchStage>(
 
  462                     b_grid_desc_bk0_n_bk1,
 
  465                     b_block_desc_bk0_n_bk1,
 
  469             GridwiseGemmLoad::template RunLoadWavePipeline<HasMainKBlockLoop>(
 
  470                 a_grid_desc_ak0_m_ak1,
 
  471                 a_block_desc_ak0_m_ak1,
 
  475                 a_block_slice_copy_step,
 
  476                 b_grid_desc_bk0_n_bk1,
 
  477                 b_block_desc_bk0_n_bk1,
 
  481                 b_block_slice_copy_step,
 
  482                 num_k_block_main_loop);
 
  491             constexpr 
bool is_single_rate_mfma =
 
  499             constexpr 
auto is_scale_mfma = 
false;
 
  507                                        is_scale_mfma>::selected_mfma.k_per_blk);
 
  510                 TileMathThreadGroupSize,
 
  514                 decltype(a_block_desc_ak0_m_ak1),
 
  515                 decltype(b_block_desc_bk0_n_bk1),
 
  523             auto c_grid_buf   = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  524                 p_e_grid, e_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
  528             GridwiseGemmMath::template RunMathWavePipeline<HasMainKBlockLoop>(
 
  529                 a_block_buf, b_block_buf, blockwise_gemm, c_thread_buf, num_k_block_main_loop);
 
  541                 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
 
  542                                   NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
 
  545                 constexpr 
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
 
  546                 constexpr 
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
 
  549                 constexpr 
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
 
  550                     blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  554                 constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
 
  555                     blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  557                 constexpr 
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
 
  558                 constexpr 
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
 
  559                 constexpr 
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
 
  560                 constexpr 
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
 
  561                 constexpr 
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
 
  562                 constexpr 
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
 
  563                 constexpr 
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
 
  564                 constexpr 
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
 
  566                 constexpr 
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
 
  569                 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  570                     static_cast<EDataTypeShuffle*
>(p_shared),
 
  571                     c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
  574                     c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
  596                 const auto c_thread_mtx_on_block =
 
  597                     blockwise_gemm.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
  599                 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
  600                 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
  602                 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
  608                 const auto m_thread_data_on_block_idx =
 
  609                     m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
  612                 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
  618                 const auto n_thread_data_on_block_idx =
 
  619                     n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
  626                     decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
  627                     decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
  629                     Sequence<CShuffleMXdlPerWavePerShuffle,
 
  630                              CShuffleNXdlPerWavePerShuffle,
 
  642                     true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  645                                            m_thread_data_on_block_idx[
I1],
 
  646                                            n_thread_data_on_block_idx[
I1],
 
  647                                            m_thread_data_on_block_idx[
I2],
 
  648                                            m_thread_data_on_block_idx[
I3],
 
  649                                            m_thread_data_on_block_idx[
I4],
 
  650                                            n_thread_data_on_block_idx[
I2]),
 
  656                     EElementwiseOperation,            
 
  657                     CGlobalMemoryDataOperation,       
 
  659                              CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
  661                              CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>, 
 
  662                     CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  666                     decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
 
  667                     decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
 
  670                     CShuffleBlockTransferScalarPerVector_NPerBlock, 
 
  673                     {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
  675                      e_grid_desc_mblock_mperblock_nblock_nperblock,
 
  680                 constexpr 
auto sfc_c_vgpr =
 
  683                                       Sequence<CShuffleMXdlPerWavePerShuffle,
 
  684                                                CShuffleNXdlPerWavePerShuffle,
 
  693                 constexpr 
auto sfc_c_global =
 
  697                                                CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
 
  699                                                CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
 
  701                 constexpr 
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
 
  703                 static_assert(num_access == sfc_c_global.GetNumOfAccess(), 
"wrong!");
 
  732                     c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  733                                                   sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
 
  735                                                   c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  736                                                   c_shuffle_block_buf);
 
  741                     c_shuffle_block_copy_lds_to_global.Run(
 
  742                         c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
 
  744                         e_grid_desc_mblock_mperblock_nblock_nperblock,
 
  747                     if constexpr(access_id < num_access - 1)
 
  749                         constexpr 
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
 
  752                         c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
 
  753                             e_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
 
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
InMemoryDataOperationEnum
Definition: ck.hpp:278
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
int64_t long_index_t
Definition: ck.hpp:301
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:300
 
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
__host__ constexpr __device__ auto chain_tensor_adaptors(const TensorAdaptor0 &adaptor0, const TensorAdaptor1 &adaptor1)
Definition: tensor_adaptor.hpp:245
 
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
 
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_smfmac_xdlops.hpp:79
 
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:82
 
static __device__ index_t GetThreadId()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:90
 
static constexpr __device__ bool IsBelong()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:85
 
static constexpr __device__ index_t GetNumOfThread()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:83
 
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:97
 
static __device__ index_t GetThreadId()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:105
 
static constexpr __device__ index_t GetNumOfThread()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:98
 
static constexpr __device__ bool IsBelong()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:100
 
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:64
 
__host__ static constexpr __device__ auto MakeDefaultBlock2ETileMap(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:235
 
__host__ static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:132
 
static constexpr auto I3
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:69
 
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:147
 
static constexpr auto AK0PerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:78
 
__host__ static constexpr __device__ auto MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const EGridDescriptor_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:314
 
__host__ static constexpr __device__ auto MakeDefaultBGridDescriptor_BK0_N_BK1(const BGridDesc_N_K &b_grid_desc_n_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:298
 
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_M_K &a_grid_desc_m_k, const BGridDesc_N_K &b_grid_desc_n_k, const EGridDesc_M_N &e_grid_desc_m_n, const Block2ETileMap &)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:177
 
static constexpr auto BK0PerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:79
 
static constexpr auto I7
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:73
 
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:115
 
static constexpr auto I0
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:66
 
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:123
 
ThisThreadBlock< TileMathThreadGroupSize > CShuffleBlockTransferThreadGroup
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:108
 
remove_cvref_t< decltype(MakeDefaultBlock2ETileMap(EGridDesc_M_N{}))> DefaultBlock2ETileMap
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:338
 
static __device__ void Run(const ABDataType *__restrict__ p_a_grid, const ABDataType *__restrict__ p_b_grid, EDataType *__restrict__ p_e_grid, void *__restrict__ p_shared, const AElementwiseOperation &a_element_op, const BElementwiseOperation &b_element_op, const EElementwiseOperation &e_element_op, const AGridDesc_AK0_M_AK1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_BK1 &b_grid_desc_bk0_n_bk1, const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock &e_grid_desc_mblock_mperblock_nblock_nperblock, const Block2ETileMap &block_2_etile_map)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:344
 
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:226
 
static constexpr auto AK1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:76
 
static constexpr auto I2
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:68
 
__host__ static constexpr __device__ index_t CalculateGridSize(const EGridDesc_M_N &e_grid_desc_m_n)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:270
 
remove_cvref_t< decltype(MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))> EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:335
 
static constexpr auto I5
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:71
 
static constexpr auto I6
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:72
 
static constexpr auto I4
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:70
 
static constexpr auto BK1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:77
 
static constexpr auto I1
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:67
 
__host__ static constexpr __device__ auto MakeDefaultAGridDescriptor_AK0_M_AK1(const AGridDesc_M_K &a_grid_desc_m_k)
Definition: gridwise_gemm_xdl_waveletmodel_cshuffle.hpp:282
 
Definition: gridwise_gemm_waveletmodel.hpp:11
 
Definition: gridwise_gemm_waveletmodel.hpp:103
 
Definition: xdlops_gemm.hpp:942
 
Definition: sequence.hpp:43
 
Definition: tensor_space_filling_curve.hpp:20
 
Definition: thread_group.hpp:12
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: unary_element_wise_operation.hpp:308