40         MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
 
   42         NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);
 
   45         MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
 
   47         NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);
 
   50         WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
 
   52         WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);
 
   55         MPerBlock * NPerBlock * KPerBlock / (BlockSize / 
WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
 
   59         printf(
" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
 
   69         printf(
" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: " 
   70                "%d, %d\n C MFMA inst: %d\n",
 
   87     typename AMmaTileDesc,
 
   88     typename BMmaTileDesc,
 
   97     bool TransposeC = 
false,
 
   99         KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
 
  101         KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
 
  144                   "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
 
  164         return threadid_to_wave_idx_adaptor.CalculateBottomIndex(
make_multi_index(thread_id));
 
  171         const auto waveId_m = wave_idx[
I0];
 
  173         const auto xdlops_a_idx = 
xdlops_gemm.CalculateAThreadOriginDataIndex();
 
  175         return make_tuple(0, waveId_m, xdlops_a_idx[
I1], KPack * xdlops_a_idx[
I0]);
 
  182         const auto waveId_n = wave_idx[
I1];
 
  184         const auto xdlops_b_idx = 
xdlops_gemm.CalculateBThreadOriginDataIndex();
 
  186         return make_tuple(0, waveId_n, xdlops_b_idx[
I1], KPack * xdlops_b_idx[
I0]);
 
  189     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
 
  190     __device__ 
static auto 
  195         const auto waveId_m = wave_idx[
I0];
 
  196         const auto waveId_n = wave_idx[
I1];
 
  198         const auto blk_idx = 
xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
 
  210         const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
 
  212         const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
 
  218     template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
 
  219     __device__ 
static auto 
  224         const auto waveId_m = wave_idx[
I0];
 
  225         const auto waveId_n = wave_idx[
I1];
 
  227         const auto blk_idx = 
xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
 
  230             m0, n0, waveId_m, waveId_n, blk_idx[
I0], blk_idx[
I1], blk_idx[
I2], blk_idx[
I3]);
 
  240         static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
 
  241                       "wrong! Desc should be known at compile-time");
 
  244                       "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
 
  246         static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
 
  255         constexpr 
auto c_m0_m1_m2_n_tblk_lens = 
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
 
  257         constexpr 
auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
 
  258         constexpr 
auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
 
  259         constexpr 
auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
 
  260         constexpr 
auto N  = c_m0_m1_m2_n_tblk_lens[
I3];
 
  269         constexpr 
auto c_m0_m1_m2_n_tblk_lens = 
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
 
  271         constexpr 
auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
 
  272         constexpr 
auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
 
  273         constexpr 
auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
 
  274         constexpr 
auto N  = c_m0_m1_m2_n_tblk_lens[
I3];
 
  282         constexpr 
auto c_m0_m1_m2_n_tblk_lens = 
xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
 
  284         constexpr 
auto M0 = c_m0_m1_m2_n_tblk_lens[
I0];
 
  285         constexpr 
auto M1 = c_m0_m1_m2_n_tblk_lens[
I1];
 
  286         constexpr 
auto M2 = c_m0_m1_m2_n_tblk_lens[
I2];
 
  287         constexpr 
auto N  = c_m0_m1_m2_n_tblk_lens[
I3];
 
  296         constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_n2 =
 
  304         return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
 
  310         constexpr 
auto c_block_desc_m0_n0_m1_n1_m2_n2 =
 
  318         return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
 
  323         constexpr 
auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
 
  332         return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
 
  333             c_block_desc_g_m0_n0_m1_n1_m2_n2);
 
  336     template <
typename CGr
idDesc_M_N>
 
  337     __host__ __device__ 
static constexpr 
auto 
  340         const auto M = c_grid_desc_m_n.GetLength(
I0);
 
  341         const auto N = c_grid_desc_m_n.GetLength(
I1);
 
  350         return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
 
  353     template <
typename CGr
idDesc_G_M_N>
 
  354     __host__ __device__ 
