20           typename ComputeTypeA,
 
   21           typename ComputeTypeB,
 
   23           typename AWmmaTileDesc,
 
   24           typename BWmmaTileDesc,
 
   25           index_t ABlockTransferSrcScalarPerVector,
 
   26           index_t BBlockTransferSrcScalarPerVector,
 
   42           typename ComputeTypeA,
 
   43           typename ComputeTypeB,
 
   45           typename AWmmaTileDesc,
 
   46           typename BWmmaTileDesc,
 
   47           index_t ABlockTransferSrcScalarPerVector,
 
   48           index_t BBlockTransferSrcScalarPerVector,
 
   66                                         ABlockTransferSrcScalarPerVector,
 
   67                                         BBlockTransferSrcScalarPerVector,
 
   84                                          ABlockTransferSrcScalarPerVector,
 
   85                                          BBlockTransferSrcScalarPerVector,
 
  103                                                     ABlockTransferSrcScalarPerVector,
 
  104                                                     BBlockTransferSrcScalarPerVector,
 
  122     using Base::wmma_gemm;
 
  125     using Base::CalculateCThreadOriginDataIndex;
 
  127         GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
 
  128     using Base::GetCThreadBuffer;
 
  130         GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
 
  132     using Base::a_block_desc_k0_m0_m1_m2_k1;
 
  133     using Base::b_block_desc_k0_n0_n1_n2_k1;
 
  141         return num_loop > PrefetchStages;
 
  258     template <
bool HasMainLoop,
 
  262               typename ABlockTransfer,
 
  263               typename AGridBuffer,
 
  264               typename ABlockBuffer,
 
  265               typename ABlockTransferStep,
 
  268               typename BBlockTransfer,
 
  269               typename BGridBuffer,
 
  270               typename BBlockBuffer,
 
  271               typename BBlockTransferStep,
 
  272               typename CThreadBuffer>
 
  273     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  274                         const ABlockDesc& a_block_desc,
 
  275                         ABlockTransfer& a_blockwise_copy,
 
  276                         const AGridBuffer& a_grid_buf,
 
  277                         ABlockBuffer& a_block_buf,
 
  278                         const ABlockTransferStep& a_block_copy_step,
 
  279                         const BGridDesc& b_grid_desc,
 
  280                         const BBlockDesc& b_block_desc,
 
  281                         BBlockTransfer& b_blockwise_copy,
 
  282                         const BGridBuffer& b_grid_buf,
 
  283                         BBlockBuffer& b_block_buf,
 
  284                         const BBlockTransferStep& b_block_copy_step,
 
  285                         CThreadBuffer& c_thread_buf,
 
  288         __builtin_amdgcn_sched_barrier(0);
 
  289         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
 
  290             a_thread_desc_.GetElementSpaceSize());
 
  291         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
 
  292             b_thread_desc_.GetElementSpaceSize());
 
  295         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  296         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  298         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  299         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  302         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  303         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  306         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  307         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  309         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  310         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  313         c_thread_buf.Clear();
 
  318             a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
 
  324             b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
 
  332         __builtin_amdgcn_sched_barrier(0);
 
  335         if constexpr(HasMainLoop)
 
  342                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  343                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  345                 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  346                 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  348                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  349                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  354                             vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
 
  355                             vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
 
  357                             static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
 
  358                                 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 
  359                                     a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  367                             static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
 
  368                                 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 
  369                                     b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  378                             using wmma_input_type_a =
 
  379                                 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 
  380                             using wmma_input_type_b =
 
  381                                 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 
  384                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
 
  386                             wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  387                                           b_thread_vec.template AsType<wmma_input_type_b>(),
 
  397                         a_block_desc_k0_m0_m1_m2_k1,
 
  404                         b_block_desc_k0_n0_n1_n2_k1,
 
  413                 __builtin_amdgcn_sched_barrier(0);
 
  416             } 
while(i < (num_loop - 1));
 
  424                         vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
 
  425                         vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
 
  427                         static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
 
  428                             a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 
  432                         static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
 
  433                             b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 
  438                         using wmma_input_type_a =
 
  439                             typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 
  440                         using wmma_input_type_b =
 
  441                             typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 
  444                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
 
  446                         wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  447                                       b_thread_vec.template AsType<wmma_input_type_b>(),
 
  459     using Base::a_thread_copy_;
 
  460     using Base::a_thread_desc_;
 
  461     using Base::b_thread_copy_;
 
  462     using Base::b_thread_desc_;
 
  463     using Base::c_thread_desc_;
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:300
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:35
 
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:139
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:144
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:150
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:273
 
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:36
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10