typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
bool TransposeC = false>
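// Specialization for BlockGemmPipelineScheduler::Intrawave.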
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;
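    // Pipeline entry point. Issues the first global->LDS stage, then runs the
    // hot loop with one K-block tile in flight.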
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
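        // Global prefetch: issue reads of the first A/B tiles and advance the
        // source windows.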
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
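        // Prefetch the first per-block B scales.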
        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
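        // Local prefill: commit the prefetched tiles to LDS.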
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
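        // Zero-initialize the C accumulators.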
        c_thread_buf.Clear();
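        // One K-block tile of math. The m0/n0/k0 indices come from enclosing
        // static_for loops over MRepeat/NRepeat/KRepeat, elided in this excerpt;
        // the per-block B scale is applied as part of the B LDS->VGPR copy.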
        auto blockwise_gemm_func = [&]() {
            // ...
            a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                               /* ... */);
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */);
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */
                               b_scale_struct.b_scale_thread_bufs(I0)[Number<
                                   n0 * BScaleStruct::num_scale_k_block +
                                   k0 / BScaleStruct::num_scale_krepeat>{}],
                               /* ... */);
            // ...

            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                        /* ... */)>{}];
            });
            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                        /* ... */)>{}];
            });

            using wmma_input_type_a =
                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
            using wmma_input_type_b =
                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                          b_thread_vec.template AsType<wmma_input_type_b>(),
                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
            // ...
        };
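        // Hot loop: prefetch tile i+1 from global memory while the WMMA math
        // consumes tile i out of LDS, then commit tile i+1 to LDS.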
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                blockwise_gemm_func();

                block_sync_lds();
                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }
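        // Tail: the last tile is already resident in LDS.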
        block_sync_lds();
        blockwise_gemm_func();
    }
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
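// Specialization for BlockGemmPipelineScheduler::Interwave: same dataflow, but
// two waves per SIMD are interleaved by toggling wave priority (s_setprio)
// around each MAC cluster (see CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS).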
typename ComputeTypeA,
typename ComputeTypeB,
typename AccDataType,
typename AWmmaTileDesc,
typename BWmmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
// ...
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    using typename Base::Empty;
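    // Same signature and dataflow as the Intrawave Run above; only the
    // in-loop scheduling differs.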
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer,
              typename BScaleStruct>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        BScaleStruct& b_scale_struct,
                        index_t num_loop,
                        index_t num_loop_per_scale) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
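        // Global prefetch: issue reads of the first A/B tiles and advance the
        // source windows.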
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
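        // Prefetch the first per-block B scales.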
        b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
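        // Local prefill: commit the prefetched tiles to LDS.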
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
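        // Zero-initialize the C accumulators.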
        c_thread_buf.Clear();
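        // As above, but KRepeat is processed in MAC clusters: k0_offset selects
        // the cluster, k0_inner the repeat within it. Scheduler fences and wave
        // priority toggles bracket each cluster.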
        auto blockwise_gemm_func = [&]() {
            // ...
            a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                               /* ... */);
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */);
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               /* ... */
                               b_scale_struct.b_scale_thread_bufs(I0)[Number<
                                   n0 * BScaleStruct::num_scale_k_block +
                                   (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
                               /* ... */);
            // ...
            __builtin_amdgcn_sched_barrier(0);
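            // sched_barrier(0) pins the fence: the backend may not reorder
            // instructions across it. Waves also rendezvous at a hard barrier
            // at every cluster boundary.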
            if constexpr(k0_offset != 0 || KRepeat == 1)
            {
                __builtin_amdgcn_s_barrier();
                __builtin_amdgcn_sched_barrier(0);
            }
            // ...
            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                        /* ... */)>{}];
            });
            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                        /* ... */)>{}];
            });
            using wmma_input_type_a =
                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
            using wmma_input_type_b =
                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

            constexpr index_t c_offset =
                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
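            // Before the tile's final WMMA, drain LDS reads so the next
            // stage's LDS writes cannot race them.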
            if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
                         n0 == NRepeat - 1)
            {
                __builtin_amdgcn_sched_barrier(0);
                block_sync_lds();
                __builtin_amdgcn_sched_barrier(0);
            }
            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                          b_thread_vec.template AsType<wmma_input_type_b>(),
                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
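            // At the first WMMA of a cluster, raise wave priority so this wave
            // keeps the MAC pipe; it is dropped again once the cluster's math
            // has been issued.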
            if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
            {
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(1);
                __builtin_amdgcn_sched_barrier(0);
            }
            // ...
            __builtin_amdgcn_sched_barrier(0);
            __builtin_amdgcn_s_setprio(0);
            __builtin_amdgcn_sched_barrier(0);
            // ...
        };
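        // Hot loop: identical structure to the Intrawave version; the
        // scheduling differences live inside blockwise_gemm_func above.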
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                blockwise_gemm_func();

                b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }
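        // Tail: the last tile is already resident in LDS.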
        block_sync_lds();
        blockwise_gemm_func();
    }
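    // Thread-level VGPR descriptors; the K dimension is grouped into clusters
    // of KRepeatPerCluster repeats.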
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(/* ... */
                   Number<KRepeatPerCluster>{},
                   /* ... */),
        make_tuple(/* ... */
                   Number<KPack / A_KRow * MRepeat>{},
                   /* ... */));

    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(/* ... */
                   Number<KRepeatPerCluster>{},
                   /* ... */),
        make_tuple(/* ... */
                   Number<KPack / B_KRow * NRepeat>{},
                   /* ... */));
    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<
        ADataType, ComputeTypeA,
        decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_),
        Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
        Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<
        BDataType, ComputeTypeB,
        decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_),
        Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
        Sequence<0, 1, 2, 3, 4, 5>, 5, B_K1, B_K1>;
    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
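// Illustrative driver sketch (an assumption, not part of this header): how a
// gridwise GEMM kernel would typically select and invoke this pipeline.
// `Pipeline`, the descriptor/buffer/copy objects, and `K` are hypothetical
// placeholders; only BlockHasHotloop/BlockLoopTailNum/Run correspond to the
// members shown above.
//
//   using Pipeline = BlockwiseGemmWmmaops_pipeline_v1</* ... */>;
//
//   const index_t num_loop = K / KPerBlock; // number of K-block tiles
//
//   if(Pipeline::BlockHasHotloop(num_loop))
//   {
//       Pipeline{}.template Run<true, TailNumber::Full>(
//           a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf,
//           a_block_buf, a_block_copy_step, b_grid_desc, b_block_desc,
//           b_blockwise_copy, b_grid_buf, b_block_buf, b_block_copy_step,
//           c_thread_buf, b_scale_struct, num_loop, num_loop_per_scale);
//   }
//   else
//   {
//       Pipeline{}.template Run<false, TailNumber::Full>(/* same arguments */);
//   }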