// Primary template; parameters elided from this excerpt are marked "// ...".
template <typename ComputeTypeA,
          typename ComputeTypeB,
          // ...
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...
          bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};
// Partial specialization for the Intrawave block-GEMM pipeline scheduler.
template <typename ComputeTypeA,
          typename ComputeTypeB,
          // ...
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          // ...>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                        // ...
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...>
    : BlockwiseGemmWmmaops_pipeline_base</* ... */
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         /* ... */>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base</* ... */
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    /* ... */>;
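    // "Intrawave" schedules each wave's own math and data movement against
    // each other within a loop iteration (instruction-level interleaving), as
    // opposed to the Interwave scheduler, which staggers phases across waves.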
    using Base::wmma_gemm;
    // ...
    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(BlockHasHotloop(num_loop)) { /* ... */ }
        // ...
    }
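    // Dispatch sketch (hypothetical caller, not part of this header): a
    // gridwise GEMM would typically select the Run<> instantiation with these
    // helpers, e.g.
    //
    //   if(BlockHasHotloop(num_loop))
    //       blockwise_gemm.template Run<true>(/* ...args... */, num_loop, num_loop_per_scale);
    //   else
    //       blockwise_gemm.template Run<false>(/* ...args... */, num_loop, num_loop_per_scale);
    //
    // with BlockLoopTailNum(num_loop) steering any tail-specific variant.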
    // ... (remaining tail-number logic and the HotLoopScheduler() definition
    // are elided from this excerpt) ...

    template <typename ABlockBuffer,
              typename AThreadBuffer,
              typename BBlockBuffer,
              typename BThreadBuffer,
              typename BScaleStruct>
    __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                     AThreadBuffer& a_thread_buf,
                                     BBlockBuffer& b_block_buf,
                                     BThreadBuffer& b_thread_buf,
                                     BScaleStruct& b_scale_struct) const
    {
        // Copy the A tile from LDS into per-thread VGPRs.
        a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                           /* ... */
                           a_thread_buf);

        if constexpr(ck::is_same_v<BScaleStruct, Empty>)
        {
            // Unscaled path: plain LDS -> VGPR copy of the B tile.
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */
                               b_thread_buf);
        }
        else
        {
            // Block-scaled path (inside static_for loops over n0 and k0,
            // elided here): the per-block B scale is applied during the copy.
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */
                               b_scale_struct.b_scale_thread_bufs(
                                   I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                              k0 / BScaleStruct::num_scale_krepeat>{}],
                               /* ... */
                               b_thread_buf);
        }
    }
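    // Scale-lookup note: n0 * num_scale_k_block selects the scale row for the
    // current N-repeat, while k0 / num_scale_krepeat walks the K blocks, so
    // consecutive K-repeats share one scale whenever num_scale_krepeat > 1.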
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              // ...
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
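        // Naming convention: *_grid_* objects describe/hold global memory,
        // *_block_* the LDS staging tile, and c_thread_buf the VGPR
        // accumulators; num_loop is the K-block iteration count, and every
        // num_loop_per_scale iterations consume one fresh set of B scales.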
        __builtin_amdgcn_sched_barrier(0);

        constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

        // Thread-local (VGPR) staging buffers for the A and B fragments.
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1: issue reads for the first A/B tiles.
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        // Stage prefetch 1 into LDS.
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2: kept in flight while tile 0 is computed.
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize the C accumulators.
        c_thread_buf.Clear();

        block_sync_lds();
        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);
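        // Pipeline state here: tile 0 is resident in VGPRs, tile 1 is in
        // flight from global memory, and the LDS buffer is free to receive
        // it. Each hot-loop iteration below overlaps the WMMA math on tile i
        // with the LDS write of tile i+1 and the global read of tile i+2.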
        if constexpr(HasMainLoop)
        {
            index_t i = 0;

            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

                // Compute on the tile already resident in VGPRs; the enclosing
                // static_for loops over m0, n0 and k_inner are elided here.
                vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                            /* ... */)>{}];
                });

                static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                            /* ... */)>{}];
                });

                using wmma_input_type_a =
                    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                using wmma_input_type_b =
                    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                constexpr index_t c_offset =
                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                              b_thread_vec.template AsType<wmma_input_type_b>(),
                              c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

                block_sync_lds();

                LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 2));
        }
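        // The hot loop stops two iterations short of num_loop, matching the
        // two global prefetches issued in the prologue: the remaining tiles
        // are drained below without issuing further global reads.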
        // Tail: drain the two prefetched tiles.
        block_sync_lds();

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

        // Second-to-last tile; same compute pattern as the hot loop (the
        // enclosing static_for loops over m0, n0 and k_inner are elided).
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

        block_sync_lds();

        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);
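        // Final tile: all operands are already resident in VGPRs, so only the
        // WMMA math remains.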
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
    }
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};