typename ComputeTypeA,
typename ComputeTypeB,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
bool TransposeC = false>
typename ComputeTypeA,
typename ComputeTypeB,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
using Base::WaveSize;
using Base::wmma_gemm;

using Base::CalculateCThreadOriginDataIndex;
using Base::
    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;

using typename Base::Empty;
template <bool HasMainLoop,
          TailNumber TailNum,
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename BScaleStruct>
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    BScaleStruct& b_scale_struct,
                    index_t num_loop,
                    index_t num_loop_per_scale) const
{
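    // Per-thread VGPR buffers that stage A/B fragments read from LDS
    // before they are fed to the WMMA instructions.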
    auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
        a_thread_desc_.GetElementSpaceSize());
    auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
        b_thread_desc_.GetElementSpaceSize());
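    // Global prefetch: issue the first A/B tile reads from global memory,
    // advance the source windows, and pull in the first B-scale values.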
    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
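    // Stage the prefetched tiles into LDS and zero the C accumulators.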
    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

    c_thread_buf.Clear();
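    // One full block-tile GEMM: for each (k0, m0, n0) repeat, copy A/B
    // fragments from LDS into VGPRs (applying the B dequantization scale
    // for the matching scale block) and accumulate with WMMA.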
    auto blockwise_gemm_func = [&]() {
        a_block_desc_k0_m0_m1_m2_k1,
        if constexpr(m0 == I0)
        b_block_desc_k0_n0_n1_n2_k1,
        b_block_desc_k0_n0_n1_n2_k1,
        b_scale_struct.b_scale_thread_bufs(
            I0)[Number<n0 * BScaleStruct::num_scale_k_block +
                       k0 / BScaleStruct::num_scale_krepeat>{}],
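        // Pack this K-slice of A and B elements into WMMA input vectors.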
        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
    if constexpr(HasMainLoop)
    {
        index_t i = 0;
        do
        {
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            blockwise_gemm_func();

            b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
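            // Hot-loop scheduler hints: count the per-iteration LDS-write and
            // buffer-load instructions, then use sched_group_barrier(mask, n, id)
            // to ask the compiler to place n instructions of the given class
            // here (0x020 VMEM read, 0x100 DS read, 0x200 DS write,
            // 0x008 MFMA/WMMA), interleaving global loads, LDS traffic, and
            // matrix ops across the unrolled body.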
            constexpr index_t num_ds_write_inst =
                HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;

            constexpr index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
                                                     HotLoopInstList::B_Buffer_Load_Inst_Num;
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
            if constexpr(m0 == I0)
                __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
            ++i;
        } while(i < (num_loop - 1));
    }
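    // Tail: the last prefetched tile is already resident in LDS, so one
    // final block-tile GEMM completes the accumulation.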
    blockwise_gemm_func();
    decltype(a_block_desc_k0_m0_m1_m2_k1),
    decltype(a_thread_desc_),
    Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,

    decltype(b_block_desc_k0_n0_n1_n2_k1),
    decltype(b_thread_desc_),
    Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
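    // Threadwise copy objects; each thread's source origin in LDS is derived
    // from its wave and lane position.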
    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
using Base::wmma_gemm;

using Base::CalculateCThreadOriginDataIndex;
using Base::
    GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
using Base::GetCThreadBuffer;
using Base::
    GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;

using typename Base::Empty;
template <bool HasMainLoop,
          TailNumber TailNum,
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer,
          typename BScaleStruct>
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    const BBlockDesc& b_block_desc,
                    BBlockTransfer& b_blockwise_copy,
                    const BGridBuffer& b_grid_buf,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    BScaleStruct& b_scale_struct,
                    index_t num_loop,
                    index_t num_loop_per_scale) const
{
    auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
        a_thread_desc_.GetElementSpaceSize());
    auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
        b_thread_desc_.GetElementSpaceSize());
    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);

    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

    c_thread_buf.Clear();
    auto blockwise_gemm_func = [&]() {
        a_block_desc_k0_m0_m1_m2_k1,
        b_block_desc_k0_n0_n1_n2_k1,
        b_block_desc_k0_n0_n1_n2_k1,
        b_scale_struct.b_scale_thread_bufs(I0)[Number<
            n0 * BScaleStruct::num_scale_k_block +
            (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
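        // Interwave scheduling: sched_barrier(0) is a full fence for the
        // compiler's instruction scheduler, s_barrier synchronizes the
        // workgroup between K-step clusters, and s_setprio raises wave
        // priority around a MAC cluster so the compute and memory phases
        // of different waves interleave.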
        __builtin_amdgcn_sched_barrier(0);

        if constexpr(k0_offset != 0 || KRepeat == 1)
        {
            __builtin_amdgcn_s_barrier();
            __builtin_amdgcn_sched_barrier(0);
        }
        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
        using wmma_input_type_a =
            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
        using wmma_input_type_b =
            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

        constexpr index_t c_offset =
            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
        if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
                     n0 == NRepeat - 1)
        {
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_sched_barrier(0);
        }

        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                      b_thread_vec.template AsType<wmma_input_type_b>(),
                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));

        if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
        {
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(1);
            __builtin_amdgcn_sched_barrier(0);
        }

        __builtin_amdgcn_sched_barrier(0);
        __builtin_amdgcn_s_setprio(0);
        __builtin_amdgcn_sched_barrier(0);
    if constexpr(HasMainLoop)
    {
        index_t i = 0;
        do
        {
            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

            blockwise_gemm_func();

            b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

            ++i;
        } while(i < (num_loop - 1));
    }
    blockwise_gemm_func();
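    // Per-thread A/B tile descriptors. K repeats are grouped into clusters of
    // KRepeatPerCluster so that one MAC cluster's worth of fragments stays
    // resident in VGPRs between synchronization points.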
    static constexpr auto a_thread_desc_ =
        Number<KRepeatPerCluster>{},
        Number<KPack / A_KRow * MRepeat>{},

    static constexpr auto b_thread_desc_ =
        Number<KRepeatPerCluster>{},
        Number<KPack / B_KRow * NRepeat>{},
    decltype(a_block_desc_k0_m0_m1_m2_k1),
    decltype(a_thread_desc_),
    Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,

    decltype(b_block_desc_k0_n0_n1_n2_k1),
    decltype(b_thread_desc_),
    Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;