static constexpr 
auto 
  357         const auto G = c_grid_desc_g_m_n.GetLength(
I0);
 
  358         const auto M = c_grid_desc_g_m_n.GetLength(
I1);
 
  359         const auto N = c_grid_desc_g_m_n.GetLength(
I2);
 
  369         return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
 
  370             c_grid_desc_g_m0_n0_m1_n1_m2_n2);
 
  376         constexpr 
auto num_ds_read_inst =
 
  378         constexpr 
auto num_ds_write_inst =
 
  381         constexpr 
auto num_buffer_load_inst =
 
  386         constexpr 
auto num_issue = num_buffer_load_inst;
 
  390             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  391             __builtin_amdgcn_sched_group_barrier(
 
  392                 0x100, num_ds_read_inst / num_buffer_load_inst, 0); 
 
  393             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);      
 
  394             __builtin_amdgcn_sched_group_barrier(
 
  395                 0x200, num_ds_write_inst / num_buffer_load_inst, 0); 
 
  396             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);       
 
  397             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);       
 
  398             __builtin_amdgcn_sched_group_barrier(
 
  399                 0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); 
 
  403     template <index_t stage>
 
  409     __device__ constexpr 
auto TailScheduler<1>()
 
  412         constexpr 
auto num_ds_read_inst =
 
  414         constexpr 
auto num_ds_write_inst =
 
  419         constexpr 
auto num_issue = num_ds_write_inst;
 
  423             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  424             __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  425             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  426             __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); 
 
  427             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  428             __builtin_amdgcn_sched_group_barrier(
 
  429                 0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); 
 
  430             __builtin_amdgcn_sched_group_barrier(
 
  431                 0x008, num_mfma_inst / num_ds_write_inst - 3, 0); 
 
  436     __device__ constexpr 
auto TailScheduler<2>()
 
  439         constexpr 
auto num_ds_read_inst =
 
  443         constexpr 
auto num_issue = num_ds_read_inst;
 
  447             __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); 
 
  448             __builtin_amdgcn_sched_group_barrier(
 
  449                 0x008, num_mfma_inst / num_ds_read_inst, 0); 
 
  456     template <
