20           typename ComputeDataType,
 
   24           typename AMmaTileDesc,
 
   25           typename BMmaTileDesc,
 
   26           index_t ABlockTransferSrcScalarPerVector,
 
   27           index_t BBlockTransferSrcScalarPerVector,
 
   36 struct BlockwiseGemmXdlops_pipeline_v4
 
   43           typename ComputeDataType,
 
   47           typename AMmaTileDesc,
 
   48           typename BMmaTileDesc,
 
   49           index_t ABlockTransferSrcScalarPerVector,
 
   50           index_t BBlockTransferSrcScalarPerVector,
 
   71                                        ABlockTransferSrcScalarPerVector,
 
   72                                        BBlockTransferSrcScalarPerVector,
 
   90                                         ABlockTransferSrcScalarPerVector,
 
   91                                         BBlockTransferSrcScalarPerVector,
 
  111                                                    ABlockTransferSrcScalarPerVector,
 
  112                                                    BBlockTransferSrcScalarPerVector,
 
  124     using Base::xdlops_gemm;
 
  127     using Base::CalculateCThreadOriginDataIndex;
 
  128     using Base::CalculateCThreadOriginDataIndex8D;
 
  129     using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  130     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  131     using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  132     using Base::GetCThreadBuffer;
 
  133     using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  134     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  135     using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
 
  136     using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
 
  137     using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
 
  139     using Base::a_block_desc_m0_m1_m2_k;
 
  140     using Base::b_block_desc_n0_n1_n2_k;
 
  142     using Base::AMmaKStride;
 
  143     using Base::BMmaKStride;
 
  152         return num_loop > PrefetchStages;
 
  157         if(num_loop % HotloopUnroll == 1)
 
  171         constexpr 
auto num_ds_read_inst_a =
 
  175         constexpr 
auto num_ds_read_inst_b =
 
  181         constexpr 
auto num_dswrite_per_issue_a =
 
  183         constexpr 
auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;
 
  186         constexpr 
auto num_dswrite_per_issue_b =
 
  188         constexpr 
auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;
 
  190         constexpr 
auto num_mfma_per_issue =
 
  197                 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); 
 
  198                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  203                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  204                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  207             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  208             __builtin_amdgcn_sched_group_barrier(0x008,
 
  209                                                  num_mfma_per_issue - num_dsread_per_issue_a -
 
  210                                                      num_dswrite_per_issue_a,
 
  218                 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); 
 
  219                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  224                 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); 
 
  225                 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); 
 
  228             __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); 
 
  229             __builtin_amdgcn_sched_group_barrier(0x008,
 
  230                                                  num_mfma_per_issue - num_dsread_per_issue_a -
 
  231                                                      num_dswrite_per_issue_b,
 
  234         __builtin_amdgcn_sched_barrier(0);
 
  237     template <
bool HasMainLoop,
 
  241               typename ABlockTransfer,
 
  242               typename AGridBuffer,
 
  243               typename ABlockBuffer,
 
  244               typename ABlockTransferStep,
 
  247               typename BBlockTransfer,
 
  248               typename BGridBuffer,
 
  249               typename BBlockBuffer,
 
  250               typename BBlockTransferStep,
 
  251               typename CThreadBuffer>
 
  252     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  253                         const ABlockDesc& a_block_desc,
 
  254                         ABlockTransfer& a_blockwise_copy,
 
  255                         const AGridBuffer& a_grid_buf,
 
  256                         ABlockBuffer& a_block_buf,
 
  257                         const ABlockTransferStep& a_block_copy_step,
 
  258                         const BGridDesc& b_grid_desc,
 
  259                         const BBlockDesc& b_block_desc,
 
  260                         BBlockTransfer& b_blockwise_copy,
 
  261                         const BGridBuffer& b_grid_buf,
 
  262                         BBlockBuffer& b_block_buf,
 
  263                         const BBlockTransferStep& b_block_copy_step,
 
  264                         CThreadBuffer& c_thread_buf,
 
  267         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
 
  269         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
 
  276         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  277         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  279         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  280         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  283         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I0));
 
  284         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I0));
 
  308         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  309         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  311         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  312         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  315         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(
I1));
 
  316         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(
I1));
 
  319         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  320         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  322         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  323         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  326         c_thread_buf.Clear();
 
  329         if constexpr(HasMainLoop)
 
  335                 auto LoopFunc = [&](
auto lds_read_buf,
 
  336                                     auto lds_read_reg_buf,
 
  345                                                a_block_buf.At(lds_read_buf),
 
  348                                                a_thread_bufs(lds_read_reg_buf));
 
  353                                                b_block_buf.At(lds_read_buf),
 
  356                                                b_thread_bufs(lds_read_reg_buf));
 
  360                     a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
 
  361                     b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
 
  363                     a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  364                     b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  366                     a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  367                     b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  376                                     a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  377                                         a_thread_bufs[mfma_reg_buf]
 
  380                                     b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  381                                         b_thread_bufs[mfma_reg_buf]
 
  386                                 using mfma_input_type =
 
  394                                     a_thread_vec.template AsType<mfma_input_type>(),
 
  395                                     b_thread_vec.template AsType<mfma_input_type>(),
 
  408             } 
