   20           typename ComputeTypeA,
   21           typename ComputeTypeB,
   23           typename AWmmaTileDesc,
   24           typename BWmmaTileDesc,
   25           index_t ABlockTransferSrcScalarPerVector,
   26           index_t BBlockTransferSrcScalarPerVector,

   42           typename ComputeTypeA,
   43           typename ComputeTypeB,
   45           typename AWmmaTileDesc,
   46           typename BWmmaTileDesc,
   47           index_t ABlockTransferSrcScalarPerVector,
   48           index_t BBlockTransferSrcScalarPerVector,

   66                                         ABlockTransferSrcScalarPerVector,
   67                                         BBlockTransferSrcScalarPerVector,
   84                                          ABlockTransferSrcScalarPerVector,
   85                                          BBlockTransferSrcScalarPerVector,
  104                                                     ABlockTransferSrcScalarPerVector,
  105                                                     BBlockTransferSrcScalarPerVector,

  123     using Base::wmma_gemm;
  125     using Base::CalculateCThreadOriginDataIndex;
  126     using Base::
  127         GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
  128     using Base::GetCThreadBuffer;
  129     using Base::
  130         GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
  132     using Base::a_block_desc_k0_m0_m1_m2_k1;
  133     using Base::b_block_desc_k0_n0_n1_n2_k1;
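// Run(): software-pipelined GEMM body for the Intrawave scheduler. Each iteration
// prefetches the next A/B tile from global memory while the tile already staged in
// LDS is consumed by the WMMA math in blockwise_gemm_func below.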
 
  149     template <bool HasMainLoop,
  153               typename ABlockTransfer,
  154               typename AGridBuffer,
  155               typename ABlockBuffer,
  156               typename ABlockTransferStep,
  159               typename BBlockTransfer,
  160               typename BGridBuffer,
  161               typename BBlockBuffer,
  162               typename BBlockTransferStep,
  163               typename CThreadBuffer,
  164               typename BScaleStruct>
  165     __device__ void Run(const AGridDesc& a_grid_desc,
  166                         const ABlockDesc& a_block_desc,
  167                         ABlockTransfer& a_blockwise_copy,
  168                         const AGridBuffer& a_grid_buf,
  169                         ABlockBuffer& a_block_buf,
  170                         const ABlockTransferStep& a_block_copy_step,
  171                         const BGridDesc& b_grid_desc,
  172                         const BBlockDesc& b_block_desc,
  173                         BBlockTransfer& b_blockwise_copy,
  174                         const BGridBuffer& b_grid_buf,
  175                         BBlockBuffer& b_block_buf,
  176                         const BBlockTransferStep& b_block_copy_step,
  177                         CThreadBuffer& c_thread_buf,
  179                         BScaleStruct& b_scale_struct,
  180                         index_t num_loop,
  181                         index_t num_loop_per_scale) const
  183         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
  184             a_thread_desc_.GetElementSpaceSize());
  185         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
  186             b_thread_desc_.GetElementSpaceSize());
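// Global prefetch: read the first A/B K-block tiles from global memory and advance
// the source slice windows to the next K block.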
 
  189         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
  190         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
  192         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
  193         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
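// Prefetch the per-block B scales for the first scale block, stage the prefetched
// A/B tiles in LDS, and clear the C accumulators.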
 
  195         b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
  198         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
  199         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
  202         c_thread_buf.Clear();
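// blockwise_gemm_func: consume the tile currently staged in LDS. For each
// (m0, n0, k0) repeat it copies A and B fragments from LDS into VGPRs, applies the
// per-block B scale to the B fragment, and accumulates into c_thread_buf via
// wmma_gemm.Run.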
 
  204         auto blockwise_gemm_func = [&]() {
  208                         a_block_desc_k0_m0_m1_m2_k1,
  219                             b_block_desc_k0_n0_n1_n2_k1,
  231                             b_block_desc_k0_n0_n1_n2_k1,
  234                             b_scale_struct.b_scale_thread_bufs(
  235                                 I0)[Number<n0 * BScaleStruct::num_scale_k_block +
  236                                            k0 / BScaleStruct::num_scale_krepeat>{}],
  245                         vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
  246                         vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
  248                         static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
  249                             a_thread_vec.template AsType<ComputeTypeA>()(ik) =
  253                         static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
  254                             b_thread_vec.template AsType<ComputeTypeB>()(ik) =
  259                         using wmma_input_type_a =
  260                             typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
  261                         using wmma_input_type_b =
  262                             typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
  265                             c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
  267                         wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
  268                                       b_thread_vec.template AsType<wmma_input_type_b>(),
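// Main loop: overlap the global prefetch of the next tile with the WMMA math on the
// tile already in LDS, then write the freshly read tile to LDS for the next iteration.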
 
  276         if constexpr(HasMainLoop)
  281                 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
  282                 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
  284                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
  285                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
  288                 blockwise_gemm_func();
  291                 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
  292                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
  293                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
  296             } while(i < (num_loop - 1));
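// Tail: one final pass over the last tile staged in LDS.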
 
  303             blockwise_gemm_func();

  308     using Base::a_thread_copy_;
  309     using Base::a_thread_desc_;
  310     using Base::b_thread_copy_;
  311     using Base::b_thread_desc_;
  312     using Base::c_thread_desc_;
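// Specialization for BlockGemmPipelineScheduler::Interwave. The data flow matches the
// Intrawave pipeline above; the difference is in the hot loop, which is annotated with
// scheduling intrinsics so the math and memory phases of different waves interleave.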
 
  318           typename ComputeTypeA,
  319           typename ComputeTypeB,
  320           typename AccDataType,
  321           typename AWmmaTileDesc,
  322           typename BWmmaTileDesc,
  323           index_t ABlockTransferSrcScalarPerVector,
  324           index_t BBlockTransferSrcScalarPerVector,

  342                                         ABlockTransferSrcScalarPerVector,
  343                                         BBlockTransferSrcScalarPerVector,
  360                                          ABlockTransferSrcScalarPerVector,
  361                                          BBlockTransferSrcScalarPerVector,
  380                                                     ABlockTransferSrcScalarPerVector,
  381                                                     BBlockTransferSrcScalarPerVector,

  400     using Base::wmma_gemm;
  402     using Base::CalculateCThreadOriginDataIndex;
  403     using Base::
  404         GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
  405     using Base::GetCThreadBuffer;
  406     using Base::
  407         GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
  409     using Base::a_block_desc_k0_m0_m1_m2_k1;
  410     using Base::b_block_desc_k0_n0_n1_n2_k1;
 
  429     template <bool HasMainLoop,
  433               typename ABlockTransfer,
  434               typename AGridBuffer,
  435               typename ABlockBuffer,
  436               typename ABlockTransferStep,
  439               typename BBlockTransfer,
  440               typename BGridBuffer,
  441               typename BBlockBuffer,
  442               typename BBlockTransferStep,
  443               typename CThreadBuffer,
  444               typename BScaleStruct>
  445     __device__ void Run(const AGridDesc& a_grid_desc,
  446                         const ABlockDesc& a_block_desc,
  447                         ABlockTransfer& a_blockwise_copy,
  448                         const AGridBuffer& a_grid_buf,
  449                         ABlockBuffer& a_block_buf,
  450                         const ABlockTransferStep& a_block_copy_step,
  451                         const BGridDesc& b_grid_desc,
  452                         const BBlockDesc& b_block_desc,
  453                         BBlockTransfer& b_blockwise_copy,
  454                         const BGridBuffer& b_grid_buf,
  455                         BBlockBuffer& b_block_buf,
  456                         const BBlockTransferStep& b_block_copy_step,
  457                         CThreadBuffer& c_thread_buf,
  459                         BScaleStruct& b_scale_struct,
  460                         index_t num_loop,
  461                         index_t num_loop_per_scale) const
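// Prologue: identical to the Intrawave pipeline (global prefetch, window advance,
// B-scale prefetch, LDS staging, accumulator clear).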
  463         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
  464             a_thread_desc_.GetElementSpaceSize());
  465         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
  466             b_thread_desc_.GetElementSpaceSize());

  469         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
  470         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
  472         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
  473         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
  475         b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
  478         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
  479         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
  482         c_thread_buf.Clear();
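// blockwise_gemm_func (Interwave): KRepeat is processed in clusters of KRepeatPerCluster.
// Each cluster first copies its A/B fragments from LDS into VGPRs (applying the per-block
// B scale), then runs the WMMA math for the cluster. __builtin_amdgcn_sched_barrier(0)
// fences keep the compiler from mixing the copy and math phases,
// __builtin_amdgcn_s_barrier() re-synchronizes the block before a cluster's math phase,
// and __builtin_amdgcn_s_setprio(1)/(0) raises wave priority once a cluster's first WMMA
// has issued and drops it again at the end, so different waves interleave their math and
// memory phases.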
 
  484         auto blockwise_gemm_func = [&]() {
  489                             a_block_desc_k0_m0_m1_m2_k1,
  505                                 b_block_desc_k0_n0_n1_n2_k1,
  522                                 b_block_desc_k0_n0_n1_n2_k1,
  530                                 b_scale_struct.b_scale_thread_bufs(I0)[Number<
  531                                     n0 * BScaleStruct::num_scale_k_block +
  532                                     (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
  540                 __builtin_amdgcn_sched_barrier(0);
  547                 if constexpr(k0_offset != 0 || KRepeat == 1)
  549                     __builtin_amdgcn_s_barrier();
  550                     __builtin_amdgcn_sched_barrier(0);
  555                             vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
  556                             vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
  558                             static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
  559                                 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
  560                                     a_thread_buf[Number<a_thread_desc_.CalculateOffset(
  568                             static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
  569                                 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
  570                                     b_thread_buf[Number<b_thread_desc_.CalculateOffset(
  579                             using wmma_input_type_a =
  580                                 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
  581                             using wmma_input_type_b =
  582                                 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
  585                                 c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
  593                             if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
  596                                 __builtin_amdgcn_sched_barrier(0);
  598                                 __builtin_amdgcn_sched_barrier(0);
  600                             wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
  601                                           b_thread_vec.template AsType<wmma_input_type_b>(),
  603                             if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
  605                                 __builtin_amdgcn_sched_barrier(0);
  606                                 __builtin_amdgcn_s_setprio(1);
  607                                 __builtin_amdgcn_sched_barrier(0);
  612                 __builtin_amdgcn_sched_barrier(0);
  613                 __builtin_amdgcn_s_setprio(0);
  614                 __builtin_amdgcn_sched_barrier(0);
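// Main loop and tail: same double-buffered structure as the Intrawave pipeline; the
// final prefetched tile is consumed by the trailing blockwise_gemm_func() call.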
 
  619         if constexpr(HasMainLoop)
  624                 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
  625                 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
  627                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
  628                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
  631                 blockwise_gemm_func();
  633                 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
  634                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
  635                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
  638             } while(i < (num_loop - 1));

  645             blockwise_gemm_func();
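// Private per-thread tile descriptors (sized for one cluster of KRepeatPerCluster k0
// steps) and the LDS-to-VGPR copy helpers for A and B.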
 
  650     static constexpr auto a_thread_desc_ =
  653                                                 Number<KRepeatPerCluster>{},
  659                                                 Number<KPack / A_KRow * MRepeat>{},
  664     static constexpr auto b_thread_desc_ =
  667                                                 Number<KRepeatPerCluster>{},
  673                                                 Number<KPack / B_KRow * NRepeat>{},

  681                                          decltype(a_block_desc_k0_m0_m1_m2_k1),
  682                                          decltype(a_thread_desc_),
  683                                          Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,

  692                                          decltype(b_block_desc_k0_n0_n1_n2_k1),
  693                                          decltype(b_thread_desc_),
  694                                          Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,

  700     AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
  701     BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
  702     using Base::c_thread_desc_;
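// For orientation: a minimal, hypothetical sketch of how a gridwise kernel might drive
// this pipeline. The object names (blockwise_gemm, descriptors, buffers, copy steps and
// loop counts) are placeholders, not taken from this file; only the argument order of
// Run() matches the signature above.
//
//     blockwise_gemm.Run(a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf,
//                        a_block_buf, a_block_copy_step,
//                        b_grid_desc, b_block_desc, b_blockwise_copy, b_grid_buf,
//                        b_block_buf, b_block_copy_step,
//                        c_thread_buf, b_scale_struct, num_loop, num_loop_per_scale);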
 
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, ... >::BlockHasHotloop static bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:421

ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, ... >::BlockLoopTailNum static TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:423

ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, ... >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:445

ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, ... >::Run __device__ void Run(...) const, same parameter list as the Interwave overload above
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:165

ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, ... >::BlockHasHotloop static bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:141

ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, ... >::BlockLoopTailNum static TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:143
 