20 namespace tensor_operation {
29 typename InElementwiseOperation,
30 typename WeiElementwiseOperation,
31 typename OutElementwiseOperation,
41 typename M1N1ThreadClusterM1Xs,
42 typename M1N1ThreadClusterN1Xs,
43 typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
44 typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
45 typename ABlockTransferThreadClusterArrangeOrder,
46 typename ABlockTransferSrcAccessOrder,
47 typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
48 typename ABlockTransferSrcVectorTensorContiguousDimOrder,
49 typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
50 typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
51 typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
52 typename BBlockTransferThreadClusterArrangeOrder,
53 typename BBlockTransferSrcAccessOrder,
54 typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
55 typename BBlockTransferSrcVectorTensorContiguousDimOrder,
56 typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
57 typename CThreadTransferSrcDstAccessOrder,
58 index_t CThreadTransferSrcDstVectorDim,
59 index_t CThreadTransferDstScalarPerVector>
63 ck::tuple_element_t<NDimSpatial - 1,
64 ck::Tuple<ck::tensor_layout::convolution::NWC,
65 ck::tensor_layout::convolution::NHWC,
66 ck::tensor_layout::convolution::NDHWC>>,
67 ck::tuple_element_t<NDimSpatial - 1,
68 ck::Tuple<ck::tensor_layout::convolution::KXC,
69 ck::tensor_layout::convolution::KYXC,
70 ck::tensor_layout::convolution::KZYXC>>,
71 ck::tuple_element_t<NDimSpatial - 1,
72 ck::Tuple<ck::tensor_layout::convolution::NWK,
73 ck::tensor_layout::convolution::NHWK,
74 ck::tensor_layout::convolution::NDHWK>>,
78 InElementwiseOperation,
79 WeiElementwiseOperation,
80 OutElementwiseOperation>
100 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
105 std::vector<ck::index_t> input_spatial_lengths,
106 std::vector<ck::index_t> filter_spatial_lengths,
107 std::vector<ck::index_t> output_spatial_lengths,
108 std::vector<ck::index_t> conv_filter_strides,
109 std::vector<ck::index_t> conv_filter_dilations,
110 std::vector<ck::index_t> input_left_pads,
111 std::vector<ck::index_t> input_right_pads,
112 std::vector<ck::index_t> tildes)
118 const index_t Wi = input_spatial_lengths[0];
119 const index_t Wo = output_spatial_lengths[0];
120 const index_t X = filter_spatial_lengths[0];
121 const index_t InLeftPadW = input_left_pads[0];
122 const index_t InRightPadW = input_right_pads[0];
123 const index_t ConvStrideW = conv_filter_strides[0];
124 const index_t ConvDilationW = conv_filter_dilations[0];
126 const auto K0 = K / K1;
130 if constexpr(ConvBackwardDataSpecialization ==
142 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
159 in_n_x_wo_c_grid_desc,
166 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
167 wei_gemmk0_gemmn_gemmk1_grid_desc,
168 in_gemmm_gemmn_grid_desc);
172 const auto out_n_wo_k_grid_desc =
174 const auto wei_k_x_c_grid_desc =
177 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
179 const auto XTilde = ConvStrideW / GcdStrideDilationW;
188 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
193 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
200 out_n_wo_k_grid_desc,
208 out_n_wop_k_grid_desc,
218 out_n_xdot_wtilde_k_grid_desc,
227 out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
245 wei_k_xdot_xtilde_c_grid_desc,
254 wei_k0_k1_xdotslice_c_grid_desc,
271 in_n_wip_c_grid_desc,
280 in_n_xtilde_wtilde_c_grid_desc,
289 in_n_wtildeslice_c_grid_desc,
295 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
296 wei_gemmk0_gemmn_gemmk1_grid_desc,
297 in_gemmm_gemmn_grid_desc);
301 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
306 std::vector<ck::index_t> input_spatial_lengths,
307 std::vector<ck::index_t> filter_spatial_lengths,
308 std::vector<ck::index_t> output_spatial_lengths,
309 std::vector<ck::index_t> conv_filter_strides,
310 std::vector<ck::index_t> conv_filter_dilations,
311 std::vector<ck::index_t> input_left_pads,
312 std::vector<ck::index_t> input_right_pads,
313 std::vector<ck::index_t> tildes)
320 const index_t Hi = input_spatial_lengths[0];
321 const index_t Wi = input_spatial_lengths[1];
323 const index_t Ho = output_spatial_lengths[0];
324 const index_t Wo = output_spatial_lengths[1];
326 const index_t Y = filter_spatial_lengths[0];
327 const index_t X = filter_spatial_lengths[1];
329 const index_t InLeftPadH = input_left_pads[0];
330 const index_t InLeftPadW = input_left_pads[1];
332 const index_t InRightPadH = input_right_pads[0];
333 const index_t InRightPadW = input_right_pads[1];
335 const index_t ConvStrideH = conv_filter_strides[0];
336 const index_t ConvStrideW = conv_filter_strides[1];
338 const index_t ConvDilationH = conv_filter_dilations[0];
339 const index_t ConvDilationW = conv_filter_dilations[1];
341 const auto K0 = K / K1;
343 const auto out_n_ho_wo_k_grid_desc =
345 const auto wei_k_y_x_c_grid_desc =
347 const auto in_n_hi_wi_c_grid_desc =
350 if constexpr(ConvBackwardDataSpecialization ==
362 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
371 in_n_hi_wi_c_grid_desc,
380 in_n_y_ho_x_wo_c_grid_desc,
388 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
389 wei_gemmk0_gemmn_gemmk1_grid_desc,
390 in_gemmm_gemmn_grid_desc);
394 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
395 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
397 const auto YTilde = ConvStrideH / GcdStrideDilationH;
398 const auto XTilde = ConvStrideW / GcdStrideDilationW;
410 math::max(
I0, InLeftPadH - ConvDilationH * (YTilde -
I1)), ConvStrideH);
412 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
419 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
420 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
428 out_n_ho_wo_k_grid_desc,
437 out_n_hop_wop_k_grid_desc,
448 const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
450 out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
471 out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
480 wei_k_y_x_c_grid_desc,
490 const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
512 wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
521 in_n_hi_wi_c_grid_desc,
530 in_n_hip_wip_c_grid_desc,
541 in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
562 in_n_htildeslice_wtildeslice_c_grid_desc,
568 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
569 wei_gemmk0_gemmn_gemmk1_grid_desc,
570 in_gemmm_gemmn_grid_desc);
575 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
580 std::vector<ck::index_t> input_spatial_lengths,
581 std::vector<ck::index_t> filter_spatial_lengths,
582 std::vector<ck::index_t> output_spatial_lengths,
583 std::vector<ck::index_t> conv_filter_strides,
584 std::vector<ck::index_t> conv_filter_dilations,
585 std::vector<ck::index_t> input_left_pads,
586 std::vector<ck::index_t> input_right_pads,
587 std::vector<ck::index_t> tildes)
591 const index_t i_ztilde = tildes[0];
592 const index_t i_ytilde = tildes[1];
593 const index_t i_xtilde = tildes[2];
595 const index_t Di = input_spatial_lengths[0];
596 const index_t Hi = input_spatial_lengths[1];
597 const index_t Wi = input_spatial_lengths[2];
599 const index_t Do = output_spatial_lengths[0];
600 const index_t Ho = output_spatial_lengths[1];
601 const index_t Wo = output_spatial_lengths[2];
603 const index_t Z = filter_spatial_lengths[0];
604 const index_t Y = filter_spatial_lengths[1];
605 const index_t X = filter_spatial_lengths[2];
607 const index_t InLeftPadD = input_left_pads[0];
608 const index_t InLeftPadH = input_left_pads[1];
609 const index_t InLeftPadW = input_left_pads[2];
611 const index_t InRightPadD = input_right_pads[0];
612 const index_t InRightPadH = input_right_pads[1];
613 const index_t InRightPadW = input_right_pads[2];
615 const index_t ConvStrideD = conv_filter_strides[0];
616 const index_t ConvStrideH = conv_filter_strides[1];
617 const index_t ConvStrideW = conv_filter_strides[2];
619 const index_t ConvDilationD = conv_filter_dilations[0];
620 const index_t ConvDilationH = conv_filter_dilations[1];
621 const index_t ConvDilationW = conv_filter_dilations[2];
623 const auto K0 = K / K1;
625 const auto out_n_do_ho_wo_k_grid_desc =
627 const auto wei_k_z_y_x_c_grid_desc =
629 const auto in_n_di_hi_wi_c_grid_desc =
632 if constexpr(ConvBackwardDataSpecialization ==
644 const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
653 in_n_di_hi_wi_c_grid_desc,
668 in_n_z_do_y_ho_x_wo_c_grid_desc,
681 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
682 wei_gemmk0_gemmn_gemmk1_grid_desc,
683 in_gemmm_gemmn_grid_desc);
687 const auto GcdStrideDilationD =
math::gcd(ConvStrideD, ConvDilationD);
688 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
689 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
691 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
692 const auto YTilde = ConvStrideH / GcdStrideDilationH;
693 const auto XTilde = ConvStrideW / GcdStrideDilationW;
708 math::max(
I0, InLeftPadD - ConvDilationD * (ZTilde -
I1)), ConvStrideD);
710 math::max(
I0, InLeftPadH - ConvDilationH * (YTilde -
I1)), ConvStrideH);
712 math::max(
I0, InLeftPadW - ConvDilationW * (XTilde -
I1)), ConvStrideW);
721 const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
722 const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
723 const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
732 out_n_do_ho_wo_k_grid_desc,
743 const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
745 out_n_dop_hop_wop_k_grid_desc,
764 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
766 out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
793 out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
802 const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
804 wei_k_z_y_x_c_grid_desc,
822 const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
850 wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
859 in_n_di_hi_wi_c_grid_desc,
870 const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
872 in_n_dip_hip_wip_c_grid_desc,
889 const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
891 in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
918 in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
925 return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
926 wei_gemmk0_gemmn_gemmk1_grid_desc,
927 in_gemmm_gemmn_grid_desc);
932 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
935 return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
936 1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
939 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
942 return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
943 1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
946 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
949 return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
985 M1N1ThreadClusterM1Xs,
986 M1N1ThreadClusterN1Xs,
987 ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
988 ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
989 ABlockTransferThreadClusterArrangeOrder,
990 ABlockTransferSrcAccessOrder,
991 ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
992 ABlockTransferSrcVectorTensorContiguousDimOrder,
993 ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
994 BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
995 BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
996 BBlockTransferThreadClusterArrangeOrder,
997 BBlockTransferSrcAccessOrder,
998 BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
999 BBlockTransferSrcVectorTensorContiguousDimOrder,
1000 BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
1001 CThreadTransferSrcDstAccessOrder,
1002 CThreadTransferSrcDstVectorDim,
1003 CThreadTransferDstScalarPerVector>;
1017 const WeiDataType* p_wei_grid,
1018 const OutDataType* p_out_grid,
1022 std::vector<ck::index_t> input_spatial_lengths,
1023 std::vector<ck::index_t> filter_spatial_lengths,
1024 std::vector<ck::index_t> output_spatial_lengths,
1025 std::vector<ck::index_t> conv_filter_strides,
1026 std::vector<ck::index_t> conv_filter_dilations,
1027 std::vector<ck::index_t> input_left_pads,
1028 std::vector<ck::index_t> input_right_pads,
1029 InElementwiseOperation in_element_op,
1030 WeiElementwiseOperation wei_element_op,
1031 OutElementwiseOperation out_element_op)
1049 CreateABCDesc<NDimSpatial>();
1052 template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type =
false>
1057 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1058 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1062 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1072 DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
1102 template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type =
false>
1111 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
1112 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1114 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1115 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1119 for(
index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1121 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1126 if(YDotSlice * XDotSlice <= 0)
1132 DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
1143 {i_ytilde, i_xtilde});
1163 template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type =
false>
1174 const auto GcdStrideDilationD =
math::gcd(ConvStrideD, ConvDilationD);
1175 const auto GcdStrideDilationH =
math::gcd(ConvStrideH, ConvDilationH);
1176 const auto GcdStrideDilationW =
math::gcd(ConvStrideW, ConvDilationW);
1178 const auto ZTilde = ConvStrideD / GcdStrideDilationD;
1179 const auto YTilde = ConvStrideH / GcdStrideDilationH;
1180 const auto XTilde = ConvStrideW / GcdStrideDilationW;
1185 for(
index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
1187 for(
index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1189 for(
index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1195 if(ZDotSlice * YDotSlice * XDotSlice <= 0)
1201 DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
1212 {i_ztilde, i_ytilde, i_xtilde});
1277 std::cout <<
"arg.a_grid_desc_k0_m_k1_container_{"
1283 std::cout <<
"arg.b_grid_desc_k0_n_k1_container_{"
1289 std::cout <<
"arg.c_grid_desc_m_n_container_{ "
1294 std::cout <<
"arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
1306 <<
" ) " << std::endl;
1313 throw std::runtime_error(
1314 "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
1321 auto has_double_tail_k_block_loop) {
1322 constexpr
bool has_main_loop = has_main_k_block_loop.value;
1323 constexpr
bool has_double_loop = has_double_tail_k_block_loop;
1353 const bool has_double_tail_k_block_loop =
1356 if(has_main_k_block_loop && has_double_tail_k_block_loop)
1360 else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
1365 else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
1382 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
1401 if constexpr(ConvBackwardDataSpecialization ==
1405 for(
int i = 0; i < NDimSpatial; i++)
1417 auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
1418 if(srcVectorLengths[
I1] != 1 || srcVectorLengths[
I2] != 1)
1422 if(K1 % srcVectorLengths[
I3] != 0 || K0PerBlock % srcVectorLengths[
I0] != 0)
1429 if(K % (srcVectorLengths[
I0] * srcVectorLengths[
I3]) != 0)
1437 auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
1438 auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
1439 if(srcVectorLengths[
I0] != 1 || srcVectorLengths[
I3] != 1)
1443 if(srcLoadLenghts[
I1] % srcVectorLengths[
I1] != 0 ||
1444 srcLoadLenghts[
I2] % srcVectorLengths[
I2] != 0)
1451 if(C % (srcVectorLengths[
I1] * srcVectorLengths[
I2]) != 0)
1457 if(!(arg.
Conv_C_ % CThreadTransferDstScalarPerVector == 0))
1459 std::cout <<
"Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
1460 << arg.
Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
1483 const WeiDataType* p_wei_grid,
1484 const OutDataType* p_out_grid,
1488 std::vector<ck::index_t> input_spatial_lengths,
1489 std::vector<ck::index_t> filter_spatial_lengths,
1490 std::vector<ck::index_t> output_spatial_lengths,
1491 std::vector<ck::index_t> conv_filter_strides,
1492 std::vector<ck::index_t> conv_filter_dilations,
1493 std::vector<ck::index_t> input_left_pads,
1494 std::vector<ck::index_t> input_right_pads,
1495 InElementwiseOperation in_element_op,
1496 WeiElementwiseOperation wei_element_op,
1497 OutElementwiseOperation out_element_op)
1505 input_spatial_lengths,
1506 filter_spatial_lengths,
1507 output_spatial_lengths,
1508 conv_filter_strides,
1509 conv_filter_dilations,
1519 std::unique_ptr<BaseArgument>
1521 const void* p_wei_grid,
1522 const void* p_out_grid,
1526 std::vector<ck::index_t> input_spatial_lengths,
1527 std::vector<ck::index_t> filter_spatial_lengths,
1528 std::vector<ck::index_t> output_spatial_lengths,
1529 std::vector<ck::index_t> conv_filter_strides,
1530 std::vector<ck::index_t> conv_filter_dilations,
1531 std::vector<ck::index_t> input_left_pads,
1532 std::vector<ck::index_t> input_right_pads,
1533 InElementwiseOperation in_element_op,
1534 WeiElementwiseOperation wei_element_op,
1535 OutElementwiseOperation out_element_op)
override
1537 return std::make_unique<Argument>(
static_cast<InDataType*
>(p_in_grid),
1538 static_cast<const WeiDataType*
>(p_wei_grid),
1539 static_cast<const OutDataType*
>(p_out_grid),
1543 input_spatial_lengths,
1544 filter_spatial_lengths,
1545 output_spatial_lengths,
1546 conv_filter_strides,
1547 conv_filter_dilations,
1557 return std::make_unique<Invoker>(
Invoker{});
1562 auto str = std::stringstream();
1565 str <<
"DeviceConvNdBwdDataNwcKxcNwk_Dl"
1567 << BlockSize <<
", "
1568 << MPerBlock <<
", "
1569 << NPerBlock <<
", "
1570 << K0PerBlock <<
", "
1573 if constexpr(ConvBackwardDataSpecialization ==
1576 str<<
" Filter1x1Stride1Pad0";
#define CK_ENV(name)
Definition: env.hpp:128
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ auto integer_divide_floor(X x, Y y)
Definition: math.hpp:66
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
__host__ constexpr __device__ index_t gcd(index_t x, index_t y)
Definition: math.hpp:154
ConvolutionBackwardDataSpecialization
Definition: convolution_backward_data_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
std::string get_device_name()
Definition: device_prop.hpp:12
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
bool is_gfx12_supported()
Definition: device_prop.hpp:94
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_dl_v1r3.hpp:33
bool is_gfx103_supported()
Definition: device_prop.hpp:81
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:110
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
bool is_gfx11_supported()
Definition: device_prop.hpp:88
Definition: stream_config.hpp:10
Definition: gridwise_gemm_dl_v1r3.hpp:93
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:129
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:208
__host__ static constexpr __device__ auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition: gridwise_gemm_dl_v1r3.hpp:188
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_v1r3.hpp:153
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:241
__host__ static constexpr __device__ bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_v1r3.hpp:160
__host__ static constexpr __device__ auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition: gridwise_gemm_dl_v1r3.hpp:168
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_conv_bwd_data.hpp:25
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1015
const BDataType * p_b_grid_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1235
std::vector< CGridDesc_M0_M10_M11_N0_N10_N11 > c_grid_desc_m0_m10_m11_n0_n10_n11_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1243
index_t Conv_N_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1252
std::vector< BGridDesc_K0_N0_N1_K1 > b_grid_desc_k0_n0_n1_k1_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1242
std::vector< CGridDesc_M_N > c_grid_desc_m_n_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1239
std::vector< ck::index_t > filter_spatial_lengths_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1257
index_t Conv_C_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1254
InElementwiseOperation c_element_op_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1250
std::vector< ck::index_t > output_spatial_lengths_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1258
std::vector< ck::index_t > input_left_pads_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1261
std::vector< ck::index_t > conv_filter_strides_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1259
std::vector< ck::index_t > input_spatial_lengths_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1256
WeiElementwiseOperation b_element_op_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1249
index_t Conv_K_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1253
std::vector< BGridDesc_K0_N_K1 > b_grid_desc_k0_n_k1_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1238
CDataType * p_c_grid_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1236
std::vector< ck::index_t > input_right_pads_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1262
void CreateABCDesc()
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1053
std::vector< AGridDesc_K0_M_K1 > a_grid_desc_k0_m_k1_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1237
const ADataType * p_a_grid_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1234
std::vector< DefaultBlock2CTileMap > block_2_ctile_map_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1245
Argument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1016
OutElementwiseOperation a_element_op_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1248
std::vector< ck::index_t > conv_filter_dilations_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1260
std::vector< AGridDesc_K0_M0_M1_K1 > a_grid_desc_k0_m0_m1_k1_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1241
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1267
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1379
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1270
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:81
InDataType CDataType
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:86
InDataType ABDataType
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:89
static constexpr auto I3
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:94
static constexpr bool IsValidCompilationParameter()
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1386
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:962
static constexpr auto I7
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:98
static constexpr auto I5
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:96
std::string GetTypeString() const override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1560
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1008
static bool IsSupportedArgument(const Argument &arg)
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1392
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, std::vector< ck::index_t > tildes)
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:102
static auto MakeArgument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1482
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:965
static constexpr auto I2
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:93
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1555
static auto MakeInvoker()
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1517
static constexpr auto I4
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:95
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_in_grid, const void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1520
OutDataType ADataType
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:84
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:966
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:964
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1477
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1010
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1003
static constexpr auto I6
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:97
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1012
static constexpr auto I0
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:91
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1006
static constexpr auto I1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:92
WeiDataType BDataType
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:85
static auto GetABCGridDesc()
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:933