19 template <
typename ALayout,
25 typename CShuffleDataType,
27 typename AElementwiseOperation,
28 typename BElementwiseOperation,
29 typename CElementwiseOperation,
41 typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
42 typename ABlockTransferThreadClusterArrangeOrder,
43 typename ABlockTransferSrcAccessOrder,
44 index_t ABlockTransferSrcVectorDim,
45 index_t ABlockTransferSrcScalarPerVector,
46 index_t ABlockTransferDstScalarPerVector_AK1,
47 bool AThreadTransferSrcResetCoordinateAfterRun,
49 typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
50 typename BBlockTransferThreadClusterArrangeOrder,
51 typename BBlockTransferSrcAccessOrder,
52 index_t BBlockTransferSrcVectorDim,
53 index_t BBlockTransferSrcScalarPerVector,
54 index_t BBlockTransferDstScalarPerVector_BK1,
55 bool BThreadTransferSrcResetCoordinateAfterRun,
57 index_t CShuffleMXdlPerWavePerShuffle,
58 index_t CShuffleNXdlPerWavePerShuffle,
59 typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
60 index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
63 typename ComputeTypeA = CDataType,
64 typename ComputeTypeB = ComputeTypeA>
110 auto K_t = K_Batch * KPerBlock;
111 return (K + K_t - 1) / K_t * (KPerBlock / AK1Value);
116 auto K_t = K_Batch * KPerBlock;
117 return (K + K_t - 1) / K_t * (KPerBlock / BK1Value);
122 auto K_t = K_Batch * KPerBlock;
123 return (K + K_t - 1) / K_t * KPerBlock;
129 auto K_t = K_Batch * KReadVec;
130 return (K + K_t - 1) / K_t * KReadVec;
143 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl,
typename TileDesc_K0_MN_K1>
158 template <
typename ABlockDesc_AK0_M_AK1>
159 __host__ __device__
static constexpr
auto
162 constexpr
index_t MWaves = MPerBlock / (MXdlPerWave * MPerXdl);
164 return MakeGemmMmaTileDescriptor<MXdlPerWave, MWaves, MPerXdl>(ABlockDesc_AK0_M_AK1{});
167 template <
typename BBlockDesc_BK0_N_BK1>
168 __host__ __device__
static constexpr
auto
171 constexpr
index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl);
173 return MakeGemmMmaTileDescriptor<NXdlPerWave, NWaves, NPerXdl>(BBlockDesc_BK0_N_BK1{});
205 std::cout <<
"problem {"
214 <<
"KRead:" <<
KRead <<
", "
216 <<
"AK0:" <<
AK0 <<
", "
217 <<
"BK0:" <<
BK0 <<
", "
218 <<
"MBlock: " <<
MBlock <<
", "
219 <<
"NBlock: " <<
NBlock <<
"}" << std::endl;
243 const BDataType* p_b_grid_,
244 CDataType* p_c_grid_,
252 :
Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, k_batch_},
267 if constexpr(ABlockLdsExtraM)
277 constexpr
auto MLdsLayer = 32 * 4 / KPerBlock /
sizeof(ADataType) < 1
279 : 32 * 4 / KPerBlock /
sizeof(ADataType);
294 a_lds_block_desc_permuted,
302 a_lds_block_desc_ak0_mldslayer_m_ak1,
310 return a_lds_block_desc_ak0_m_ak1;
317 constexpr
auto M0 = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I1);
318 constexpr
auto M1 = MPerBlock / M0;
320 constexpr
auto KThreadWrite = ABlockTransferThreadClusterLengths_AK0_M_AK1{}.At(
I0);
321 constexpr
auto K0PerThreadWrite =
AK0Number / KThreadWrite;
322 constexpr
auto KThreadRead = 64 / MPerXdl;
323 constexpr
auto K0PerThreadRead =
AK0Number / KThreadRead;
325 constexpr
auto kfold = (
AK1Number * M0 *
sizeof(ADataType) > 128)
327 : 128 / (
AK1Number * M0 *
sizeof(ADataType));
328 constexpr
auto KThreadReadPerm =
329 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
330 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
334 constexpr
auto mpair = (
AK1Number * MPerXdl *
sizeof(ADataType) > 128)
336 : ((128 / (
AK1Number * MPerXdl *
sizeof(ADataType))) > M0
338 : 128 / (
AK1Number * MPerXdl *
sizeof(ADataType)));
344 Number<kfold * M0 / mpair>{},
363 a_lds_block_desc_permuted,
385 a_lds_block_desc_unmerged,
388 Number<KThreadWrite / kfold / KThreadReadPerm>{},
397 return a_lds_block_desc_ak0_m_ak1;
404 if constexpr(BBlockLdsExtraN)
413 constexpr
auto NLdsLayer = 32 * 4 / KPerBlock /
sizeof(BDataType) < 1
415 : 32 * 4 / KPerBlock /
sizeof(BDataType);
431 b_lds_block_desc_permuted,
439 b_lds_block_desc_bk0_nldslayer_n_bk1,
447 return b_lds_block_desc_bk0_n_bk1;
451 constexpr
auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I1);
452 constexpr
auto N1 = NPerBlock / N0;
454 constexpr
auto KThreadWrite = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(
I0);
455 constexpr
auto K0PerThreadWrite =
BK0Number / KThreadWrite;
456 constexpr
auto KThreadRead = 64 / NPerXdl;
457 constexpr
auto K0PerThreadRead =
BK0Number / KThreadRead;
459 constexpr
auto kfold = (
BK1Number * N0 *
sizeof(BDataType) > 128)
461 : 128 / (
BK1Number * N0 *
sizeof(BDataType));
462 constexpr
auto KThreadReadPerm =
463 (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
464 ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
468 constexpr
auto npair = (
BK1Number * NPerXdl *
sizeof(BDataType) > 128)
470 : ((128 / (
BK1Number * NPerXdl *
sizeof(BDataType))) > N0
472 : 128 / (
BK1Number * NPerXdl *
sizeof(BDataType)));
478 Number<kfold * N0 / npair>{},
497 b_lds_block_desc_permuted,
519 b_lds_block_desc_unmerged,
522 Number<KThreadWrite / kfold / KThreadReadPerm>{},
531 return b_lds_block_desc_bk0_n_bk1;
537 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
538 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
540 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
547 return c_shuffle_block_desc_mblock_mperblock_nblock_nperblock;
565 ABlockTransferSrcScalarPerVector,
566 BBlockTransferSrcScalarPerVector,
586 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
589 b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
592 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
595 constexpr
auto c_block_size =
596 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
598 return math::max((a_block_space_size_aligned *
sizeof(ADataType) +
599 b_block_space_size_aligned *
sizeof(BDataType)),
600 c_block_size *
sizeof(CShuffleDataType));
605 const index_t num_loop = K / KPerBlock;
607 return BlockwiseGemmPipe::BlockHasHotloop(num_loop);
612 const index_t num_loop = K / KPerBlock;
614 return BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
617 template <
typename CGr
idDesc>
628 return c_grid_desc_mblock_mperblock_nblock_nperblock;
635 template <
typename AGridDesc_AK0_M_K1,
636 typename BGridDesc_BK0_N_K1,
637 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
638 bool HasMainKBlockLoop,
641 __device__
static void Run(
const ADataType* p_a_grid,
642 const BDataType* p_b_grid,
646 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
647 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
648 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
649 c_grid_desc_mblock_mperblock_nblock_nperblock,
652 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
653 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
654 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
655 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
656 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
657 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
659 const AElementwiseOperation a_element_op{};
660 const BElementwiseOperation b_element_op{};
661 const CElementwiseOperation c_element_op{};
666 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
669 if(!block_2_ctile_map.ValidCTileIndex(
671 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
672 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
677 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
678 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
681 const index_t m_block_data_idx_on_grid =
682 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
684 const index_t n_block_data_idx_on_grid =
685 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
697 auto a_blockwise_copy =
699 AElementwiseOperation,
703 ABlockTransferThreadClusterLengths_AK0_M_AK1,
704 ABlockTransferThreadClusterArrangeOrder,
707 decltype(a_grid_desc_ak0_m_ak1),
708 decltype(a_block_desc_ak0_m_ak1),
709 ABlockTransferSrcAccessOrder,
711 ABlockTransferSrcVectorDim,
713 ABlockTransferSrcScalarPerVector,
714 ABlockTransferDstScalarPerVector_AK1,
717 AThreadTransferSrcResetCoordinateAfterRun,
719 BlockwiseGemmPipe::GlobalBufferNum>(
720 a_grid_desc_ak0_m_ak1,
723 a_block_desc_ak0_m_ak1,
728 auto b_blockwise_copy =
730 BElementwiseOperation,
734 BBlockTransferThreadClusterLengths_BK0_N_BK1,
735 BBlockTransferThreadClusterArrangeOrder,
738 decltype(b_grid_desc_bk0_n_bk1),
739 decltype(b_block_desc_bk0_n_bk1),
740 BBlockTransferSrcAccessOrder,
742 BBlockTransferSrcVectorDim,
744 BBlockTransferSrcScalarPerVector,
745 BBlockTransferDstScalarPerVector_BK1,
748 BThreadTransferSrcResetCoordinateAfterRun,
750 BlockwiseGemmPipe::GlobalBufferNum>(
751 b_grid_desc_bk0_n_bk1,
754 b_block_desc_bk0_n_bk1,
760 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
763 auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
764 static_cast<ADataType*
>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
766 auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
767 static_cast<BDataType*
>(p_shared) +
768 a_block_space_size_aligned *
sizeof(ADataType) /
sizeof(BDataType),
769 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
775 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
777 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
779 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
780 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
781 (KPerBlock * problem.
KBatch));
783 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
784 a_block_desc_ak0_m_ak1,
788 a_block_slice_copy_step,
789 b_grid_desc_bk0_n_bk1,
790 b_block_desc_bk0_n_bk1,
794 b_block_slice_copy_step,
796 num_k_block_main_loop);
800 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
801 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
804 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
805 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
808 constexpr
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
809 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
813 constexpr
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
814 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
816 constexpr
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
817 constexpr
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
818 constexpr
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
819 constexpr
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
820 constexpr
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
821 constexpr
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
822 constexpr
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
823 constexpr
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
825 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
828 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
829 static_cast<CShuffleDataType*
>(p_shared),
830 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
833 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
853 const auto c_thread_mtx_on_block =
854 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
856 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
857 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
859 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
865 const auto m_thread_data_on_block_idx =
866 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
869 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
875 const auto n_thread_data_on_block_idx =
876 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
880 auto c_thread_copy_vgpr_to_lds =
883 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
884 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
886 Sequence<CShuffleMXdlPerWavePerShuffle,
887 CShuffleNXdlPerWavePerShuffle,
900 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
903 m_thread_data_on_block_idx[
I1],
904 n_thread_data_on_block_idx[
I1],
905 m_thread_data_on_block_idx[
I2],
906 m_thread_data_on_block_idx[
I3],
907 m_thread_data_on_block_idx[
I4],
908 n_thread_data_on_block_idx[
I2]),
914 CElementwiseOperation,
915 CGlobalMemoryDataOperation,
917 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
919 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
920 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
924 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
925 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
928 CShuffleBlockTransferScalarPerVector_NPerBlock,
931 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
933 c_grid_desc_mblock_mperblock_nblock_nperblock,
938 constexpr
auto sfc_c_vgpr =
941 Sequence<CShuffleMXdlPerWavePerShuffle,
942 CShuffleNXdlPerWavePerShuffle,
951 constexpr
auto sfc_c_global =
955 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
957 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
959 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
961 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
968 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
969 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
971 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
972 c_shuffle_block_buf);
978 c_shuffle_block_copy_lds_to_global.Run(
979 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
981 c_grid_desc_mblock_mperblock_nblock_nperblock,
984 if constexpr(access_id < num_access - 1)
986 constexpr
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
989 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
990 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
996 template <
typename AGridDesc_AK0_M_K1,
997 typename BGridDesc_BK0_N_K1,
998 typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
999 bool HasMainKBlockLoop,
1002 __device__
static void Run_2Lds(
const ADataType* p_a_grid,
1003 const BDataType* p_b_grid,
1004 CDataType* p_c_grid,
1008 const AGridDesc_AK0_M_K1& a_grid_desc_ak0_m_ak1,
1009 const BGridDesc_BK0_N_K1& b_grid_desc_bk0_n_bk1,
1010 const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock&
1011 c_grid_desc_mblock_mperblock_nblock_nperblock,
1014 const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1015 p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize());
1016 const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1017 p_b_grid, b_grid_desc_bk0_n_bk1.GetElementSpaceSize());
1018 auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
1019 p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1021 const AElementwiseOperation a_element_op{};
1022 const BElementwiseOperation b_element_op{};
1023 const CElementwiseOperation c_element_op{};
1028 const auto block_work_idx = block_2_ctile_map.CalculateBottomIndex(
1031 if(!block_2_ctile_map.ValidCTileIndex(
1033 make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I0),
1034 c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(
I2))))
1039 const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I0]);
1040 const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[
I1]);
1043 const index_t m_block_data_idx_on_grid =
1044 __builtin_amdgcn_readfirstlane(block_m_id * MPerBlock);
1046 const index_t n_block_data_idx_on_grid =
1047 __builtin_amdgcn_readfirstlane(block_n_id * NPerBlock);
1059 auto a_blockwise_copy =
1061 AElementwiseOperation,
1065 ABlockTransferThreadClusterLengths_AK0_M_AK1,
1066 ABlockTransferThreadClusterArrangeOrder,
1069 decltype(a_grid_desc_ak0_m_ak1),
1070 decltype(a_block_desc_ak0_m_ak1),
1071 ABlockTransferSrcAccessOrder,
1073 ABlockTransferSrcVectorDim,
1075 ABlockTransferSrcScalarPerVector,
1076 ABlockTransferDstScalarPerVector_AK1,
1079 AThreadTransferSrcResetCoordinateAfterRun,
1081 BlockwiseGemmPipe::GlobalBufferNum>(
1082 a_grid_desc_ak0_m_ak1,
1085 a_block_desc_ak0_m_ak1,
1090 auto b_blockwise_copy =
1092 BElementwiseOperation,
1096 BBlockTransferThreadClusterLengths_BK0_N_BK1,
1097 BBlockTransferThreadClusterArrangeOrder,
1100 decltype(b_grid_desc_bk0_n_bk1),
1101 decltype(b_block_desc_bk0_n_bk1),
1102 BBlockTransferSrcAccessOrder,
1104 BBlockTransferSrcVectorDim,
1106 BBlockTransferSrcScalarPerVector,
1107 BBlockTransferDstScalarPerVector_BK1,
1110 BThreadTransferSrcResetCoordinateAfterRun,
1112 BlockwiseGemmPipe::GlobalBufferNum>(
1113 b_grid_desc_bk0_n_bk1,
1116 b_block_desc_bk0_n_bk1,
1122 a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
1124 auto a_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1125 static_cast<ADataType*
>(p_shared_0), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1127 auto b_block_buf_ping = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1128 static_cast<BDataType*
>(p_shared_0) +
1129 a_block_space_size_aligned *
sizeof(ADataType) /
sizeof(BDataType),
1130 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1132 auto a_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1133 static_cast<ADataType*
>(p_shared_1), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
1135 auto b_block_buf_pong = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1136 static_cast<BDataType*
>(p_shared_1) +
1137 a_block_space_size_aligned *
sizeof(ADataType) /
sizeof(BDataType),
1138 b_block_desc_bk0_n_bk1.GetElementSpaceSize());
1140 auto a_block_bufs =
make_tuple(a_block_buf_ping, a_block_buf_pong);
1141 auto b_block_bufs =
make_tuple(b_block_buf_ping, b_block_buf_pong);
1147 static_assert(std::is_default_constructible_v<BlockwiseGemmPipe>);
1149 auto c_thread_buf = blockwise_gemm_pipeline.GetCThreadBuffer();
1151 const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
1152 (a_grid_desc_ak0_m_ak1.GetLength(
I0) * a_grid_desc_ak0_m_ak1.GetLength(
I2)) /
1153 (KPerBlock * problem.
KBatch));
1155 blockwise_gemm_pipeline.template Run<HasMainKBlockLoop, TailNum>(a_grid_desc_ak0_m_ak1,
1156 a_block_desc_ak0_m_ak1,
1160 a_block_slice_copy_step,
1161 b_grid_desc_bk0_n_bk1,
1162 b_block_desc_bk0_n_bk1,
1166 b_block_slice_copy_step,
1168 num_k_block_main_loop);
1172 static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
1173 NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0,
1176 constexpr
index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl);
1177 constexpr
index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl);
1180 constexpr
auto c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2 =
1181 blockwise_gemm_pipeline.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1185 constexpr
auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp =
1186 blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();
1188 constexpr
auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I0);
1189 constexpr
auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I1);
1190 constexpr
auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I2);
1191 constexpr
auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I3);
1192 constexpr
auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I4);
1193 constexpr
auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I5);
1194 constexpr
auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I6);
1195 constexpr
auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(
I7);
1197 constexpr
auto c_shuffle_block_desc_mblock_mperblock_nblock_nperblock =
1200 auto c_shuffle_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
1201 static_cast<CShuffleDataType*
>(p_shared_0),
1202 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
1205 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1225 const auto c_thread_mtx_on_block =
1226 blockwise_gemm_pipeline.CalculateCThreadOriginDataIndex(
I0,
I0,
I0,
I0);
1228 const index_t m_thread_data_on_block = c_thread_mtx_on_block[
I0];
1229 const index_t n_thread_data_on_block = c_thread_mtx_on_block[
I1];
1231 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
1237 const auto m_thread_data_on_block_idx =
1238 m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
1241 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
1247 const auto n_thread_data_on_block_idx =
1248 n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
1252 auto c_thread_copy_vgpr_to_lds =
1255 decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1256 decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
1258 Sequence<CShuffleMXdlPerWavePerShuffle,
1259 CShuffleNXdlPerWavePerShuffle,
1272 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1275 m_thread_data_on_block_idx[
I1],
1276 n_thread_data_on_block_idx[
I1],
1277 m_thread_data_on_block_idx[
I2],
1278 m_thread_data_on_block_idx[
I3],
1279 m_thread_data_on_block_idx[
I4],
1280 n_thread_data_on_block_idx[
I2]),
1286 CElementwiseOperation,
1287 CGlobalMemoryDataOperation,
1289 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1291 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>,
1292 CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
1296 decltype(c_shuffle_block_desc_mblock_mperblock_nblock_nperblock),
1297 decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
1300 CShuffleBlockTransferScalarPerVector_NPerBlock,
1303 {c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1305 c_grid_desc_mblock_mperblock_nblock_nperblock,
1310 constexpr
auto sfc_c_vgpr =
1313 Sequence<CShuffleMXdlPerWavePerShuffle,
1314 CShuffleNXdlPerWavePerShuffle,
1323 constexpr
auto sfc_c_global =
1327 CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl,
1329 CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
1331 constexpr
index_t num_access = sfc_c_vgpr.GetNumOfAccess();
1333 static_assert(num_access == sfc_c_global.GetNumOfAccess(),
"wrong!");
1340 c_thread_copy_vgpr_to_lds.Run(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1341 sfc_c_vgpr.GetIndexTupleOfNumber(access_id),
1343 c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
1344 c_shuffle_block_buf);
1350 c_shuffle_block_copy_lds_to_global.Run(
1351 c_shuffle_block_desc_mblock_mperblock_nblock_nperblock,
1352 c_shuffle_block_buf,
1353 c_grid_desc_mblock_mperblock_nblock_nperblock,
1356 if constexpr(access_id < num_access - 1)
1358 constexpr
auto c_global_step = sfc_c_global.GetForwardStep(access_id);
1361 c_shuffle_block_copy_lds_to_global.MoveDstSliceWindow(
1362 c_grid_desc_mblock_mperblock_nblock_nperblock, c_global_step);
Y __host__ constexpr __device__ auto lcm(X x, Y y)
Definition: math.hpp:198
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
GemmSpecialization
Definition: gemm_specialization.hpp:11
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
InMemoryDataOperationEnum
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
BlockGemmPipelineVersion
Definition: blockwise_gemm_pipeline_xdlops_ab_scale_selector.hpp:13
__host__ constexpr __device__ auto make_merge_transform_v3_division_mod(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:84
TailNumber
Definition: blkgemmpipe_scheduler.hpp:18
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
constexpr auto BlockGemmPipeline_Selector()
Definition: blockwise_gemm_pipeline_xdlops_b_scale_selector.hpp:44
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
__host__ constexpr __device__ auto make_xor_with_modulo_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:132
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
Definition: block_to_ctile_map.hpp:270
__host__ static constexpr __device__ index_t CalculateGridSize(index_t M, index_t N)
Definition: block_to_ctile_map.hpp:281
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:241
__host__ Argument(const ADataType *p_a_grid_, const BDataType *p_b_grid_, CDataType *p_c_grid_, index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t k_batch_)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:242
const BDataType * p_b_grid
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:260
const ADataType * p_a_grid
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:259
CDataType * p_c_grid
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:261
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:177
index_t StrideB
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:226
index_t M
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:222
index_t KRead
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:231
index_t BK0
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:234
index_t NPadded
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:230
index_t MBlock
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:235
index_t KPadded
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:232
__host__ void Print() const
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:203
index_t NBlock
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:236
index_t MPadded
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:229
index_t AK0
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:233
index_t N
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:223
index_t KBatch
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:228
index_t StrideA
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:225
__host__ Problem(index_t M_, index_t N_, index_t K_, index_t StrideA_, index_t StrideB_, index_t StrideC_, index_t KBatch_)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:178
index_t K
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:224
index_t StrideC
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:227
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:66
static __host__ auto CalculateBK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:114
static constexpr auto I5
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:72
static __device__ void Run_2Lds(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared_0, void *p_shared_1, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const index_t k_id=0)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:1002
static constexpr auto AK1Number
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:79
static constexpr auto BK1Number
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:80
static constexpr __host__ TailNumber CalculateKBlockLoopTailNum(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:610
static constexpr auto I3
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:70
static constexpr auto I1
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:68
static constexpr auto I4
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:71
static __host__ auto CalculateKPadded(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:103
static __host__ auto CalculateKPadded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:120
static __host__ auto CalculateAK0Padded(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:108
static constexpr __device__ index_t GetSharedMemoryNumberOfByte()
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:576
static constexpr __device__ auto GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:401
static constexpr auto AK0Number
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:77
static constexpr __device__ auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:264
__host__ static constexpr __device__ auto MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc &c_grid_desc_m_n, index_t MBlock, index_t NBlock)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:618
static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:603
static __device__ void Run(const ADataType *p_a_grid, const BDataType *p_b_grid, CDataType *p_c_grid, void *p_shared, const Problem &problem, const AGridDesc_AK0_M_K1 &a_grid_desc_ak0_m_ak1, const BGridDesc_BK0_N_K1 &b_grid_desc_bk0_n_bk1, const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock &c_grid_desc_mblock_mperblock_nblock_nperblock, const index_t k_id=0)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:641
static constexpr auto I7
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:74
static constexpr index_t KPack
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:82
static __host__ auto CalculateNPadded(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:98
static constexpr auto I0
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:67
__host__ static constexpr __device__ auto MakeAMmaTileDescriptor_M0_M1_M2_K(const ABlockDesc_AK0_M_AK1 &)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:160
static constexpr __device__ auto GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock()
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:535
__host__ static constexpr __device__ auto MakeBMmaTileDescriptor_N0_N1_N2_K(const BBlockDesc_BK0_N_BK1 &)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:169
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:86
static __host__ auto CalculateGridSize(index_t M, index_t N, index_t KBatch, index_t Batch)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:88
static constexpr auto BK0Number
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:78
static constexpr auto I2
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:69
static constexpr auto I6
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:73
static __host__ auto CalculateKRead(index_t K, index_t K_Batch=1)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:126
__host__ static constexpr __device__ auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1 &)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:144
static __host__ auto CalculateMBlock(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:133
remove_cvref_t< decltype(BlockGemmPipeline_Selector< BlkGemmPipelineVer, BlkGemmPipeSched, BlockSize, ADataType, BDataType, ComputeTypeA, AccDataType, decltype(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()), decltype(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1()), decltype(MakeAMmaTileDescriptor_M0_M1_M2_K(GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1())), decltype(MakeBMmaTileDescriptor_N0_N1_N2_K(GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1())), ABlockTransferSrcScalarPerVector, BBlockTransferSrcScalarPerVector, MPerBlock, NPerBlock, KPerBlock, MPerXdl, NPerXdl, MXdlPerWave, NXdlPerWave, KPack >())> BlockwiseGemmPipe
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:574
static __host__ auto CalculateNBlock(index_t N)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:138
static __host__ auto CalculateMPadded(index_t M)
Definition: gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp:93
Definition: xdlops_gemm.hpp:886
Definition: sequence.hpp:43
Definition: tensor_space_filling_curve.hpp:20
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r1.hpp:46
Definition: thread_group_tensor_slice_transfer_v6r1.hpp:34
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31
Definition: device_base.hpp:50
Definition: unary_element_wise_operation.hpp:241