23 template <
typename Gr
idwiseGemm>
 
   25 #if CK_USE_LAUNCH_BOUNDS 
   29                                    const typename GridwiseGemm::FloatAB* p_b_grid,
 
   30                                    typename GridwiseGemm::FloatC* p_c_grid,
 
   38                                    typename GridwiseGemm::Block2CTileMap block_mapping)
 
   40 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \ 
   42     constexpr 
index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();
 
   44     __shared__ uint8_t p_shared[shared_size];
 
   46     GridwiseGemm::Run(p_a_grid,
 
   57                       static_cast<void*
>(p_shared));
 
   74           typename Block2CTileMap_,
 
   81           typename AElementwiseOperation,
 
   82           typename BElementwiseOperation,
 
   83           typename CElementwiseOperation,
 
   92           typename ABlockTransferThreadClusterLengths_K0_M_K1,
 
   93           typename ABlockTransferThreadClusterArrangeOrder,
 
   94           typename ABlockTransferSrcAccessOrder,
 
   95           index_t ABlockTransferSrcVectorDim,
 
   96           index_t ABlockTransferSrcScalarPerVector,
 
   97           index_t ABlockTransferDstScalarPerVector_K1,
 
   98           bool AThreadTransferSrcResetCoordinateAfterRun,
 
  100           typename BBlockTransferThreadClusterLengths_K0_N_K1,
 
  101           typename BBlockTransferThreadClusterArrangeOrder,
 
  102           typename BBlockTransferSrcAccessOrder,
 
  103           index_t BBlockTransferSrcVectorDim,
 
  104           index_t BBlockTransferSrcScalarPerVector,
 
  105           index_t BBlockTransferDstScalarPerVector_K1,
 
  106           bool BThreadTransferSrcResetCoordinateAfterRun,
 
  108           index_t CShuffleMRepeatPerShuffle,
 
  109           index_t CShuffleNRepeatPerShuffle,
 
  110           index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
 
  111           typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
 
  125     static constexpr 
auto M01       = 1;
 
  126     static constexpr 
auto N01       = 1;
 
  161                  uint32_t num_sk_blocks_)
 
  183                       << 
"SC:" << 
StrideC << std::endl;
 
  196     __host__ __device__ 
