20           typename ComputeDataType,
 
   24           typename AMmaTileDesc,
 
   25           typename BMmaTileDesc,
 
   26           index_t ABlockTransferSrcScalarPerVector,
 
   27           index_t BBlockTransferSrcScalarPerVector,
 
   43           typename ComputeDataType,
 
   47           typename AMmaTileDesc,
 
   48           typename BMmaTileDesc,
 
   49           index_t ABlockTransferSrcScalarPerVector,
 
   50           index_t BBlockTransferSrcScalarPerVector,
 
   71                                        ABlockTransferSrcScalarPerVector,
 
   72                                        BBlockTransferSrcScalarPerVector,
 
   90                                         ABlockTransferSrcScalarPerVector,
 
   91                                         BBlockTransferSrcScalarPerVector,
 
  111                                                    ABlockTransferSrcScalarPerVector,
 
  112                                                    BBlockTransferSrcScalarPerVector,
 
  126     using Base::xdlops_gemm;
 
  129     using Base::CalculateCThreadOriginDataIndex;
 
  130     using Base::CalculateCThreadOriginDataIndex8D;
 
  131     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  132     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  133     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  134     using Base::GetCThreadBuffer;
 
  135     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  136     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  137     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  138     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  139     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  141     using Base::a_block_desc_m0_m1_m2_k;
 
  142     using Base::b_block_desc_n0_n1_n2_k;
 
  144     using Base::AMmaKStride;
 
  145     using Base::BMmaKStride;
 
  154         return num_loop > PrefetchStages;
 
  159         if(num_loop % HotloopUnroll == 1)
 
  174         constexpr 
auto num_ds_read_inst_a =
 
  175             HotLoopInstList::A_LDS_Read_Width * 
sizeof(ADataType) == 16
 
  176                 ? HotLoopInstList::A_LDS_Read_Inst_Num
 
  177                 : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
 
  178         constexpr 
auto num_ds_read_inst_b =
 
  179             HotLoopInstList::B_LDS_Read_Width * 
sizeof(BDataType) == 16
 
  180                 ? HotLoopInstList::B_LDS_Read_Inst_Num
 
  181                 : HotLoopInstList::B_LDS_Read_Inst_Num / 2;
 
  183         constexpr 
auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
 
  184         constexpr 
auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;
 
  186         constexpr 
auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
 
  187         constexpr 
auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
 
  189         constexpr 
auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;
 
  191         constexpr 
auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
 
  192         constexpr 
auto ds_read_a_issue_cycle =
 
  193             HotLoopInstList::A_LDS_Read_Width * 
sizeof(ADataType) == 16 ? 8 : 4;
 
  194         constexpr 
auto ds_read_b_issue_cycle =
 
  195             HotLoopInstList::B_LDS_Read_Width * 
sizeof(BDataType) == 16 ? 8 : 4;
 
  196         constexpr 
