typename ComputeDataType,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
typename ComputeDataType,
typename AMmaTileDesc,
typename BMmaTileDesc,
index_t ABlockTransferSrcScalarPerVector,
index_t BBlockTransferSrcScalarPerVector,
ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,

ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,

ABlockTransferSrcScalarPerVector,
BBlockTransferSrcScalarPerVector,
using Base::xdlops_gemm;

using Base::a_block_desc_m0_m1_m2_k;
using Base::CalculateCThreadOriginDataIndex;
using Base::CalculateCThreadOriginDataIndex8D;
using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::GetCThreadBuffer;
using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

using Base::AMmaKStride;
using Base::BMmaKStride;
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
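// Note: when MRepeat is odd, consecutive hot-loop iterations finish on
// opposite halves of the double-buffered A thread scratch, so the scratch
// index must be toggled per ping-pong buffer; the
// (... + HotloopLocalBufSwitch * mfma_reg_buf) % 2 index terms in Run()
// below apply exactly this switch. With an even MRepeat no toggle is needed.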
template <typename TileDesc_M0_M1_M2_K>

constexpr index_t K2 = KPack / KGroup;
constexpr index_t K1 = 64 / NPerXDL;
constexpr index_t K0 = KRepeat * KGroup;
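// Example (illustrative parameter values, not fixed by this header): with
// KPack = 8, KGroup = 1, NPerXDL = 32 and KRepeat = 2, this gives
// K2 = 8 / 1 = 8, K1 = 64 / 32 = 2, K0 = 2 * 1 = 2, i.e. the K axis of the
// A tile is viewed as K0 x K1 x K2 = 2 x 2 x 8 = 32 elements.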
TileDesc_M0_M1_M2_K{},

static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
    MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
return num_loop > PrefetchStages;
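// Note: the pipeline peels PrefetchStages iterations into its prologue and
// epilogue, so a hot loop exists only when num_loop exceeds PrefetchStages
// (evidently 2 here, matching the num_loop - 2 peeling in Run() below).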
constexpr auto num_ds_read_inst_a =
    HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
        ? HotLoopInstList::A_LDS_Read_Inst_Num
        : HotLoopInstList::A_LDS_Read_Inst_Num / 2;

constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;

constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2;
constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
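// Note: the factor of 2 on num_mfma_inst, like the * 2 on the B buffer-load
// count above, reflects the gate/up fusion in this pipeline: each stage does
// the MFMA and B-load work twice, once for c_thread_buf and once for its
// _up counterpart, while the A operand is shared.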
constexpr auto ds_read_a_issue_cycle =
    HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
    (mfma_cycle - 4) / (2 * ds_read_a_issue_cycle);
constexpr auto num_total_stages = MRepeat;

constexpr auto num_mfma_perstage      = num_mfma_inst / num_total_stages;
constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
constexpr auto num_ds_read_a_mfma_perstage =
    math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);

constexpr auto num_ds_read_a_prefetch_stages = 2;
constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
    (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
    (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
constexpr auto buffer_load_stages_more =
    (num_buffer_load_inst_a + num_buffer_load_inst_b) -
    math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
                               (num_total_stages - 2)) *
        ((num_total_stages - 2));
constexpr auto buffer_load_b_stages =
    buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
        ? num_buffer_load_inst_b / buffer_load_perstage_more
        : (buffer_load_stages_more +
           (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
               buffer_load_perstage_less);
constexpr auto buffer_load_a_stages =
    num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
constexpr auto buffer_load_issue_point_b = 0;
constexpr auto buffer_load_issue_point_interval_more =
    num_mfma_perstage / buffer_load_perstage_more;
constexpr auto buffer_load_issue_point_interval_less =
    num_mfma_perstage / buffer_load_perstage_less;
constexpr auto ds_write_issue_point    = 0;
constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
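// Legend for the __builtin_amdgcn_sched_group_barrier(mask, size, sync_id)
// calls below (mask bits per the LLVM AMDGPU scheduling intrinsics):
//   0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS (LDS) read, 0x200 = DS write.
// Each call asks the scheduler to place `size` instructions of the given
// class at this point, interleaving memory traffic with the MFMA stream.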
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
if constexpr(((i < buffer_load_stages_more) &&
              (imfma % buffer_load_issue_point_interval_more ==
               buffer_load_issue_point_b)) ||
             ((i >= buffer_load_stages_more) &&
              (imfma % buffer_load_issue_point_interval_less ==
               buffer_load_issue_point_b)))
{
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0);
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
              (imfma % buffer_load_issue_point_interval_more ==
               ds_write_issue_point)) ||
             (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
              (imfma % buffer_load_issue_point_interval_less ==
               ds_write_issue_point)))
{
    __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
}
if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
              (imfma % buffer_load_issue_point_interval_more ==
               buffer_load_issue_point_a)) ||
             (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
              (imfma % buffer_load_issue_point_interval_less ==
               buffer_load_issue_point_a)))
{
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);
}
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0);
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
{
    __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0);
}
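// Summary of the hot-loop schedule built above: each of the num_total_stages
// (= MRepeat) stages issues num_mfma_perstage MFMAs; the B (and then A)
// buffer loads plus the A LDS writes are slotted at fixed issue points
// between them, and the A LDS reads feeding the next stage are grouped
// behind the final num_ds_read_a_mfma_perstage MFMAs of each stage.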
template <typename Stage>

constexpr auto num_ds_read_inst_a  = HotLoopInstList::A_LDS_Read_Inst_Num;
constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_b =
    MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;

constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma           = num_mfma / MRepeat;

constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
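// A minimal sanity-check sketch of the ratio above (hypothetical helper,
// not part of this header):
//
//   constexpr int staged_mfma_per_ds_read(int num_mfma, int num_ds_read_a,
//                                         int m_repeat)
//   {
//       return (num_mfma / m_repeat) / (num_ds_read_a / m_repeat);
//   }
//   static_assert(staged_mfma_per_ds_read(16, 4, 2) == 4,
//                 "one ds_read slot per four MFMAs in this example");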
if constexpr(stage.value == 0)
{
    constexpr auto staged_num_buffer_load_b_per_ds_read_a =
        num_buffer_load_inst_b / staged_num_ds_read_inst_a;
    constexpr auto staged_num_mfma_per_buffer_load_b =
        staged_num_mfma / num_buffer_load_inst_b;

    __builtin_amdgcn_sched_group_barrier(
        0x008, staged_num_mfma_per_buffer_load_b, 0);
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);

    __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
    __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
    __builtin_amdgcn_sched_group_barrier(
        0x008, staged_num_mfma_per_buffer_load_b - 1, 0);
    __builtin_amdgcn_sched_group_barrier(0x020, 1, 0);

    __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(stage.value == 1)
{
    constexpr auto staged_num_mfma_per_ds_write_a =
        math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);

    constexpr auto stage_more_mfma =
        staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
    if constexpr(i_inst.value < stage_more_mfma)
    {
        if(i_inst.value < staged_num_ds_read_inst_a)
        {
            __builtin_amdgcn_sched_group_barrier(
                0x008, staged_num_mfma_per_ds_write_a - 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
        }
        else
        {
            __builtin_amdgcn_sched_group_barrier(
                0x008, staged_num_mfma_per_ds_write_a, 0);
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
        }
    }
    else
    {
        if(i_inst.value < staged_num_ds_read_inst_a)
        {
            __builtin_amdgcn_sched_group_barrier(
                0x008, staged_num_mfma_per_ds_write_a - 2, 0);
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0);
        }
        else
        {
            __builtin_amdgcn_sched_group_barrier(
                0x008, staged_num_mfma_per_ds_write_a - 1, 0);
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0);
        }
    }

    __builtin_amdgcn_sched_barrier(0);
}
__builtin_amdgcn_sched_group_barrier(
    0x008, staged_num_mfma_per_ds_read_a, 0);
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0);

__builtin_amdgcn_sched_barrier(0);
constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;

constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;

constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
constexpr auto staged_num_mfma           = num_mfma / MRepeat;

constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;

__builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0);
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0);

__builtin_amdgcn_sched_barrier(0);
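// Note: __builtin_amdgcn_sched_barrier(0) passes a zero mask, which forbids
// any instruction from being scheduled across it; the epilogue schedulers
// use it to seal their hand-built instruction groups off from surrounding
// code motion.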
template <bool HasMainLoop,
          TailNumber TailNum,
          typename AGridDesc,
          typename ABlockDesc,
          typename ABlockTransfer,
          typename AGridBuffer,
          typename ABlockBuffer,
          typename ABlockTransferStep,
          typename BGridDesc,
          typename BBlockTransfer,
          typename BGridBuffer,
          typename BBlockBuffer,
          typename BBlockTransferStep,
          typename CThreadBuffer>
__device__ void Run(const AGridDesc& a_grid_desc,
                    const ABlockDesc& a_block_desc,
                    ABlockTransfer& a_blockwise_copy,
                    const AGridBuffer& a_grid_buf,
                    ABlockBuffer& a_block_buf,
                    const ABlockTransferStep& a_block_copy_step,
                    const BGridDesc& b_grid_desc,
                    BBlockTransfer& b_blockwise_copy,
                    BBlockTransfer& b_blockwise_copy_up,
                    const BGridBuffer& b_grid_buf,
                    const BGridBuffer& b_grid_buf_up,
                    BBlockBuffer& b_block_buf,
                    const BBlockTransferStep& b_block_copy_step,
                    CThreadBuffer& c_thread_buf,
                    CThreadBuffer& c_thread_buf_up,
                    index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
    a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
    b_thread_desc_.GetElementSpaceSize());
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
b_blockwise_copy.Run(b_grid_desc,
                     b_grid_buf,
                     b_block_desc_n0_n1_k0_k1,
                     b_block_origin_idx,
                     b_thread_bufs(I0));

b_blockwise_copy_up.Run(b_grid_desc,
                        b_grid_buf_up,
                        b_block_desc_n0_n1_k0_k1,
                        b_block_origin_idx,
                        b_thread_bufs_up(I0));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
__builtin_amdgcn_sched_barrier(0);

a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));

a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
c_thread_buf.Clear();
c_thread_buf_up.Clear();

__builtin_amdgcn_sched_barrier(0);
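// Prologue recap: both B tiles (regular and _up) have been prefetched from
// global memory directly into registers, the first A tile has been staged
// through LDS buffer 0 while the second A global read was issued, and both
// accumulators start cleared.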
if constexpr(HasMainLoop)

auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
    b_blockwise_copy.Run(b_grid_desc,
                         b_grid_buf,
                         b_block_desc_n0_n1_k0_k1,
                         b_block_origin_idx,
                         b_thread_bufs(local_read_buf));
    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
    b_blockwise_copy_up.Run(b_grid_desc,
                            b_grid_buf_up,
                            b_block_desc_n0_n1_k0_k1,
                            b_block_origin_idx,
                            b_thread_bufs_up(local_read_buf));
    b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
    a_thread_vec.template AsType<ComputeDataType>()(ik) =
        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
            make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %

    b_thread_vec.template AsType<ComputeDataType>()(ik) =
        b_thread_bufs[mfma_reg_buf]
                     [Number<b_thread_desc_.CalculateOffset(

    b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
        b_thread_bufs_up[mfma_reg_buf]
                        [Number<b_thread_desc_.CalculateOffset(

    using mfma_input_type =
        typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                    b_thread_vec.template AsType<mfma_input_type>(),

    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                    b_thread_vec_up.template AsType<mfma_input_type>(),
    if constexpr(m0.value == MRepeat - 2)

        a_block_desc_m0_m1_m2_k0_k1_k2,

        a_block_buf.At(local_read_buf),

        Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %

    else if constexpr(m0.value == (MRepeat - 1))

        a_block_desc_m0_m1_m2_k0_k1_k2,

        a_block_buf.At(local_read_buf),

        Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %

        a_block_desc_m0_m1_m2_k0_k1_k2,

        a_block_buf.At(mfma_reg_buf),

        Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %

} while(i < (num_loop - 2));
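// Note: two iterations are peeled off the main loop (num_loop - 2) because
// the pipeline keeps two iterations in flight, one per A LDS buffer /
// B register buffer pair; the epilogue below drains them.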
b_blockwise_copy.Run(b_grid_desc,
                     b_grid_buf,
                     b_block_desc_n0_n1_k0_k1,
                     b_block_origin_idx,
                     b_thread_bufs(I1));

b_blockwise_copy_up.Run(b_grid_desc,
                        b_grid_buf_up,
                        b_block_desc_n0_n1_k0_k1,
                        b_block_origin_idx,
                        b_thread_bufs_up(I1));
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
a_thread_vec.template AsType<ComputeDataType>()(ik) =
    a_thread_buf[Number<a_thread_desc_.CalculateOffset(

b_thread_vec.template AsType<ComputeDataType>()(ik) =
    b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(

b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
    b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(

using mfma_input_type =
    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec.template AsType<mfma_input_type>(),

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec_up.template AsType<mfma_input_type>(),
if constexpr(m0.value == (MRepeat - 2))

    a_block_desc_m0_m1_m2_k0_k1_k2,

else if constexpr(m0.value == MRepeat - 1)

    a_block_desc_m0_m1_m2_k0_k1_k2,

    a_block_desc_m0_m1_m2_k0_k1_k2,
a_thread_vec.template AsType<ComputeDataType>()(ik) =
    a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
        (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
    b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(

b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
    b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(

using mfma_input_type =
    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec.template AsType<mfma_input_type>(),

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec_up.template AsType<mfma_input_type>(),
if constexpr(m0.value < (MRepeat - 2))

    a_block_desc_m0_m1_m2_k0_k1_k2,
a_thread_vec.template AsType<ComputeDataType>()(ik) =
    a_thread_buf[Number<a_thread_desc_.CalculateOffset(

b_thread_vec.template AsType<ComputeDataType>()(ik) =
    b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(

b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
    b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(

using mfma_input_type =
    typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec.template AsType<mfma_input_type>(),

xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                b_thread_vec_up.template AsType<mfma_input_type>(),
if constexpr(m0.value < (MRepeat - 2))

    a_block_desc_m0_m1_m2_k0_k1_k2,
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                     ComputeDataType,
                                                     decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
                                                     decltype(a_thread_desc_),
                                                     Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
                                                     Sequence<0, 1, 2, 3, 4, 5>,
                                                     5,
                                                     A_K1,
                                                     A_K1>;

AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
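// Note: each A thread copy moves one KPack / KGroup-wide vector per slice
// (the Sequence<1, 1, 1, 1, 1, KPack / KGroup> slice lengths), reading the
// LDS tile view a_block_desc_m0_m1_m2_k0_k1_k2 into the per-thread buffer
// consumed by the MFMA issue groups above.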
using Base::c_thread_desc_;