static auto 
  201         const auto a_grid_desc_m_k = [&]() {
 
  225     __host__ __device__ 
static auto 
  230         const auto b_grid_desc_k_n = [&]() {
 
  254     __host__ __device__ 
static auto 
  257         const auto c_grid_desc_m_n = [&]() {
 
  293         constexpr 
auto max_lds_align = 
K1;
 
  299         constexpr 
auto a_block_space_size_aligned =
 
  302         constexpr 
auto b_block_space_size_aligned =
 
  305         constexpr 
auto c_block_size =
 
  308         return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
 
  317             if(karg.
K % ABlockTransferSrcScalarPerVector != 0)
 
  322             if(karg.
M % ABlockTransferSrcScalarPerVector != 0)
 
  328             if(karg.
N % BBlockTransferSrcScalarPerVector != 0)
 
  333             if(karg.
K % BBlockTransferSrcScalarPerVector != 0)
 
  339             if(karg.
N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
 
  344             if(karg.
M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
 
  353         const bool has_main_k0_block_loop = K0 > K0PerBlock;
 
  355         return has_main_k0_block_loop;
 
  358     template <
typename CGr
idDesc>
 
  359     __host__ __device__ 
static constexpr 
auto 
  362         const auto M = c_m_n_grid_desc.GetLength(
I0);
 
  363         const auto N = c_m_n_grid_desc.GetLength(
I1);
 
  365         const auto MBlock = M / MPerBlock;
 
  366         const auto NBlock = N / NPerBlock;
 
  377     template <
typename CGr
idDesc>
 
  382             c_m_n_grid_desc, 8, KBatch);
 
  385     __host__ __device__ 
static constexpr 
auto 
  388         constexpr 
index_t MWave = MPerBlock / (MRepeat * MPerXDL);
 
  389         constexpr 
index_t NWave = NPerBlock / (NRepeat * NPerXDL);
 
  398     __host__ __device__ 
static constexpr 
auto 
  401         constexpr 
index_t MWave = MPerBlock / (MRepeat * MPerXDL);
 
  402         constexpr 
index_t NWave = NPerBlock / (NRepeat * NPerXDL);
 
  407                        Number<NRepeat / CShuffleNRepeatPerShuffle>{},
 
  415         constexpr 
auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
 
  416         constexpr 
auto NPerBlockReduction =
 
  417             NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
 
  418         constexpr 
auto MPerBlockReduction =
 
  419             (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
 
  425         const auto c_partial_acc_block_m_n = [&]() {
 
  437         return c_partial_acc_block_m_n;
 
  453                                void* __restrict__ p_shared_block)
 
  458         uint32_t pad_m    = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
 
  459         uint32_t pad_n    = (n + NPerBlock - 1) / NPerBlock * NPerBlock;
 
  461         uint32_t stride_a = StrideA;
 
  462         uint32_t stride_b = StrideB;
 
  463         uint32_t stride_c = StrideC;
 
  469         const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
 
  471         const AElementwiseOperation a_element_op = AElementwiseOperation{};
 
  472         const BElementwiseOperation b_element_op = BElementwiseOperation{};
 
  473         const CElementwiseOperation c_element_op = CElementwiseOperation{};
 
  475         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  476             p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
 
  477         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  478             p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
 
  479         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  480             p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
 
  483         constexpr 
auto max_lds_align = 
K1;
 
  491         auto blockwise_gemm =
 
  496                                                                 decltype(a_block_desc_k0_m_k1),
 
  497                                                                 decltype(b_block_desc_k0_n_k1),
 
  504         auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();
 
  507         constexpr 
auto a_block_space_size =
 
  511         FloatAB* p_b_block = 
static_cast<FloatAB*
>(p_shared_block) + a_block_space_size;
 
  513         constexpr 
auto a_block_slice_copy_step = 
make_multi_index(K0PerBlock, 0, 0);
 
  514         constexpr 
auto b_block_slice_copy_step = 
make_multi_index(K0PerBlock, 0, 0);
 
  516         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  517             p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
 
  518         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  519             p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
 
  524         uint32_t block_idx = block_mapping.get_block_idx();
 
  525         bool is_sk_block   = block_idx < block_mapping.sk_num_blocks;
 
  526         bool is_dp_block   = block_idx >= block_mapping.dp_start_block_idx &&
 
  527                            block_idx < block_mapping.reduction_start_block_idx;
 
  528         bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
 
  529         bool is_padding_block   = block_idx >= block_mapping.sk_num_blocks &&
 
  530                                 block_idx < block_mapping.dp_start_block_idx;
 
  531         uint32_t iter_start, iter_end;
 
  532         block_mapping.get_block_itr(block_idx, iter_start, iter_end);
 
  533         uint32_t total_iter_length = iter_end - iter_start;
 
  538         uint32_t* p_semaphore =
 
  539             reinterpret_cast<uint32_t*
>(
reinterpret_cast<char*
>(p_workspace) +
 
  540                                         block_mapping.get_workspace_size_for_acc(
sizeof(
FloatAcc)));
 
  544             if(is_reduction_block)
 
  549                 const auto reduce_thread_cluster_idx =
 
  551                 const auto thread_m_cluster_id = reduce_thread_cluster_idx[
I0];
 
  552                 const auto thread_n_cluster_id = reduce_thread_cluster_idx[
I1];
 
  554                 constexpr 
auto MReduceIters =
 
  558                     cluster_length_reduce.At(
I1) *
 
  569                     0, cluster_length_reduce.At(
I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
 
  570                 constexpr 
auto partial_acc_load_step_n_reverse =
 
  572                                      -1 * cluster_length_reduce.At(
I1).value * (NReduceIters - 1) *
 
  573                                          CBlockTransferScalarPerVector_NWaveNPerXDL);
 
  574                 constexpr 
auto partial_acc_load_step_m =
 
  581                     cluster_length_reduce.At(
I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
 
  582                 constexpr 
auto partial_acc_store_step_n_reverse =
 
  586                                      -1 * cluster_length_reduce.At(
I1).value * (NReduceIters - 1) *
 
  587                                          CBlockTransferScalarPerVector_NWaveNPerXDL);
 
  588                 constexpr 
auto partial_acc_store_step_m =
 
  593                              CBlockTransferScalarPerVector_NWaveNPerXDL,
 
  598                              CBlockTransferScalarPerVector_NWaveNPerXDL,
 
  603                 auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
 
  604                 auto spatial_idx   = block_mapping.tile_to_spatial(reduction_idx, m, n);
 
  608                 uint32_t tile_acc_offset_start =
 
  609                     block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
 
  610                 uint32_t tile_acc_offset_end =
 
  611                     block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);
 
  616                     decltype(c_partial_acc_block_m_n),                       
 
  617                     decltype(acc_thread_buf_load_desc),                      
 
  621                     CBlockTransferScalarPerVector_NWaveNPerXDL,              
 
  624                     >{c_partial_acc_block_m_n,
 
  626                                        thread_n_cluster_id *
 
  627                                            CBlockTransferScalarPerVector_NWaveNPerXDL)};
 
  632                     decltype(acc_thread_buf_store_desc),                     
 
  633                     decltype(c_grid_desc_mblock_mperblock_nblock_nperblock), 
 
  634                     CElementwiseOperation, 
 
  638                     CBlockTransferScalarPerVector_NWaveNPerXDL, 
 
  642                     >{c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  645                                        __builtin_amdgcn_readfirstlane(spatial_idx[
I1]),
 
  646                                        thread_n_cluster_id *
 
  647                                            CBlockTransferScalarPerVector_NWaveNPerXDL),
 
  648                       CElementwiseOperation{}};
 
  651                 wg_barrier.
wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);
 
  654                 if(threadIdx.x == 0) {
 
  655                     printf(
"bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n", 
static_cast<int>(blockIdx.x),
 
  656                         reduction_idx, __builtin_amdgcn_readfirstlane(tile_acc_offset_start), __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
 
  657                         __builtin_amdgcn_readfirstlane(spatial_idx[
I0]),
 
  658                         __builtin_amdgcn_readfirstlane(spatial_idx[
I1]));
 
  662                 using Accumulation = ck::detail::
 
  665                 for(
int i_m = 0; i_m < MReduceIters; i_m++)
 
  669                         for(
auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
 
  671                             auto c_partial_acc_buf =
 
  674                                     reinterpret_cast<FloatAcc*
>(p_workspace) +
 
  675                                         i * c_partial_acc_block_m_n.GetElementSpaceSize(),
 
  676                                     c_partial_acc_block_m_n.GetElementSpaceSize());
 
  678                             acc_load.Run(c_partial_acc_block_m_n,
 
  680                                          acc_thread_buf_load_desc,
 
  686                                     constexpr 
auto offset =
 
  687                                         acc_thread_buf_load_desc.CalculateOffset(
 
  694                         if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
 
  697                             acc_store.Run(acc_thread_buf_store_desc,
 
  700                                           c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  703                         if constexpr(NReduceIters != 1)
 
  705                             if constexpr(i_n_reduce != (NReduceIters - 1))
 
  707                                 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
 
  708                                                             partial_acc_load_step_n);
 
  709                                 acc_store.MoveDstSliceWindow(
 
  710                                     c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  711                                     partial_acc_store_step_n);
 
  715                                 acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
 
  716                                                             partial_acc_load_step_n_reverse);
 
  717                                 acc_store.MoveDstSliceWindow(
 
  718                                     c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  719                                     partial_acc_store_step_n_reverse);
 
  724                         acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
 
  725                                                     partial_acc_load_step_m);
 
  726                         acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  727                                                      partial_acc_store_step_m);
 
  735         uint32_t block_acc_offset =
 
  736             (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
 
  741             uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
 
  742                 block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
 
  743             uint32_t tile_idx, iter_offset;
 
  744             block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
 
  745             iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
 
  746             auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);
 
  748             const index_t m_block_data_idx_on_grid =
 
  749                 __builtin_amdgcn_readfirstlane(spatial_idx[
I0] * MPerBlock);
 
  751             const index_t n_block_data_idx_on_grid =
 
  752                 __builtin_amdgcn_readfirstlane(spatial_idx[
I1] * NPerBlock);
 
  754             const index_t k0_block_data_idx_on_grid =
 
  755                 __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
 
  758             auto a_blockwise_copy =
 
  760                                                     AElementwiseOperation,
 
  764                                                     ABlockTransferThreadClusterLengths_K0_M_K1,
 
  765                                                     ABlockTransferThreadClusterArrangeOrder,
 
  768                                                     decltype(a_k0_m_k1_grid_desc),
 
  769                                                     decltype(a_block_desc_k0_m_k1),
 
  770                                                     ABlockTransferSrcAccessOrder,
 
  772                                                     ABlockTransferSrcVectorDim,
 
  774                                                     ABlockTransferSrcScalarPerVector,
 
  775                                                     ABlockTransferDstScalarPerVector_K1,
 
  778                                                     AThreadTransferSrcResetCoordinateAfterRun,
 
  783                     a_block_desc_k0_m_k1,
 
  788             auto b_blockwise_copy =
 
  790                                                     BElementwiseOperation,
 
  794                                                     BBlockTransferThreadClusterLengths_K0_N_K1,
 
  795                                                     BBlockTransferThreadClusterArrangeOrder,
 
  798                                                     decltype(b_k0_n_k1_grid_desc),
 
  799                                                     decltype(b_block_desc_k0_n_k1),
 
  800                                                     BBlockTransferSrcAccessOrder,
 
  802                                                     BBlockTransferSrcVectorDim,
 
  804                                                     BBlockTransferSrcScalarPerVector,
 
  805                                                     BBlockTransferDstScalarPerVector_K1,
 
  808                                                     BThreadTransferSrcResetCoordinateAfterRun,
 
  813                     b_block_desc_k0_n_k1,
 
  817             const index_t num_k_block_main_loop = current_iter_length;
 
  819             gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
 
  820                                        a_block_desc_k0_m_k1,
 
  824                                        a_block_slice_copy_step,
 
  826                                        b_block_desc_k0_n_k1,
 
  830                                        b_block_slice_copy_step,
 
  833                                        num_k_block_main_loop);
 
  837                 constexpr 
