20 typename ComputeTypeA,
21 typename ComputeTypeB,
23 typename AWmmaTileDesc,
24 typename BWmmaTileDesc,
25 index_t ABlockTransferSrcScalarPerVector,
26 index_t BBlockTransferSrcScalarPerVector,
36 bool TransposeC =
false>
44 typename ComputeTypeA,
45 typename ComputeTypeB,
47 typename AWmmaTileDesc,
48 typename BWmmaTileDesc,
49 index_t ABlockTransferSrcScalarPerVector,
50 index_t BBlockTransferSrcScalarPerVector,
70 ABlockTransferSrcScalarPerVector,
71 BBlockTransferSrcScalarPerVector,
90 ABlockTransferSrcScalarPerVector,
91 BBlockTransferSrcScalarPerVector,
111 ABlockTransferSrcScalarPerVector,
112 BBlockTransferSrcScalarPerVector,
134 using Base::wmma_gemm;
136 using Base::CalculateCThreadOriginDataIndex;
138 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
139 using Base::GetCThreadBuffer;
141 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
143 using Base::a_block_desc_k0_m0_m1_m2_k1;
144 using Base::b_block_desc_k0_n0_n1_n2_k1;
146 using typename Base::Empty;
160 template <
bool HasMainLoop,
164 typename ABlockTransfer,
165 typename AGridBuffer,
166 typename ABlockBuffer,
167 typename ABlockTransferStep,
170 typename BBlockTransfer,
171 typename BGridBuffer,
172 typename BBlockBuffer,
173 typename BBlockTransferStep,
174 typename CThreadBuffer,
175 typename BScaleStruct>
176 __device__
void Run(
const AGridDesc& a_grid_desc,
177 const ABlockDesc& a_block_desc,
178 ABlockTransfer& a_blockwise_copy,
179 const AGridBuffer& a_grid_buf,
180 ABlockBuffer& a_block_buf,
181 const ABlockTransferStep& a_block_copy_step,
182 const BGridDesc& b_grid_desc,
183 const BBlockDesc& b_block_desc,
184 BBlockTransfer& b_blockwise_copy,
185 const BGridBuffer& b_grid_buf,
186 BBlockBuffer& b_block_buf,
187 const BBlockTransferStep& b_block_copy_step,
188 CThreadBuffer& c_thread_buf,
190 BScaleStruct& b_scale_struct,
192 index_t num_loop_per_scale)
const
194 constexpr
index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
196 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
197 a_thread_desc_.GetElementSpaceSize());
198 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
199 b_thread_desc_.GetElementSpaceSize());
202 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
203 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
205 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
206 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
208 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
211 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
212 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
215 c_thread_buf.Clear();
217 auto blockwise_gemm_func = [&]() {
220 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
226 if constexpr(m0 == I0)
231 b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
243 b_block_desc_k0_n0_n1_n2_k1,
246 b_scale_struct.b_scale_thread_bufs(
247 I0)[
Number<n0 * BScaleStruct::num_scale_k_block +
248 k0 / BScaleStruct::num_scale_krepeat>{}],
258 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
259 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
261 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
262 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
263 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
264 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
273 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
274 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
275 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
276 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
286 using wmma_input_type_a =
287 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
288 using wmma_input_type_b =
289 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
292 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
294 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
295 b_thread_vec.template AsType<wmma_input_type_b>(),
304 if constexpr(HasMainLoop)
309 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
310 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
312 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
313 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
316 blockwise_gemm_func();
319 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
324 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
325 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
327 constexpr
index_t num_ds_write_inst =
328 HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
330 constexpr
index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
331 HotLoopInstList::B_Buffer_Load_Inst_Num;
333 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
337 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
338 if constexpr(m0 == I0)
341 __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
346 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
352 __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
356 }
while(i < (num_loop - 1));
363 blockwise_gemm_func();
379 decltype(a_block_desc_k0_m0_m1_m2_k1),
380 decltype(a_thread_desc_),
381 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
390 decltype(b_block_desc_k0_n0_n1_n2_k1),
391 decltype(b_thread_desc_),
392 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
398 AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
399 BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
400 using Base::c_thread_desc_;
406 typename ComputeTypeA,
407 typename ComputeTypeB,
408 typename AccDataType,
409 typename AWmmaTileDesc,
410 typename BWmmaTileDesc,
411 index_t ABlockTransferSrcScalarPerVector,
412 index_t BBlockTransferSrcScalarPerVector,
432 ABlockTransferSrcScalarPerVector,
433 BBlockTransferSrcScalarPerVector,
452 ABlockTransferSrcScalarPerVector,
453 BBlockTransferSrcScalarPerVector,
473 ABlockTransferSrcScalarPerVector,
474 BBlockTransferSrcScalarPerVector,
495 using Base::wmma_gemm;
497 using Base::CalculateCThreadOriginDataIndex;
499 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
500 using Base::GetCThreadBuffer;
502 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
504 using Base::a_block_desc_k0_m0_m1_m2_k1;
505 using Base::b_block_desc_k0_n0_n1_n2_k1;
507 using typename Base::Empty;
524 template <
bool HasMainLoop,
528 typename ABlockTransfer,
529 typename AGridBuffer,
530 typename ABlockBuffer,
531 typename ABlockTransferStep,
534 typename BBlockTransfer,
535 typename BGridBuffer,
536 typename BBlockBuffer,
537 typename BBlockTransferStep,
538 typename CThreadBuffer,
539 typename BScaleStruct>
540 __device__
void Run(
const AGridDesc& a_grid_desc,
541 const ABlockDesc& a_block_desc,
542 ABlockTransfer& a_blockwise_copy,
543 const AGridBuffer& a_grid_buf,
544 ABlockBuffer& a_block_buf,
545 const ABlockTransferStep& a_block_copy_step,
546 const BGridDesc& b_grid_desc,
547 const BBlockDesc& b_block_desc,
548 BBlockTransfer& b_blockwise_copy,
549 const BGridBuffer& b_grid_buf,
550 BBlockBuffer& b_block_buf,
551 const BBlockTransferStep& b_block_copy_step,
552 CThreadBuffer& c_thread_buf,
554 BScaleStruct& b_scale_struct,
556 index_t num_loop_per_scale)
const
558 constexpr
index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
560 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
561 a_thread_desc_.GetElementSpaceSize());
562 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
563 b_thread_desc_.GetElementSpaceSize());
566 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
567 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
569 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
570 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
572 b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
575 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
576 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
579 c_thread_buf.Clear();
581 auto blockwise_gemm_func = [&]() {
585 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
586 make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
596 b_block_desc_k0_n0_n1_n2_k1,
597 make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
608 b_block_desc_k0_n0_n1_n2_k1,
609 make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
611 b_scale_struct.b_scale_thread_bufs(I0)[
Number<
612 n0 * BScaleStruct::num_scale_k_block +
613 (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
621 __builtin_amdgcn_sched_barrier(0);
628 if constexpr(k0_offset != 0 || KRepeat == 1)
630 __builtin_amdgcn_s_barrier();
631 __builtin_amdgcn_sched_barrier(0);
637 vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
638 vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
640 static_for<0, KPack / A_KRow / KInner, 1>{}([&](
auto ik) {
641 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
642 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
643 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
652 static_for<0, KPack / B_KRow / KInner, 1>{}([&](
auto ik) {
653 constexpr
index_t kk = ik + k_inner * KPerWaveBlock;
654 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
655 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
665 using wmma_input_type_a =
666 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
667 using wmma_input_type_b =
668 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
671 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
679 if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
680 m0 == MRepeat - 1 && n0 == NRepeat - 1)
682 __builtin_amdgcn_sched_barrier(0);
684 __builtin_amdgcn_sched_barrier(0);
687 a_thread_vec.template AsType<wmma_input_type_a>(),
688 b_thread_vec.template AsType<wmma_input_type_b>(),
690 if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
692 __builtin_amdgcn_sched_barrier(0);
693 __builtin_amdgcn_s_setprio(1);
694 __builtin_amdgcn_sched_barrier(0);
700 __builtin_amdgcn_sched_barrier(0);
701 __builtin_amdgcn_s_setprio(0);
702 __builtin_amdgcn_sched_barrier(0);
707 if constexpr(HasMainLoop)
712 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
713 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
715 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
716 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
719 blockwise_gemm_func();
721 b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
726 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
727 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
730 }
while(i < (num_loop - 1));
737 blockwise_gemm_func();
742 static constexpr
auto a_thread_desc_ =
745 Number<KRepeatPerCluster>{},
752 Number<KPack / A_KRow * MRepeat>{},
758 static constexpr
auto b_thread_desc_ =
761 Number<KRepeatPerCluster>{},
768 Number<KPack / B_KRow * NRepeat>{},
777 decltype(a_block_desc_k0_m0_m1_m2_k1),
778 decltype(a_thread_desc_),
779 Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
788 decltype(b_block_desc_k0_n0_n1_n2_k1),
789 decltype(b_thread_desc_),
790 Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
796 AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
797 BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
798 using Base::c_thread_desc_;
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:209
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
__device__ void block_sync_lds()
Definition: synchronization.hpp:16
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:36
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC >::BlockLoopTailNum static TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:518
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC >::BlockHasHotloop static bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:516
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Interwave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:540
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC >::BlockLoopTailNum static TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:154
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:176
ck::BlockwiseGemmWmmaops_pipeline_v1< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack, KInner, TransposeC >::BlockHasHotloop static bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:152
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:38
Definition: sequence.hpp:43
ck::ThreadwiseTensorSliceTransfer_v4< ADataType, ComputeTypeA, decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_), Sequence< KPack/A_K1/A_KRow, 1, 1, 1, 1, 1, A_K1 >, Sequence< 0, 1, 2, 3, 4, 5, 6 >, 6, A_K1, A_K1 >
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10