template <typename GridwiseGemm>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_gemm_xdlops_streamk(const typename GridwiseGemm::FloatAB* p_a_grid,
                                   const typename GridwiseGemm::FloatAB* p_b_grid,
                                   typename GridwiseGemm::FloatC* p_c_grid,
                                   void* p_workspace,
                                   index_t M,
                                   index_t N,
                                   index_t K,
                                   index_t StrideA,
                                   index_t StrideB,
                                   index_t StrideC,
                                   typename GridwiseGemm::Block2CTileMap block_mapping)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
    /* ... remaining gfx targets elided ... */)
    constexpr index_t shared_size = GridwiseGemm::GetSharedMemoryNumberOfByte();

    __shared__ uint8_t p_shared[shared_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_workspace,
                      M,
                      N,
                      K,
                      StrideA,
                      StrideB,
                      StrideC,
                      block_mapping,
                      static_cast<void*>(p_shared));
#else
    // ... (unsupported-target fallback elided)
#endif
}
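// Example host-side launch (illustrative sketch, not part of the header).
// Assumes a concrete instantiation `GG` of the gridwise GEMM below, a filled
// GG::Argument `karg`, and hypothetical locals `grid_size`, `p_workspace`,
// `stream`; BlockSize is the instantiation's thread count.
//
//   const index_t grid_size = GG::CalculateGridSize(karg);
//   hipLaunchKernelGGL(kernel_gemm_xdlops_streamk<GG>,
//                      dim3(grid_size), dim3(BlockSize), 0, stream,
//                      karg.p_a_grid, karg.p_b_grid, karg.p_c_grid,
//                      p_workspace,
//                      karg.M, karg.N, karg.K,
//                      karg.StrideA, karg.StrideB, karg.StrideC,
//                      karg.block_mapping);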
template <typename Block2CTileMap_,
          // ... (leading type/value parameters elided in this excerpt)
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          // ...
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          // ...
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          // ...
          index_t CShuffleMRepeatPerShuffle,
          index_t CShuffleNRepeatPerShuffle,
          index_t CBlockTransferScalarPerVector_NWaveNPerXDL,
          typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock>
struct GridwiseGemm // struct name elided in this excerpt
{
    // ... (I0..I7, K1, type aliases elided)
    static constexpr auto M01 = 1;
    static constexpr auto N01 = 1;
    // ...

    struct Argument
    {
        Argument(const FloatAB* p_a_grid_,
                 const FloatAB* p_b_grid_,
                 FloatC* p_c_grid_,
                 index_t M_,
                 index_t N_,
                 index_t K_,
                 index_t StrideA_,
                 index_t StrideB_,
                 index_t StrideC_,
                 uint32_t num_cu,
                 uint32_t occupancy,
                 uint32_t num_sk_blocks_)
        // ... (member initializers elided)

        void Print() const
        {
            std::cout << /* ... problem sizes and other strides elided ... */
                      << "SC:" << StrideC << std::endl;
        }
        // ...
    };
    __host__ __device__ static auto
    MakeAGridDescriptor_K0_M_K1(index_t M, index_t MPad, index_t K, index_t KPad, index_t StrideA)
    {
        const auto a_grid_desc_m_k = [&]() {
            // ... (layout-dependent construction elided)
        }();
        // ... (pad and K0/K1 split elided)
    }
    __host__ __device__ static auto
    MakeBGridDescriptor_K0_N_K1(index_t K, index_t KPad, index_t N, index_t NPad, index_t StrideB)
    {
        const auto b_grid_desc_k_n = [&]() {
            // ...
        }();
        // ...
    }
    __host__ __device__ static auto
    MakeCGridDescriptor_M_N(index_t M, index_t MPad, index_t N, index_t NPad, index_t StrideC)
    {
        const auto c_grid_desc_m_n = [&]() {
            // ...
        }();
        // ...
    }
    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        constexpr auto max_lds_align = K1;

        // LDS holds the aligned A and B tiles during the main loop and is
        // reused for the C shuffle tile on the way out; the larger use wins.
        constexpr auto a_block_space_size_aligned = /* ... */;
        constexpr auto b_block_space_size_aligned = /* ... */;
        constexpr auto c_block_size = /* ... */;

        return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                             /* ... A/B element size ... */,
                         /* ... c_block_size in bytes ... */);
    }
    __host__ __device__ static constexpr bool CheckValidity(const Argument& karg)
    {
        // Vector-width divisibility checks; which extent is tested depends on
        // the (elided) layout branch around each check.
        if(karg.K % ABlockTransferSrcScalarPerVector != 0)
        {
            return false;
        }
        // ...
        if(karg.M % ABlockTransferSrcScalarPerVector != 0)
        {
            return false;
        }
        // ...
        if(karg.N % BBlockTransferSrcScalarPerVector != 0)
        {
            return false;
        }
        // ...
        if(karg.K % BBlockTransferSrcScalarPerVector != 0)
        {
            return false;
        }
        // ...
        if(karg.N % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
        {
            return false;
        }
        // ...
        if(karg.M % CBlockTransferScalarPerVector_NWaveNPerXDL != 0)
        {
            return false;
        }
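        // Example: with ABlockTransferSrcScalarPerVector = 8 and a K-contiguous
        // A layout, K = 4096 passes (4096 % 8 == 0) while K = 4100 is rejected,
        // since rows could not be read in whole 8-element vectors.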
        // ... (remaining checks elided)
        return true;
    }

    __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0)
    {
        const bool has_main_k0_block_loop = K0 > K0PerBlock;

        return has_main_k0_block_loop;
    }
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCGridDesc_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        const auto MBlock = M / MPerBlock;
        const auto NBlock = N / NPerBlock;
        // ... (unmerge into (MBlock, MPerBlock, NBlock, NPerBlock) elided)
    }
    template <typename CGridDesc>
    __host__ __device__ static constexpr auto
    MakeCBlockClusterAdaptor(const CGridDesc& c_m_n_grid_desc, index_t, index_t, index_t KBatch)
    {
        return /* ... map type elided ... */ (c_m_n_grid_desc, 8, KBatch);
    }
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);
        // ... (descriptor construction elided)
    }
    __host__ __device__ static constexpr auto
    GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle()
    {
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

        // ... (descriptor construction elided; the N-repeat extent is
        //      Number<NRepeat / CShuffleNRepeatPerShuffle>{})
    }
    __host__ __device__ static constexpr auto GetClusterLengthReduction()
    {
        // Shape the BlockSize threads into an (M, N) reduction cluster: one
        // thread per CBlockTransferScalarPerVector_NWaveNPerXDL-wide vector
        // along N (NPerBlock rounded up to a power of two), the rest along M.
        constexpr auto NPerBlockPow2 = math::next_power_of_two<NPerBlock>();
        constexpr auto NPerBlockReduction =
            NPerBlockPow2 / CBlockTransferScalarPerVector_NWaveNPerXDL;
        constexpr auto MPerBlockReduction =
            (BlockSize + NPerBlockReduction - 1) / NPerBlockReduction;
        // ...
    }

    __host__ __device__ static constexpr auto GetPartialAccBlockDescriptor()
    {
        const auto c_partial_acc_block_m_n = [&]() {
            // ... (layout-dependent naive descriptor elided)
        }();

        return c_partial_acc_block_m_n;
    }
    static __device__ void Run(const FloatAB* p_a_grid,
                               const FloatAB* p_b_grid,
                               FloatC* p_c_grid,
                               void* p_workspace,
                               index_t M,
                               index_t N,
                               index_t K,
                               index_t StrideA,
                               index_t StrideB,
                               index_t StrideC,
                               Block2CTileMap block_mapping,
                               void* __restrict__ p_shared_block)
    {
        // ... (local m/n problem-size copies elided)
        uint32_t pad_m = (m + MPerBlock - 1) / MPerBlock * MPerBlock;
        uint32_t pad_n = (n + NPerBlock - 1) / NPerBlock * NPerBlock;

        uint32_t stride_a = StrideA;
        uint32_t stride_b = StrideB;
        uint32_t stride_c = StrideC;
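        // Round-up example: with MPerBlock = 256 and m = 1000,
        // pad_m = (1000 + 255) / 256 * 256 = 1024, the next tile multiple.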
        // ... (A/B/C grid descriptors from the Make*GridDescriptor helpers elided)
        const auto c_grid_desc_mblock_mperblock_nblock_nperblock = /* ... */;

        const AElementwiseOperation a_element_op = AElementwiseOperation{};
        const BElementwiseOperation b_element_op = BElementwiseOperation{};
        const CElementwiseOperation c_element_op = CElementwiseOperation{};

        const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());

        constexpr auto max_lds_align = K1;
        // ... (A/B LDS block descriptors elided)
        auto blockwise_gemm =
            /* ... blockwise XDL GEMM type elided; parameterized on
               decltype(a_block_desc_k0_m_k1) and decltype(b_block_desc_k0_n_k1) ... */{};

        auto c_thread_buf = blockwise_gemm.GetCThreadBuffer();

        constexpr auto a_block_space_size = /* ... */;

        FloatAB* p_a_block = static_cast<FloatAB*>(p_shared_block); // (reconstructed: A tile at LDS base)
        FloatAB* p_b_block = static_cast<FloatAB*>(p_shared_block) + a_block_space_size;

        constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

        auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_a_block, a_block_desc_k0_m_k1.GetElementSpaceSize());
        auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_b_block, b_block_desc_k0_n_k1.GetElementSpaceSize());
        // Stream-K role assignment: sk blocks run split K-iterations, dp blocks
        // run whole tiles data-parallel, reduction blocks combine partials.
        uint32_t block_idx = block_mapping.get_block_idx();
        bool is_sk_block = block_idx < block_mapping.sk_num_blocks;
        bool is_dp_block = block_idx >= block_mapping.dp_start_block_idx &&
                           block_idx < block_mapping.reduction_start_block_idx;
        bool is_reduction_block = block_idx >= block_mapping.reduction_start_block_idx;
        bool is_padding_block = block_idx >= block_mapping.sk_num_blocks &&
                                block_idx < block_mapping.dp_start_block_idx;

        uint32_t iter_start, iter_end;
        block_mapping.get_block_itr(block_idx, iter_start, iter_end);
        uint32_t total_iter_length = iter_end - iter_start;

        // ...
        // The tile semaphores live directly after the accumulator workspace.
        uint32_t* p_semaphore =
            reinterpret_cast<uint32_t*>(reinterpret_cast<char*>(p_workspace) +
                                        block_mapping.get_workspace_size_for_acc(sizeof(FloatAcc)));
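        // Illustration of the work split (a sketch, not Block2CTileMap's actual
        // implementation): stream-K spreads the GEMM's total K-chunk iterations,
        // num_tiles * iters_per_tile, evenly over the sk blocks, e.g.
        //
        //   uint32_t iters_per_sk_block = (total_iters + sk_num_blocks - 1) / sk_num_blocks;
        //   uint32_t iter_start         = block_idx * iters_per_sk_block;
        //   uint32_t iter_end           = min(iter_start + iters_per_sk_block, total_iters);
        //
        // so a block's [iter_start, iter_end) range may straddle C-tile
        // boundaries, which is why partials must later be reduced across blocks.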
        if(is_reduction_block)
        {
            constexpr auto cluster_length_reduce = GetClusterLengthReduction();
            // ... (cluster descriptor elided)
            const auto reduce_thread_cluster_idx = /* ... */;
            const auto thread_m_cluster_id = reduce_thread_cluster_idx[I0];
            const auto thread_n_cluster_id = reduce_thread_cluster_idx[I1];

            constexpr auto MReduceIters = /* ... */;
            constexpr auto NReduceIters =
                /* ... derived from cluster_length_reduce.At(I1) * ... */;

            // Per-iteration window steps over the partial-accumulation buffer
            // (load side) and the C grid (store side).
            constexpr auto partial_acc_load_step_n = make_multi_index(
                0, cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_load_step_n_reverse = make_multi_index(
                /* ... */,
                -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                    CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_load_step_m = /* ... */;

            constexpr auto partial_acc_store_step_n = make_multi_index(
                /* ... */,
                cluster_length_reduce.At(I1) * CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_store_step_n_reverse = make_multi_index(
                /* ... */,
                -1 * cluster_length_reduce.At(I1).value * (NReduceIters - 1) *
                    CBlockTransferScalarPerVector_NWaveNPerXDL);
            constexpr auto partial_acc_store_step_m = /* ... */;

            // Thread-buffer descriptors for one reduction step, each
            // CBlockTransferScalarPerVector_NWaveNPerXDL elements wide in N.
            constexpr auto acc_thread_buf_load_desc =
                /* ... CBlockTransferScalarPerVector_NWaveNPerXDL ... */;
            constexpr auto acc_thread_buf_store_desc =
                /* ... CBlockTransferScalarPerVector_NWaveNPerXDL ... */;
            auto reduction_idx = blockIdx.x - block_mapping.reduction_start_block_idx;
            auto spatial_idx = block_mapping.tile_to_spatial(reduction_idx, m, n);
            // ...
            uint32_t tile_acc_offset_start =
                block_mapping.get_acc_buffer_offset_from_tile(reduction_idx);
            uint32_t tile_acc_offset_end =
                block_mapping.get_acc_buffer_offset_from_tile(reduction_idx + 1);

            // Threadwise load of partial accumulators from the workspace.
            auto acc_load = /* ... transfer type elided; key parameters:
                decltype(c_partial_acc_block_m_n),
                decltype(acc_thread_buf_load_desc),
                CBlockTransferScalarPerVector_NWaveNPerXDL ... */
                {c_partial_acc_block_m_n,
                 make_multi_index(/* ... */,
                                  thread_n_cluster_id *
                                      CBlockTransferScalarPerVector_NWaveNPerXDL)};

            // Threadwise store of the reduced tile to the C grid.
            auto acc_store = /* ... transfer type elided; key parameters:
                decltype(acc_thread_buf_store_desc),
                decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
                CElementwiseOperation,
                CBlockTransferScalarPerVector_NWaveNPerXDL ... */
                {c_grid_desc_mblock_mperblock_nblock_nperblock,
                 make_multi_index(/* ... */,
                                  __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                                  thread_n_cluster_id *
                                      CBlockTransferScalarPerVector_NWaveNPerXDL),
                 CElementwiseOperation{}};
            // Wait until every stream-K block contributing to this C tile has
            // flushed its partial accumulators to the workspace.
            wg_barrier.wait_eq(reduction_idx, tile_acc_offset_end - tile_acc_offset_start);

            // (debug print)
            if(threadIdx.x == 0)
            {
                printf("bid:%d, rid:%d, os:%d,%d, spatial:%d,%d\n",
                       static_cast<int>(blockIdx.x),
                       reduction_idx,
                       __builtin_amdgcn_readfirstlane(tile_acc_offset_start),
                       __builtin_amdgcn_readfirstlane(tile_acc_offset_end),
                       __builtin_amdgcn_readfirstlane(spatial_idx[I0]),
                       __builtin_amdgcn_readfirstlane(spatial_idx[I1]));
            }
            using Accumulation = ck::detail::/* ... accumulation helper elided ... */;

            for(int i_m = 0; i_m < MReduceIters; i_m++)
            {
                // ... (static_for over i_n_reduce elided)
                for(auto i = tile_acc_offset_start; i < tile_acc_offset_end; i++)
                {
                    auto c_partial_acc_buf =
                        make_dynamic_buffer</* ... */>(
                            reinterpret_cast<FloatAcc*>(p_workspace) +
                                i * c_partial_acc_block_m_n.GetElementSpaceSize(),
                            c_partial_acc_block_m_n.GetElementSpaceSize());

                    acc_load.Run(c_partial_acc_block_m_n,
                                 /* ... */
                                 acc_thread_buf_load_desc,
                                 /* ... */);

                    // Accumulate the loaded partials into the running buffer.
                    constexpr auto offset =
                        acc_thread_buf_load_desc.CalculateOffset(/* ... */);
                    // ...
                }

                if(thread_n_cluster_id * CBlockTransferScalarPerVector_NWaveNPerXDL <
                   /* ... valid N extent ... */)
                {
                    acc_store.Run(acc_thread_buf_store_desc,
                                  /* ... */
                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                  /* ... */);
                }

                if constexpr(NReduceIters != 1)
                {
                    if constexpr(i_n_reduce != (NReduceIters - 1))
                    {
                        acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                    partial_acc_load_step_n);
                        acc_store.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            partial_acc_store_step_n);
                    }
                    else
                    {
                        acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                                    partial_acc_load_step_n_reverse);
                        acc_store.MoveDstSliceWindow(
                            c_grid_desc_mblock_mperblock_nblock_nperblock,
                            partial_acc_store_step_n_reverse);
                    }
                }

                acc_load.MoveSrcSliceWindow(c_partial_acc_block_m_n,
                                            partial_acc_load_step_m);
                acc_store.MoveDstSliceWindow(c_grid_desc_mblock_mperblock_nblock_nperblock,
                                             partial_acc_store_step_m);
            }

            return; // (reconstructed: reduction workgroups exit here)
        }
        // Stream-K / data-parallel GEMM path.
        uint32_t block_acc_offset =
            (block_mapping.get_acc_buffer_offset_from_block(block_idx + 1) - 1) * MPerBlock *
            NPerBlock; // (second factor reconstructed from the decrement at the loop tail)

        // Persistent loop over this block's stream-K tiles starts here
        // (loop construct elided; it ends at the iter_end update below).
        uint32_t current_iter_length = __builtin_amdgcn_readfirstlane(
            block_mapping.get_current_iter_length(iter_start, iter_end, total_iter_length));
        uint32_t tile_idx, iter_offset;
        block_mapping.get_tile_idx_with_offset(iter_end - 1, tile_idx, iter_offset);
        iter_offset = __builtin_amdgcn_readfirstlane(iter_offset - current_iter_length + 1);
        auto spatial_idx = block_mapping.tile_to_spatial(tile_idx, m, n);

        const index_t m_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(spatial_idx[I0] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(spatial_idx[I1] * NPerBlock);

        const index_t k0_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(iter_offset * K0PerBlock);
        // Blockwise copies: global A/B tiles -> LDS (some parameters elided).
        auto a_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1</* ... */
                                                AElementwiseOperation,
                                                /* ... */
                                                ABlockTransferThreadClusterLengths_K0_M_K1,
                                                ABlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(a_k0_m_k1_grid_desc),
                                                decltype(a_block_desc_k0_m_k1),
                                                ABlockTransferSrcAccessOrder,
                                                /* ... */
                                                ABlockTransferSrcVectorDim,
                                                /* ... */
                                                ABlockTransferSrcScalarPerVector,
                                                ABlockTransferDstScalarPerVector_K1,
                                                /* ... */
                                                AThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */>(
                /* ... */
                a_block_desc_k0_m_k1,
                /* ... */);

        auto b_blockwise_copy =
            ThreadGroupTensorSliceTransfer_v4r1</* ... */
                                                BElementwiseOperation,
                                                /* ... */
                                                BBlockTransferThreadClusterLengths_K0_N_K1,
                                                BBlockTransferThreadClusterArrangeOrder,
                                                /* ... */
                                                decltype(b_k0_n_k1_grid_desc),
                                                decltype(b_block_desc_k0_n_k1),
                                                BBlockTransferSrcAccessOrder,
                                                /* ... */
                                                BBlockTransferSrcVectorDim,
                                                /* ... */
                                                BBlockTransferSrcScalarPerVector,
                                                BBlockTransferDstScalarPerVector_K1,
                                                /* ... */
                                                BThreadTransferSrcResetCoordinateAfterRun,
                                                /* ... */>(
                /* ... */
                b_block_desc_k0_n_k1,
                /* ... */);

        const index_t num_k_block_main_loop = current_iter_length;

        gridwise_gemm_pipeline.Run(a_k0_m_k1_grid_desc,
                                   a_block_desc_k0_m_k1,
                                   // ...
                                   a_block_slice_copy_step,
                                   // ...
                                   b_block_desc_k0_n_k1,
                                   // ...
                                   b_block_slice_copy_step,
                                   // ...
                                   num_k_block_main_loop);
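        // What the pipeline call amounts to (illustrative sketch only; the real
        // GridwiseGemmPipeline_v3 double-buffers and overlaps these stages):
        //
        //   for(index_t i = 0; i < num_k_block_main_loop; ++i)
        //   {
        //       a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf);
        //       b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf);
        //       a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, a_block_slice_copy_step);
        //       b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, b_block_slice_copy_step);
        //       a_blockwise_copy.RunWrite(a_block_desc_k0_m_k1, a_block_buf);
        //       b_blockwise_copy.RunWrite(b_block_desc_k0_n_k1, b_block_buf);
        //       block_sync_lds();
        //       blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
        //       block_sync_lds();
        //   }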
        constexpr index_t MWave = MPerBlock / (MRepeat * MPerXDL);
        constexpr index_t NWave = NPerBlock / (NRepeat * NPerXDL);

        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc =
            blockwise_gemm.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc =
            blockwise_gemm.GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2();

        constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0);
        constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1);
        constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2);
        constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3);
        constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4);
        constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5);
        constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6);
        constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7);
        constexpr auto c_block_desc_mblock_mpershuffle_nblock_npershuffle =
            GetCBlockDescriptor_MBlock_MPerShuffle_NBlock_NPerShuffle();

        constexpr auto c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle =
            GetCBlockDescriptor_MShuffleRepeat_MPerShuffle_NShuffleRepeat_NPerShuffle();

        auto c_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            /* ... */,
            c_block_desc_mblock_mpershuffle_nblock_npershuffle.GetElementSpaceSize());

        auto c_partial_acc_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                reinterpret_cast<FloatAcc*>(p_workspace) + block_acc_offset,
                c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle
                    .GetElementSpaceSize());

        // ... (tensor transform over
        //      c_block_desc_mblock_mpershuffle_nblock_npershuffle elided)
        // Origin of this thread's C fragment within the block tile.
        const auto c_thread_mtx_on_block =
            blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0);

        const index_t m_thread_data_on_block = c_thread_mtx_on_block[I0];
        const index_t n_thread_data_on_block = c_thread_mtx_on_block[I1];

        const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor =
            /* ... single-stage adaptor elided ... */;

        const auto m_thread_data_on_block_idx =
            m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex(
                make_multi_index(m_thread_data_on_block));

        const auto n_thread_data_on_block_to_n0_n1_n2_adaptor =
            /* ... */;

        const auto n_thread_data_on_block_idx =
            n_thread_data_on_block_to_n0_n1_n2_adaptor.CalculateBottomIndex(
                make_multi_index(n_thread_data_on_block));

        // VGPR -> LDS copy of one C shuffle slab.
        auto c_thread_copy_vgpr_to_lds =
            /* ... threadwise transfer type elided; key parameters:
               decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc),
               decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2),
               CShuffleMRepeatPerShuffle, CShuffleNRepeatPerShuffle, ...,
               true ... */
            {c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
             make_multi_index(/* ... */,
                              m_thread_data_on_block_idx[I1],
                              n_thread_data_on_block_idx[I1],
                              m_thread_data_on_block_idx[I2],
                              m_thread_data_on_block_idx[I3],
                              m_thread_data_on_block_idx[I4],
                              n_thread_data_on_block_idx[I2]),
             /* ... */};
        // LDS -> global copy of the shuffled C tile (some parameters elided).
        auto c_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1r2<
            /* ... */
            CElementwiseOperation,
            /* ... */
            Sequence</* ... */
                     CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                     /* ... */
                     CShuffleNRepeatPerShuffle * NWave * NPerXDL>,
            CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            /* ... */
            decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
            decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
            /* ... */
            CBlockTransferScalarPerVector_NWaveNPerXDL,
            /* ... */>
            {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
             /* ... */
             c_grid_desc_mblock_mperblock_nblock_nperblock,
             make_multi_index(/* ... */,
                              __builtin_amdgcn_readfirstlane(spatial_idx[I1]),
                              /* ... */),
             /* ... */};

        // LDS -> workspace copy of partial accumulators (stream-K reduction).
        auto c_block_copy_lds_to_partial_acc = ThreadGroupTensorSliceTransfer_v6r1r2<
            /* ... */
            CElementwiseOperation,
            /* ... */
            Sequence</* ... */
                     CShuffleMRepeatPerShuffle * MWave * MPerXDL,
                     /* ... */
                     CShuffleNRepeatPerShuffle * NWave * NPerXDL>,
            CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            /* ... */
            decltype(c_block_desc_mblock_mpershuffle_nblock_npershuffle),
            decltype(c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle),
            /* ... */
            CBlockTransferScalarPerVector_NWaveNPerXDL,
            /* ... */>
            {c_block_desc_mblock_mpershuffle_nblock_npershuffle,
             /* ... */
             c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
             /* ... */};
        // Sweep the C tile: forward in N on even M-steps, backward on odd ones.
        constexpr auto mxdlperwave_forward_step = /* ... */;
        constexpr auto nxdlperwave_forward_step = /* ... */;
        constexpr auto nxdlperwave_backward_step = /* ... */;

        // (static_for over mxdlperwave_iter / nxdlperwave_iter elided)
        constexpr auto mxdlperwave = mxdlperwave_iter;

        constexpr bool nxdlperwave_forward_sweep =
            (mxdlperwave % (2 * CShuffleMRepeatPerShuffle) == 0);

        constexpr index_t nxdlperwave_value =
            nxdlperwave_forward_sweep
                ? /* ... */
                : (NRepeat - nxdlperwave_iter - CShuffleNRepeatPerShuffle);
        // Stage one shuffle slab in LDS, then write it out by block role.
        c_thread_copy_vgpr_to_lds.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc,
                                      /* ... */
                                      c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                      /* ... */);
        // ...
        c_block_copy_lds_to_global.SetSrcSliceOrigin(
            c_block_desc_mblock_mpershuffle_nblock_npershuffle,
            /* ... */);
        // ...
        if(is_dp_block) // (condition reconstructed from the else-if below)
            c_block_copy_lds_to_global.template Run<decltype(c_block_buf),
                                                    decltype(c_grid_buf),
                                                    /* ... */>(
                c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                /* ... */
                c_grid_desc_mblock_mperblock_nblock_nperblock,
                /* ... */);
        else if(is_sk_block)
        {
            if constexpr(Block2CTileMap::ReductionStrategy ==
                         StreamKReductionStrategy::Reduction)
            {
                // Flush partials to the workspace for the reduction blocks.
                c_block_copy_lds_to_partial_acc.SetSrcSliceOrigin(
                    c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                    /* ... */);
                c_block_copy_lds_to_partial_acc.SetDstSliceOrigin(
                    c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                    make_tuple(mxdlperwave.value, 0, nxdlperwave.value, 0));

                c_block_copy_lds_to_partial_acc
                    .template Run<decltype(c_block_buf),
                                  decltype(c_partial_acc_buf),
                                  /* ... */>(
                        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                        /* ... */
                        c_block_desc_mshuffle_mpershuffle_nshuffle_npershuffle,
                        /* ... */);
            }
            else if constexpr(Block2CTileMap::ReductionStrategy ==
                              StreamKReductionStrategy::Atomic)
            {
                // Partial tiles go straight to C via atomic accumulation.
                c_block_copy_lds_to_global
                    .template Run<decltype(c_block_buf),
                                  decltype(c_grid_buf),
                                  /* ... */>(
                        c_block_desc_mblock_mpershuffle_nblock_npershuffle,
                        /* ... */
                        c_grid_desc_mblock_mperblock_nblock_nperblock,
                        /* ... */);
            }
        }

        // Zig-zag the destination window in N, then step in M.
        if constexpr(nxdlperwave_forward_sweep &&
                     (nxdlperwave < NRepeat - CShuffleNRepeatPerShuffle))
        {
            c_block_copy_lds_to_global.MoveDstSliceWindow(
                c_grid_desc_mblock_mperblock_nblock_nperblock,
                nxdlperwave_forward_step);
        }
        else if constexpr((!nxdlperwave_forward_sweep) && (nxdlperwave > 0))
        {
            c_block_copy_lds_to_global.MoveDstSliceWindow(
                c_grid_desc_mblock_mperblock_nblock_nperblock,
                nxdlperwave_backward_step);
        }
        // ...
        if constexpr(mxdlperwave < MRepeat - CShuffleMRepeatPerShuffle)
        {
            c_block_copy_lds_to_global.MoveDstSliceWindow(
                c_grid_desc_mblock_mperblock_nblock_nperblock,
                mxdlperwave_forward_step);
        }
        // ... (end of shuffle sweep)

        if constexpr(Block2CTileMap::ReductionStrategy ==
                     StreamKReductionStrategy::Reduction)
        {
            // ... (sk-block check and fence elided)
            // Signal the reduction block that this tile's partials are ready.
            wg_barrier.inc(tile_idx);
        }

        // Retire this tile's iterations; exit once the range is drained.
        iter_end -= current_iter_length;
        if(iter_end <= iter_start)
            break; // (reconstructed: leaves the elided persistent loop)
        // ...
        block_acc_offset -= MPerBlock * NPerBlock;
        // ... (end of the persistent stream-K loop)
    }
    // Maps a tensor layout to a one-letter tag for GetTypeString():
    // "" for the primary template, "R" for row-major, "C" for column-major.
    template <typename Layout>
    struct /* ... name elided ... */
    {
        static std::string Get() { return ""; }
    };

    // ... (row-major specialization)
    {
        static std::string Get() { return "R"; }
    };

    // ... (column-major specialization)
    {
        static std::string Get() { return "C"; }
    };
    static std::string GetTypeString()
    {
        auto str = std::stringstream();

        str << "GemmXdlStreamK_"
            << std::string(ALayout::name)[0]
            << std::string(BLayout::name)[0]
            << std::string(CLayout::name)[0]
            // ...
            << "B" << BlockSize << "_"
            << "Vec" << ABlockTransferSrcScalarPerVector << "x"
            << BBlockTransferSrcScalarPerVector << "x"
            << CBlockTransferScalarPerVector_NWaveNPerXDL << "_"
            // ... (block tile extents)
            << K0PerBlock << "x"
            /* ... */;

        return str.str();
    }
    // ...
};