index_t MWave = MPerBlock / (MRepeat * MPerXDL);
 
  838                 constexpr 
index_t NWave = NPerBlock / (NRepeat * NPerXDL);
 
  840                 constexpr 
auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
 
  841                     blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  843                 constexpr 
auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
 
  844                     blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
 
  846                 constexpr 
auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(
I0);
 
  847                 constexpr 
auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(
I1);
 
  848                 constexpr 
auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(
I2);
 
  849                 constexpr 
auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(
I3);
 
  850                 constexpr 
auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(
I4);
 
  851                 constexpr 
auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(
I5);
 
  852                 constexpr 
auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(
I6);
 
  853                 constexpr 
auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(
I7);
 
  855                 constexpr 
auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
 
  858                 constexpr 
auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
 
  861                 auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
 
  863                     c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());
 
  865                 auto c_partial_acc_buf =
 
  866                     make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
 
  867                         reinterpret_cast<FloatAcc*
>(p_workspace) + block_acc_offset,
 
  868                         c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
 
  869                             .GetElementSpaceSize());
 
  872                     c_block_desc_mblock_mpershuffle_nblock_npershuffle,
 
  893                 const auto c_thread_mtx_on_block =
 
  894                     blockwise_gemm.CalculateCThreadOriginDataIndex(
I0, 
I0, 
I0, 
I0);
 
  896                 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
 
  897                 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
 
  899                 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
 
  905                 const auto m_thread_data_on_block_idx =
 
  906                     m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
 
  909                 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
 
  915                 const auto n_thread_data_on_block_idx =
 
  916                     n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
 
  923                     decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
 
  924                     decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
 
  927                              CShuffleNRepeatPerShuffle,
 
  939                     true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
  942                                            m_thread_data_on_block_idx[
I1],
 
  943                                            n_thread_data_on_block_idx[
I1],
 
  944                                            m_thread_data_on_block_idx[
I2],
 
  945                                            m_thread_data_on_block_idx[
I3],
 
  946                                            m_thread_data_on_block_idx[
I4],
 
  947                                            n_thread_data_on_block_idx[
I2]),
 
  953                     CElementwiseOperation, 
 
  956                              CShuffleMRepeatPerShuffle * MWave * MPerXDL,
 
  958                              CShuffleNRepeatPerShuffle * NWave * NPerXDL>, 
 
  959                     CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  963                     decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
 
  964                     decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
 
  967                     CBlockTransferScalarPerVector_NWaveNPerXDL, 
 
  970                     {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
 
  972                      c_grid_desc_mblock_mperblock_nblock_nperblock,
 
  975                                       __builtin_amdgcn_readfirstlane(spatial_idx[
I1]),
 
  982                     CElementwiseOperation, 
 
  985                              CShuffleMRepeatPerShuffle * MWave * MPerXDL,
 
  987                              CShuffleNRepeatPerShuffle * NWave * NPerXDL>, 
 
  988                     CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
 
  992                     decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
 
  993                     decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
 
  996                     CBlockTransferScalarPerVector_NWaveNPerXDL, 
 
 1001                     {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
 
 1003                      c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
 
 1007                 constexpr 
auto mxdlperwave_forward_step =
 
 1009                 constexpr 
auto nxdlperwave_forward_step =
 
 1011                 constexpr 
auto nxdlperwave_backward_step =
 
 1015                     constexpr 
auto mxdlperwave = mxdlperwave_iter;
 
 1018                         constexpr 
bool nxdlperwave_forward_sweep =
 
 1019                             (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);
 
 1021                         constexpr 
index_t nxdlperwave_value =
 
 1022                             nxdlperwave_forward_sweep
 
 1024                                 : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
 
 1032                         c_thread_copy_vgpr_to_lds.Run(
 
 1033                             c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
 
 1036                             c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
 
 1042                         c_block_copy_lds_to_global.SetSrcSliceOrigin(
 
 1043                             c_block_desc_mblock_mpershuffle_nblock_npershuffle,
 
 1048                             c_block_copy_lds_to_global.template 
Run<decltype(c_block_buf),
 
 1049                                                                     decltype(c_grid_buf),
 
 1051                                 c_block_desc_mblock_mpershuffle_nblock_npershuffle,
 
 1053                                 c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1055                         else if(is_sk_block)
 
 1057                             if constexpr(Block2CTileMap::ReductionStrategy ==
 
 1061                                 c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
 
 1062                                     c_block_desc_mblock_mpershuffle_nblock_npershuffle,
 
 1065                                 c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
 
 1066                                     c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
 
 1067                                     make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));
 
 1069                                 c_block_copy_lds_to_partial_acc
 
 1070                                     .template 
Run<decltype(c_block_buf),
 
 1071                                                   decltype(c_partial_acc_buf),
 
 1073                                         c_block_desc_mblock_mpershuffle_nblock_npershuffle,
 
 1075                                         c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
 
 1078                             else if constexpr(Block2CTileMap::ReductionStrategy ==
 
 1081                                 c_block_copy_lds_to_global
 
 1082                                     .template 
Run<decltype(c_block_buf),
 
 1083                                                   decltype(c_grid_buf),
 
 1085                                         c_block_desc_mblock_mpershuffle_nblock_npershuffle,
 
 1087                                         c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1093                         if constexpr(nxdlperwave_forward_sweep &&
 
 1094                                      (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
 
 1096                             c_block_copy_lds_to_global.MoveDstSliceWindow(
 
 1097                                 c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1098                                 nxdlperwave_forward_step);
 
 1100                         else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
 
 1102                             c_block_copy_lds_to_global.MoveDstSliceWindow(
 
 1103                                 c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1104                                 nxdlperwave_backward_step);
 
 1109                     if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
 
 1111                         c_block_copy_lds_to_global.MoveDstSliceWindow(
 
 1112                             c_grid_desc_mblock_mperblock_nblock_nperblock,
 
 1113                             mxdlperwave_forward_step);
 
 1117                 if constexpr(Block2CTileMap::ReductionStrategy ==
 
 1124                         wg_barrier.
inc(tile_idx);
 
 1130             iter_end -= current_iter_length;
 
 1131             if(iter_end <= iter_start)
 
 1136                 block_acc_offset -= MPerBlock * NPerBlock;
 
 1143     template <
typename Layout>
 
 1146         static std::string 
Get() { 
return ""; }
 
 1152         static std::string 
Get() { 
return "R"; }
 
 1158         static std::string 
Get() { 
return "C"; }
 
 1163         auto str = std::stringstream();
 
 1166         str << 
"GemmXdlStreamK_" 
 1167             << std::string(ALayout::name)[0]
 
 1168             << std::string(BLayout::name)[0]
 
 1169             << std::string(CLayout::name)[0]
 
 1171             << 
"B" << BlockSize << 
"_" 
 1172             << 
"Vec" << ABlockTransferSrcScalarPerVector << 
"x" 
 1173             << BBlockTransferSrcScalarPerVector << 
"x" 
 1174             << CBlockTransferScalarPerVector_NWaveNPerXDL << 
"_" 
 1177             << K0PerBlock << 
"x" 
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
 
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
 
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
 
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
 
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
@ Atomic
Definition: block_to_ctile_map.hpp:1011
 
@ Reduction
Definition: block_to_ctile_map.hpp:1012
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
 
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__global__ void kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB *p_a_grid, const typename GridwiseGemm::FloatAB *p_b_grid, typename GridwiseGemm::FloatC *p_c_grid, void *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, typename GridwiseGemm::Block2CTileMap block_mapping)
Definition: gridwise_gemm_xdlops_streamk.hpp:28
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:300
 
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
 
__host__ constexpr __device__ auto make_dynamic_buffer(T *p, ElementSpaceSize element_space_size)
Definition: dynamic_buffer.hpp:461
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
 
Definition: block_to_ctile_map.hpp:540
 
Definition: blockwise_gemm_smfmac_xdlops.hpp:44
 
Definition: gridwise_gemm_xdlops_streamk.hpp:138
 
index_t K
Definition: gridwise_gemm_xdlops_streamk.hpp:144
 
const FloatAB * p_b_grid
Definition: gridwise_gemm_xdlops_streamk.hpp:140
 
void Print() const
Definition: gridwise_gemm_xdlops_streamk.hpp:175
 
index_t M
Definition: gridwise_gemm_xdlops_streamk.hpp:142
 
Argument(const FloatAB *p_a_grid_, const FloatAB *p_b_grid_, FloatC *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, uint32_t num_cu, uint32_t occupancy, uint32_t num_sk_blocks_)
Definition: gridwise_gemm_xdlops_streamk.hpp:150
 
FloatC * p_c_grid
Definition: gridwise_gemm_xdlops_streamk.hpp:141
 
const FloatAB * p_a_grid
Definition: gridwise_gemm_xdlops_streamk.hpp:139
 
index_t StrideC
Definition: gridwise_gemm_xdlops_streamk.hpp:147
 
index_t StrideB
Definition: gridwise_gemm_xdlops_streamk.hpp:146
 
index_t StrideA
Definition: gridwise_gemm_xdlops_streamk.hpp:145
 
index_t N
Definition: gridwise_gemm_xdlops_streamk.hpp:143
 
Block2CTileMap block_mapping
Definition: gridwise_gemm_xdlops_streamk.hpp:148
 
static std::string Get()
Definition: gridwise_gemm_xdlops_streamk.hpp:1158
 
static std::string Get()
Definition: gridwise_gemm_xdlops_streamk.hpp:1152
 
Definition: gridwise_gemm_xdlops_streamk.hpp:1145
 
static std::string Get()
Definition: gridwise_gemm_xdlops_streamk.hpp:1146
 
Definition: gridwise_gemm_xdlops_streamk.hpp:113
 
static constexpr auto I5
Definition: gridwise_gemm_xdlops_streamk.hpp:119
 
static __device__ void Run(const FloatAB *p_a_grid, const FloatAB *p_b_grid, FloatC *p_c_grid, void *p_workspace, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, Block2CTileMap block_mapping, void *__restrict__ p_shared_block)
Definition: gridwise_gemm_xdlops_streamk.hpp:442
 
__host__ static constexpr __device__ auto MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_m_n_grid_desc)
Definition: gridwise_gemm_xdlops_streamk.hpp:360
 
__host__ static __device__ auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
Definition: gridwise_gemm_xdlops_streamk.hpp:197
 
__host__ static __device__ auto CalculateK0(index_t KPad)
Definition: gridwise_gemm_xdlops_streamk.hpp:194
 
static constexpr auto I0
Definition: gridwise_gemm_xdlops_streamk.hpp:114
 
Block2CTileMap_ Block2CTileMap
Definition: gridwise_gemm_xdlops_streamk.hpp:133
 
__host__ static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdlops_streamk.hpp:283
 
__host__ static __device__ auto CalculateGridSize(const Argument &karg)
Definition: gridwise_gemm_xdlops_streamk.hpp:187
 
FloatAcc FloatCShuffle
Definition: gridwise_gemm_xdlops_streamk.hpp:131
 
__host__ static constexpr __device__ auto GetClusterLengthReduction()
Definition: gridwise_gemm_xdlops_streamk.hpp:411
 
__host__ static constexpr __device__ bool CalculateHasMainK0BlockLoop(index_t K0)
Definition: gridwise_gemm_xdlops_streamk.hpp:351
 
static constexpr auto N01
Definition: gridwise_gemm_xdlops_streamk.hpp:126
 
static constexpr auto I6
Definition: gridwise_gemm_xdlops_streamk.hpp:120
 
__host__ static constexpr __device__ bool CheckValidity(const Argument &karg)
Definition: gridwise_gemm_xdlops_streamk.hpp:313
 
static constexpr auto M01
Definition: gridwise_gemm_xdlops_streamk.hpp:125
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
Definition: gridwise_gemm_xdlops_streamk.hpp:386
 
__host__ static __device__ auto MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
Definition: gridwise_gemm_xdlops_streamk.hpp:255
 
static std::string GetTypeString()
Definition: gridwise_gemm_xdlops_streamk.hpp:1161
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
Definition: gridwise_gemm_xdlops_streamk.hpp:399
 
__host__ static constexpr __device__ auto GetPartialAccBlockDescriptor()
Definition: gridwise_gemm_xdlops_streamk.hpp:423
 
static constexpr auto I2
Definition: gridwise_gemm_xdlops_streamk.hpp:116
 
static constexpr auto I1
Definition: gridwise_gemm_xdlops_streamk.hpp:115
 
FloatAB_ FloatAB
Definition: gridwise_gemm_xdlops_streamk.hpp:134
 
static constexpr auto K1
Definition: gridwise_gemm_xdlops_streamk.hpp:124
 
static constexpr auto KPerBlock
Definition: gridwise_gemm_xdlops_streamk.hpp:127
 
__host__ static constexpr __device__ auto MakeCBlockClusterAdaptor(const CGridDesc &c_m_n_grid_desc, index_t, index_t, index_t KBatch)
Definition: gridwise_gemm_xdlops_streamk.hpp:378
 
FloatAcc_ FloatAcc
Definition: gridwise_gemm_xdlops_streamk.hpp:130
 
remove_cvref_t< decltype(MakeCGridDescriptor_M_N(1, 1, 1, 1, 1))> CGridDesc_M_N
Definition: gridwise_gemm_xdlops_streamk.hpp:440
 
static constexpr auto I3
Definition: gridwise_gemm_xdlops_streamk.hpp:117
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdlops_streamk.hpp:129
 
static constexpr auto I7
Definition: gridwise_gemm_xdlops_streamk.hpp:121
 
__host__ static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdlops_streamk.hpp:291
 
static constexpr auto I4
Definition: gridwise_gemm_xdlops_streamk.hpp:118
 
FloatC_ FloatC
Definition: gridwise_gemm_xdlops_streamk.hpp:135
 
__host__ static __device__ auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
Definition: gridwise_gemm_xdlops_streamk.hpp:226
 
__host__ static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdlops_streamk.hpp:275
 
Definition: gridwise_gemm_pipeline_v3.hpp:11
 
Definition: sequence.hpp:43
 
Definition: static_buffer.hpp:16
 
__host__ __device__ void Clear()
Definition: static_buffer.hpp:63
 
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
 
Definition: thread_group_tensor_slice_transfer_v6r1r2.hpp:33
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
 
Definition: integral_constant.hpp:20
 
Definition: reduction_operator.hpp:37
 
Definition: functional2.hpp:33
 
Definition: tensor_layout.hpp:21
 
Definition: tensor_layout.hpp:16
 
Definition: device_base.hpp:51
 
Definition: unary_element_wise_operation.hpp:308
 
Definition: workgroup_barrier.hpp:7
 
__device__ void inc(uint32_t offset)
Definition: workgroup_barrier.hpp:62
 
__device__ void wait_eq(uint32_t offset, uint32_t value)
Definition: workgroup_barrier.hpp:29