// Primary template (excerpt)
template <// ...
          typename ComputeDataType,
          // ...
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
          >
struct BlockwiseGemmXdlops_pipeline_v3;

// Specialization for BlockGemmPipelineScheduler::Intrawave (excerpt)
template <// ...
          typename ComputeDataType,
          // ...
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
          >
struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<// ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...
                                        >
{
    using Base = BlockwiseGemmXdlops_pipeline_base<// ...
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   // ...
                                                   >;
    // ...

    // Per-iteration instruction counts consumed by HotLoopScheduler(); the
    // counting helper type is elided here, its template arguments include the
    // two source vector widths:
    using HotLoopInstList = /* ... */<// ...
                                      ABlockTransferSrcScalarPerVector,
                                      BBlockTransferSrcScalarPerVector,
                                      // ...
                                      >;
    // ...
 
    using Base::xdlops_gemm;
    // ...
    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    // ...

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }
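    // The prologue keeps PrefetchStages K-tiles in flight ahead of compute, so
    // a hot loop only exists when there are more K-tiles than prefetch stages;
    // shorter problems go straight to the tail path chosen by BlockLoopTailNum().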
 
    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        // ...
    }

    static constexpr __device__ auto HotLoopScheduler()
    {
        // Count the instructions one hot-loop iteration will issue. LDS reads
        // narrower than 16 bytes are counted in pairs.
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
        constexpr auto mfma_cycle    = HotLoopInstList::C_MFMA_Inst_Cycle;

        // ds_reads grouped per MFMA, derived from the issue-cycle budget: a
        // 16-byte ds_read takes 8 cycles to issue, a narrower one takes 4.
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

        // MFMAs needed to cover all LDS reads at those rates (ceiling division)
        constexpr auto num_dsread_a_mfma =
            (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
        constexpr auto num_dsread_b_mfma =
            (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
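        // Worked example (illustrative numbers, not constants from this header):
        // with mfma_cycle = 32 and 16-byte ds_reads (issue cycle 8), the rate is
        // (32 - 4 + 2*8 - 1) / (2*8) = 43 / 16 = 2 ds_reads per MFMA, so e.g.
        // 8 ds_read instructions need (8 + 2 - 1) / 2 = 4 MFMAs to hide behind.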
 
        // Stage 1: pair every buffer_load with its share of ds_writes and
        // MFMAs. The MFMAs not reserved for covering stage-2 ds_reads are
        // split evenly across the (A + B) buffer_loads.
        constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
        constexpr auto num_mfma_per_issue =
            num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
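        // Example with illustrative numbers: num_mfma_inst = 32 with 4 MFMAs
        // reserved for stage 2 leaves num_mfma_stage1 = 28; with 4 A-loads and
        // 3 B-loads that is 28 / 7 = 4 MFMAs per buffer_load slot, and 8 A
        // ds_writes over 4 A-loads gives num_dswrite_per_issue_a = 2.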
 
        // __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) asks the
        // scheduler to place `size` instructions of the masked class at this
        // point in issue order; masks used here: 0x008 = MFMA, 0x020 = VMEM
        // read, 0x100 = DS read, 0x200 = DS write.
        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
        });
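        // Net pattern per A buffer_load (B analogous): num_dswrite_per_issue_a
        // DS-write/MFMA pairs, one VMEM read, then the slot's remaining MFMAs.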
 
        // Stage 2: drain the remaining LDS reads, ds_read_x_mfma_rate per MFMA,
        // with the final group taking whatever remainder is left.
        static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
                         ds_read_a_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
                                                                              ds_read_a_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
        static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
                         ds_read_b_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
                                                                              ds_read_b_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
    }
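    // Together the two stages account for roughly every MFMA in the iteration,
    // so each buffer_load, ds_write and ds_read is anchored to a specific MFMA
    // slot instead of being bunched up by the compiler's scheduler.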
 
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        __builtin_amdgcn_sched_barrier(0); // mask 0: nothing may be reordered across this point
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());
 
        // Global prefetch 1: issue reads for tile 0, then advance the window
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1: write tile 0 into LDS
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2: issue reads for tile 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();
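        // Pipeline state here: tile 0 is in LDS, tile 1 sits in the copy
        // objects' register staging, and the source windows point at tile 2.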
 
        // Local prefetch 1: stage the first K-slice of A and B from LDS to VGPRs
        block_sync_lds();
        // ... (static_for over the M/N repeats)
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   /* ... */);
                // ...
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   /* ... */);
        // ...

        __builtin_amdgcn_sched_barrier(0);
 
        // Main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                // Write prefetched tile i+1 to LDS, then start the global
                // prefetch of tile i+2
                block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                // Issue MFMAs on the K-slices already staged in VGPRs
                // ... (static_for over k/m/n repeats, packing KPack elements)
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        /* ... */)>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        /* ... */)>{}];
                // ...
                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                /* ... */);
                // ...

                // Local prefetch: stage the next iteration's K-slices from LDS
                block_sync_lds();
                // ...
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           /* ... */);
                        // ...
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           /* ... */);
                // ...

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }
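        // The main body stops one iteration early on purpose: the last K-tile
        // is already staged in LDS/VGPRs, so the tail below finishes it
        // without issuing any further global reads.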
 
        // Tail
        // ...
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    /* ... */)>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    /* ... */)>{}];
        // ...
                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        /* ... */);
        // ...
    }
 
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
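A minimal host-side sketch of the dispatch decision made with BlockHasHotloop():
the values below (PrefetchStages = 2, matching the two global prefetches in
Run()'s prologue, and the K / KPerBlock sizes) are illustrative assumptions,
not constants read from this header.

#include <cstdint>
#include <cstdio>

int main()
{
    const int32_t PrefetchStages = 2;       // assumed two-stage prefetch depth
    const int32_t K = 512, KPerBlock = 64;  // illustrative problem/tile sizes
    const int32_t num_loop = K / KPerBlock; // 8 K-tiles to iterate over
    const bool has_hot_loop = num_loop > PrefetchStages; // same test as BlockHasHotloop()
    std::printf("num_loop=%d, has_hot_loop=%d\n", num_loop, has_hot_loop);
    return 0;
}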
 