bool HasMainLoop,
 
  460               typename ABlockTransfer,
 
  461               typename AGridBuffer,
 
  462               typename ABlockBuffer,
 
  463               typename ABlockTransferStep,
 
  466               typename BBlockTransfer,
 
  467               typename BGridBuffer,
 
  468               typename BBlockBuffer,
 
  469               typename BBlockTransferStep,
 
  470               typename CThreadBuffer>
 
  471     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  472                         const ABlockDesc& a_block_desc,
 
  473                         ABlockTransfer& a_blockwise_copy,
 
  474                         const AGridBuffer& a_grid_buf,
 
  475                         ABlockBuffer& a_block_buf,
 
  476                         const ABlockTransferStep& a_block_copy_step,
 
  477                         const BGridDesc& b_grid_desc,
 
  478                         const BBlockDesc& b_block_desc,
 
  479                         BBlockTransfer& b_blockwise_copy,
 
  480                         const BGridBuffer& b_grid_buf,
 
  481                         BBlockBuffer& b_block_buf,
 
  482                         const BBlockTransferStep& b_block_copy_step,
 
  483                         CThreadBuffer& c_thread_buf,
 
  486         __builtin_amdgcn_sched_barrier(0);
 
  487         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
 
  489         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
 
  502         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  503         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  505         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  506         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  508         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I0));
 
  509         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I0));
 
  533         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  534         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  536         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  537         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  539         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I1));
 
  540         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I1));
 
  543         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  544         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  546         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  547         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  550         c_thread_buf.Clear();
 
  553         if constexpr(HasMainLoop)
 
  571                                            a_block_buf.At(PongP1{}),
 
  574                                            a_thread_bufs(PongP1{}));
 
  578                                                b_block_buf.At(PongP1{}),
 
  581                                                b_thread_bufs(PongP1{}));
 
  586                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
 
  587                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
 
  589                 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  590                 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  592                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  593                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  595                 static_for<0, KRepeat, 1>{}([&](
auto k0) {
 
  596                     static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  597                         static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  598                             vector_type<FloatAB, KPack> a_thread_vec;
 
  599                             vector_type<FloatAB, KPack> b_thread_vec;
 
  601                             static_for<0, KPack, 1>{}([&](
auto ik) {
 
  602                                 a_thread_vec.template AsType<FloatAB>()(ik) =
 
  603                                     a_thread_bufs[PingP1{}][
Number<a_thread_desc_.CalculateOffset(
 
  605                                 b_thread_vec.template AsType<FloatAB>()(ik) =
 
  606                                     b_thread_bufs[PingP1{}][
Number<b_thread_desc_.CalculateOffset(
 
  610                             using mfma_input_type =
 
  611                                 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
 
  614                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  617                                 a_thread_vec.template AsType<mfma_input_type>(),
 
  618                                 b_thread_vec.template AsType<mfma_input_type>(),
 
  619                                 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  625                 __builtin_amdgcn_sched_barrier(0);
 
  628                 using PingP2 = Number<1>;
 
  629                 using PongP2 = Number<0>;
 
  635                 static_for<0, KRepeat, 1>{}([&](
auto k) {
 
  636                     static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  637                         a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  638                                            make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
 
  639                                            a_block_buf.At(PongP2{}),
 
  642                                            a_thread_bufs(PongP2{}));
 
  643                         static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  644                             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  645                                                make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
 
  646                                                b_block_buf.At(PongP2{}),
 
  649                                                b_thread_bufs(PongP2{}));
 
  654                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
 
  655                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));
 
  657                 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  658                 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  660                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  661                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  663                 static_for<0, KRepeat, 1>{}([&](
auto k0) {
 
  664                     static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  665                         static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  666                             vector_type<FloatAB, KPack> a_thread_vec;
 
  667                             vector_type<FloatAB, KPack> b_thread_vec;
 
  669                             static_for<0, KPack, 1>{}([&](
auto ik) {
 
  670                                 a_thread_vec.template AsType<FloatAB>()(ik) =
 
  671                                     a_thread_bufs[PingP2{}][
Number<a_thread_desc_.CalculateOffset(
 
  673                                 b_thread_vec.template AsType<FloatAB>()(ik) =
 
  674                                     b_thread_bufs[PingP2{}][
Number<b_thread_desc_.CalculateOffset(
 
  678                             using mfma_input_type =
 
  679                                 typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
 
  682                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  685                                 a_thread_vec.template AsType<mfma_input_type>(),
 
  686                                 b_thread_vec.template AsType<mfma_input_type>(),
 
  687                                 c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  693                 __builtin_amdgcn_sched_barrier(0);
 
  696             } 
while(i < (num_loop - 3));
 
  700         if constexpr(TailNum == 3)
 
  702             using PingP1 = Number<0>;
 
  703             using PongP1 = Number<1>;
 
  709             static_for<0, KRepeat, 1>{}([&](
auto k) {
 
  710                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  711                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  712                                        make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
 
  713                                        a_block_buf.At(PongP1{}),
 
  716                                        a_thread_bufs(PongP1{}));
 
  717                     static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  718                         b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  719                                            make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
 
  720                                            b_block_buf.At(PongP1{}),
 
  723                                            b_thread_bufs(PongP1{}));
 
  728             a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
 
  729             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));
 
  731             static_for<0, KRepeat, 1>{}([&](
auto k0) {
 
  732                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  733                     static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  734                         vector_type<FloatAB, KPack> a_thread_vec;
 
  735                         vector_type<FloatAB, KPack> b_thread_vec;
 
  737                         static_for<0, KPack, 1>{}([&](
auto ik) {
 
  738                             a_thread_vec.template AsType<FloatAB>()(ik) =
 
  739                                 a_thread_bufs[PingP1{}][
Number<a_thread_desc_.CalculateOffset(
 
  741                             b_thread_vec.template AsType<FloatAB>()(ik) =
 
  742                                 b_thread_bufs[PingP1{}][
Number<b_thread_desc_.CalculateOffset(
 
  746                         using mfma_input_type =
 
  747                             typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
 
  750                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  752                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  753                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  754                                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  760             __builtin_amdgcn_sched_barrier(0);
 
  763             using PingP2 = Number<1>;
 
  764             using PongP2 = Number<0>;
 
  770             static_for<0, KRepeat, 1>{}([&](
auto k) {
 
  771                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  772                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  773                                        make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
 
  774                                        a_block_buf.At(PongP2{}),
 
  777                                        a_thread_bufs(PongP2{}));
 
  778                     static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  779                         b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  780                                            make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
 
  781                                            b_block_buf.At(PongP2{}),
 
  784                                            b_thread_bufs(PongP2{}));
 
  789             static_for<0, KRepeat, 1>{}([&](
auto k0) {
 
  790                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  791                     static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  792                         vector_type<FloatAB, KPack> a_thread_vec;
 
  793                         vector_type<FloatAB, KPack> b_thread_vec;
 
  795                         static_for<0, KPack, 1>{}([&](
auto ik) {
 
  796                             a_thread_vec.template AsType<FloatAB>()(ik) =
 
  797                                 a_thread_bufs[PingP2{}][
Number<a_thread_desc_.CalculateOffset(
 
  799                             b_thread_vec.template AsType<FloatAB>()(ik) =
 
  800                                 b_thread_bufs[PingP2{}][
Number<b_thread_desc_.CalculateOffset(
 
  804                         using mfma_input_type =
 
  805                             typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
 
  808                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  810                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  811                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  812                                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  818             __builtin_amdgcn_sched_barrier(0);
 
  820             static_for<0, KRepeat, 1>{}([&](
auto k) {
 
  821                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  822                     static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  823                         vector_type<FloatAB, KPack> a_thread_vec;
 
  824                         vector_type<FloatAB, KPack> b_thread_vec;
 
  826                         static_for<0, KPack, 1>{}([&](
auto ik) {
 
  827                             a_thread_vec.template AsType<FloatAB>()(ik) =
 
  828                                 a_thread_bufs[PongP2{}][
Number<a_thread_desc_.CalculateOffset(
 
  830                             b_thread_vec.template AsType<FloatAB>()(ik) =
 
  831                                 b_thread_bufs[PongP2{}][
Number<b_thread_desc_.CalculateOffset(
 
  835                         using mfma_input_type =
 
  836                             typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
 
  839                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  841                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  842                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  843                                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  849             __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); 
 
  850             __builtin_amdgcn_sched_barrier(0);
 
  852         else if constexpr(TailNum == 2)
 
  854             using PingP1 = Number<0>;
 
  855             using PongP1 = Number<1>;
 
  861             static_for<0, KRepeat, 1>{}([&](
auto k) {
 
  862                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  863                     a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  864                                        make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
 
  865                                        a_block_buf.At(PongP1{}),
 
  868                                        a_thread_bufs(PongP1{}));
 
  869                     static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  870                         b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  871                                            make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
 
  872                                            b_block_buf.At(PongP1{}),
 
  875                                            b_thread_bufs(PongP1{}));
 
  880             static_for<0, KRepeat, 1>{}([&](
auto k0) {
 
  881                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  882                     static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  883                         vector_type<FloatAB, KPack> a_thread_vec;
 
  884                         vector_type<FloatAB, KPack> b_thread_vec;
 
  886                         static_for<0, KPack, 1>{}([&](
auto ik) {
 
  887                             a_thread_vec.template AsType<FloatAB>()(ik) =
 
  888                                 a_thread_bufs[PingP1{}][
Number<a_thread_desc_.CalculateOffset(
 
  890                             b_thread_vec.template AsType<FloatAB>()(ik) =
 
  891                                 b_thread_bufs[PingP1{}][
Number<b_thread_desc_.CalculateOffset(
 
  895                         using mfma_input_type =
 
  896                             typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
 
  899                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  901                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  902                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  903                                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  909             __builtin_amdgcn_sched_barrier(0);
 
  912             using PingP2 = Number<1>;
 
  917             static_for<0, KRepeat, 1>{}([&](
auto k0) {
 
  918                 static_for<0, MRepeat, 1>{}([&](
auto m0) {
 
  919                     static_for<0, NRepeat, 1>{}([&](
auto n0) {
 
  920                         vector_type<FloatAB, KPack> a_thread_vec;
 
  921                         vector_type<FloatAB, KPack> b_thread_vec;
 
  923                         static_for<0, KPack, 1>{}([&](
auto ik) {
 
  924                             a_thread_vec.template AsType<FloatAB>()(ik) =
 
  925                                 a_thread_bufs[PingP2{}][
Number<a_thread_desc_.CalculateOffset(
 
  927                             b_thread_vec.template AsType<FloatAB>()(ik) =
 
  928                                 b_thread_bufs[PingP2{}][
Number<b_thread_desc_.CalculateOffset(
 
  932                         using mfma_input_type =
 
  933                             typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
 
  936                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  938                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  939                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  940                                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
 
  946             __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); 
 
  947             __builtin_amdgcn_sched_barrier(0);
 
  958             Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
 
  964             Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
 
  972                                                          decltype(a_block_desc_m0_m1_m2_k),
 
  973                                                          decltype(a_thread_desc_),
 
  982                                                          decltype(b_block_desc_n0_n1_n2_k),
 
  983                                                          decltype(b_thread_desc_),
 
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
 
__host__ constexpr __device__ index_t get_warp_size()
Definition: get_id.hpp:10
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
 
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
 
int32_t index_t
Definition: ck.hpp:300
 
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
integral_constant< index_t, N > Number
Definition: number.hpp:12
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
 
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
 
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
 
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
 
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
 
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
 
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
 
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops.hpp:35
 
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
 
static constexpr auto Print()
Definition: blockwise_gemm_pipeline_xdlops.hpp:57
 
static constexpr index_t WaveNumN
Definition: blockwise_gemm_pipeline_xdlops.hpp:37
 
static constexpr index_t WaveNumM
Definition: blockwise_gemm_pipeline_xdlops.hpp:36
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:103
 
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops.hpp:105
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_pipeline_xdlops.hpp:355
 
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_xdlops.hpp:124
 
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops.hpp:253
 
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_pipeline_xdlops.hpp:338
 
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_xdlops.hpp:115
 
static constexpr index_t A_K0
Definition: blockwise_gemm_pipeline_xdlops.hpp:113
 
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:961
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops.hpp:308
 
static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:373
 
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops.hpp:111
 
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:991
 
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition: blockwise_gemm_pipeline_xdlops.hpp:233
 
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops.hpp:104
 
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:453
 
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops.hpp:178
 
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:990
 
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops.hpp:267
 
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_xdlops.hpp:109
 
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:454
 
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_xdlops.hpp:122
 
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_xdlops.hpp:153
 
static constexpr auto I3
Definition: blockwise_gemm_pipeline_xdlops.hpp:107
 
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_xdlops.hpp:116
 
static constexpr auto I2
Definition: blockwise_gemm_pipeline_xdlops.hpp:106
 
static constexpr index_t B_K0
Definition: blockwise_gemm_pipeline_xdlops.hpp:114
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops.hpp:321
 
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops.hpp:167
 
static constexpr index_t KPerThread
Definition: blockwise_gemm_pipeline_xdlops.hpp:121
 
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_xdlops.hpp:155
 
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:955
 
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops.hpp:294
 
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_xdlops.hpp:125
 
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, FloatAcc, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_xdlops.hpp:144
 
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops.hpp:280
 
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops.hpp:118
 
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops.hpp:220
 
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops.hpp:471
 
__host__ __device__ BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Definition: blockwise_gemm_pipeline_xdlops.hpp:236
 
static constexpr __device__ auto TailScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:404
 
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops.hpp:191
 
Definition: sequence.hpp:43
 
Definition: static_buffer.hpp:75
 
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
 
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
 
ck::ThreadwiseTensorSliceTransfer_v4< FloatAB, FloatAB, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 >  
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1297
 
Definition: xdlops_gemm.hpp:1399
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33