20           typename ComputeDataType,
 
   24           typename AMmaTileDesc,
 
   25           typename BMmaTileDesc,
 
   26           index_t ABlockTransferSrcScalarPerVector,
 
   27           index_t BBlockTransferSrcScalarPerVector,
 
   43           typename ComputeDataType,
 
   47           typename AMmaTileDesc,
 
   48           typename BMmaTileDesc,
 
   49           index_t ABlockTransferSrcScalarPerVector,
 
   50           index_t BBlockTransferSrcScalarPerVector,
 
   71                                        ABlockTransferSrcScalarPerVector,
 
   72                                        BBlockTransferSrcScalarPerVector,
 
   90                                         ABlockTransferSrcScalarPerVector,
 
   91                                         BBlockTransferSrcScalarPerVector,
 
  111                                                    ABlockTransferSrcScalarPerVector,
 
  112                                                    BBlockTransferSrcScalarPerVector,
 
  123     using Base::xdlops_gemm;
 
  125     using Base::CalculateCThreadOriginDataIndex;
 
  126     using Base::CalculateCThreadOriginDataIndex8D;
 
  127     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  128     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  129     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  130     using Base::GetCThreadBuffer;
 
  131     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  132     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  133     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  134     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  135     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  137     using Base::a_block_desc_m0_m1_m2_k;
 
  138     using Base::b_block_desc_n0_n1_n2_k;
 
  140     using Base::AMmaKStride;
 
  141     using Base::BMmaKStride;
 
  144         (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
 
  147         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
 
  149         FullMemBandPrefetchStages >= 2
 
  150             ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
 
  154     static constexpr index_t GlobalBufferNum = PrefetchStages;
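
The lines above (144–154) size the software pipeline: FullMemBandPrefetchStages estimates how many prefetch stages a workgroup can keep in flight, dividing a per-CU byte budget by the footprint of one stage, (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock, and PrefetchStages clamps that estimate to the range [2, 8]. A minimal host-side sketch of the same arithmetic; the byte budget and tile sizes below are illustrative assumptions, not values taken from this header:

#include <cstdint>
#include <cstdio>

// Illustrative stand-ins; in the pipeline these are template parameters.
constexpr std::int64_t per_cu_budget_bytes = 32768;
constexpr std::int64_t m_per_block = 128, n_per_block = 128, k_per_block = 32;
constexpr std::int64_t a_bytes = 2, b_bytes = 2; // e.g. fp16 A and B

constexpr std::int64_t integer_divide_ceil(std::int64_t x, std::int64_t y)
{
    return (x + y - 1) / y;
}

// Same shape as the clamp in the listing: at least 2 stages, at most 8.
constexpr std::int64_t full_mem_band_stages = integer_divide_ceil(
    per_cu_budget_bytes, (m_per_block * a_bytes + n_per_block * b_bytes) * k_per_block);
constexpr std::int64_t prefetch_stages =
    full_mem_band_stages >= 2 ? (full_mem_band_stages <= 8 ? full_mem_band_stages : 8) : 2;

int main()
{
    std::printf("stages: raw=%lld clamped=%lld\n",
                static_cast<long long>(full_mem_band_stages),
                static_cast<long long>(prefetch_stages));
}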
 
  158         return num_loop > PrefetchStages;
 
  163         if(num_loop % PrefetchStages == 1)
 
  167         else if(num_loop % PrefetchStages == 2)
 
  171         else if(num_loop % PrefetchStages == 3)
 
  175         else if(num_loop % PrefetchStages == 4)
 
  179         else if(num_loop % PrefetchStages == 5)
 
  183         else if(num_loop % PrefetchStages == 6)
 
  187         else if(num_loop % PrefetchStages == 7)
 
  197     template <bool HasMainLoop,
 
  201               typename ABlockTransfer,
 
  202               typename AGridBuffer,
 
  203               typename ABlockBuffer,
 
  204               typename ABlockTransferStep,
 
  207               typename BBlockTransfer,
 
  208               typename BGridBuffer,
 
  209               typename BBlockBuffer,
 
  210               typename BBlockTransferStep,
 
  211               typename CThreadBuffer>
 
  212     __device__ void Run(const AGridDesc& a_grid_desc,
 
  213                         const ABlockDesc& a_block_desc,
 
  214                         ABlockTransfer& a_blockwise_copy,
 
  215                         const AGridBuffer& a_grid_buf,
 
  216                         ABlockBuffer& a_block_buf,
 
  217                         const ABlockTransferStep& a_block_copy_step,
 
  218                         const BGridDesc& b_grid_desc,
 
  219                         const BBlockDesc& b_block_desc,
 
  220                         BBlockTransfer& b_blockwise_copy,
 
  221                         const BGridBuffer& b_grid_buf,
 
  222                         BBlockBuffer& b_block_buf,
 
  223                         const BBlockTransferStep& b_block_copy_step,
 
  224                         CThreadBuffer& c_thread_buf,
 
  227         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
 
  228             a_thread_desc_.GetElementSpaceSize());
 
  229         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
 
  230             b_thread_desc_.GetElementSpaceSize());
 
  233         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 
  234         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
 
  236         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  237         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  240         c_thread_buf.Clear();
 
  243         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
 
  244         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
 
  248             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
 
  249             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
 
  251             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  252             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
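
Lines 233–252 above are the pipeline prologue: the stage-0 tiles of A and B are read from global memory, the source windows are advanced, the accumulator is cleared, stage 0 is committed to LDS, and the remaining PrefetchStages - 1 global reads are issued so the hot loop always has data in flight. A toy sketch of that multi-buffer prologue using plain arrays in place of the blockwise copy objects and descriptors; every name below is illustrative rather than the library's API:

#include <array>
#include <cstdio>
#include <vector>

constexpr int prefetch_stages = 4;
constexpr int tile_elems      = 8; // stand-in for one block tile

using Tile = std::array<float, tile_elems>;

std::vector<Tile> global_tiles(16, Tile{});          // "grid" source
std::array<Tile, prefetch_stages> staged{};          // one register stage per prefetch
Tile lds_tile{};                                     // the tile visible to the compute loop

void run_read(int loop_idx, int stage) { staged[stage] = global_tiles[loop_idx]; }
void run_write(int stage)              { lds_tile = staged[stage]; }

int main()
{
    int loop_idx = 0;

    // Stage 0: global read, then commit to LDS so compute can start immediately.
    run_read(loop_idx++, 0);
    run_write(0);

    // Issue the remaining prefetches; their LDS writes happen inside the hot loop,
    // one stage ahead of the compute that consumes them.
    for(int stage = 1; stage < prefetch_stages; ++stage)
        run_read(loop_idx++, stage);

    std::printf("prefetches in flight after the prologue: %d\n", prefetch_stages - 1);
}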
 
  256         if constexpr(HasMainLoop)
 
  266                             a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  274                             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  290                                     a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  291                                         a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 
  293                                     b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  294                                         b_thread_buf[Number<b_thread_desc_.CalculateOffset(
 
  298                                 using mfma_input_type =
 
  300                                                          xdlops_gemm.K1PerXdlops>::type;
 
  303                                     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
 
  306                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  307                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  314                     a_blockwise_copy.RunWrite(
 
  315                         a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
 
  316                     b_blockwise_copy.RunWrite(
 
  317                         b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
 
  319                     a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
 
  320                     b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
 
  322                     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  323                     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  327             } while(i < (num_loop - PrefetchStages));
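
The do/while above is the hot loop: each pass computes on the tile already in LDS, commits the next register stage to LDS with the rotating index (iprefetch + 1) % PrefetchStages, and issues a replacement global read, exiting once fewer than PrefetchStages iterations remain so the tail code can drain them. A small sketch of just the iteration and stage bookkeeping; it assumes the loop counter advances by one full group of PrefetchStages per pass, which the excerpt itself does not show:

#include <cstdio>

int main()
{
    constexpr int prefetch_stages = 4;
    const int     num_loop        = 11; // total K-block iterations for this tile

    int i = 0;
    if(num_loop > prefetch_stages) // BlockHasHotloop
    {
        do
        {
            for(int iprefetch = 0; iprefetch < prefetch_stages; ++iprefetch)
                std::printf("iter %d: LDS <- stage %d, new global read -> stage %d\n",
                            i + iprefetch, (iprefetch + 1) % prefetch_stages, iprefetch);
            i += prefetch_stages;
        } while(i < (num_loop - prefetch_stages));
    }
    std::printf("iterations left for the tail: %d\n", num_loop - i);
}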
 
  332         auto LoopTailFunc = [&](auto tail_num) {
 
  337                         a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  345                         b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  361                                 a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  362                                     a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 
  364                                 b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  365                                     b_thread_buf[Number<b_thread_desc_.CalculateOffset(
 
  369                             using mfma_input_type =
 
  371                                                      xdlops_gemm.K1PerXdlops>::type;
 
  374                                 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
 
  377                                 a_thread_vec.template AsType<mfma_input_type>(),
 
  378                                 b_thread_vec.template AsType<mfma_input_type>(),
 
  385                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
 
  386                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
 
  392                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  400                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  416                             a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  417                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 
  419                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  420                                 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
 
  424                         using mfma_input_type =
 
  425                             typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
 
  428                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
 
  430                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  431                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  443                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  451                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  467                             a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  468                                 a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 
  470                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  471                                 b_thread_buf[Number<b_thread_desc_.CalculateOffset(
 
  475                         using mfma_input_type =
 
  476                             typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
 
  479                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
 
  481                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  482                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  519     using Base::a_thread_copy_;
 
  520     using Base::a_thread_desc_;
 
  521     using Base::b_thread_copy_;
 
  522     using Base::b_thread_desc_;
 
  523     using Base::c_thread_desc_;
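
Throughout the hot loop and both tails of this first specialization, the inner (m0, n0) step follows one pattern: copy a K-slice of A and B from LDS into a_thread_buf and b_thread_buf, repack K1PerXdlops consecutive elements into an mfma_input_type vector via AsType, and pass both vectors to xdlops_gemm.Run, which accumulates into c_thread_buf at the offset computed by c_thread_desc_. A toy CPU-side version of that repack-and-accumulate step; the vector width and the dot-product "MFMA" are stand-ins, not the real instruction:

#include <array>
#include <cstdio>

constexpr int kpack = 4; // stand-in for xdlops_gemm.K1PerXdlops

using mfma_input_type = std::array<float, kpack>; // stand-in for vector_type<...>::type

// A real xdlops_gemm.Run issues an MFMA that updates a whole accumulator tile;
// here a single scalar per (m0, n0) stands in for that tile.
void mfma_run(const mfma_input_type& a, const mfma_input_type& b, float& c)
{
    for(int k = 0; k < kpack; ++k)
        c += a[k] * b[k];
}

int main()
{
    std::array<float, 2 * kpack> a_thread_buf{}, b_thread_buf{};
    a_thread_buf.fill(1.f);
    b_thread_buf.fill(2.f);

    float c_acc = 0.f;
    for(int k0 = 0; k0 < 2; ++k0) // two K-slices, as if KRepeat == 2
    {
        mfma_input_type a_vec{}, b_vec{};
        for(int ik = 0; ik < kpack; ++ik) // the AsType<...>(ik) repack in the listing
        {
            a_vec[ik] = a_thread_buf[k0 * kpack + ik];
            b_vec[ik] = b_thread_buf[k0 * kpack + ik];
        }
        mfma_run(a_vec, b_vec, c_acc);
    }
    std::printf("c = %f\n", c_acc); // 16: 2 slices * 4 lanes * (1 * 2)
}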
 
  529           typename ComputeDataType,
 
  530           typename AccDataType,
 
  533           typename AMmaTileDesc,
 
  534           typename BMmaTileDesc,
 
  535           index_t ABlockTransferSrcScalarPerVector,
 
  536           index_t BBlockTransferSrcScalarPerVector,
 
  557                                        ABlockTransferSrcScalarPerVector,
 
  558                                        BBlockTransferSrcScalarPerVector,
 
  576                                         ABlockTransferSrcScalarPerVector,
 
  577                                         BBlockTransferSrcScalarPerVector,
 
  597                                                    ABlockTransferSrcScalarPerVector,
 
  598                                                    BBlockTransferSrcScalarPerVector,
 
  611     using Base::KPerThread;
 
  612     using Base::xdlops_gemm;
 
  614     using Base::CalculateCThreadOriginDataIndex;
 
  615     using Base::CalculateCThreadOriginDataIndex8D;
 
  616     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  617     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  618     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  619     using Base::GetCThreadBuffer;
 
  620     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  621     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  622     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  623     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  624     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  626     using Base::a_block_desc_m0_m1_m2_k;
 
  627     using Base::b_block_desc_n0_n1_n2_k;
 
  631     static constexpr index_t KRepeat        = KPerThread / KPerInnerLoop;
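
In this second specialization the per-thread K work is split at line 631 into KRepeat outer steps of KPerInnerLoop elements each, so that one inner loop's worth of MFMAs can be fenced and prioritised as a single cluster by the barriers shown further down. A short illustration of how the flat K loop is re-nested; the parameter values are made up:

#include <cstdio>

int main()
{
    constexpr int k_per_thread     = 16; // illustrative
    constexpr int k_per_inner_loop = 4;  // illustrative
    constexpr int k_repeat         = k_per_thread / k_per_inner_loop;
    static_assert(k_per_thread % k_per_inner_loop == 0, "K must split evenly");

    // A flat loop over k_per_thread becomes an outer loop over KRepeat and an
    // inner loop over KPerInnerLoop; each inner chunk is one schedulable cluster.
    for(int k0 = 0; k0 < k_repeat; ++k0)
        for(int k = 0; k < k_per_inner_loop; ++k)
            std::printf("k0=%d k=%d -> flat k=%d\n", k0, k, k0 * k_per_inner_loop + k);
}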
 
  634         (4 * WarpSize / BlockSize) >= 1 ? 4 * WarpSize / BlockSize : 1;
 
  637         (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
 
  639         FullMemBandPrefetchStages >= 2
 
  640             ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
 
  644     static constexpr index_t GlobalBufferNum = PrefetchStages;
 
  648         return num_loop > PrefetchStages;
 
  653         if(num_loop % PrefetchStages == 1)
 
  657         else if(num_loop % PrefetchStages == 2)
 
  661         else if(num_loop % PrefetchStages == 3)
 
  665         else if(num_loop % PrefetchStages == 4)
 
  669         else if(num_loop % PrefetchStages == 5)
 
  673         else if(num_loop % PrefetchStages == 6)
 
  677         else if(num_loop % PrefetchStages == 7)
 
  687     template <bool HasMainLoop,
 
  691               typename ABlockTransfer,
 
  692               typename AGridBuffer,
 
  693               typename ABlockBuffer,
 
  694               typename ABlockTransferStep,
 
  697               typename BBlockTransfer,
 
  698               typename BGridBuffer,
 
  699               typename BBlockBuffer,
 
  700               typename BBlockTransferStep,
 
  701               typename CThreadBuffer>
 
  702     __device__ void Run(const AGridDesc& a_grid_desc,
 
  703                         const ABlockDesc& a_block_desc,
 
  704                         ABlockTransfer& a_blockwise_copy,
 
  705                         const AGridBuffer& a_grid_buf,
 
  706                         ABlockBuffer& a_block_buf,
 
  707                         const ABlockTransferStep& a_block_copy_step,
 
  708                         const BGridDesc& b_grid_desc,
 
  709                         const BBlockDesc& b_block_desc,
 
  710                         BBlockTransfer& b_blockwise_copy,
 
  711                         const BGridBuffer& b_grid_buf,
 
  712                         BBlockBuffer& b_block_buf,
 
  713                         const BBlockTransferStep& b_block_copy_step,
 
  714                         CThreadBuffer& c_thread_buf,
 
  717         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
 
  718             a_thread_desc_.GetElementSpaceSize());
 
  719         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
 
  720             b_thread_desc_.GetElementSpaceSize());
 
  723         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 
  724         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
 
  726         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  727         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  730         c_thread_buf.Clear();
 
  733         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
 
  734         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
 
  738             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
 
  739             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
 
  741             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  742             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  746         if constexpr(HasMainLoop)
 
  756                             a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  764                             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  771                         __builtin_amdgcn_sched_barrier(0);
 
  779                         if constexpr(k0.value != 0 || KRepeat == 1)
 
  781                             __builtin_amdgcn_s_barrier();
 
  782                             __builtin_amdgcn_sched_barrier(0);
 
  791                                         a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  792                                             a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 
  794                                         b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  795                                             b_thread_buf[Number<b_thread_desc_.CalculateOffset(
 
  799                                     using mfma_input_type =
 
  801                                                              xdlops_gemm.K1PerXdlops>::type;
 
  804                                         c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
 
  812                                     if constexpr(k0.value == KRepeat - 1 &&
 
  813                                                  k_.value == KPerInnerLoop - KPack &&
 
  814                                                  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
 
  816                                         __builtin_amdgcn_sched_barrier(0);
 
  818                                         __builtin_amdgcn_sched_barrier(0);
 
  821                                         a_thread_vec.template AsType<mfma_input_type>(),
 
  822                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  824                                     if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
 
  826                                         __builtin_amdgcn_sched_barrier(0);
 
  827                                         __builtin_amdgcn_s_setprio(1);
 
  828                                         __builtin_amdgcn_sched_barrier(0);
 
  833                         __builtin_amdgcn_sched_barrier(0);
 
  834                         __builtin_amdgcn_s_setprio(0);
 
  835                         __builtin_amdgcn_sched_barrier(0);
 
  839                     a_blockwise_copy.RunWrite(
 
  840                         a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
 
  841                     b_blockwise_copy.RunWrite(
 
  842                         b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
 
  844                     a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
 
  845                     b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
 
  847                     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  848                     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  851             } while(i < (num_loop - PrefetchStages));
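
The __builtin_amdgcn_sched_barrier(0), __builtin_amdgcn_s_setprio(...) and __builtin_amdgcn_s_barrier() calls in this hot loop fence the compiler's instruction scheduler and toggle wave priority around each MFMA cluster: priority is raised at the first MFMA of a cluster (the k_ == 0 && m0 == 0 && n0 == 0 branch), dropped back to 0 once the inner k0 iteration finishes, and sched_barrier(0) forbids the compiler from moving any instruction across those points. A bare HIP kernel showing the same bracketing in isolation; it does no useful work and assumes compilation with hipcc for an AMDGCN target:

#include <hip/hip_runtime.h>

// Toy kernel: bracket a "high priority" region the way the pipeline brackets an
// MFMA cluster. The fused multiply-adds merely stand in for xdlops_gemm.Run.
__global__ void priority_bracket(float* out, const float* a, const float* b, int n)
{
    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    float acc = 0.f;

    __builtin_amdgcn_sched_barrier(0); // nothing may be scheduled across this point
    __builtin_amdgcn_s_setprio(1);     // raise wave priority for the compute cluster
    __builtin_amdgcn_sched_barrier(0);

    for(int k = 0; k < n; ++k)
        acc += a[(tid + k) % n] * b[k];

    __builtin_amdgcn_sched_barrier(0);
    __builtin_amdgcn_s_setprio(0);     // restore default priority after the cluster
    __builtin_amdgcn_sched_barrier(0);

    out[tid] = acc;
}

The mask argument of sched_barrier selects which instruction classes may still cross the barrier; 0 means none, which is why the pipeline places it immediately before and after every setprio and s_barrier.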
 
  856         auto LoopTailFunc = [&](auto tail_num) {
 
  861                         a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  869                         b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  877                     __builtin_amdgcn_sched_barrier(0);
 
  878                     if constexpr(k0.value != 0 || KRepeat == 1)
 
  880                         __builtin_amdgcn_s_barrier();
 
  881                         __builtin_amdgcn_sched_barrier(0);
 
  890                                     a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  891                                         a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 
  893                                     b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  894                                         b_thread_buf[Number<b_thread_desc_.CalculateOffset(
 
  898                                 using mfma_input_type =
 
  900                                                          xdlops_gemm.K1PerXdlops>::type;
 
  903                                     c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
 
  905                                 if constexpr(k0.value == KRepeat - 1 &&
 
  906                                              k_.value == KPerInnerLoop - KPack &&
 
  907                                              m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
 
  909                                     __builtin_amdgcn_sched_barrier(0);
 
  911                                     __builtin_amdgcn_sched_barrier(0);
 
  914                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  915                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  917                                 if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
 
  919                                     __builtin_amdgcn_sched_barrier(0);
 
  920                                     __builtin_amdgcn_s_setprio(1);
 
  921                                     __builtin_amdgcn_sched_barrier(0);
 
  926                     __builtin_amdgcn_sched_barrier(0);
 
  927                     __builtin_amdgcn_s_setprio(0);
 
  928                     __builtin_amdgcn_sched_barrier(0);
 
  931                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
 
  932                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
 
  937                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  945                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  953                 __builtin_amdgcn_sched_barrier(0);
 
  954                 if constexpr(k0.value != 0 || KRepeat == 1)
 
  956                     __builtin_amdgcn_s_barrier();
 
  957                     __builtin_amdgcn_sched_barrier(0);
 
  966                                 a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  967                                     a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 
  969                                 b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  970                                     b_thread_buf[Number<b_thread_desc_.CalculateOffset(
 
  974                             using mfma_input_type =
 
  976                                                      xdlops_gemm.K1PerXdlops>::type;
 
  979                                 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
 
  981                             if constexpr(k0.value == KRepeat - 1 &&
 
  982                                          k_.value == KPerInnerLoop - KPack &&
 
  983                                          m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
 
  985                                 __builtin_amdgcn_sched_barrier(0);
 
  987                                 __builtin_amdgcn_sched_barrier(0);
 
  990                                 a_thread_vec.template AsType<mfma_input_type>(),
 
  991                                 b_thread_vec.template AsType<mfma_input_type>(),
 
  993                             if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
 
  995                                 __builtin_amdgcn_sched_barrier(0);
 
  996                                 __builtin_amdgcn_s_setprio(1);
 
  997                                 __builtin_amdgcn_sched_barrier(0);
 
 1002                 __builtin_amdgcn_sched_barrier(0);
 
 1003                 __builtin_amdgcn_s_setprio(0);
 
 1004                 __builtin_amdgcn_sched_barrier(0);
 
 1013                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
 1021                     b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
 1029                 __builtin_amdgcn_sched_barrier(0);
 
 1030                 if constexpr(k0.value != 0 || KRepeat == 1)
 
 1032                     __builtin_amdgcn_s_barrier();
 
 1033                     __builtin_amdgcn_sched_barrier(0);
 
 1042                                 a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
 1043                                     a_thread_buf[Number<a_thread_desc_.CalculateOffset(
 
 1045                                 b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
 1046                                     b_thread_buf[Number<b_thread_desc_.CalculateOffset(
 
 1050                             using mfma_input_type =
 
 1052                                                      xdlops_gemm.K1PerXdlops>::type;
 
 1055                                 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
 
 1057                             if constexpr(k0.value == KRepeat - 1 &&
 
 1058                                          k_.value == KPerInnerLoop - KPack &&
 
 1059                                          m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
 
 1061                                 __builtin_amdgcn_sched_barrier(0);
 
 1063                                 __builtin_amdgcn_sched_barrier(0);
 
 1066                                 a_thread_vec.template AsType<mfma_input_type>(),
 
 1067                                 b_thread_vec.template AsType<mfma_input_type>(),
 
 1069                             if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
 
 1071                                 __builtin_amdgcn_sched_barrier(0);
 
 1072                                 __builtin_amdgcn_s_setprio(1);
 
 1073                                 __builtin_amdgcn_sched_barrier(0);
 
 1078                 __builtin_amdgcn_sched_barrier(0);
 
 1079                 __builtin_amdgcn_s_setprio(0);
 
 1080                 __builtin_amdgcn_sched_barrier(0);
 
 1118                    Number<KRepeat * MRepeat * KPerInnerLoop>{},
 
 1119                    Number<MRepeat * KPerInnerLoop>{},
 
 1125                    Number<KRepeat * NRepeat * KPerInnerLoop>{},
 
 1126                    Number<NRepeat * KPerInnerLoop>{},
 
 1131                                                          decltype(a_block_desc_m0_m1_m2_k),
 
 1132                                                          decltype(a_thread_desc_),
 
 1141                                                          decltype(b_block_desc_n0_n1_n2_k),
 
 1142                                                          decltype(b_thread_desc_),
 
 1151     using Base::c_thread_desc_;
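
The stride expressions at lines 1118–1126 above build the per-thread A and B staging descriptors with make_naive_tensor_descriptor: a naive descriptor maps a multi-index to a flat element offset as a dot product with its strides, and the quoted strides keep the KPerInnerLoop innermost elements contiguous for the MFMA repack. A small offset calculation under an assumed dimension order and a trailing unit stride, with made-up repeat counts:

#include <array>
#include <cstdio>

constexpr int m_repeat = 2, k_repeat = 2, k_per_inner_loop = 8; // illustrative

// Strides patterned after lines 1118-1119 (A side), plus an assumed unit stride
// for the innermost K index.
constexpr std::array<int, 3> strides = {k_repeat * m_repeat * k_per_inner_loop,
                                        m_repeat * k_per_inner_loop,
                                        k_per_inner_loop};

constexpr int offset(int i0, int i1, int i2, int k)
{
    return i0 * strides[0] + i1 * strides[1] + i2 * strides[2] + k;
}

int main()
{
    std::printf("offset(1, 1, 1, 3) = %d\n", offset(1, 1, 1, 3)); // 32 + 16 + 8 + 3 = 59
}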
 