/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp Source File
device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
18 
19 namespace ck {
20 namespace tensor_operation {
21 namespace device {
22 
23 // out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
24 template <ck::index_t NDimSpatial,
25  typename InDataType,
26  typename WeiDataType,
27  typename OutDataType,
28  typename AccDataType,
29  typename InElementwiseOperation,
30  typename WeiElementwiseOperation,
31  typename OutElementwiseOperation,
32  ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
33  ck::index_t BlockSize,
34  ck::index_t MPerBlock,
35  ck::index_t NPerBlock,
36  ck::index_t K0PerBlock,
37  ck::index_t K1,
38  index_t M1PerThread,
39  index_t N1PerThread,
40  index_t KPerThread,
41  typename M1N1ThreadClusterM1Xs,
42  typename M1N1ThreadClusterN1Xs,
43  typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
44  typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
45  typename ABlockTransferThreadClusterArrangeOrder,
46  typename ABlockTransferSrcAccessOrder,
47  typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
48  typename ABlockTransferSrcVectorTensorContiguousDimOrder,
49  typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
50  typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
51  typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
52  typename BBlockTransferThreadClusterArrangeOrder,
53  typename BBlockTransferSrcAccessOrder,
54  typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
55  typename BBlockTransferSrcVectorTensorContiguousDimOrder,
56  typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
57  typename CThreadTransferSrcDstAccessOrder,
58  index_t CThreadTransferSrcDstVectorDim,
59  index_t CThreadTransferDstScalarPerVector>
61  : public DeviceConvBwdData<
62  NDimSpatial,
63  ck::tuple_element_t<NDimSpatial - 1,
64  ck::Tuple<ck::tensor_layout::convolution::NWC,
65  ck::tensor_layout::convolution::NHWC,
66  ck::tensor_layout::convolution::NDHWC>>,
67  ck::tuple_element_t<NDimSpatial - 1,
68  ck::Tuple<ck::tensor_layout::convolution::KXC,
69  ck::tensor_layout::convolution::KYXC,
70  ck::tensor_layout::convolution::KZYXC>>,
71  ck::tuple_element_t<NDimSpatial - 1,
72  ck::Tuple<ck::tensor_layout::convolution::NWK,
73  ck::tensor_layout::convolution::NHWK,
74  ck::tensor_layout::convolution::NDHWK>>,
75  InDataType,
76  WeiDataType,
77  OutDataType,
78  InElementwiseOperation,
79  WeiElementwiseOperation,
80  OutElementwiseOperation>
81 {
83 
// GEMM operand types. The descriptor builders in this struct label A as the output
// tensor, B as the weight tensor and C as the input tensor, which is what these aliases
// encode.
84  using ADataType = OutDataType;
85  using BDataType = WeiDataType;
86  using CDataType = InDataType;
87 
88  // TODO make A/B datatype different
// NOTE(review): ABDataType aliases InDataType rather than OutDataType/WeiDataType;
// presumably all three are expected to be the same type for now - confirm.
89  using ABDataType = InDataType;
90 
// Number<0> .. Number<7> instances. Within this file they are used both as tuple indices
// (e.g. descs[I0]) and as the literals 0 / 1 in arithmetic and transform arguments
// (e.g. X - I1, make_pad_transform(Wo, I0, I0)).
91  static constexpr auto I0 = Number<0>{};
92  static constexpr auto I1 = Number<1>{};
93  static constexpr auto I2 = Number<2>{};
94  static constexpr auto I3 = Number<3>{};
95  static constexpr auto I4 = Number<4>{};
96  static constexpr auto I5 = Number<5>{};
97  static constexpr auto I6 = Number<6>{};
98  static constexpr auto I7 = Number<7>{};
99 
// 1-D overload (participates only when NDim == 1, via enable_if).
// Builds the three tensor descriptors for one backward-data GEMM and returns them as
//   make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,   // A: output tensor
//              wei_gemmk0_gemmn_gemmk1_grid_desc,   // B: weight tensor
//              in_gemmm_gemmn_grid_desc)            // C: input tensor
// All std::vector arguments are taken by value and only element [0] (the W / X axis) is
// read. tildes[0] is i_xtilde: in the general (else) branch it is the index at which the
// XTilde dimension is frozen and it determines XDotSlice.
// NOTE(review): source line 102 (function name and the first parameter, N) is absent from
// this listing; the calls elsewhere in this file spell the name
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N and pass (N, K, C, ...).
100  template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
101  static auto
103  ck::index_t K,
104  ck::index_t C,
105  std::vector<ck::index_t> input_spatial_lengths,
106  std::vector<ck::index_t> filter_spatial_lengths,
107  std::vector<ck::index_t> output_spatial_lengths,
108  std::vector<ck::index_t> conv_filter_strides,
109  std::vector<ck::index_t> conv_filter_dilations,
110  std::vector<ck::index_t> input_left_pads,
111  std::vector<ck::index_t> input_right_pads,
112  std::vector<ck::index_t> tildes)
113  {
114  using namespace ck;
115 
116  index_t i_xtilde = tildes[0];
117 
118  const index_t Wi = input_spatial_lengths[0];
119  const index_t Wo = output_spatial_lengths[0];
120  const index_t X = filter_spatial_lengths[0];
121  const index_t InLeftPadW = input_left_pads[0];
122  const index_t InRightPadW = input_right_pads[0];
123  const index_t ConvStrideW = conv_filter_strides[0];
124  const index_t ConvDilationW = conv_filter_dilations[0];
125 
     // K is split as K0 x K1 using integer division.
     // NOTE(review): presumably K must be a multiple of K1 - confirm where this is checked.
126  const auto K0 = K / K1;
127 
     // Packed (contiguous) [N, Wi, C] view of the input tensor; used by both branches.
128  const auto in_n_wi_c_grid_desc = make_naive_tensor_descriptor_packed(make_tuple(N, Wi, C));
129 
     // Compile-time branch on ConvBackwardDataSpecialization. The enumerator it is compared
     // against is on a source line (131) missing from this listing.
130  if constexpr(ConvBackwardDataSpecialization ==
132  {
133  // A: output tensor
134  const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
140 
141  // B: weight tensor
142  const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
148 
149  // C: input tensor
     // The W axis is embedded as (I1, Wo) with coefficients (I1, ConvStrideW), i.e. a
     // length-1 filter dimension and Wo positions spaced by the stride.
150  const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
151  in_n_wi_c_grid_desc,
153  make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
157 
158  const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
159  in_n_x_wo_c_grid_desc,
165 
166  return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
167  wei_gemmk0_gemmn_gemmk1_grid_desc,
168  in_gemmm_gemmn_grid_desc);
169  }
170  else
171  {
172  const auto out_n_wo_k_grid_desc =
174  const auto wei_k_x_c_grid_desc =
176 
     // Derived sizes for the general case:
     //   XTilde = ConvStrideW / gcd(ConvStrideW, ConvDilationW)
     //   XDot   = ceil(X / XTilde)
     //   WTilde = Wo + ceil(ConvDilationW * (X - 1) / ConvStrideW)
177  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
178 
179  const auto XTilde = ConvStrideW / GcdStrideDilationW;
180 
181  const auto XDot = math::integer_divide_ceil(X, XTilde);
182 
183  const auto WTilde =
184  Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
185 
186  // only work on the WTilde range that contributes to the non-padding area of the input tensor
     //   begin = floor(max(0, InLeftPadW - ConvDilationW * (XTilde - 1)) / ConvStrideW)
     //   end   = min(WTilde, ceil((InLeftPadW + Wi - 1) / ConvStrideW) + 1)
187  const auto IWTildeSliceBegin = math::integer_divide_floor(
188  math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
189 
190  const auto IWTildeSliceEnd = math::min(
191  WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
192 
193  const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
194 
195  // GemmK is different for each GEMM
     // XDotSlice = ceil((X - i_xtilde) / XTilde); it is the slice length applied to the
     // XDot dimension of both the output and the weight descriptors below.
196  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
197 
198  // A: output tensor
     // The pad transform on Wo uses zero left and right padding.
199  const auto out_n_wop_k_grid_desc = transform_tensor_descriptor(
200  out_n_wo_k_grid_desc,
202  make_pad_transform(Wo, I0, I0),
206 
     // Wo is embedded as (XDot, WTilde); the XDot coefficient is the negated
     // ConvDilationW divided by the gcd, the WTilde coefficient is 1.
207  const auto out_n_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
208  out_n_wop_k_grid_desc,
209  make_tuple(
211  make_embed_transform(make_tuple(XDot, WTilde),
212  make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
216 
217  const auto out_n_xdotslice_wtildeslice_k0_k1_grid_desc = transform_tensor_descriptor(
218  out_n_xdot_wtilde_k_grid_desc,
220  make_slice_transform(XDot, I0, XDotSlice),
221  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
225 
226  const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
227  out_n_xdotslice_wtildeslice_k0_k1_grid_desc,
229  make_merge_transform(make_tuple(N, WTildeSlice)),
233 
234  // B weight tensor
     // X is embedded as (XDot, XTilde) with coefficients (ConvStrideW / gcd, 1).
235  const auto wei_k_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
236  wei_k_x_c_grid_desc,
238  make_embed_transform(make_tuple(XDot, XTilde),
239  make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
243 
     // XDot is sliced to [0, XDotSlice) and XTilde is frozen at i_xtilde.
244  const auto wei_k0_k1_xdotslice_c_grid_desc = transform_tensor_descriptor(
245  wei_k_xdot_xtilde_c_grid_desc,
247  make_slice_transform(XDot, I0, XDotSlice),
248  make_freeze_transform(i_xtilde),
252 
253  const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
254  wei_k0_k1_xdotslice_c_grid_desc,
260 
261  // C: input tensor
262  const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
263  in_n_wi_c_grid_desc,
265  make_pad_transform(Wi, InLeftPadW, InRightPadW),
269 
     // Padded Wi is embedded as (XTilde, WTilde) with coefficients
     // (ConvDilationW, ConvStrideW).
270  const auto in_n_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
271  in_n_wip_c_grid_desc,
273  make_embed_transform(make_tuple(XTilde, WTilde),
274  make_tuple(ConvDilationW, ConvStrideW)),
278 
     // XTilde is frozen at i_xtilde; WTilde is restricted to the slice computed above.
279  const auto in_n_wtildeslice_c_grid_desc = transform_tensor_descriptor(
280  in_n_xtilde_wtilde_c_grid_desc,
282  make_freeze_transform(i_xtilde),
283  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
287 
288  const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
289  in_n_wtildeslice_c_grid_desc,
290  make_tuple(make_merge_transform(make_tuple(N, WTildeSlice)),
294 
295  return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
296  wei_gemmk0_gemmn_gemmk1_grid_desc,
297  in_gemmm_gemmn_grid_desc);
298  }
299 
300  } // function end
// 2-D overload (participates only when NDim == 2, via enable_if).
// Builds the three tensor descriptors for one backward-data GEMM and returns them as
//   make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,   // A: output tensor
//              wei_gemmk0_gemmn_gemmk1_grid_desc,   // B: weight tensor
//              in_gemmm_gemmn_grid_desc)            // C: input tensor
// All std::vector arguments are taken by value; element [0] is read as the H / Y axis and
// element [1] as the W / X axis. tildes = {i_ytilde, i_xtilde}: in the general (else)
// branch these are the indices at which YTilde / XTilde are frozen and they determine
// YDotSlice / XDotSlice.
// NOTE(review): source line 303 (function name and the first parameter, N) is absent from
// this listing; the calls elsewhere in this file spell the name
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N and pass (N, K, C, ...).
301  template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
302  static auto
304  ck::index_t K,
305  ck::index_t C,
306  std::vector<ck::index_t> input_spatial_lengths,
307  std::vector<ck::index_t> filter_spatial_lengths,
308  std::vector<ck::index_t> output_spatial_lengths,
309  std::vector<ck::index_t> conv_filter_strides,
310  std::vector<ck::index_t> conv_filter_dilations,
311  std::vector<ck::index_t> input_left_pads,
312  std::vector<ck::index_t> input_right_pads,
313  std::vector<ck::index_t> tildes)
314  {
315  using namespace ck;
316 
317  index_t i_ytilde = tildes[0];
318  index_t i_xtilde = tildes[1];
319 
320  const index_t Hi = input_spatial_lengths[0];
321  const index_t Wi = input_spatial_lengths[1];
322 
323  const index_t Ho = output_spatial_lengths[0];
324  const index_t Wo = output_spatial_lengths[1];
325 
326  const index_t Y = filter_spatial_lengths[0];
327  const index_t X = filter_spatial_lengths[1];
328 
329  const index_t InLeftPadH = input_left_pads[0];
330  const index_t InLeftPadW = input_left_pads[1];
331 
332  const index_t InRightPadH = input_right_pads[0];
333  const index_t InRightPadW = input_right_pads[1];
334 
335  const index_t ConvStrideH = conv_filter_strides[0];
336  const index_t ConvStrideW = conv_filter_strides[1];
337 
338  const index_t ConvDilationH = conv_filter_dilations[0];
339  const index_t ConvDilationW = conv_filter_dilations[1];
340 
     // K is split as K0 x K1 using integer division.
     // NOTE(review): presumably K must be a multiple of K1 - confirm where this is checked.
341  const auto K0 = K / K1;
342 
     // Base descriptors for the three tensors; their initializers are on source lines
     // missing from this listing.
343  const auto out_n_ho_wo_k_grid_desc =
345  const auto wei_k_y_x_c_grid_desc =
347  const auto in_n_hi_wi_c_grid_desc =
349 
     // Compile-time branch on ConvBackwardDataSpecialization. The enumerator it is compared
     // against is on a source line (351) missing from this listing.
350  if constexpr(ConvBackwardDataSpecialization ==
352  {
353  // A: output tensor
354  const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
360 
361  // B: weight tensor
362  const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
368 
369  // C: input tensor
     // H and W are each embedded as (I1, <out length>) with coefficients (I1, <stride>),
     // i.e. a length-1 filter dimension and output positions spaced by the stride.
370  const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
371  in_n_hi_wi_c_grid_desc,
373  make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
374  make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
378 
379  const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
380  in_n_y_ho_x_wo_c_grid_desc,
383  make_merge_transform(make_tuple(N, Ho, Wo)),
387 
388  return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
389  wei_gemmk0_gemmn_gemmk1_grid_desc,
390  in_gemmm_gemmn_grid_desc);
391  }
392  else
393  {
     // Derived sizes for the general case, computed independently per axis (shown for W):
     //   XTilde = ConvStrideW / gcd(ConvStrideW, ConvDilationW)
     //   XDot   = ceil(X / XTilde)
     //   WTilde = Wo + ceil(ConvDilationW * (X - 1) / ConvStrideW)
394  const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
395  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
396 
397  const auto YTilde = ConvStrideH / GcdStrideDilationH;
398  const auto XTilde = ConvStrideW / GcdStrideDilationW;
399 
400  const auto YDot = math::integer_divide_ceil(Y, YTilde);
401  const auto XDot = math::integer_divide_ceil(X, XTilde);
402 
403  const auto HTilde =
404  Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
405  const auto WTilde =
406  Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
407 
408  // only work on HTilde and WTilde that contribute to non-padding area of input tensor
     //   begin = floor(max(0, LeftPad - Dilation * (Tilde - 1)) / Stride)
     //   end   = min(HTilde or WTilde, ceil((LeftPad + InLength - 1) / Stride) + 1)
409  const auto IHTildeSliceBegin = math::integer_divide_floor(
410  math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
411  const auto IWTildeSliceBegin = math::integer_divide_floor(
412  math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
413 
414  const auto IHTildeSliceEnd = math::min(
415  HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
416  const auto IWTildeSliceEnd = math::min(
417  WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
418 
419  const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
420  const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
421 
422  // GemmK is different for each GEMM
     // YDotSlice / XDotSlice = ceil((filter length - i_tilde) / Tilde); they are the slice
     // lengths applied to the YDot / XDot dimensions below and are merged with K0.
423  const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
424  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
425 
426  // A: output tensor
     // The pad transforms on Ho and Wo use zero left and right padding.
427  const auto out_n_hop_wop_k_grid_desc = transform_tensor_descriptor(
428  out_n_ho_wo_k_grid_desc,
430  make_pad_transform(Ho, I0, I0),
431  make_pad_transform(Wo, I0, I0),
435 
     // Ho -> (YDot, HTilde) and Wo -> (XDot, WTilde); the Dot coefficient is the negated
     // dilation divided by the gcd, the Tilde coefficient is 1.
436  const auto out_n_ydot_htilde_xdot_wtilde_k_grid_desc = transform_tensor_descriptor(
437  out_n_hop_wop_k_grid_desc,
438  make_tuple(
440  make_embed_transform(make_tuple(YDot, HTilde),
441  make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
442  make_embed_transform(make_tuple(XDot, WTilde),
443  make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
447 
448  const auto out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
450  out_n_ydot_htilde_xdot_wtilde_k_grid_desc,
452  make_slice_transform(YDot, I0, YDotSlice),
453  make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
454  make_slice_transform(XDot, I0, XDotSlice),
455  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
458  Sequence<1>{},
459  Sequence<2>{},
460  Sequence<3>{},
461  Sequence<4>{},
462  Sequence<5>{}),
464  Sequence<1>{},
465  Sequence<2>{},
466  Sequence<3>{},
467  Sequence<4>{},
468  Sequence<5, 6>{}));
469 
470  const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
471  out_n_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
472  make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
473  make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
477 
478  // B weight tensor
     // Y -> (YDot, YTilde) and X -> (XDot, XTilde) with coefficients (stride / gcd, 1).
479  const auto wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc = transform_tensor_descriptor(
480  wei_k_y_x_c_grid_desc,
482  make_embed_transform(make_tuple(YDot, YTilde),
483  make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
484  make_embed_transform(make_tuple(XDot, XTilde),
485  make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
489 
     // YDot / XDot are sliced to [0, *DotSlice); YTilde / XTilde are frozen at
     // i_ytilde / i_xtilde. The two frozen dimensions appear to map to the empty
     // Sequence<>{} entries in the upper-dimension list, i.e. they are dropped from the
     // resulting descriptor (first list entries are on lines missing from this listing).
490  const auto wei_k0_k1_ydotslice_xdotslice_c_grid_desc =
491  transform_tensor_descriptor(wei_k_ydot_ytilde_xdot_xtilde_c_grid_desc,
493  make_slice_transform(YDot, I0, YDotSlice),
494  make_slice_transform(XDot, I0, XDotSlice),
495  make_freeze_transform(i_ytilde),
496  make_freeze_transform(i_xtilde),
499  Sequence<1>{},
500  Sequence<3>{},
501  Sequence<2>{},
502  Sequence<4>{},
503  Sequence<5>{}),
505  Sequence<2>{},
506  Sequence<3>{},
507  Sequence<>{},
508  Sequence<>{},
509  Sequence<4>{}));
510 
511  const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
512  wei_k0_k1_ydotslice_xdotslice_c_grid_desc,
513  make_tuple(make_merge_transform(make_tuple(YDotSlice, XDotSlice, K0)),
518 
519  // C: input tensor
520  const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
521  in_n_hi_wi_c_grid_desc,
523  make_pad_transform(Hi, InLeftPadH, InRightPadH),
524  make_pad_transform(Wi, InLeftPadW, InRightPadW),
528 
     // Padded Hi -> (YTilde, HTilde) and padded Wi -> (XTilde, WTilde) with coefficients
     // (dilation, stride).
529  const auto in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc = transform_tensor_descriptor(
530  in_n_hip_wip_c_grid_desc,
532  make_embed_transform(make_tuple(YTilde, HTilde),
533  make_tuple(ConvDilationH, ConvStrideH)),
534  make_embed_transform(make_tuple(XTilde, WTilde),
535  make_tuple(ConvDilationW, ConvStrideW)),
539 
     // YTilde / XTilde are frozen at i_ytilde / i_xtilde (empty Sequence<>{} in the upper
     // list); HTilde / WTilde are restricted to the slices computed above.
540  const auto in_n_htildeslice_wtildeslice_c_grid_desc = transform_tensor_descriptor(
541  in_n_ytilde_htilde_xtilde_wtilde_c_grid_desc,
543  make_freeze_transform(i_ytilde),
544  make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
545  make_freeze_transform(i_xtilde),
546  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
549  Sequence<1>{},
550  Sequence<2>{},
551  Sequence<3>{},
552  Sequence<4>{},
553  Sequence<5>{}),
555  Sequence<>{},
556  Sequence<1>{},
557  Sequence<>{},
558  Sequence<2>{},
559  Sequence<3>{}));
560 
561  const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
562  in_n_htildeslice_wtildeslice_c_grid_desc,
563  make_tuple(make_merge_transform(make_tuple(N, HTildeSlice, WTildeSlice)),
567 
568  return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
569  wei_gemmk0_gemmn_gemmk1_grid_desc,
570  in_gemmm_gemmn_grid_desc);
571  }
572 
573  } // function end
574 
// 3-D overload (participates only when NDim == 3, via enable_if).
// Builds the three tensor descriptors for one backward-data GEMM and returns them as
//   make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,   // A: output tensor
//              wei_gemmk0_gemmn_gemmk1_grid_desc,   // B: weight tensor
//              in_gemmm_gemmn_grid_desc)            // C: input tensor
// All std::vector arguments are taken by value; element [0] is read as the D / Z axis,
// [1] as the H / Y axis and [2] as the W / X axis. tildes = {i_ztilde, i_ytilde, i_xtilde}:
// in the general (else) branch these are the indices at which ZTilde / YTilde / XTilde are
// frozen and they determine ZDotSlice / YDotSlice / XDotSlice.
// NOTE(review): source line 577 (function name and the first parameter, N) is absent from
// this listing; the calls elsewhere in this file spell the name
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N and pass (N, K, C, ...).
575  template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
576  static auto
578  ck::index_t K,
579  ck::index_t C,
580  std::vector<ck::index_t> input_spatial_lengths,
581  std::vector<ck::index_t> filter_spatial_lengths,
582  std::vector<ck::index_t> output_spatial_lengths,
583  std::vector<ck::index_t> conv_filter_strides,
584  std::vector<ck::index_t> conv_filter_dilations,
585  std::vector<ck::index_t> input_left_pads,
586  std::vector<ck::index_t> input_right_pads,
587  std::vector<ck::index_t> tildes)
588  {
589  using namespace ck;
590 
591  const index_t i_ztilde = tildes[0];
592  const index_t i_ytilde = tildes[1];
593  const index_t i_xtilde = tildes[2];
594 
595  const index_t Di = input_spatial_lengths[0];
596  const index_t Hi = input_spatial_lengths[1];
597  const index_t Wi = input_spatial_lengths[2];
598 
599  const index_t Do = output_spatial_lengths[0];
600  const index_t Ho = output_spatial_lengths[1];
601  const index_t Wo = output_spatial_lengths[2];
602 
603  const index_t Z = filter_spatial_lengths[0];
604  const index_t Y = filter_spatial_lengths[1];
605  const index_t X = filter_spatial_lengths[2];
606 
607  const index_t InLeftPadD = input_left_pads[0];
608  const index_t InLeftPadH = input_left_pads[1];
609  const index_t InLeftPadW = input_left_pads[2];
610 
611  const index_t InRightPadD = input_right_pads[0];
612  const index_t InRightPadH = input_right_pads[1];
613  const index_t InRightPadW = input_right_pads[2];
614 
615  const index_t ConvStrideD = conv_filter_strides[0];
616  const index_t ConvStrideH = conv_filter_strides[1];
617  const index_t ConvStrideW = conv_filter_strides[2];
618 
619  const index_t ConvDilationD = conv_filter_dilations[0];
620  const index_t ConvDilationH = conv_filter_dilations[1];
621  const index_t ConvDilationW = conv_filter_dilations[2];
622 
     // K is split as K0 x K1 using integer division.
     // NOTE(review): presumably K must be a multiple of K1 - confirm where this is checked.
623  const auto K0 = K / K1;
624 
     // Base descriptors for the three tensors; their initializers are on source lines
     // missing from this listing.
625  const auto out_n_do_ho_wo_k_grid_desc =
627  const auto wei_k_z_y_x_c_grid_desc =
629  const auto in_n_di_hi_wi_c_grid_desc =
631 
     // Compile-time branch on ConvBackwardDataSpecialization. The enumerator it is compared
     // against is on a source line (633) missing from this listing.
632  if constexpr(ConvBackwardDataSpecialization ==
634  {
635  // A: output tensor
     // Starts from a packed 2-D [N * Do * Ho * Wo, K] view of the output tensor.
636  const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
637  make_naive_tensor_descriptor_packed(make_tuple(N * Do * Ho * Wo, K)),
638  make_tuple(make_pass_through_transform(N * Do * Ho * Wo),
642 
643  // B: weight tensor
644  const auto wei_gemmk0_gemmn_gemmk1_grid_desc =
650 
651  // C: input tensor
     // D, H and W are each embedded as (I1, <out length>) with coefficients
     // (I1, <stride>), i.e. a length-1 filter dimension and output positions spaced by
     // the stride.
652  const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
653  in_n_di_hi_wi_c_grid_desc,
655  make_embed_transform(make_tuple(I1, Do), make_tuple(I1, ConvStrideD)),
656  make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)),
657  make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)),
659  make_tuple(
662  Sequence<1, 2>{},
663  Sequence<3, 4>{},
664  Sequence<5, 6>{},
665  Sequence<7>{}));
666 
667  const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
668  in_n_z_do_y_ho_x_wo_c_grid_desc,
672  make_merge_transform(make_tuple(N, Do, Ho, Wo)),
675  Sequence<3>{},
676  Sequence<5>{},
678  Sequence<7>{}),
680 
681  return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
682  wei_gemmk0_gemmn_gemmk1_grid_desc,
683  in_gemmm_gemmn_grid_desc);
684  }
685  else
686  {
     // Derived sizes for the general case, computed independently per axis (shown for W):
     //   XTilde = ConvStrideW / gcd(ConvStrideW, ConvDilationW)
     //   XDot   = ceil(X / XTilde)
     //   WTilde = Wo + ceil(ConvDilationW * (X - 1) / ConvStrideW)
687  const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
688  const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
689  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
690 
691  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
692  const auto YTilde = ConvStrideH / GcdStrideDilationH;
693  const auto XTilde = ConvStrideW / GcdStrideDilationW;
694 
695  const auto ZDot = math::integer_divide_ceil(Z, ZTilde);
696  const auto YDot = math::integer_divide_ceil(Y, YTilde);
697  const auto XDot = math::integer_divide_ceil(X, XTilde);
698 
699  const auto DTilde =
700  Do + math::integer_divide_ceil(ConvDilationD * (Z - I1), ConvStrideD);
701  const auto HTilde =
702  Ho + math::integer_divide_ceil(ConvDilationH * (Y - I1), ConvStrideH);
703  const auto WTilde =
704  Wo + math::integer_divide_ceil(ConvDilationW * (X - I1), ConvStrideW);
705 
706  // only work on DTilde, HTilde and WTilde that contribute to non-padding area of input tensor
     //   begin = floor(max(0, LeftPad - Dilation * (Tilde - 1)) / Stride)
     //   end   = min(DTilde / HTilde / WTilde, ceil((LeftPad + InLength - 1) / Stride) + 1)
707  const auto IDTildeSliceBegin = math::integer_divide_floor(
708  math::max(I0, InLeftPadD - ConvDilationD * (ZTilde - I1)), ConvStrideD);
709  const auto IHTildeSliceBegin = math::integer_divide_floor(
710  math::max(I0, InLeftPadH - ConvDilationH * (YTilde - I1)), ConvStrideH);
711  const auto IWTildeSliceBegin = math::integer_divide_floor(
712  math::max(I0, InLeftPadW - ConvDilationW * (XTilde - I1)), ConvStrideW);
713 
714  const auto IDTildeSliceEnd = math::min(
715  DTilde, math::integer_divide_ceil(InLeftPadD + Di - I1, ConvStrideD) + I1);
716  const auto IHTildeSliceEnd = math::min(
717  HTilde, math::integer_divide_ceil(InLeftPadH + Hi - I1, ConvStrideH) + I1);
718  const auto IWTildeSliceEnd = math::min(
719  WTilde, math::integer_divide_ceil(InLeftPadW + Wi - I1, ConvStrideW) + I1);
720 
721  const auto DTildeSlice = IDTildeSliceEnd - IDTildeSliceBegin;
722  const auto HTildeSlice = IHTildeSliceEnd - IHTildeSliceBegin;
723  const auto WTildeSlice = IWTildeSliceEnd - IWTildeSliceBegin;
724 
725  // GemmK is different for each GEMM
     // *DotSlice = ceil((filter length - i_tilde) / Tilde); they are the slice lengths
     // applied to the ZDot / YDot / XDot dimensions below and are merged with K0.
726  const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
727  const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
728  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
729 
730  // A: output tensor
     // The pad transforms on Do, Ho and Wo use zero left and right padding.
731  const auto out_n_dop_hop_wop_k_grid_desc = transform_tensor_descriptor(
732  out_n_do_ho_wo_k_grid_desc,
734  make_pad_transform(Do, I0, I0),
735  make_pad_transform(Ho, I0, I0),
736  make_pad_transform(Wo, I0, I0),
738  make_tuple(
740  make_tuple(
742 
     // Do -> (ZDot, DTilde), Ho -> (YDot, HTilde), Wo -> (XDot, WTilde); the Dot
     // coefficient is the negated dilation divided by the gcd, the Tilde coefficient is 1.
743  const auto out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc =
745  out_n_dop_hop_wop_k_grid_desc,
746  make_tuple(
748  make_embed_transform(make_tuple(ZDot, DTilde),
749  make_tuple(-ConvDilationD / GcdStrideDilationD, I1)),
750  make_embed_transform(make_tuple(YDot, HTilde),
751  make_tuple(-ConvDilationH / GcdStrideDilationH, I1)),
752  make_embed_transform(make_tuple(XDot, WTilde),
753  make_tuple(-ConvDilationW / GcdStrideDilationW, I1)),
755  make_tuple(
758  Sequence<1, 2>{},
759  Sequence<3, 4>{},
760  Sequence<5, 6>{},
761  Sequence<7>{}));
762 
763  const auto
764  out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc =
766  out_n_zdot_dtilde_ydot_htilde_xdot_wtilde_k_grid_desc,
768  make_slice_transform(ZDot, I0, ZDotSlice),
769  make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
770  make_slice_transform(YDot, I0, YDotSlice),
771  make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
772  make_slice_transform(XDot, I0, XDotSlice),
773  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
776  Sequence<1>{},
777  Sequence<2>{},
778  Sequence<3>{},
779  Sequence<4>{},
780  Sequence<5>{},
781  Sequence<6>{},
782  Sequence<7>{}),
784  Sequence<1>{},
785  Sequence<2>{},
786  Sequence<3>{},
787  Sequence<4>{},
788  Sequence<5>{},
789  Sequence<6>{},
790  Sequence<7, 8>{}));
791 
792  const auto out_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
793  out_n_zdotslice_dtildeslice_ydotslice_htildeslice_xdotslice_wtildeslice_k0_k1_grid_desc,
794  make_tuple(
795  make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
796  make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
800 
801  // B weight tensor
     // Z -> (ZDot, ZTilde), Y -> (YDot, YTilde), X -> (XDot, XTilde) with coefficients
     // (stride / gcd, 1).
802  const auto wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc =
804  wei_k_z_y_x_c_grid_desc,
805  make_tuple(
807  make_embed_transform(make_tuple(ZDot, ZTilde),
808  make_tuple(ConvStrideD / GcdStrideDilationD, I1)),
809  make_embed_transform(make_tuple(YDot, YTilde),
810  make_tuple(ConvStrideH / GcdStrideDilationH, I1)),
811  make_embed_transform(make_tuple(XDot, XTilde),
812  make_tuple(ConvStrideW / GcdStrideDilationW, I1)),
814  make_tuple(
817  Sequence<1, 2>{},
818  Sequence<3, 4>{},
819  Sequence<5, 6>{},
820  Sequence<7>{}));
821 
     // The Dot dimensions are sliced to [0, *DotSlice); the Tilde dimensions are frozen at
     // i_ztilde / i_ytilde / i_xtilde. The three frozen dimensions appear to map to the
     // empty Sequence<>{} entries in the upper-dimension list, i.e. they are dropped from
     // the resulting descriptor (first list entries are on lines missing from this listing).
822  const auto wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc =
823  transform_tensor_descriptor(wei_k_zdot_ztilde_ydot_ytilde_xdot_xtilde_c_grid_desc,
825  make_slice_transform(ZDot, I0, ZDotSlice),
826  make_slice_transform(YDot, I0, YDotSlice),
827  make_slice_transform(XDot, I0, XDotSlice),
828  make_freeze_transform(i_ztilde),
829  make_freeze_transform(i_ytilde),
830  make_freeze_transform(i_xtilde),
833  Sequence<1>{},
834  Sequence<3>{},
835  Sequence<5>{},
836  Sequence<2>{},
837  Sequence<4>{},
838  Sequence<6>{},
839  Sequence<7>{}),
841  Sequence<2>{},
842  Sequence<3>{},
843  Sequence<4>{},
844  Sequence<>{},
845  Sequence<>{},
846  Sequence<>{},
847  Sequence<5>{}));
848 
849  const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
850  wei_k0_k1_zdotslice_ydotslice_xdotslice_c_grid_desc,
851  make_tuple(make_merge_transform(make_tuple(ZDotSlice, YDotSlice, XDotSlice, K0)),
856 
857  // C: input tensor
858  const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
859  in_n_di_hi_wi_c_grid_desc,
861  make_pad_transform(Di, InLeftPadD, InRightPadD),
862  make_pad_transform(Hi, InLeftPadH, InRightPadH),
863  make_pad_transform(Wi, InLeftPadW, InRightPadW),
865  make_tuple(
867  make_tuple(
869 
     // Padded Di -> (ZTilde, DTilde), Hi -> (YTilde, HTilde), Wi -> (XTilde, WTilde) with
     // coefficients (dilation, stride).
870  const auto in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc =
872  in_n_dip_hip_wip_c_grid_desc,
874  make_embed_transform(make_tuple(ZTilde, DTilde),
875  make_tuple(ConvDilationD, ConvStrideD)),
876  make_embed_transform(make_tuple(YTilde, HTilde),
877  make_tuple(ConvDilationH, ConvStrideH)),
878  make_embed_transform(make_tuple(XTilde, WTilde),
879  make_tuple(ConvDilationW, ConvStrideW)),
881  make_tuple(
884  Sequence<1, 2>{},
885  Sequence<3, 4>{},
886  Sequence<5, 6>{},
887  Sequence<7>{}));
888 
     // The Tilde filter dimensions are frozen at i_ztilde / i_ytilde / i_xtilde (empty
     // Sequence<>{} in the upper list); DTilde / HTilde / WTilde are restricted to the
     // slices computed above.
889  const auto in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc =
891  in_n_ztilde_dtilde_ytilde_htilde_xtilde_wtilde_c_grid_desc,
893  make_freeze_transform(i_ztilde),
894  make_slice_transform(DTilde, IDTildeSliceBegin, DTildeSlice),
895  make_freeze_transform(i_ytilde),
896  make_slice_transform(HTilde, IHTildeSliceBegin, HTildeSlice),
897  make_freeze_transform(i_xtilde),
898  make_slice_transform(WTilde, IWTildeSliceBegin, WTildeSlice),
901  Sequence<1>{},
902  Sequence<2>{},
903  Sequence<3>{},
904  Sequence<4>{},
905  Sequence<5>{},
906  Sequence<6>{},
907  Sequence<7>{}),
909  Sequence<>{},
910  Sequence<1>{},
911  Sequence<>{},
912  Sequence<2>{},
913  Sequence<>{},
914  Sequence<3>{},
915  Sequence<4>{}));
916 
917  const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor(
918  in_n_dtildeslice_htildeslice_wtildeslice_c_grid_desc,
919  make_tuple(
920  make_merge_transform(make_tuple(N, DTildeSlice, HTildeSlice, WTildeSlice)),
924 
925  return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc,
926  wei_gemmk0_gemmn_gemmk1_grid_desc,
927  in_gemmm_gemmn_grid_desc);
928  }
929 
930  } // function end
931 
// 1-D variant (enabled only when NDim == 1). Calls the 1-D descriptor builder with
// placeholder arguments: N = K = C = 1, every length/stride/dilation/pad vector {1}, and
// tildes {0}. The only use visible in this file is inside decltype(...) to obtain the
// returned tuple type, so the argument values themselves appear not to matter.
932  template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
933  static auto GetABCGridDesc()
934  {
935  return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<1>(
936  1, 1, 1, {1}, {1}, {1}, {1}, {1}, {1}, {1}, {0});
937  }
938 
// 2-D variant (enabled only when NDim == 2). Calls the 2-D descriptor builder with
// placeholder arguments: N = K = C = 1, every length/stride/dilation/pad vector {1, 1},
// and tildes {0, 0}. The only use visible in this file is inside decltype(...) to obtain
// the returned tuple type, so the argument values themselves appear not to matter.
939  template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
940  static auto GetABCGridDesc()
941  {
942  return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<2>(
943  1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {0, 0});
944  }
945 
// 3-D variant (enabled only when NDim == 3). Calls the 3-D descriptor builder with
// placeholder arguments: N = K = C = 1, every length/stride/dilation/pad vector {1, 1, 1},
// and tildes {0, 0, 0}. The only use visible in this file is inside decltype(...) to
// obtain the returned tuple type, so the argument values themselves appear not to matter.
946  template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
947  static auto GetABCGridDesc()
948  {
949  return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(1,
950  1,
951  1,
952  {1, 1, 1},
953  {1, 1, 1},
954  {1, 1, 1},
955  {1, 1, 1},
956  {1, 1, 1},
957  {1, 1, 1},
958  {1, 1, 1},
959  {0, 0, 0});
960  }
961 
// Type of the (A, B, C) descriptor tuple returned by GetABCGridDesc for this struct's
// NDimSpatial; obtained via decltype, so GetABCGridDesc is not evaluated here.
962  using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
963 
967 
968  // GridwiseGemm
// Alias for the grid-level GEMM type used by this device op. The template being
// instantiated and several of its arguments are on source lines (970, 974-977) missing
// from this listing. The visible arguments are the A / accumulator / C data types
// followed by this struct's own tiling, thread-cluster, block-transfer and
// C-thread-transfer template parameters, forwarded unchanged and in declaration order.
969  using GridwiseGemm =
971  ADataType,
972  AccDataType,
973  CDataType,
978  MPerBlock,
979  NPerBlock,
980  K0PerBlock,
981  K1,
982  M1PerThread,
983  N1PerThread,
984  KPerThread,
985  M1N1ThreadClusterM1Xs,
986  M1N1ThreadClusterN1Xs,
987  ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
988  ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
989  ABlockTransferThreadClusterArrangeOrder,
990  ABlockTransferSrcAccessOrder,
991  ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
992  ABlockTransferSrcVectorTensorContiguousDimOrder,
993  ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
994  BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
995  BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
996  BBlockTransferThreadClusterArrangeOrder,
997  BBlockTransferSrcAccessOrder,
998  BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
999  BBlockTransferSrcVectorTensorContiguousDimOrder,
1000  BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
1001  CThreadTransferSrcDstAccessOrder,
1002  CThreadTransferSrcDstVectorDim,
1003  CThreadTransferDstScalarPerVector>;
1004 
1013  // Argument
1014  struct Argument : public BaseArgument
1015  {
// Captures one backward-data convolution problem and builds its GEMM descriptors.
// Role mapping performed by the initializer list:
//   p_out_grid / out_element_op -> A (p_a_grid_, a_element_op_)
//   p_wei_grid / wei_element_op -> B (p_b_grid_, b_element_op_)
//   p_in_grid  / in_element_op  -> C (p_c_grid_, c_element_op_)
// p_in_grid is the only non-const pointer parameter. N, K, C are stored as Conv_N_,
// Conv_K_, Conv_C_. Each std::vector parameter is taken by value and copied into the
// member of the same name with a trailing underscore.
// The body then calls CreateABCDesc<NDimSpatial>(), which fills the descriptor containers.
1016  Argument(InDataType* p_in_grid,
1017  const WeiDataType* p_wei_grid,
1018  const OutDataType* p_out_grid,
1019  ck::index_t N,
1020  ck::index_t K,
1021  ck::index_t C,
1022  std::vector<ck::index_t> input_spatial_lengths,
1023  std::vector<ck::index_t> filter_spatial_lengths,
1024  std::vector<ck::index_t> output_spatial_lengths,
1025  std::vector<ck::index_t> conv_filter_strides,
1026  std::vector<ck::index_t> conv_filter_dilations,
1027  std::vector<ck::index_t> input_left_pads,
1028  std::vector<ck::index_t> input_right_pads,
1029  InElementwiseOperation in_element_op,
1030  WeiElementwiseOperation wei_element_op,
1031  OutElementwiseOperation out_element_op)
1032  : p_a_grid_{p_out_grid},
1033  p_b_grid_{p_wei_grid},
1034  p_c_grid_{p_in_grid},
1035  a_element_op_{out_element_op},
1036  b_element_op_{wei_element_op},
1037  c_element_op_{in_element_op},
1038  Conv_N_{N},
1039  Conv_K_{K},
1040  Conv_C_{C},
1041  input_spatial_lengths_{input_spatial_lengths},
1042  filter_spatial_lengths_{filter_spatial_lengths},
1043  output_spatial_lengths_{output_spatial_lengths},
1044  conv_filter_strides_{conv_filter_strides},
1045  conv_filter_dilations_{conv_filter_dilations},
1046  input_left_pads_{input_left_pads},
1047  input_right_pads_{input_right_pads}
1048  {
1049  CreateABCDesc<NDimSpatial>();
1050  }
1051 
// 1-D variant (enabled only when NDim == 1). For every i_xtilde in [0, XTilde) it builds
// one (A, B, C) descriptor triple and appends it to the per-GEMM containers.
// NOTE(review): source line 1053 (the function name and parameter list) is absent from
// this listing; the constructor invokes it as CreateABCDesc<NDimSpatial>().
1052  template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
1054  {
1055  const index_t ConvStrideW = conv_filter_strides_[0];
1056  const index_t ConvDilationW = conv_filter_dilations_[0];
      // XTilde = ConvStrideW / gcd(ConvStrideW, ConvDilationW): the number of loop
      // iterations, i.e. the upper bound on the number of GEMMs created here.
1057  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
1058  const auto XTilde = ConvStrideW / GcdStrideDilationW;
1059 
1060  const index_t X = filter_spatial_lengths_[0];
1061 
1062  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1063  {
1064  // check slice is valid
      // XDotSlice = ceil((X - i_xtilde) / XTilde); iterations where it is not positive
      // are skipped and contribute no descriptors.
1065  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1066  if(XDotSlice <= 0)
1067  {
1068  continue;
1069  }
1070 
      // The arguments between Conv_C_ and {i_xtilde} are on source lines (1076-1082)
      // missing from this listing; {i_xtilde} is the trailing `tildes` argument.
1071  const auto descs =
1072  DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
1073  Conv_N_,
1074  Conv_K_,
1075  Conv_C_,
1083  {i_xtilde});
      // The three descriptors are stored unconditionally, before the validity check.
1084  a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1085  b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1086  c_grid_desc_m_n_container_.push_back(descs[I2]);
1087 
      // Only when GridwiseGemm accepts the triple is the guarded section executed; its
      // visible part appends to block_2_ctile_map_container_ (the rest of the section is
      // on lines missing from this listing).
1088  if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
1089  {
1096 
1097  block_2_ctile_map_container_.push_back(
1099  }
1100  }
1101  }
// 2-D overload (enabled when NDim == 2): same scheme as the 1-D case, with
// one sub-problem per (i_ytilde, i_xtilde) pair.
// NOTE: this rendering elides listing line 1103 (the declaration line).
1102  template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
1104  {
// Index 0 is the H axis, index 1 the W axis of every per-dimension vector.
1105  const index_t ConvStrideH = conv_filter_strides_[0];
1106  const index_t ConvStrideW = conv_filter_strides_[1];
1107 
1108  const index_t ConvDilationH = conv_filter_dilations_[0];
1109  const index_t ConvDilationW = conv_filter_dilations_[1];
1110 
1111  const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
1112  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
1113 
// Sub-problem counts per axis: stride / gcd(stride, dilation).
1114  const auto YTilde = ConvStrideH / GcdStrideDilationH;
1115  const auto XTilde = ConvStrideW / GcdStrideDilationW;
1116 
1117  const index_t Y = filter_spatial_lengths_[0];
1118  const index_t X = filter_spatial_lengths_[1];
1119  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1120  {
1121  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1122  {
1123  // check slice is valid
// The product of the per-axis slice sizes is tested, not each factor alone.
1124  const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
1125  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1126  if(YDotSlice * XDotSlice <= 0)
1127  {
1128  continue;
1129  }
1130 
1131  const auto descs =
1132  DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
1133  Conv_N_,
1134  Conv_K_,
1135  Conv_C_,
// NOTE: listing lines 1136-1142 are elided here; presumably the remaining
// size/stride/dilation/pad arguments are passed - confirm against the header.
1143  {i_ytilde, i_xtilde});
// Stored for every non-empty slice, valid or not.
1144  a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1145  b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1146  c_grid_desc_m_n_container_.push_back(descs[I2]);
1147 
// The block-to-C-tile map is appended only when the descriptors are valid.
// NOTE: listing lines 1150-1155 and 1158 are elided in this rendering.
1148  if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
1149  {
1156 
1157  block_2_ctile_map_container_.push_back(
1159  }
1160  }
1161  }
1162  }
// 3-D overload (enabled when NDim == 3): same scheme as the lower-rank
// cases, with one sub-problem per (i_ztilde, i_ytilde, i_xtilde) triple.
// NOTE: this rendering elides listing line 1164 (the declaration line).
1163  template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
1165  {
// Index 0 is the D axis, 1 the H axis, 2 the W axis of each vector.
1166  const index_t ConvStrideD = conv_filter_strides_[0];
1167  const index_t ConvStrideH = conv_filter_strides_[1];
1168  const index_t ConvStrideW = conv_filter_strides_[2];
1169 
1170  const index_t ConvDilationD = conv_filter_dilations_[0];
1171  const index_t ConvDilationH = conv_filter_dilations_[1];
1172  const index_t ConvDilationW = conv_filter_dilations_[2];
1173 
1174  const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
1175  const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
1176  const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
1177 
// Sub-problem counts per axis: stride / gcd(stride, dilation).
1178  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
1179  const auto YTilde = ConvStrideH / GcdStrideDilationH;
1180  const auto XTilde = ConvStrideW / GcdStrideDilationW;
1181 
1182  const index_t Z = filter_spatial_lengths_[0];
1183  const index_t Y = filter_spatial_lengths_[1];
1184  const index_t X = filter_spatial_lengths_[2];
1185  for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
1186  {
1187  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
1188  {
1189  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
1190  {
1191  // check slice is valid
// The product of the three per-axis slice sizes is tested, not each factor.
1192  const auto ZDotSlice = math::integer_divide_ceil(Z - i_ztilde, ZTilde);
1193  const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
1194  const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
1195  if(ZDotSlice * YDotSlice * XDotSlice <= 0)
1196  {
1197  continue;
1198  }
1199 
1200  const auto descs =
1201  DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
1202  Conv_N_,
1203  Conv_K_,
1204  Conv_C_,
// NOTE: listing lines 1205-1211 are elided here; presumably the remaining
// size/stride/dilation/pad arguments are passed - confirm against the header.
1212  {i_ztilde, i_ytilde, i_xtilde});
// Stored for every non-empty slice, valid or not.
1213  a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
1214  b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
1215  c_grid_desc_m_n_container_.push_back(descs[I2]);
1216 
// The block-to-C-tile map is appended only when the descriptors are valid.
// NOTE: listing lines 1219-1224 and 1227 are elided in this rendering.
1217  if(GridwiseGemm::CheckValidity(descs[I0], descs[I1], descs[I2]))
1218  {
1225 
1226  block_2_ctile_map_container_.push_back(
1228  }
1229  }
1230  }
1231  }
1232  }
1233 
// Data members of Argument.
// NOTE: listing lines 1234-1236 are elided in this rendering; the page's
// cross-reference index places p_a_grid_, p_b_grid_ and p_c_grid_ there.
// One entry per sub-problem produced by CreateABCDesc().
1237  std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
1238  std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
1239  std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
1240 
// Finer-grained descriptor forms read by Invoker::Run().
1241  std::vector<AGridDesc_K0_M0_M1_K1> a_grid_desc_k0_m0_m1_k1_container_;
1242  std::vector<BGridDesc_K0_N0_N1_K1> b_grid_desc_k0_n0_n1_k1_container_;
1243  std::vector<CGridDesc_M0_M10_M11_N0_N10_N11> c_grid_desc_m0_m10_m11_n0_n10_n11_container_;
1244 
// Appended to only for sub-problems that pass GridwiseGemm::CheckValidity.
1245  std::vector<DefaultBlock2CTileMap> block_2_ctile_map_container_;
1246 
1247  // element-wise op
// a = out op, b = wei op, c = in op (matches the constructor's binding).
1248  OutElementwiseOperation a_element_op_;
1249  WeiElementwiseOperation b_element_op_;
1250  InElementwiseOperation c_element_op_;
1251  // for checking IsSupportedArgument()
// NOTE: listing lines 1252-1254 are elided in this rendering; the page's
// cross-reference index places Conv_N_, Conv_K_ and Conv_C_ there.
1255 
// Copies of the convolution parameters passed to the constructor.
1256  std::vector<ck::index_t> input_spatial_lengths_;
1257  std::vector<ck::index_t> filter_spatial_lengths_;
1258  std::vector<ck::index_t> output_spatial_lengths_;
1259  std::vector<ck::index_t> conv_filter_strides_;
1260  std::vector<ck::index_t> conv_filter_dilations_;
1261  std::vector<ck::index_t> input_left_pads_;
1262  std::vector<ck::index_t> input_right_pads_;
1263  };
1264 
1265  // Invoker
1266  struct Invoker : public BaseInvoker
1267  {
1269 
// Launches one GEMM kernel per entry of arg.a_grid_desc_k0_m_k1_container_
// and returns the accumulated value reported by launch_and_time_kernel.
// Note that, despite its name, ave_time is a running sum over all launches;
// it is never divided by the number of sub-problems.
// Throws std::runtime_error when the (elided) per-entry check fails.
1270  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
1271  {
1272  float ave_time = 0;
1273  for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
1274  {
// Optional diagnostics: dump descriptor lengths when CK_LOGGING is enabled.
1275  if(ck::EnvIsEnabled(CK_ENV(CK_LOGGING)))
1276  {
1277  std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
1278  << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
1279  << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
1280  << arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
1281  << std::endl;
1282 
1283  std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
1284  << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
1285  << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
1286  << arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
1287  << std::endl;
1288 
1289  std::cout << "arg.c_grid_desc_m_n_container_{ "
1290  << arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
1291  << arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
1292  << std::endl;
1293 
// NOTE: the value-printing lines (odd listing lines 1295-1305) are elided in
// this rendering; only the separators survive.
1294  std::cout << "arg.c_grid_desc_m0_m10_m11_n0_n10_n11_container_( "
1296  << ", "
1298  << ", "
1300  << ", "
1302  << ", "
1304  << ", "
1306  << " ) " << std::endl;
1307  }
1308 
// NOTE: the guarding condition (listing lines 1309-1311) is elided here.
// NOTE(review): the message names an xdlops v3r1 gridwise GEMM, while the
// kernel launched below is kernel_gemm_dl_v1r3 - the text looks stale; confirm.
1312  {
1313  throw std::runtime_error(
1314  "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
1315  }
1316 
// NOTE(review): block_2_ctile_map_container_ is only appended to for valid
// sub-problems in CreateABCDesc(), yet it is indexed with the same i as
// a_grid_desc_k0_m_k1_container_. Presumably the elided check above stops at
// the first invalid entry so the indices still line up - confirm.
// NOTE: the CalculateGridSize argument (listing line 1318) is elided.
1317  const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
1319 
// Generic lambda: the two flags arrive as compile-time constant wrappers so
// that they can be used as template arguments of the kernel.
1320  auto launch_kernel = [&](auto has_main_k_block_loop,
1321  auto has_double_tail_k_block_loop) {
1322  constexpr bool has_main_loop = has_main_k_block_loop.value;
// Relies on implicit conversion to bool (no .value, unlike the line above).
1323  constexpr bool has_double_loop = has_double_tail_k_block_loop;
1324 
// NOTE: listing lines 1329-1332 (further template arguments) are elided.
1325  const auto kernel = kernel_gemm_dl_v1r3<
1326  GridwiseGemm,
1327  ADataType, // TODO: distinguish A/B datatype
1328  CDataType,
1333  has_main_loop,
1334  has_double_loop>;
1335 
// One block per C tile (grid_size), BlockSize threads, 0 bytes for lds_byte.
// NOTE: listing lines 1345-1348 (remaining kernel arguments) are elided.
1336  ave_time +=
1337  launch_and_time_kernel(stream_config,
1338  kernel,
1339  dim3(grid_size),
1340  dim3(BlockSize),
1341  0,
1342  arg.p_a_grid_,
1343  arg.p_b_grid_,
1344  arg.p_c_grid_,
1349  };
1350 
// Select the kernel instantiation from the K0 extent of this sub-problem.
// NOTE: listing line 1354 (initializer of has_double_tail_k_block_loop) and
// the bodies of the four branches below are elided in this rendering.
1351  const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_container_[i].GetLength(I0);
1352  const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
1353  const bool has_double_tail_k_block_loop =
1355 
1356  if(has_main_k_block_loop && has_double_tail_k_block_loop)
1357  {
1359  }
1360  else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
1361  {
1364  }
1365  else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
1366  {
1369  }
1370  else
1371  {
1374  }
1375  }
1376  return ave_time;
1377  }
1378 
// Type-erased entry point: downcasts p_arg to Argument and forwards to the
// typed Run() overload.
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so passing a BaseArgument of another type (or nullptr) is undefined
// behavior; callers must pass an Argument created by this device op.
1379  float Run(const BaseArgument* p_arg,
1380  const StreamConfig& stream_config = StreamConfig{}) override
1381  {
1382  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
1383  }
1384  };
1385 
// Compile-time validity hook for the template parameters of this device op.
// Currently a placeholder that accepts every configuration.
1386  static constexpr bool IsValidCompilationParameter()
1387  {
1388  // TODO: properly implement this check
1389  return true;
1390  }
1391 
// Returns true when this instance can run the problem described by arg.
// Checks, in order: the device, the 1x1/stride-1/pad-0 specialization
// constraints (when compiled in), the vector-access constraints for
// matrices A and B, the C store vector width, and a per-sub-problem check.
1392  static bool IsSupportedArgument(const Argument& arg)
1393  {
1394  // check device
// NOTE: listing line 1396 (the rest of this condition) is elided in this
// rendering; presumably further architecture checks - confirm.
1395  if(!(ck::get_device_name() == "gfx906" || ck::is_gfx103_supported() ||
1397  {
1398  return false;
1399  }
1400 
// NOTE: listing line 1402 (the enumerator compared against) is elided.
1401  if constexpr(ConvBackwardDataSpecialization ==
1403  {
1404  // check if it's 1x1, stride=1 pad = 0 conv
1405  for(int i = 0; i < NDimSpatial; i++)
1406  {
1407  if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
1408  arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
1409  {
1410  return false;
1411  }
1412  }
1413  }
1414 
1415  // matrix A
1416  {
// Source vectors for A may only span the K0 (I0) and K1 (I3) dimensions.
1417  auto srcVectorLengths = ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1{};
1418  if(srcVectorLengths[I1] != 1 || srcVectorLengths[I2] != 1)
1419  {
1420  return false;
1421  }
1422  if(K1 % srcVectorLengths[I3] != 0 || K0PerBlock % srcVectorLengths[I0] != 0)
1423  {
1424  return false;
1425  }
1426 
1427  const index_t K = arg.Conv_K_;
1428 
// K must be a multiple of the combined K0*K1 vector length.
1429  if(K % (srcVectorLengths[I0] * srcVectorLengths[I3]) != 0)
1430  {
1431  return false;
1432  }
1433  }
1434 
1435  // matrix B
1436  {
// Source vectors for B may only span the N0 (I1) and N1 (I2) dimensions.
1437  auto srcLoadLenghts = BBlockTransferThreadSliceLengths_K0_N0_N1_K1{};
1438  auto srcVectorLengths = BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1{};
1439  if(srcVectorLengths[I0] != 1 || srcVectorLengths[I3] != 1)
1440  {
1441  return false;
1442  }
1443  if(srcLoadLenghts[I1] % srcVectorLengths[I1] != 0 ||
1444  srcLoadLenghts[I2] % srcVectorLengths[I2] != 0)
1445  {
1446  return false;
1447  }
1448 
// NOTE(review): the local is named C but is initialized from arg.Conv_K_.
// This looks like a copy-paste slip from the matrix-A check above
// (arg.Conv_C_ was presumably intended) - confirm before relying on it.
1449  const index_t C = arg.Conv_K_;
1450 
1451  if(C % (srcVectorLengths[I1] * srcVectorLengths[I2]) != 0)
1452  {
1453  return false;
1454  }
1455  }
1456  // vector store C matrix into global memory
1457  if(!(arg.Conv_C_ % CThreadTransferDstScalarPerVector == 0))
1458  {
// This rejection path prints to std::cout unconditionally (it is not gated
// on CK_LOGGING). NOTE(review): the message text contains the typo
// "surpport"; it is runtime output, so it is left untouched here.
1459  std::cout << "Not surpport,because: arg.Conv_C_ % CThreadTransferDstScalarPerVector = "
1460  << arg.Conv_C_ % CThreadTransferDstScalarPerVector << std::endl;
1461  return false;
1462  }
1463 
1464  // Gridwise GEMM size
// NOTE: listing lines 1466-1469 (the opening brace and per-entry condition)
// are elided in this rendering.
1465  for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
1466  {
1470  {
1471  return false;
1472  }
1473  }
1474  return true;
1475  }
1476 
// Type-erased overload: downcasts p_arg to Argument and forwards to the
// static IsSupportedArgument(const Argument&).
// NOTE(review): the dynamic_cast result is dereferenced without a null check,
// so a BaseArgument of another type (or nullptr) is undefined behavior rather
// than a "not supported" answer.
1477  bool IsSupportedArgument(const BaseArgument* p_arg) override
1478  {
1479  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
1480  }
1481 
// Typed factory: forwards every parameter, in order, to the Argument
// constructor and returns the Argument by value.
1482  static auto MakeArgument(InDataType* p_in_grid,
1483  const WeiDataType* p_wei_grid,
1484  const OutDataType* p_out_grid,
1485  ck::index_t N,
1486  ck::index_t K,
1487  ck::index_t C,
1488  std::vector<ck::index_t> input_spatial_lengths,
1489  std::vector<ck::index_t> filter_spatial_lengths,
1490  std::vector<ck::index_t> output_spatial_lengths,
1491  std::vector<ck::index_t> conv_filter_strides,
1492  std::vector<ck::index_t> conv_filter_dilations,
1493  std::vector<ck::index_t> input_left_pads,
1494  std::vector<ck::index_t> input_right_pads,
1495  InElementwiseOperation in_element_op,
1496  WeiElementwiseOperation wei_element_op,
1497  OutElementwiseOperation out_element_op)
1498  {
1499  return Argument{p_in_grid,
1500  p_wei_grid,
1501  p_out_grid,
1502  N,
1503  K,
1504  C,
1505  input_spatial_lengths,
1506  filter_spatial_lengths,
1507  output_spatial_lengths,
1508  conv_filter_strides,
1509  conv_filter_dilations,
1510  input_left_pads,
1511  input_right_pads,
1512  in_element_op,
1513  wei_element_op,
1514  out_element_op};
1515  }
1516 
// Typed factory: returns a value-initialized Invoker.
1517  static auto MakeInvoker() { return Invoker{}; }
1518 
// Type-erased factory: static_casts the untyped grid pointers to
// InDataType* / const WeiDataType* / const OutDataType*, builds an Argument
// on the heap and returns it as std::unique_ptr<BaseArgument>.
// The static_casts are unchecked; the caller is responsible for passing
// buffers of the matching element types.
1519  std::unique_ptr<BaseArgument>
1520  MakeArgumentPointer(void* p_in_grid,
1521  const void* p_wei_grid,
1522  const void* p_out_grid,
1523  ck::index_t N,
1524  ck::index_t K,
1525  ck::index_t C,
1526  std::vector<ck::index_t> input_spatial_lengths,
1527  std::vector<ck::index_t> filter_spatial_lengths,
1528  std::vector<ck::index_t> output_spatial_lengths,
1529  std::vector<ck::index_t> conv_filter_strides,
1530  std::vector<ck::index_t> conv_filter_dilations,
1531  std::vector<ck::index_t> input_left_pads,
1532  std::vector<ck::index_t> input_right_pads,
1533  InElementwiseOperation in_element_op,
1534  WeiElementwiseOperation wei_element_op,
1535  OutElementwiseOperation out_element_op) override
1536  {
1537  return std::make_unique<Argument>(static_cast<InDataType*>(p_in_grid),
1538  static_cast<const WeiDataType*>(p_wei_grid),
1539  static_cast<const OutDataType*>(p_out_grid),
1540  N,
1541  K,
1542  C,
1543  input_spatial_lengths,
1544  filter_spatial_lengths,
1545  output_spatial_lengths,
1546  conv_filter_strides,
1547  conv_filter_dilations,
1548  input_left_pads,
1549  input_right_pads,
1550  in_element_op,
1551  wei_element_op,
1552  out_element_op);
1553  }
1554 
// Type-erased factory: returns a heap-allocated Invoker as
// std::unique_ptr<BaseInvoker>. The Invoker is constructed from a temporary
// Invoker{} passed to make_unique.
1555  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
1556  {
1557  return std::make_unique<Invoker>(Invoker{});
1558  }
1559 
// Returns a human-readable identifier for this instance:
// "DeviceConvNdBwdDataNwcKxcNwk_Dl<BlockSize, MPerBlock, NPerBlock,
// K0PerBlock, K1>", with " Filter1x1Stride1Pad0" appended when the
// specialization condition below holds.
1560  std::string GetTypeString() const override
1561  {
1562  auto str = std::stringstream();
1563 
1564  // clang-format off
1565  str << "DeviceConvNdBwdDataNwcKxcNwk_Dl"
1566  << "<"
1567  << BlockSize << ", "
1568  << MPerBlock << ", "
1569  << NPerBlock << ", "
1570  << K0PerBlock << ", "
1571  << K1
1572  << ">";
// NOTE: listing line 1574 (the enumerator compared against and, presumably,
// the opening brace matching the brace on line 1577) is elided here.
1573  if constexpr(ConvBackwardDataSpecialization ==
1575 
1576  str<< " Filter1x1Stride1Pad0";
1577  }
1578 
1579 
1580  return str.str();
1581  }
1582 };
1583 
1584 } // namespace device
1585 } // namespace tensor_operation
1586 } // namespace ck
#define CK_ENV(name)
Definition: env.hpp:128
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ auto integer_divide_floor(X x, Y y)
Definition: math.hpp:66
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
__host__ constexpr __device__ index_t gcd(index_t x, index_t y)
Definition: math.hpp:154
ConvolutionBackwardDataSpecialization
Definition: convolution_backward_data_specialization.hpp:11
CK_TILE_HOST float launch_kernel(const stream_config &s, Callables... callables)
Definition: kernel_launch.hpp:72
Definition: ck.hpp:264
typename remove_reference< T >::type remove_reference_t
Definition: type.hpp:292
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
std::string get_device_name()
Definition: device_prop.hpp:12
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:98
bool is_gfx12_supported()
Definition: device_prop.hpp:94
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__global__ void kernel_gemm_dl_v1r3(const FloatAB *__restrict__ p_a_grid, const FloatAB *__restrict__ p_b_grid, FloatC *__restrict__ p_c_grid, const AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1, const BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1, const CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11, const Block2CTileMap block_2_ctile_map)
Definition: gridwise_gemm_dl_v1r3.hpp:33
bool is_gfx103_supported()
Definition: device_prop.hpp:81
__host__ constexpr __device__ auto make_slice_transform(const LowLength &low_length, const SliceBegin &slice_begin, const SliceEnd &slice_end)
Definition: multi_index_transform_helper.hpp:110
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:139
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
bool is_gfx11_supported()
Definition: device_prop.hpp:88
Definition: stream_config.hpp:10
Definition: gridwise_gemm_dl_v1r3.hpp:93
__host__ static constexpr __device__ bool CheckValidity(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1, const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1, const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:129
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:208
__host__ static constexpr __device__ auto MakeBGridDescriptor_K0_N0_N1_K1(const BGridDesc_K0_N_K1 &b_grid_desc_k0_n_k1)
Definition: gridwise_gemm_dl_v1r3.hpp:188
__host__ static constexpr __device__ bool CalculateHasMainKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_v1r3.hpp:153
__host__ static constexpr __device__ auto MakeDefaultBlock2CTileMap(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: gridwise_gemm_dl_v1r3.hpp:241
__host__ static constexpr __device__ bool CalculateHasDoubleTailKBlockLoop(index_t K0)
Definition: gridwise_gemm_dl_v1r3.hpp:160
__host__ static constexpr __device__ auto MakeAGridDescriptor_K0_M0_M1_K1(const AGridDesc_K0_M_K1 &a_grid_desc_k0_m_k1)
Definition: gridwise_gemm_dl_v1r3.hpp:168
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_conv_bwd_data.hpp:25
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1015
const BDataType * p_b_grid_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1235
std::vector< CGridDesc_M0_M10_M11_N0_N10_N11 > c_grid_desc_m0_m10_m11_n0_n10_n11_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1243
index_t Conv_N_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1252
std::vector< BGridDesc_K0_N0_N1_K1 > b_grid_desc_k0_n0_n1_k1_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1242
std::vector< CGridDesc_M_N > c_grid_desc_m_n_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1239
std::vector< ck::index_t > filter_spatial_lengths_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1257
index_t Conv_C_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1254
InElementwiseOperation c_element_op_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1250
std::vector< ck::index_t > output_spatial_lengths_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1258
std::vector< ck::index_t > input_left_pads_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1261
std::vector< ck::index_t > conv_filter_strides_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1259
std::vector< ck::index_t > input_spatial_lengths_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1256
WeiElementwiseOperation b_element_op_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1249
index_t Conv_K_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1253
std::vector< BGridDesc_K0_N_K1 > b_grid_desc_k0_n_k1_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1238
CDataType * p_c_grid_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1236
std::vector< ck::index_t > input_right_pads_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1262
void CreateABCDesc()
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1053
std::vector< AGridDesc_K0_M_K1 > a_grid_desc_k0_m_k1_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1237
const ADataType * p_a_grid_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1234
std::vector< DefaultBlock2CTileMap > block_2_ctile_map_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1245
Argument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1016
OutElementwiseOperation a_element_op_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1248
std::vector< ck::index_t > conv_filter_dilations_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1260
std::vector< AGridDesc_K0_M0_M1_K1 > a_grid_desc_k0_m0_m1_k1_container_
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1241
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1267
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1379
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1270
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:81
InDataType CDataType
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:86
InDataType ABDataType
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:89
static constexpr auto I3
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:94
static constexpr bool IsValidCompilationParameter()
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1386
decltype(GetABCGridDesc< NDimSpatial >()) ABCGridDescs
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:962
static constexpr auto I7
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:98
static constexpr auto I5
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:96
std::string GetTypeString() const override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1560
decltype(GridwiseGemm::MakeBGridDescriptor_K0_N0_N1_K1(BGridDesc_K0_N_K1{})) BGridDesc_K0_N0_N1_K1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1008
static bool IsSupportedArgument(const Argument &arg)
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1392
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, std::vector< ck::index_t > tildes)
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:102
static auto MakeArgument(InDataType *p_in_grid, const WeiDataType *p_wei_grid, const OutDataType *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op)
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1482
remove_cvref_t< decltype(ABCGridDescs{}[I1])> BGridDesc_K0_N_K1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:965
static constexpr auto I2
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:93
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1555
static auto MakeInvoker()
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1517
static constexpr auto I4
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:95
std::unique_ptr< BaseArgument > MakeArgumentPointer(void *p_in_grid, const void *p_wei_grid, const void *p_out_grid, ck::index_t N, ck::index_t K, ck::index_t C, std::vector< ck::index_t > input_spatial_lengths, std::vector< ck::index_t > filter_spatial_lengths, std::vector< ck::index_t > output_spatial_lengths, std::vector< ck::index_t > conv_filter_strides, std::vector< ck::index_t > conv_filter_dilations, std::vector< ck::index_t > input_left_pads, std::vector< ck::index_t > input_right_pads, InElementwiseOperation in_element_op, WeiElementwiseOperation wei_element_op, OutElementwiseOperation out_element_op) override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1520
OutDataType ADataType
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:84
remove_cvref_t< decltype(ABCGridDescs{}[I2])> CGridDesc_M_N
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:966
remove_cvref_t< decltype(ABCGridDescs{}[I0])> AGridDesc_K0_M_K1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:964
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1477
decltype(GridwiseGemm::MakeCGridDescriptor_M0_M10_M11_N0_N10_N11(CGridDesc_M_N{})) CGridDesc_M0_M10_M11_N0_N10_N11
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1010
GridwiseGemmDl_km_kn_mn_v1r3< BlockSize, ADataType, AccDataType, CDataType, InMemoryDataOperationEnum::Set, AGridDesc_K0_M_K1, BGridDesc_K0_N_K1, CGridDesc_M_N, MPerBlock, NPerBlock, K0PerBlock, K1, M1PerThread, N1PerThread, KPerThread, M1N1ThreadClusterM1Xs, M1N1ThreadClusterN1Xs, ABlockTransferThreadSliceLengths_K0_M0_M1_K1, ABlockTransferThreadClusterLengths_K0_M0_M1_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, ABlockTransferSrcVectorTensorContiguousDimOrder, ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, BBlockTransferThreadSliceLengths_K0_N0_N1_K1, BBlockTransferThreadClusterLengths_K0_N0_N1_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, BBlockTransferSrcVectorTensorContiguousDimOrder, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector > GridwiseGemm
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1003
static constexpr auto I6
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:97
decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{})) DefaultBlock2CTileMap
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1012
static constexpr auto I0
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:91
decltype(GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(AGridDesc_K0_M_K1{})) AGridDesc_K0_M0_M1_K1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:1006
static constexpr auto I1
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:92
WeiDataType BDataType
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:85
static auto GetABCGridDesc()
Definition: device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp:933