          typename ComputeDataType,

          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,

          typename ComputeDataType,

          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,

                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,

                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,

                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
 
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

        return num_loop > PrefetchStages;
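    // BlockHasHotloop: the pipelined main (hot) loop is only emitted when there are more
    // K-block iterations than prefetch stages; otherwise the prologue plus tail already
    // cover every iteration.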
 
    template <bool HasMainLoop,

              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,

              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
 
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
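        // Prologue: the first A/B K-tiles are prefetched from global memory, the source
        // windows are advanced by one K-block, the prefetched tiles are written into LDS,
        // and the C accumulator thread buffer is zeroed before the main loop starts.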
 
        if constexpr(HasMainLoop)

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,

                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(

                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            } while(i < (num_loop - 1));
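            // The hot loop runs num_loop - 1 iterations; the final K-block is handled by the
            // tail below, which only performs LDS->VGPR copies and MFMAs since its data is
            // already resident in LDS.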
 
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,

                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(

                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
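                        // For each (m0, n0) repeat the per-thread A/B fragments are packed into
                        // mfma_input_type vectors (vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>)
                        // and passed to xdlops_gemm.Run, which accumulates into c_thread_buf at the
                        // offset c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)).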
 
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
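
// How a caller is expected to drive this pipeline (illustrative sketch, not part of this
// header): the grid-wise GEMM computes the number of K-block iterations and uses
// BlockHasHotloop() to pick the HasMainLoop template argument of Run(). Names such as
// BlockwiseGemmPipe and num_k_block_loop below are assumptions for illustration only.
//
//     const index_t num_k_block_loop = K / KPerBlock; // assumed exact split of K into K-blocks
//
//     if(BlockwiseGemmPipe::BlockHasHotloop(num_k_block_loop))
//     {
//         blockwise_gemm_pipeline.template Run<true>(a_grid_desc, a_block_desc, a_blockwise_copy,
//                                                    a_grid_buf, a_block_buf, a_block_copy_step,
//                                                    b_grid_desc, b_block_desc, b_blockwise_copy,
//                                                    b_grid_buf, b_block_buf, b_block_copy_step,
//                                                    c_thread_buf, num_k_block_loop);
//     }
//     else
//     {
//         blockwise_gemm_pipeline.template Run<false>(/* same arguments */ ..., num_k_block_loop);
//     }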
 
          typename ComputeDataType,
          typename AccDataType,

          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,

                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,

                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,

                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
 
    using Base::KPerThread;
    using Base::xdlops_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
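    // e.g. (hypothetical values) with KPerThread = 16 and KPerInnerLoop = 4, each thread's
    // K work is split into KRepeat = 4 inner loops of KPerInnerLoop elements each.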
 
        return num_loop > PrefetchStages;
 
    template <bool HasMainLoop,

              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,

              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
 
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
 
        if constexpr(HasMainLoop)

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,

                    __builtin_amdgcn_sched_barrier(0);

                    if constexpr(k0.value != 0 || KRepeat == 1)

                        __builtin_amdgcn_s_barrier();
                        __builtin_amdgcn_sched_barrier(0);

                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(

                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                if constexpr(k0.value == KRepeat - 1 &&
                                             k_.value == KPerInnerLoop - KPack &&
                                             m0.value == MRepeat - 1 && n0.value == NRepeat - 1)

                                    __builtin_amdgcn_sched_barrier(0);

                                    __builtin_amdgcn_sched_barrier(0);

                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),

                                if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)

                                    __builtin_amdgcn_sched_barrier(0);
                                    __builtin_amdgcn_s_setprio(1);
                                    __builtin_amdgcn_sched_barrier(0);

                    __builtin_amdgcn_sched_barrier(0);
                    __builtin_amdgcn_s_setprio(0);
                    __builtin_amdgcn_sched_barrier(0);
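                    // Interwave scheduling: __builtin_amdgcn_sched_barrier(0) stops the compiler
                    // from moving instructions across these points, __builtin_amdgcn_s_barrier()
                    // synchronizes the waves of the workgroup, and __builtin_amdgcn_s_setprio(1/0)
                    // raises wave priority for the MFMA burst and restores it afterwards, so one
                    // wave's math phase can overlap the memory phase of the others.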
 
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            } while(i < (num_loop - 1));
 
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,

                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,

                __builtin_amdgcn_sched_barrier(0);
                if constexpr(k0.value != 0 || KRepeat == 1)

                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);

                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(

                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            if constexpr(k0.value == KRepeat - 1 &&
                                         k_.value == KPerInnerLoop - KPack &&
                                         m0.value == MRepeat - 1 && n0.value == NRepeat - 1)

                                __builtin_amdgcn_sched_barrier(0);

                                __builtin_amdgcn_sched_barrier(0);

                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),

                            if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)

                                __builtin_amdgcn_sched_barrier(0);
                                __builtin_amdgcn_s_setprio(1);
                                __builtin_amdgcn_sched_barrier(0);

                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
 
                   Number<KRepeat * MRepeat * KPerInnerLoop>{},
                   Number<MRepeat * KPerInnerLoop>{},

                   Number<KRepeat * NRepeat * KPerInnerLoop>{},
                   Number<NRepeat * KPerInnerLoop>{},

                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),

                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
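    // a_thread_copy_ / b_thread_copy_ perform the per-thread LDS->VGPR copies used in Run(),
    // starting from this thread's A/B origin as computed by the base class.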
 
    using Base::c_thread_desc_;
 