20           typename ComputeTypeA,
 
   21           typename ComputeTypeB,
 
   23           typename AWmmaTileDesc,
 
   24           typename BWmmaTileDesc,
 
   25           index_t ABlockTransferSrcScalarPerVector,
 
   26           index_t BBlockTransferSrcScalarPerVector,
 
   42           typename ComputeTypeA,
 
   43           typename ComputeTypeB,
 
   45           typename AWmmaTileDesc,
 
   46           typename BWmmaTileDesc,
 
   47           index_t ABlockTransferSrcScalarPerVector,
 
   48           index_t BBlockTransferSrcScalarPerVector,
 
   66                                         ABlockTransferSrcScalarPerVector,
 
   67                                         BBlockTransferSrcScalarPerVector,
 
   84                                          ABlockTransferSrcScalarPerVector,
 
   85                                          BBlockTransferSrcScalarPerVector,
 
  103                                                     ABlockTransferSrcScalarPerVector,
 
  104                                                     BBlockTransferSrcScalarPerVector,
 
  122     using Base::wmma_gemm;
 
  125     using Base::CalculateCThreadOriginDataIndex;
 
  127         GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
 
  128     using Base::GetCThreadBuffer;
 
  130         GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
 
  132     using Base::a_block_desc_k0_m0_m1_m2_k1;
 
  133     using Base::b_block_desc_k0_n0_n1_n2_k1;
 
  143         return num_loop > PrefetchStages;
 
  260     template <
typename ABlockBuffer,
 
  261               typename AThreadBuffer,
 
  262               typename BBlockBuffer,
 
  263               typename BThreadBuffer,
 
  264               typename BScaleStruct>
 
  265     __device__ 
inline void LocalLoad(ABlockBuffer& a_block_buf,
 
  266                                      AThreadBuffer& a_thread_buf,
 
  267                                      BBlockBuffer& b_block_buf,
 
  268                                      BThreadBuffer& b_thread_buf,
 
  269                                      BScaleStruct& b_scale_struct)
 const 
  274                     a_block_desc_k0_m0_m1_m2_k1,
 
  282             if constexpr(ck::is_same_v<BScaleStruct, Empty>)
 
  286                         b_block_desc_k0_n0_n1_n2_k1,
 
  298                         b_block_desc_k0_n0_n1_n2_k1,
 
  301                         b_scale_struct.b_scale_thread_bufs(
 
  302                             I0)[
Number<n0 * BScaleStruct::num_scale_k_block +
 
  303                                        k0 / BScaleStruct::num_scale_krepeat>{}],
 
  312     template <
bool HasMainLoop,
 
  316               typename ABlockTransfer,
 
  317               typename AGridBuffer,
 
  318               typename ABlockBuffer,
 
  319               typename ABlockTransferStep,
 
  322               typename BBlockTransfer,
 
  323               typename BGridBuffer,
 
  324               typename BBlockBuffer,
 
  325               typename BBlockTransferStep,
 
  326               typename CThreadBuffer,
 
  327               typename BScaleStruct>
 
  328     __device__ 
void Run(
const AGridDesc& a_grid_desc,
 
  329                         const ABlockDesc& a_block_desc,
 
  330                         ABlockTransfer& a_blockwise_copy,
 
  331                         const AGridBuffer& a_grid_buf,
 
  332                         ABlockBuffer& a_block_buf,
 
  333                         const ABlockTransferStep& a_block_copy_step,
 
  334                         const BGridDesc& b_grid_desc,
 
  335                         const BBlockDesc& b_block_desc,
 
  336                         BBlockTransfer& b_blockwise_copy,
 
  337                         const BGridBuffer& b_grid_buf,
 
  338                         BBlockBuffer& b_block_buf,
 
  339                         const BBlockTransferStep& b_block_copy_step,
 
  340                         CThreadBuffer& c_thread_buf,
 
  342                         BScaleStruct& b_scale_struct,
 
  344                         index_t num_loop_per_scale)
 const 
  346         __builtin_amdgcn_sched_barrier(0);
 
  347         auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
 
  348             a_thread_desc_.GetElementSpaceSize());
 
  349         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
 
  350             b_thread_desc_.GetElementSpaceSize());
 
  353         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  354         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  356         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  357         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  359         b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
 
  362         a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  363         b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  366         a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  367         b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  369         a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  370         b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  373         c_thread_buf.Clear();
 
  378         LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
 
  380         __builtin_amdgcn_sched_barrier(0);
 
  383         if constexpr(HasMainLoop)
 
  390                 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
 
  391                 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
 
  393                 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
 
  394                 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
 
  396                 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
 
  397                 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
 
  399                 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
 
  404                             vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
 
  405                             vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
 
  407                             static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
 
  408                                 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 
  409                                     a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
 
  417                             static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
 
  418                                 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 
  419                                     b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
 
  428                             using wmma_input_type_a =
 
  429                                 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 
  430                             using wmma_input_type_b =
 
  431                                 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 
  434                                 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
 
  436                             wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  437                                           b_thread_vec.template AsType<wmma_input_type_b>(),
 
  445                 LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
 
  448                 __builtin_amdgcn_sched_barrier(0);
 
  451             } 
while(i < (num_loop - 1));
 
  459                         vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
 
  460                         vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
 
  462                         static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
 
  463                             a_thread_vec.template AsType<ComputeTypeA>()(ik) =
 
  467                         static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
 
  468                             b_thread_vec.template AsType<ComputeTypeB>()(ik) =
 
  473                         using wmma_input_type_a =
 
  474                             typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
 
  475                         using wmma_input_type_b =
 
  476                             typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
 
  479                             c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
 
  481                         wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
 
  482                                       b_thread_vec.template AsType<wmma_input_type_b>(),
 
  494     using Base::a_thread_copy_;
 
  495     using Base::a_thread_desc_;
 
  496     using Base::b_thread_copy_;
 
  497     using Base::b_thread_desc_;
 
  498     using Base::c_thread_desc_;
 
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
 
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
 
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:297
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:95
 
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:35
 
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:141
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:328
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:146
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::LocalLoad __device__ void LocalLoad(ABlockBuffer &a_block_buf, AThreadBuffer &a_thread_buf, BBlockBuffer &b_block_buf, BThreadBuffer &b_thread_buf, BScaleStruct &b_scale_struct) const
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:265
 
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:152
 
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:36
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10