          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          bool TransposeC = false,
          bool BSkipLDS   = false>
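// Partial specialization for BlockGemmPipelineScheduler::Intrawave, deriving
// from the WMMA pipeline base (blockwise_gemm_pipeline_wmmaops_base.hpp).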
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
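// The two vector-width parameters recur in three argument lists below: the
// specialization head, the base-class inheritance, and the `using Base = ...`
// alias.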
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,

                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,

                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
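    // Symbols re-exported from the pipeline base class.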
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;
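    // Loop-shape queries (full signatures per the declarations in this file):
    //   static constexpr bool       BlockHasHotloop(index_t num_loop);
    //   static constexpr TailNumber BlockLoopTailNum(index_t num_loop);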
        return num_loop > PrefetchStages;

        if(BlockHasHotloop(num_loop))
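    // LocalLoad: copy the current A/B tiles from LDS into per-thread VGPR
    // buffers. When a non-Empty BScaleStruct is supplied, the corresponding
    // per-block B scale is handed to the B-side copy.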
    template <typename ABlockBuffer,
              typename AThreadBuffer,
              typename BBlockBuffer,
              typename BThreadBuffer,
              typename BScaleStruct>
    __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                     AThreadBuffer& a_thread_buf,
                                     BBlockBuffer& b_block_buf,
                                     BThreadBuffer& b_thread_buf,
                                     BScaleStruct& b_scale_struct) const
    {
        a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, /* ... */);

        if constexpr(ck::is_same_v<BScaleStruct, Empty>)
        {
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, /* ... */);
        }
        else
        {
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */
                               b_scale_struct.b_scale_thread_bufs(
                                   I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                                              k0 / BScaleStruct::num_scale_krepeat>{}],
                               /* ... */);
        }
    }
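    // Run: the software-pipelined GEMM main loop. Two global prefetch stages
    // are kept in flight so global->LDS traffic, LDS->VGPR loads, and WMMA
    // math overlap across iterations.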
    template <bool HasMainLoop,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
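        // sched_barrier(0) fences the compiler's instruction scheduler so code
        // cannot be reordered across the pipeline-stage boundaries below.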
        __builtin_amdgcn_sched_barrier(0);

        constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();

        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
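        // Prologue, stage 1: issue the first global reads of A and B and
        // advance the source windows to the next K tile.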
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
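        // Prologue, stage 2: commit prefetch 1 to LDS, then start global
        // prefetch 2 and clear the accumulators.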
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        c_thread_buf.Clear();
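        // First LDS->VGPR load of the A/B fragments (and B scales, if any).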
        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);

        if constexpr(HasMainLoop)
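        // Hot loop: each iteration commits the in-flight global prefetch to
        // LDS, kicks off the next one, and runs the WMMA math on the
        // fragments loaded at the end of the previous iteration.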
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
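                // Inside a static_for nest over m0 (MRepeat), n0 (NRepeat)
                // and k_inner: gather the A/B fragments into contiguous
                // vectors, then issue one WMMA per (m0, n0) tile.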
                vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
                vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

                static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                    a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                            /* ... */)>{}];
                });

                static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
                    constexpr index_t kk = ik + k_inner * KPerWaveBlock;
                    b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                            /* ... */)>{}];
                });

                using wmma_input_type_a =
                    typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                using wmma_input_type_b =
                    typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                constexpr index_t c_offset =
                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                              b_thread_vec.template AsType<wmma_input_type_b>(),
                              /* ... */);
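                // Overlap the next LDS->VGPR load with the tail of the WMMA
                // issue for this iteration.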
                LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 2));
        }
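        // Tail: drain the two prefetched tiles still in flight; no further
        // global reads are issued.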
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
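        // Math on the second-to-last tile (same gather + WMMA pattern as the
        // hot loop, inside the same static_for nest).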
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      /* ... */);
        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);
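        // Math on the last tile, consuming the fragments just loaded.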
        vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;

        static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
        });

        static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
            constexpr index_t kk = ik + k_inner * KPerWaveBlock;
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                    /* ... */)>{}];
        });

        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      /* ... */);
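    // Thread-level copy objects and descriptors inherited from the base class.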
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;