auto ds_read_a_mfma_rate =
 
  197             (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
 
  198         constexpr 
auto ds_read_b_mfma_rate =
 
  199             (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
 
  201         constexpr 
auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
 
  202         constexpr 
auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
 
  203         constexpr 
auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
 
  204         constexpr 
auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
 
  206         constexpr 
auto num_dsread_stage1_a_mfma =
 
  207             (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
 
  208         constexpr 
auto num_dsread_stage1_b_mfma =
 
  209             (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
 
  210         constexpr 
auto num_dsread_stage3_a_mfma =
 
  211             (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
 
  212         constexpr 
auto num_dsread_stage3_b_mfma =
 
  213             (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
 
  215         constexpr 
auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate -
 
  216                                          num_ds_read_inst_b / ds_read_b_mfma_rate;
 
  217         constexpr 
auto num_mfma_per_issue =
 
  218             num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
 
  219         constexpr 
auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
 
  220         constexpr 
auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
 
  225             if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
 
  228                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); 
 
  232                 __builtin_amdgcn_sched_group_barrier(
 
  234                     num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
 
  237             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  241             if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
 
  244                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); 
 
  248                 __builtin_amdgcn_sched_group_barrier(
 
  250                     num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
 
  253             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  261                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  262                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  264             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  265             __builtin_amdgcn_sched_group_barrier(
 
  266                 0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); 
 
  272                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  273                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  275             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  276             __builtin_amdgcn_sched_group_barrier(
 
  277                 0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); 
 
  283             if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
 
  286                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); 
 
  290                 __builtin_amdgcn_sched_group_barrier(
 
  292                     num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
 
  295             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  299             if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
 
  302                 __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); 
 
  306                 __builtin_amdgcn_sched_group_barrier(
 
  308                     num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
 
  311             __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  316         __builtin_amdgcn_sched_barrier(0);
 
  319     template <
bool HasMainLoop,
 
  323               typename ABlockTransfer,
 
  324               typename AGridBuffer,
 
  325               typename ABlockBuffer,
 
  326               typename ABlockTransferStep,
 
  329               typename BBlockTransfer,
 
  330               typename BGridBuffer,
 
  331               typename BBlockBuffer,
 
  332               typename BBlockTransferStep,
 
  333               typename CThreadBuffer>
 
  334     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  335                         const ABlockDesc& a_block_desc,
 
  336                         ABlockTransfer& a_blockwise_copy,
 
  337                         const AGridBuffer& a_grid_buf,
 
  338                         ABlockBuffer& a_block_buf,
 
  339                         const ABlockTransferStep& a_block_copy_step,
 
  340                         const BGridDesc& b_grid_desc,
 
  341                         const BBlockDesc& b_block_desc,
 
  342                         BBlockTransfer& b_blockwise_copy,
 
  343                         const BGridBuffer& b_grid_buf,
 
  344                         BBlockBuffer& b_block_buf,
 
  345                         const BBlockTransferStep& b_block_copy_step,
 
  346                         CThreadBuffer& c_thread_buf,
 
  349         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
 
  350             a_thread_desc_.GetElementSpaceSize());
 
  351         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
 
  352             b_thread_desc_.GetElementSpaceSize());
 
  355         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 
  356         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
 
  358         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  359         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  362         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
 
  363         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
 
  366         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
 
  367         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
 
  369         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  370         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  373         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
 
  374         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);
 
  376         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  377         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  380         c_thread_buf.Clear();
 
  385             a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 
  393             b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 
  402         if constexpr(HasMainLoop)
 
  407                 auto LoopFunc = [&](
auto vmem_buf) {
 
  412                         if constexpr(k0 == (KRepeat - 1))
 
  416                             a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
 
  417                             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);
 
  419                             a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
 
  420                             b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);
 
  422                             a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  423                             b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  430                                     a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  431                                         a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  435                                     b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  436                                         b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  440                                 using mfma_input_type =
 
  442                                                          xdlops_gemm.K1PerXdlops>::type;
 
  445                                     c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  448                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  449                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  454                                 a_block_desc_m0_m1_m2_k,
 
  464                                 b_block_desc_n0_n1_n2_k,
 
  480             } 
while(i < (num_loop - PrefetchStages));
 
  483         auto ReadWriteCompFunc = [&](
auto vmem_buf) {
 
  488                 if constexpr(k0 == (KRepeat - 1))
 
  492                     a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
 
  493                     b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);
 
  500                             a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  501                                 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  505                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  506                                 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  510                         using mfma_input_type =
 
  511                             typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
 
  514                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  516                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  517                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  521                         a_block_desc_m0_m1_m2_k,
 
  531                         b_block_desc_n0_n1_n2_k,
 
  542         auto ReadCompFunc = [&]() {
 
  546             static_for<0, KRepeat - 1, 1>{}([&](
auto k0) {
 
  550                             a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  551                                 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  555                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  556                                 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  560                         using mfma_input_type =
 
  561                             typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
 
  564                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  566                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  567                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  572                         a_block_desc_m0_m1_m2_k,
 
  582                         b_block_desc_n0_n1_n2_k,
 
  594                         a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
 
  598                         b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf
 
  602                     using mfma_input_type =
 
  603                         typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
 
  606                         c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, 0));
 
  608                     xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  609                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  619             ReadWriteCompFunc(I0);
 
  620             ReadWriteCompFunc(I1);
 
  625             ReadWriteCompFunc(I0);
 
  632     static constexpr 
auto a_thread_desc_ =
 
  636     static constexpr 
auto b_thread_desc_ =
 
  641                                                          decltype(a_block_desc_m0_m1_m2_k),
 
  642                                                          decltype(a_thread_desc_),
 
  651                                                          decltype(b_block_desc_n0_n1_n2_k),
 
  652                                                          decltype(b_thread_desc_),
 
  659     AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
 
  660     BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
 
  661     using Base::c_thread_desc_;
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:297
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
 
ck::BlockwiseGemmXdlops_pipeline_v5< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:169
 
ck::BlockwiseGemmXdlops_pipeline_v5< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:334
 
ck::BlockwiseGemmXdlops_pipeline_v5< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum static constexpr __host__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:157
 
ck::BlockwiseGemmXdlops_pipeline_v5< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop static constexpr __host__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:152
 
Definition: blockwise_gemm_pipeline_xdlops_v5.hpp:37
 
Definition: sequence.hpp:43
 
ck::ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeDataType, decltype(a_block_desc_m0_m1_m2_k), decltype(a_thread_desc_), Sequence< 1, 1, 1, KPack >, Sequence< 0, 1, 2, 3 >, 3, A_K1, A_K1 >  
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10