// Primary template; the Intrawave specialization below provides the
// implementation (leading template parameters elided).
template </* BlkGemmPipelineVer, BlockSize, ADataType, BDataType, */
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          /* ... MPerBlock .. KPack tile-shape parameters ... */
          bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};
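
// Specialization for BlockGemmPipelineScheduler::Intrawave. Judging from the
// prologue of Run() below, this pipeline keeps two K-blocks in flight:
// global loads, LDS staging, and WMMA math are overlapped within a single
// wave's instruction stream (hence "intrawave" scheduling).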
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          /* ... MPerBlock .. KPack, TransposeC ... */>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                        /* ... */
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        /* ... */>
    : BlockwiseGemmWmmaops_pipeline_base</* same parameters */
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         /* ... */>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base</* same parameters */
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    /* ... */>;

    using Base::wmma_gemm;
    // ... (further Base imports elided) ...
    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    // Tag type passed as BScaleStruct when B carries no quantization scales.
    using typename Base::Empty;

    // PrefetchStages (declared earlier in the full source) is the software
    // pipeline depth; the two-stage prologue of Run() implies a value of 2.
    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(BlockHasHotloop(num_loop))
        {
            // ... (tail-number selection elided) ...
        }
        // ...
    }
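
    // Hypothetical host-side sketch of how these helpers drive kernel
    // selection (the names `Pipeline` and `K` are illustrative, not from
    // this file):
    //
    //   const index_t num_loop = K / KPerBlock;          // K-blocks to process
    //   const bool has_hot_loop = Pipeline::BlockHasHotloop(num_loop);
    //   const TailNumber tail   = Pipeline::BlockLoopTailNum(num_loop);
    //   // ...dispatch to the Run<HasMainLoop, TailNum> instantiation that
    //   // matches (has_hot_loop, tail)...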

    // ... (HotLoopScheduler() and the remaining scheduling machinery elided) ...

    // Stage the current K-block of A and B from LDS into thread-private VGPRs.
    template <typename ABlockBuffer,
              typename AThreadBuffer,
              typename BBlockBuffer,
              typename BThreadBuffer,
              typename BScaleStruct>
    __device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
                                     AThreadBuffer& a_thread_buf,
                                     BBlockBuffer& b_block_buf,
                                     BThreadBuffer& b_thread_buf,
                                     BScaleStruct& b_scale_struct) const
    {
        // NOTE: bodies below are abbreviated; the copy calls follow the usual
        // CK threadwise-copy pattern, with elided arguments marked /* ... */.

        // A: threadwise LDS -> VGPR copy over the block descriptor
        // (enclosing repeat loops and remaining copy arguments elided).
        a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                           /* ... */);

        if constexpr(ck::is_same_v<BScaleStruct, Empty>)
        {
            // B without scales: a plain LDS -> VGPR copy.
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */);
        }
        else
        {
            // B with scales: the copy consumes the staged per-(n0, k-group)
            // B-scale value (enclosing static_for loops over n0/k0 elided).
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */,
                b_scale_struct.b_scale_thread_bufs(
                    I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                               k0 / BScaleStruct::num_scale_krepeat>{}],
                /* ... */);
        }
    }
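
    // Worked example of the scale index above, with hypothetical sizes
    // num_scale_k_block = 2 and num_scale_krepeat = 4: for repeat n0 = 1,
    // k0 = 0..7 selects slots 2,2,2,2,3,3,3,3. Each group of
    // num_scale_krepeat k-iterations shares one scale value, and each n0
    // repeat owns num_scale_k_block consecutive slots.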

    // Software-pipelined GEMM K-loop: global prefetch, LDS staging, and WMMA
    // math, two K-blocks deep. (TailNum and the grid/block descriptor type
    // parameters are reconstructed.)
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1: read the first K-block of A and B into registers.
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // B-scale global fetch for the pipeline head.
        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

        // LDS write 1: commit prefetch 1 to LDS.
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2.
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize the accumulators.
        c_thread_buf.Clear();

        block_sync_lds(); // reconstructed: make LDS write 1 visible to LocalLoad
        LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

        __builtin_amdgcn_sched_barrier(0);
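
        // Pipeline state here: K-block 0 is in LDS and staged into VGPRs,
        // K-block 1 is held in the copy registers, and the accumulators are
        // cleared; the hot loop below keeps this two-block pipeline full.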
        // Hot loop: compute K-block i while staging K-block i+1 into LDS and
        // fetching K-block i+2 from global memory.
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds(); // reconstructed: previous LocalLoad must drain first

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

                // Repeat-loop nest reconstructed; the exact nesting order of
                // m0/n0/k0 may differ in the full source.
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KRepeat, 1>{}([&](auto k0) {
                            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        /* m0/k0/ik indexing elided */)>{}];
                            });
                            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        /* n0/k0/ik indexing elided */)>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            wmma_gemm.Run(
                                a_thread_vec.template AsType<wmma_input_type_a>(),
                                b_thread_vec.template AsType<wmma_input_type_b>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds();
                LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

                HotLoopScheduler(); // per-iteration instruction-interleaving hints
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 2));
        }
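
        // The hot loop runs only to num_loop - 2 because two K-blocks are
        // permanently in flight. Note the GlobalLoad<0> argument above:
        // B-scales for iteration i + 2 are requested while iteration i
        // computes, i.e. one full pipeline depth ahead.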
        // Tail, stage 1 (structure reconstructed): commit the last global
        // prefetch to LDS, then run the WMMA block on the K-block staged by
        // the final hot-loop LocalLoad.
        {
            block_sync_lds();

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);

            // Same WMMA block as in the hot loop (see the reconstruction note there).
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[/* as in the hot loop */];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_buf[/* as in the hot loop */];
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(
                            a_thread_vec.template AsType<wmma_input_type_a>(),
                            b_thread_vec.template AsType<wmma_input_type_b>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            block_sync_lds();
            LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);

            __builtin_amdgcn_sched_barrier(0);
        }
        // Tail, stage 2: compute the final prefetched K-block straight from
        // the thread buffers; no further global or LDS traffic is needed.
        {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[/* as in the hot loop */];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_buf[/* as in the hot loop */];
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(
                            a_thread_vec.template AsType<wmma_input_type_a>(),
                            b_thread_vec.template AsType<wmma_input_type_b>(),
                            c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        }
    }
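
    // When Run() returns, c_thread_buf holds the complete per-thread
    // accumulator tile; writing it back to global memory is typically handled
    // by a separate C-shuffle epilogue in the surrounding gridwise kernel.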

    protected: // access specifier reconstructed
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
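
// Hypothetical instantiation sketch (all numeric values illustrative only;
// AWmmaTileDesc/BWmmaTileDesc stand for suitable LDS tile descriptors):
//
//   using Pipeline = BlockwiseGemmWmmaops_pipeline_v3<
//       BlockGemmPipelineScheduler::Intrawave,
//       /*BlockSize=*/256,
//       /*ADataType=*/half_t, /*BDataType=*/half_t,
//       /*ComputeTypeA=*/half_t, /*ComputeTypeB=*/half_t,
//       /*AccDataType=*/float,
//       AWmmaTileDesc, BWmmaTileDesc,
//       /*ABlockTransferSrcScalarPerVector=*/8,
//       /*BBlockTransferSrcScalarPerVector=*/8,
//       /*MPerBlock=*/128, /*NPerBlock=*/128, /*KPerBlock=*/64,
//       /*MPerWmma=*/16, /*NPerWmma=*/16,
//       /*MRepeat=*/4, /*NRepeat=*/4, /*KPack=*/8>;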