20 typename ComputeTypeA,
21 typename ComputeTypeB,
23 typename AWmmaTileDesc,
24 typename BWmmaTileDesc,
25 index_t ABlockTransferSrcScalarPerVector,
26 index_t BBlockTransferSrcScalarPerVector,
42 typename ComputeTypeA,
43 typename ComputeTypeB,
45 typename AWmmaTileDesc,
46 typename BWmmaTileDesc,
47 index_t ABlockTransferSrcScalarPerVector,
48 index_t BBlockTransferSrcScalarPerVector,
66 ABlockTransferSrcScalarPerVector,
67 BBlockTransferSrcScalarPerVector,
84 ABlockTransferSrcScalarPerVector,
85 BBlockTransferSrcScalarPerVector,
103 ABlockTransferSrcScalarPerVector,
104 BBlockTransferSrcScalarPerVector,
122 using Base::wmma_gemm;
125 using Base::CalculateCThreadOriginDataIndex;
127 GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
128 using Base::GetCThreadBuffer;
130 GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
132 using Base::a_block_desc_k0_m0_m1_m2_k1;
133 using Base::b_block_desc_k0_n0_n1_n2_k1;
141 return num_loop > PrefetchStages;
258 template <
bool HasMainLoop,
262 typename ABlockTransfer,
263 typename AGridBuffer,
264 typename ABlockBuffer,
265 typename ABlockTransferStep,
268 typename BBlockTransfer,
269 typename BGridBuffer,
270 typename BBlockBuffer,
271 typename BBlockTransferStep,
272 typename CThreadBuffer>
273 __device__
void Run(
const AGridDesc& a_grid_desc,
274 const ABlockDesc& a_block_desc,
275 ABlockTransfer& a_blockwise_copy,
276 const AGridBuffer& a_grid_buf,
277 ABlockBuffer& a_block_buf,
278 const ABlockTransferStep& a_block_copy_step,
279 const BGridDesc& b_grid_desc,
280 const BBlockDesc& b_block_desc,
281 BBlockTransfer& b_blockwise_copy,
282 const BGridBuffer& b_grid_buf,
283 BBlockBuffer& b_block_buf,
284 const BBlockTransferStep& b_block_copy_step,
285 CThreadBuffer& c_thread_buf,
288 __builtin_amdgcn_sched_barrier(0);
289 auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
290 a_thread_desc_.GetElementSpaceSize());
291 auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
292 b_thread_desc_.GetElementSpaceSize());
295 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
296 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
298 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
299 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
302 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
303 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
306 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
307 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
309 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
310 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
313 c_thread_buf.Clear();
318 a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
324 b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
332 __builtin_amdgcn_sched_barrier(0);
335 if constexpr(HasMainLoop)
342 a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
343 b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
345 a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
346 b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
348 a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
349 b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
354 vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
355 vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
357 static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
358 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
359 a_thread_buf[
Number<a_thread_desc_.CalculateOffset(
367 static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
368 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
369 b_thread_buf[
Number<b_thread_desc_.CalculateOffset(
378 using wmma_input_type_a =
379 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
380 using wmma_input_type_b =
381 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
384 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
386 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
387 b_thread_vec.template AsType<wmma_input_type_b>(),
397 a_block_desc_k0_m0_m1_m2_k1,
404 b_block_desc_k0_n0_n1_n2_k1,
413 __builtin_amdgcn_sched_barrier(0);
416 }
while(i < (num_loop - 1));
424 vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
425 vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
427 static_for<0, KPack / A_KRow, 1>{}([&](
auto ik) {
428 a_thread_vec.template AsType<ComputeTypeA>()(ik) =
432 static_for<0, KPack / B_KRow, 1>{}([&](
auto ik) {
433 b_thread_vec.template AsType<ComputeTypeB>()(ik) =
438 using wmma_input_type_a =
439 typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
440 using wmma_input_type_b =
441 typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
444 c_thread_desc_.CalculateOffset(
make_tuple(m0, n0, I0));
446 wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
447 b_thread_vec.template AsType<wmma_input_type_b>(),
459 using Base::a_thread_copy_;
460 using Base::a_thread_desc_;
461 using Base::b_thread_copy_;
462 using Base::b_thread_desc_;
463 using Base::c_thread_desc_;
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:300
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:35
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::BlockHasHotloop __host__ static constexpr __device__ bool BlockHasHotloop(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:139
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::BlockLoopTailNum __host__ static constexpr __device__ TailNumber BlockLoopTailNum(index_t num_loop)
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:144
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::HotLoopScheduler static constexpr __device__ auto HotLoopScheduler()
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:150
ck::BlockwiseGemmWmmaops_pipeline_v3< BlockGemmPipelineScheduler::Intrawave, BlockSize, ADataType, BDataType, ComputeTypeA, ComputeTypeB, AccDataType, AWmmaTileDesc, BWmmaTileDesc, ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerWmma, NPerWmma, MRepeat, NRepeat, KPack >::Run __device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:273
Definition: blockwise_gemm_pipeline_wmmaops_v3.hpp:36
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: dtype_vector.hpp:10