          typename ComputeTypeA,
          typename ComputeTypeB,
          // ...
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
          typename ComputeTypeA,
          typename ComputeTypeB,
          // ...
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         // ...
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    // ...
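    // The fragments above are from the template parameter lists of BlockwiseGemmWmmaops_pipeline_v1
    // and of its BlockGemmPipelineScheduler::Intrawave specialization; the repeated
    // ABlockTransferSrcScalarPerVector / BBlockTransferSrcScalarPerVector pairs are the same values
    // being passed on to helper instantiations that are elided here. The members below, through the
    // first Run() and the trailing using Base:: declarations, belong to the Intrawave specialization.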
 
    using Base::wmma_gemm;
    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;
 
    template <bool HasMainLoop,
              // ...
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              // ...
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
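        // Pipeline overview: stage per-thread VGPR buffers for the A/B fragments, prefetch the
        // first K block from global memory into LDS (prologue), then alternate further global
        // prefetches with blockwise WMMA compute in the hot loop, and finish with a tail compute
        // pass on the last block.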
 
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
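        // blockwise_gemm_func: reads the A/B sub-tiles currently resident in LDS into
        // a_thread_buf / b_thread_buf and issues one wmma_gemm.Run per (m0, n0) repeat,
        // accumulating into c_thread_buf.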
 
        auto blockwise_gemm_func = [&]() {
            // ...
                    a_block_desc_k0_m0_m1_m2_k1,
                    // ...
                    b_block_desc_k0_n0_n1_n2_k1,
                    // ...
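            // The elided code around the two LDS block descriptors above is presumably the
            // per-thread copy (a_thread_copy_ / b_thread_copy_ inherited from Base) that fills
            // the VGPR staging buffers.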
 
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
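                        // Gather KPack / KRow scalars per operand into a vector_type and, below,
                        // reinterpret it as the packed WMMA input vector type passed to
                        // wmma_gemm.Run.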
 
                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) = /* ... */;
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) = /* ... */;
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                        // ...
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                      b_thread_vec.template AsType<wmma_input_type_b>(),
                                      /* ... */);
            // ...
        };
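        // Hot loop: while more K blocks remain, prefetch the next block from global memory,
        // run the blockwise GEMM on the tile already in LDS, then overwrite LDS with the
        // prefetched data. The last block is handled by the tail call after the loop.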
 
        if constexpr(HasMainLoop)
        {
            // ...
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // ...
                blockwise_gemm_func();

                // ...
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                // ...
            } while(i < (num_loop - 1));
        }

        // ...
            blockwise_gemm_func();
        // ...
    }
 
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
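// Specialization for BlockGemmPipelineScheduler::Interwave. It follows the same
// prologue / hot-loop / tail structure as the Intrawave pipeline above, but its blockwise GEMM
// body adds explicit LDS barriers, scheduler barriers, and wave-priority hints intended to
// interleave the memory and math phases of concurrent waves.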
 
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         // ...
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    // ...
 
    using Base::wmma_gemm;
    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;
 
    template <bool HasMainLoop,
              // ...
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              // ...
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
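        // Same per-thread VGPR staging and global-to-LDS prologue as in the Intrawave Run() above.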
 
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        c_thread_buf.Clear();
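        // blockwise_gemm_func (Interwave variant): same LDS-to-VGPR staging and wmma_gemm.Run
        // sequence as above, but organized in k0_offset / k0_inner clusters and annotated with
        // scheduler and wave-priority hints.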
 
        auto blockwise_gemm_func = [&]() {
            // ...
                        a_block_desc_k0_m0_m1_m2_k1,
                        // ...
                        b_block_desc_k0_n0_n1_n2_k1,
                        // ...
                __builtin_amdgcn_sched_barrier(0);

                // ...
                if constexpr(k0_offset != 0 || KRepeat == 1)
                {
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }
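                // __builtin_amdgcn_sched_barrier(0) fences the compiler's instruction scheduler so
                // nothing is reordered across this point; __builtin_amdgcn_s_barrier() is the
                // hardware workgroup barrier; __builtin_amdgcn_s_setprio(n), used further down,
                // raises or lowers the wave's issue priority.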
 
                            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        /* ... */)>{}];
                            });
                            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        /* ... */)>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            // ...
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
 
                            if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
                                         /* ... */)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                // ...
                                __builtin_amdgcn_sched_barrier(0);
                            }

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          /* ... */);

                            if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                __builtin_amdgcn_s_setprio(1);
                                __builtin_amdgcn_sched_barrier(0);
                            }
                // ...
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
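                // Wave-priority bracketing: after the first WMMA of a cluster the wave raises its
                // priority (s_setprio(1)), presumably so its math issues ahead of other waves'
                // memory traffic, and at the end of the cluster it drops back to the default
                // (s_setprio(0)); the surrounding sched_barrier(0) calls keep the compiler from
                // moving instructions across these points.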
 
            // ...
        };

        if constexpr(HasMainLoop)
        {
            // ...
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // ...
                blockwise_gemm_func();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                // ...
            } while(i < (num_loop - 1));
        }

        // ...
            blockwise_gemm_func();
        // ...
    }
 
    static constexpr auto a_thread_desc_ =
        // ...
                                                Number<KRepeatPerCluster>{},
                                                // ...
                                                Number<KPack / A_KRow * MRepeat>{},
                                                // ...

    static constexpr auto b_thread_desc_ =
        // ...
                                                Number<KRepeatPerCluster>{},
                                                // ...
                                                Number<KPack / B_KRow * NRepeat>{},
                                                // ...

    // ...
                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                         decltype(a_thread_desc_),
                                         Sequence<KPack / A_K1 / A_KRow, MRepeat, 1, 1, 1, A_K1>,
                                         // ...
                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                         decltype(b_thread_desc_),
                                         Sequence<KPack / B_K1 / B_KRow, NRepeat, 1, 1, 1, B_K1>,
                                         // ...
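    // a_thread_desc_ / b_thread_desc_ describe this specialization's per-thread VGPR staging
    // layout, and the decltype(...) / Sequence<...> fragments above are template arguments of the
    // AThreadCopy / BThreadCopy types, which the Doxygen cross-references identify as
    // ThreadwiseTensorSliceTransfer_v4 instantiations reading from the LDS block descriptors into
    // these thread descriptors.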
 
    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
};
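// Usage sketch (hypothetical caller, not part of this header): a gridwise GEMM kernel that owns
// the descriptors, blockwise copies, and C thread buffer would drive either specialization
// roughly as follows, matching the Run() signature above:
//
//   constexpr auto pipeline =
//       BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
//                                        /* ...block/tile parameters... */>{};
//   const index_t num_loop = /* e.g. K / KPerBlock */;
//   if(pipeline.BlockHasHotloop(num_loop))
//   {
//       pipeline.template Run<true /* HasMainLoop */ /* , ... */>(
//           a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf, a_block_copy_step,
//           b_grid_desc, b_block_desc, b_blockwise_copy, b_grid_buf, b_block_buf, b_block_copy_step,
//           c_thread_buf, num_loop);
//   }
//   // otherwise dispatch to Run<false, ...>(...) for the short-loop case.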
 