template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1
{
};
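// Specialization for BlockGemmPipelineScheduler::Intrawave. Pipeline shape of
// v1: one global-prefetch stage, one LDS pre-fill stage, no LDS prefetch
// stage, and a single shared-memory buffer.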
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                        BlockSize,
                                        // ... (full parameter set as above)
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...
                                        KPack>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         // ... (same parameter set)
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         // ...
                                         KPack>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    // ...
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    // ...
                                                    KPack>;
    // ...
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;
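    // Hot-loop bookkeeping (bodies elided in this fragment):
    //   static bool BlockHasHotloop(index_t num_loop);
    //   static TailNumber BlockLoopTailNum(index_t num_loop);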
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
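        // Per-thread VGPR staging buffers for the A and B fragments.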
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
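        // Global prefetch 1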
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
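        // Local prefill 1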
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
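        // Initialize C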
        c_thread_buf.Clear();
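        // One full blockwise GEMM over the K-tile currently in LDS: copy A/B
        // fragments to VGPRs, pack WMMA input vectors, and accumulate into the
        // C thread buffer via wmma_gemm.Run.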
        auto blockwise_gemm_func = [&]() {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   /* ... */
                                   a_block_buf,
                                   a_thread_desc_,
                                   /* ... */
                                   a_thread_buf);
                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                   /* ... */
                                   b_block_buf,
                                   b_thread_desc_,
                                   /* ... */
                                   b_thread_buf);

                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    /* ... */)>{}];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    /* ... */)>{}];
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                      b_thread_vec.template AsType<wmma_input_type_b>(),
                                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        };
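        // Main body: for each remaining K-tile, overlap the global prefetch of
        // the next tile with the blockwise GEMM on the current one, then
        // refill LDS behind a block_sync_lds() fence.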
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                blockwise_gemm_func();

                block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }
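        // Tail: one final blockwise GEMM on the last tile resident in LDS.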
        block_sync_lds();
        blockwise_gemm_func();
    }
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
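// Specialization for BlockGemmPipelineScheduler::Interwave. Same single-buffer
// pipeline as above, but the MAC loop is split into clusters of
// KRepeatPerCluster K-repeats and fenced with sched_barrier/s_setprio so the
// MAC phase of one wave overlaps the memory phase of another.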
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                        BlockSize,
                                        // ... (full parameter set as above)
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        // ...
                                        KPack>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         // ... (same parameter set)
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         // ...
                                         KPack>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    // ...
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    // ...
                                                    KPack>;
    // ...
    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;
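    // Hot-loop bookkeeping (bodies elided in this fragment):
    //   static bool BlockHasHotloop(index_t num_loop);
    //   static TailNumber BlockLoopTailNum(index_t num_loop);
    // KRepeatPerCluster (definition likewise elided) splits KRepeat into MAC
    // clusters sized via CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS.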
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
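        // Per-thread VGPR staging buffers, as in the Intrawave pipeline.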
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());
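        // Global prefetch 1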
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
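        // Local prefill 1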
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
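        // Initialize C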
        c_thread_buf.Clear();
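        // One blockwise GEMM over the LDS-resident K-tile, processed in
        // clusters of KRepeatPerCluster K-repeats; per-cluster copy loops and
        // slice origins are elided in this fragment.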
        auto blockwise_gemm_func = [&]() {
            static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   /* ... */
                                   a_block_buf,
                                   a_thread_desc_,
                                   /* ... */
                                   a_thread_buf);
                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                   /* ... */
                                   b_block_buf,
                                   b_thread_desc_,
                                   /* ... */
                                   b_thread_buf);
                __builtin_amdgcn_sched_barrier(0);
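                // Re-synchronize the waves at each MAC-cluster boundary; the
                // first cluster instead relies on the caller's block_sync_lds()
                // (unless it is the only cluster).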
                if constexpr(k0_offset != 0 || KRepeat == 1)
                {
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }
                static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        /* ... */)>{}];
                            });
                            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        /* ... */)>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
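                            // Interwave scheduling: sched_barrier(0) fences pin
                            // the surrounding instructions in place; wave
                            // priority is raised after the first MAC of a pass
                            // and dropped after the last so MAC-phase and
                            // memory-phase waves interleave.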
                            if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
                                         m0 == MRepeat - 1 && n0 == NRepeat - 1)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                block_sync_lds();
                                __builtin_amdgcn_sched_barrier(0);
                            }

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                __builtin_amdgcn_s_setprio(1);
                                __builtin_amdgcn_sched_barrier(0);
                            }
                        });
                    });
                });
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
            });
        };
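        // Main body: as in the Intrawave pipeline, except the trailing LDS
        // sync is issued inside blockwise_gemm_func just before its last MAC.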
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                blockwise_gemm_func();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                ++i;
            } while(i < (num_loop - 1));
        }
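        // Tail: final blockwise GEMM on the last LDS-resident tile.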
        block_sync_lds();
        blockwise_gemm_func();
    }
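    // Per-thread VGPR tile descriptors sized to one KRepeatPerCluster cluster
    // of A/B fragments (full length/stride tuples elided in this fragment).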
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(/* ... */, Number<KRepeatPerCluster>{}, /* ... */),
        make_tuple(/* ... */, Number<KPack / A_KRow * MRepeat>{}, /* ... */));
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(/* ... */, Number<KRepeatPerCluster>{}, /* ... */),
        make_tuple(/* ... */, Number<KPack / B_KRow * NRepeat>{}, /* ... */));
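    // LDS -> VGPR fragment copies; ThreadwiseTensorSliceTransfer_v4 reads the
    // blockwise tile in LDS and writes the per-thread descriptors above.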
    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<
        ADataType, ComputeTypeA,
        decltype(a_block_desc_k0_m0_m1_m2_k1), decltype(a_thread_desc_),
        Sequence<KPack / A_K1 / A_KRow, MRepeat, 1, 1, 1, A_K1>,
        Sequence<0, 1, 2, 3, 4, 5>, 5, A_K1, A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<
        BDataType, ComputeTypeB,
        decltype(b_block_desc_k0_n0_n1_n2_k1), decltype(b_thread_desc_),
        Sequence<KPack / B_K1 / B_KRow, NRepeat, 1, 1, 1, B_K1>,
        Sequence<0, 1, 2, 3, 4, 5>, 5, B_K1, B_K1>;
    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};

    using Base::c_thread_desc_;
};
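
// ---------------------------------------------------------------------------
// Usage sketch (illustrative, not part of this header): a grid-wise GEMM
// kernel typically selects a specialization via the scheduler enum and drives
// Run() with its grid/block descriptors and transfer objects. All concrete
// values below are assumptions for illustration only.
//
//   using Pipeline = BlockwiseGemmWmmaops_pipeline_v1<
//       BlockGemmPipelineScheduler::Intrawave,
//       256,                          // BlockSize
//       half_t, half_t,               // ADataType, BDataType
//       half_t, half_t,               // ComputeTypeA, ComputeTypeB
//       float,                        // AccDataType
//       AWmmaTileDesc, BWmmaTileDesc, // LDS tile descriptors
//       8, 8,                         // A/B BlockTransferSrcScalarPerVector
//       128, 128, 64,                 // MPerBlock, NPerBlock, KPerBlock
//       16, 16,                       // MPerWmma, NPerWmma
//       4, 2,                         // MRepeat, NRepeat
//       16>;                          // KPack
//
//   Pipeline pipeline;
//   auto c_thread_buf = pipeline.GetCThreadBuffer();
//   const index_t num_loop = ...; // K / KPerBlock
//   pipeline.template Run<true, TailNumber::Full>(
//       a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf,
//       a_block_copy_step, b_grid_desc, b_block_desc, b_blockwise_copy,
//       b_grid_buf, b_block_buf, b_block_copy_step, c_thread_buf, num_loop);
// ---------------------------------------------------------------------------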