while(i < (num_loop - PrefetchStages));
 
  411         auto ReadWriteCompFunc = [&](
auto lds_read_buf,
 
  412                                      auto lds_read_reg_buf,
 
  421                                        a_block_buf.At(lds_read_buf),
 
  424                                        a_thread_bufs(lds_read_reg_buf));
 
  429                                        b_block_buf.At(lds_read_buf),
 
  432                                        b_thread_bufs(lds_read_reg_buf));
 
  436             a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf));
 
  437             b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf));
 
  446                             a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  449                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  454                         using mfma_input_type =
 
  460                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  461                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  470         auto ReadCompFunc = [&](
auto lds_read_buf, 
auto lds_read_reg_buf, 
auto mfma_reg_buf) {
 
  477                                        a_block_buf.At(lds_read_buf),
 
  480                                        a_thread_bufs(lds_read_reg_buf));
 
  485                                        b_block_buf.At(lds_read_buf),
 
  488                                        b_thread_bufs(lds_read_reg_buf));
 
  499                             a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  502                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  507                         using mfma_input_type =
 
  513                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  514                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  523         auto CompFunc = [&](
auto mfma_reg_buf) {
 
  531                             a_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  534                             b_thread_vec.template AsType<ComputeDataType>()(ik) =
 
  539                         using mfma_input_type =
 
  545                         xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
 
  546                                         b_thread_vec.template AsType<mfma_input_type>(),
 
  567     using Base::a_thread_copy_;
 
  568     using Base::a_thread_desc_;
 
  569     using Base::b_thread_copy_;
 
  570     using Base::b_thread_desc_;
 
  571     using Base::c_thread_desc_;
 
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:300
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
 
static constexpr index_t B_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:46
 
static constexpr index_t A_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:49
 
static constexpr index_t A_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:82
 
static constexpr index_t B_LDS_Read_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:51
 
static constexpr index_t A_LDS_Write_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:44
 
static constexpr index_t C_MFMA_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:54
 
static constexpr index_t A_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:39
 
static constexpr index_t B_Buffer_Load_Inst_Num
Definition: blockwise_gemm_pipeline_xdlops.hpp:41
 
static constexpr index_t B_LDS_Read_Width
Definition: blkgemmpipe_scheduler.hpp:83
 
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:150
 
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ void HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:167
 
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:155
 
ck::BlockwiseGemmXdlops_pipeline_v4< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeDataType, AccDataType, ATileDesc, BTileDesc, AMmaTileDesc, BMmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXDL, NPerXDL, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v4.hpp:252
 
Definition: blockwise_gemm_pipeline_xdlops.hpp:103
 
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops.hpp:105
 
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:961
 
static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_xdlops.hpp:373
 
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:967
 
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:991
 
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops.hpp:104
 
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:453
 
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops.hpp:990
 
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops.hpp:454
 
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops.hpp:955
 
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops.hpp:118
 
__device__ void Run(const SrcDesc &, const SrcRefToOriginDisplacement &, const SrcBuffer &src_buf, const DstDesc &, const DstOriginIdx &, DstBuffer &dst_buf) const
Definition: threadwise_tensor_slice_transfer.hpp:1297
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10