/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm.hpp Source File
transform_conv_bwd_weight_to_gemm.hpp
Go to the documentation of this file.
1 
2 // SPDX-License-Identifier: MIT
3 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
4 
5 #pragma once
6 
13 
14 namespace ck {
15 namespace tensor_operation {
16 
17 template <index_t NDimSpatial,
18  index_t MPerBlock,
19  index_t NPerBlock,
20  index_t GemmK1Number,
21  index_t K0PerBlock,
22  device::ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization>
24 {
// Compile-time index constants: I0 is Number<0>{} and I1 is Number<1>{}.
25  static constexpr auto I0 = Number<0>{};
26  static constexpr auto I1 = Number<1>{};
27 
// make_out_grid_desc, overload enabled only when NDim == 2.
// NOTE: listing line 30 (function name and the first parameter, `const index_t N`
// per the cross-reference index at the end of this page) is missing here.
//
// Builds a 2-D naive tensor descriptor for the output tensor:
//   lengths = (N * Ho * Wo, K)
//   strides = (output_strides[4], 1)
// The K stride is the compile-time constant Number<1>{}; only element [4] of
// output_strides is read.
28  template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
29  constexpr static auto
31  const index_t Ho,
32  const index_t Wo,
33  const index_t K,
34  const std::array<index_t, NDimSpatial + 3>& output_strides)
35  {
36  const index_t WoStride = output_strides[4];
37  const auto KStride = Number<1>{};
38  return make_naive_tensor_descriptor(make_tuple(N * Ho * Wo, K),
39  make_tuple(WoStride, KStride));
40  }
41 
// make_in_grid_desc, overload enabled only when NDim == 2.
// NOTE: listing line 44 (function name and the first parameter, `const index_t N`
// per the cross-reference index) and line 55 (the enumerator that
// ConvBackwardWeightSpecialization is compared against) are missing here.
//
// Builds the input tensor descriptor in one of two shapes, chosen at compile time:
//   - specialization matches: 2-D, lengths (N * Hi * Wi, C), strides (WiStride, CStride)
//   - otherwise:              4-D, lengths (N, Hi, Wi, C),
//                             strides (NStride, HiStride, WiStride, CStride)
// Strides are read from input_strides as [1]=N, [2]=C, [3]=Hi, [4]=Wi.
42  template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
43  constexpr static auto
45  const index_t Hi,
46  const index_t Wi,
47  const index_t C,
48  const std::array<index_t, NDimSpatial + 3>& input_strides)
49  {
50  const index_t NStride = input_strides[1];
51  const index_t HiStride = input_strides[3];
52  const index_t WiStride = input_strides[4];
53  const auto CStride = input_strides[2];
54  if constexpr(ConvBackwardWeightSpecialization ==
56  {
    // Flattened form: NStride and HiStride are not used in this branch.
57  return make_naive_tensor_descriptor(make_tuple(N * Hi * Wi, C),
58  make_tuple(WiStride, CStride));
59  }
60  else
61  {
62  return make_naive_tensor_descriptor(make_tuple(N, Hi, Wi, C),
63  make_tuple(NStride, HiStride, WiStride, CStride));
64  }
65  }
66 
// make_wei_grid_desc, overload enabled only when NDim == 2.
// NOTE: listing line 69 (function name and the first parameter, `const index_t K`
// per the cross-reference index) is missing here.
//
// Builds a 2-D naive tensor descriptor for the weight tensor:
//   lengths = (K, Y * X * C)
//   strides = (weights_strides[1], 1)
// The stride of the merged Y*X*C dimension is the compile-time constant Number<1>{}.
67  template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
68  constexpr static auto
70  const index_t Y,
71  const index_t X,
72  const index_t C,
73  const std::array<index_t, NDimSpatial + 3>& weights_strides)
74  {
75  const auto CStride = Number<1>{};
76  const auto KStride = weights_strides[1];
77  return make_naive_tensor_descriptor(make_tuple(K, Y * X * C), make_tuple(KStride, CStride));
78  }
79 
// make_out_grid_desc, overload enabled only when NDim == 3.
// NOTE: listing line 82 (function name and the first parameter, `const index_t N`
// per the cross-reference index) is missing here.
//
// Builds a 2-D naive tensor descriptor for the output tensor:
//   lengths = (N * Do * Ho * Wo, K)
//   strides = (output_strides[5], 1)
// The K stride is the compile-time constant Number<1>{}; only element [5] of
// output_strides is read.
80  template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
81  constexpr static auto
83  const index_t Do,
84  const index_t Ho,
85  const index_t Wo,
86  const index_t K,
87  const std::array<index_t, NDimSpatial + 3>& output_strides)
88  {
89  const index_t WoStride = output_strides[5];
90  const auto KStride = Number<1>{};
91  return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
92  make_tuple(WoStride, KStride));
93  }
94 
// make_in_grid_desc, overload enabled only when NDim == 3.
// NOTE: listing line 97 (function name and the first parameter, `const index_t N`
// per the cross-reference index), line 110 (the enumerator that
// ConvBackwardWeightSpecialization is compared against) and line 117 (the start of
// the return statement in the else branch) are missing here.
//
// Builds the input tensor descriptor in one of two shapes, chosen at compile time:
//   - specialization matches: 2-D, lengths (N * Di * Hi * Wi, C),
//                             strides (WiStride, CStride)
//   - otherwise:              5-D, lengths (N, Di, Hi, Wi, C),
//                             strides (NStride, DiStride, HiStride, WiStride, CStride)
// Strides are read from input_strides as [1]=N, [2]=C, [3]=Di, [4]=Hi, [5]=Wi.
95  template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
96  constexpr static auto
98  const index_t Di,
99  const index_t Hi,
100  const index_t Wi,
101  const index_t C,
102  const std::array<index_t, NDimSpatial + 3>& input_strides)
103  {
104  const index_t NStride = input_strides[1];
105  const index_t DiStride = input_strides[3];
106  const index_t HiStride = input_strides[4];
107  const index_t WiStride = input_strides[5];
108  const auto CStride = input_strides[2];
109  if constexpr(ConvBackwardWeightSpecialization ==
111  {
    // Flattened form: NStride, DiStride and HiStride are not used in this branch.
112  return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
113  make_tuple(WiStride, CStride));
114  }
115  else
116  {
118  make_tuple(N, Di, Hi, Wi, C),
119  make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
120  }
121  }
122 
// make_wei_grid_desc, overload enabled only when NDim == 3.
// NOTE: listing line 125 (function name and the first parameter, `const index_t K`
// per the cross-reference index) is missing here.
//
// Builds a 2-D naive tensor descriptor for the weight tensor:
//   lengths = (K, Z * Y * X * C)
//   strides = (weights_strides[1], 1)
// The stride of the merged Z*Y*X*C dimension is the compile-time constant Number<1>{}.
123  template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
124  constexpr static auto
126  const index_t Z,
127  const index_t Y,
128  const index_t X,
129  const index_t C,
130  const std::array<index_t, NDimSpatial + 3>& weights_strides)
131  {
132  const auto CStride = Number<1>{};
133  const auto KStride = weights_strides[1];
134  return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
135  make_tuple(KStride, CStride));
136  }
137 
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N, overload enabled only when NDim == 1.
// NOTE: this listing is incomplete. Listing line 139 (the function name, taken here
// from the cross-reference index) and many statement lines inside the body (for
// example 179, 183, 188-190, 195-197) are missing, so several calls below appear
// truncated.
//
// Maps a 1-D convolution backward-weight problem onto GEMM dimensions:
//   GemmKTotal = N * Wo,  GemmM = K,  GemmN = C * X
// and returns a tuple of three descriptors, in this order:
//   (output-tensor "A" descriptor, input-tensor "B" descriptor, weight-tensor "C" descriptor).
// The three stride-array parameters are unnamed in this overload and are not read.
// batch_k is used as GemmKBatch, the leading length of the K unmerge.
138  template <index_t NDim, typename enable_if<NDim == 1, bool>::type = false>
140  const index_t N,
141  const index_t K,
142  const index_t C,
143  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
144  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
145  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
146  const std::array<index_t, NDimSpatial + 3>& /* input_strides */,
147  const std::array<index_t, NDimSpatial + 3>& /* weights_strides */,
148  const std::array<index_t, NDimSpatial + 3>& /* output_strides */,
149  const std::array<index_t, NDimSpatial>& conv_filter_strides,
150  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
151  const std::array<index_t, NDimSpatial>& input_left_pads,
152  const std::array<index_t, NDimSpatial>& input_right_pads,
153  const index_t batch_k)
154  {
155  using namespace ck;
156 
    // Unpack the single spatial dimension (W) from each per-dimension array.
157  const index_t Wi = input_spatial_lengths[0];
158  const index_t Wo = output_spatial_lengths[0];
159  const index_t X = filter_spatial_lengths[0];
160  const index_t ConvStrideW = conv_filter_strides[0];
161  const index_t ConvDilationW = conv_filter_dilations[0];
162  const index_t InLeftPadW = input_left_pads[0];
163  const index_t InRightPadW = input_right_pads[0];
164 
165  const index_t GemmKTotal = N * Wo;
166  const index_t GemmM = K;
167  const index_t GemmN = C * X;
168 
    // NOTE(review): when GemmM (or GemmN) is already a multiple of MPerBlock (or
    // NPerBlock) this evaluates to a full block of padding, not 0 - confirm intended.
169  const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
170  const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
171 
    // GemmK0 is a multiple of K0PerBlock; GemmKPad = GemmKBatch * GemmK0 * GemmK1Number
    // is the K extent after right-padding (assuming integer_divide_ceil rounds up,
    // GemmKPad >= GemmKTotal).
172  const index_t GemmKBatch = batch_k;
173  const index_t GemmK0 =
174  math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
175  K0PerBlock;
176  const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
177 
    // Compile-time branch; the enumerator compared against (listing line 179) is
    // missing from this listing. This branch returns descriptors without the
    // GemmM / GemmN right-padding applied in the else branch.
178  if constexpr(ConvBackwardWeightSpecialization ==
180  {
181  // A: output tensor
182  const auto out_gemmktotal_gemmm_grid_desc =
184 
185  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
186  out_gemmktotal_gemmm_grid_desc,
187  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
191 
192  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
193  out_gemmkpad_gemmm_grid_desc,
194  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
198 
199  // B: input tensor
200  const auto in_gemmktotal_gemmn_grid_desc =
202 
203  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
204  in_gemmktotal_gemmn_grid_desc,
205  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
209 
210  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
211  in_gemmkpad_gemmn_grid_desc,
212  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
216 
217  // C: weight tensor
218  const auto wei_gemmm_gemmn_grid_desc =
220 
221  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
222  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
223  wei_gemmm_gemmn_grid_desc);
224  }
225  else
226  {
227  const auto out_gemmktotal_gemmm_grid_desc =
229  const auto in_n_wi_c_grid_desc =
231 
232  // A: output tensor
233  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
234  out_gemmktotal_gemmm_grid_desc,
235  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
239 
240  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
241  out_gemmkpad_gemmm_grid_desc,
242  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
246 
247  // B: input tensor
    // Step 1: pad Wi by (InLeftPadW, InRightPadW).
248  const auto in_n_wip_c_grid_desc = transform_tensor_descriptor(
249  in_n_wi_c_grid_desc,
251  make_pad_transform(Wi, InLeftPadW, InRightPadW),
255 
    // Step 2: embed the padded W axis as (X, Wo) with coefficients
    // (ConvDilationW, ConvStrideW).
256  const auto in_n_x_wo_c_grid_desc = transform_tensor_descriptor(
257  in_n_wip_c_grid_desc,
258  make_tuple(
260  make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
264 
265  const auto in_gemmktotal_gemmn_grid_desc =
266  transform_tensor_descriptor(in_n_x_wo_c_grid_desc,
271 
272  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
273  in_gemmktotal_gemmn_grid_desc,
274  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
278 
279  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
280  in_gemmkpad_gemmn_grid_desc,
281  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
285 
286  // C: weight tensor
287  const auto wei_gemmm_gemmn_grid_desc =
289 
290  // Pad: right-pad the GemmM / GemmN dimensions of all three descriptors by PadGemmM / PadGemmN.
291  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
293  out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
296  make_right_pad_transform(GemmM, PadGemmM),
297  make_pass_through_transform(GemmK1Number)),
300 
301  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
303  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
306  make_right_pad_transform(GemmN, PadGemmN),
307  make_pass_through_transform(GemmK1Number)),
310 
311  const auto wei_gemmm_gemmn_pad_grid_desc =
312  transform_tensor_descriptor(wei_gemmm_gemmn_grid_desc,
313  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
314  make_right_pad_transform(GemmN, PadGemmN)),
317 
318  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
319  in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
320  wei_gemmm_gemmn_pad_grid_desc);
321  }
322  }
323 
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N, overload enabled only when NDim == 2.
// NOTE: this listing is incomplete. Listing line 325 (the function name, taken here
// from the cross-reference index) and many statement lines inside the body (for
// example 382, 388-390, 395-397) are missing, so several calls below appear truncated.
//
// Maps a 2-D convolution backward-weight problem onto GEMM dimensions:
//   GemmKTotal = N * Ho * Wo,  GemmM = K,  GemmN = C * X * Y
// and returns a tuple of three descriptors, in this order:
//   (output-tensor "A" descriptor, input-tensor "B" descriptor, weight-tensor "C" descriptor).
// The base descriptors come from make_out_grid_desc / make_in_grid_desc /
// make_wei_grid_desc called with the caller-supplied stride arrays.
// batch_k is used as GemmKBatch, the leading length of the K unmerge.
324  template <index_t NDim, typename enable_if<NDim == 2, bool>::type = false>
326  const index_t N,
327  const index_t K,
328  const index_t C,
329  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
330  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
331  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
332  const std::array<index_t, NDimSpatial + 3>& input_strides,
333  const std::array<index_t, NDimSpatial + 3>& weights_strides,
334  const std::array<index_t, NDimSpatial + 3>& output_strides,
335  const std::array<index_t, NDimSpatial>& conv_filter_strides,
336  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
337  const std::array<index_t, NDimSpatial>& input_left_pads,
338  const std::array<index_t, NDimSpatial>& input_right_pads,
339  const index_t batch_k)
340  {
341  using namespace ck;
342 
    // Per-dimension arrays are indexed [0] = H, [1] = W.
343  const index_t Hi = input_spatial_lengths[0];
344  const index_t Wi = input_spatial_lengths[1];
345 
346  const index_t Ho = output_spatial_lengths[0];
347  const index_t Wo = output_spatial_lengths[1];
348 
349  const index_t Y = filter_spatial_lengths[0];
350  const index_t X = filter_spatial_lengths[1];
351 
352  const index_t ConvStrideH = conv_filter_strides[0];
353  const index_t ConvStrideW = conv_filter_strides[1];
354 
355  const index_t ConvDilationH = conv_filter_dilations[0];
356  const index_t ConvDilationW = conv_filter_dilations[1];
357 
358  const index_t InLeftPadH = input_left_pads[0];
359  const index_t InLeftPadW = input_left_pads[1];
360 
361  const index_t InRightPadH = input_right_pads[0];
362  const index_t InRightPadW = input_right_pads[1];
363 
364  const index_t GemmKTotal = N * Ho * Wo;
365  const index_t GemmM = K;
366  const index_t GemmN = C * X * Y;
367 
    // NOTE(review): when GemmM (or GemmN) is already a multiple of MPerBlock (or
    // NPerBlock) this evaluates to a full block of padding, not 0 - confirm intended.
368  const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
369  const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
370 
    // GemmK0 is a multiple of K0PerBlock; GemmKPad = GemmKBatch * GemmK0 * GemmK1Number
    // is the K extent after right-padding (assuming integer_divide_ceil rounds up,
    // GemmKPad >= GemmKTotal).
371  const index_t GemmKBatch = batch_k;
372  const index_t GemmK0 =
373  math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
374  K0PerBlock;
375  const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
376 
377  const auto out_grid_desc = make_out_grid_desc<NDim>(N, Ho, Wo, K, output_strides);
378  const auto in_grid_desc = make_in_grid_desc<NDim>(N, Hi, Wi, C, input_strides);
379  const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Y, X, C, weights_strides);
380 
    // Compile-time branch; the enumerator compared against (listing line 382) is
    // missing from this listing. This branch K-pads and unmerges the out/in
    // descriptors and returns wei_grid_desc unchanged (no GemmM / GemmN padding).
381  if constexpr(ConvBackwardWeightSpecialization ==
383  {
384  // A: output tensor
385  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
386  out_grid_desc,
387  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
391 
392  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
393  out_gemmkpad_gemmm_grid_desc,
394  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
398 
399  // B: input tensor
400  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
401  in_grid_desc,
402  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
406 
407  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
408  in_gemmkpad_gemmn_grid_desc,
409  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
413 
414  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
415  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
416  wei_grid_desc);
417  }
418  else
419  {
420  // A: output tensor
421  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
422  out_grid_desc,
423  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
427 
428  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
429  out_gemmkpad_gemmm_grid_desc,
430  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
434 
435  // B: input tensor
    // Step 1: pad Hi and Wi by their left/right pads.
436  const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
437  in_grid_desc,
439  make_pad_transform(Hi, InLeftPadH, InRightPadH),
440  make_pad_transform(Wi, InLeftPadW, InRightPadW),
444 
    // Step 2: embed each padded spatial axis as (filter, output) with coefficients
    // (dilation, stride).
445  const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
446  in_n_hip_wip_c_grid_desc,
447  make_tuple(
449  make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
450  make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
454 
    // Step 3: merge (N, Ho, Wo); the other transform of this call is on a listing
    // line missing here.
455  const auto in_gemmktotal_gemmn_grid_desc =
456  transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
458  make_merge_transform(make_tuple(N, Ho, Wo))),
461 
462  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
463  in_gemmktotal_gemmn_grid_desc,
464  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
468 
469  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
470  in_gemmkpad_gemmn_grid_desc,
471  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
475 
476  // Pad: right-pad the GemmM / GemmN dimensions of all three descriptors by PadGemmM / PadGemmN.
477  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
479  out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
482  make_right_pad_transform(GemmM, PadGemmM),
483  make_pass_through_transform(GemmK1Number)),
486 
487  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
489  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
492  make_right_pad_transform(GemmN, PadGemmN),
493  make_pass_through_transform(GemmK1Number)),
496 
497  const auto wei_gemmm_gemmn_pad_grid_desc =
498  transform_tensor_descriptor(wei_grid_desc,
499  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
500  make_right_pad_transform(GemmN, PadGemmN)),
503 
504  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
505  in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
506  wei_gemmm_gemmn_pad_grid_desc);
507  }
508  }
509 
// Grid-descriptor builder, overload enabled only when NDim == 3.
// NOTE: this listing is incomplete. Listing line 511 (the function name) is missing
// and the cross-reference index has no entry for it; presumably it is
// MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N like the NDim == 1 / NDim == 2
// overloads - TODO confirm against the real header. Many statement lines inside the
// body (for example 575, 581-583, 631, 635, 637, 639) are also missing, so several
// calls below appear truncated.
//
// Maps a 3-D convolution backward-weight problem onto GEMM dimensions:
//   GemmKTotal = N * Do * Ho * Wo,  GemmM = K,  GemmN = C * Z * X * Y
// and returns a tuple of three descriptors, in this order:
//   (output-tensor "A" descriptor, input-tensor "B" descriptor, weight-tensor "C" descriptor).
// The base descriptors come from make_out_grid_desc / make_in_grid_desc /
// make_wei_grid_desc called with the caller-supplied stride arrays.
// batch_k is used as GemmKBatch, the leading length of the K unmerge.
510  template <index_t NDim, typename enable_if<NDim == 3, bool>::type = false>
512  const index_t N,
513  const index_t K,
514  const index_t C,
515  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
516  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
517  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
518  const std::array<index_t, NDimSpatial + 3>& input_strides,
519  const std::array<index_t, NDimSpatial + 3>& weights_strides,
520  const std::array<index_t, NDimSpatial + 3>& output_strides,
521  const std::array<index_t, NDimSpatial>& conv_filter_strides,
522  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
523  const std::array<index_t, NDimSpatial>& input_left_pads,
524  const std::array<index_t, NDimSpatial>& input_right_pads,
525  const index_t batch_k)
526  {
527  using namespace ck;
528 
    // Per-dimension arrays are indexed [0] = D, [1] = H, [2] = W.
529  const index_t Di = input_spatial_lengths[0];
530  const index_t Hi = input_spatial_lengths[1];
531  const index_t Wi = input_spatial_lengths[2];
532 
533  const index_t Do = output_spatial_lengths[0];
534  const index_t Ho = output_spatial_lengths[1];
535  const index_t Wo = output_spatial_lengths[2];
536 
537  const index_t Z = filter_spatial_lengths[0];
538  const index_t Y = filter_spatial_lengths[1];
539  const index_t X = filter_spatial_lengths[2];
540 
541  const index_t ConvStrideD = conv_filter_strides[0];
542  const index_t ConvStrideH = conv_filter_strides[1];
543  const index_t ConvStrideW = conv_filter_strides[2];
544 
545  const index_t ConvDilationD = conv_filter_dilations[0];
546  const index_t ConvDilationH = conv_filter_dilations[1];
547  const index_t ConvDilationW = conv_filter_dilations[2];
548 
549  const index_t InLeftPadD = input_left_pads[0];
550  const index_t InLeftPadH = input_left_pads[1];
551  const index_t InLeftPadW = input_left_pads[2];
552 
553  const index_t InRightPadD = input_right_pads[0];
554  const index_t InRightPadH = input_right_pads[1];
555  const index_t InRightPadW = input_right_pads[2];
556 
557  const index_t GemmKTotal = N * Do * Ho * Wo;
558  const index_t GemmM = K;
559  const index_t GemmN = C * Z * X * Y;
560 
    // NOTE(review): when GemmM (or GemmN) is already a multiple of MPerBlock (or
    // NPerBlock) this evaluates to a full block of padding, not 0 - confirm intended.
561  const auto PadGemmM = MPerBlock - GemmM % MPerBlock;
562  const auto PadGemmN = NPerBlock - GemmN % NPerBlock;
563 
    // GemmK0 is a multiple of K0PerBlock; GemmKPad = GemmKBatch * GemmK0 * GemmK1Number
    // is the K extent after right-padding (assuming integer_divide_ceil rounds up,
    // GemmKPad >= GemmKTotal).
564  const index_t GemmKBatch = batch_k;
565  const index_t GemmK0 =
566  math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock * GemmKBatch) *
567  K0PerBlock;
568  const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1Number;
569 
570  const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
571  const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
572  const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
573 
    // Compile-time branch; the enumerator compared against (listing line 575) is
    // missing from this listing. This branch K-pads and unmerges the out/in
    // descriptors and returns wei_grid_desc unchanged (no GemmM / GemmN padding).
574  if constexpr(ConvBackwardWeightSpecialization ==
576  {
577  // A: output tensor
578  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
579  out_grid_desc,
580  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
584 
585  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
586  out_gemmkpad_gemmm_grid_desc,
587  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
591 
592  // B: input tensor
593  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
594  in_grid_desc,
595  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
599 
600  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
601  in_gemmkpad_gemmn_grid_desc,
602  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
606 
607  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
608  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
609  wei_grid_desc);
610  }
611  else
612  {
613  // A: output tensor
614  const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
615  out_grid_desc,
616  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
620 
621  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
622  out_gemmkpad_gemmm_grid_desc,
623  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
627 
628  // B: input tensor
    // Step 1: pad Di, Hi and Wi by their left/right pads.
629  const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
630  in_grid_desc,
632  make_pad_transform(Di, InLeftPadD, InRightPadD),
633  make_pad_transform(Hi, InLeftPadH, InRightPadH),
634  make_pad_transform(Wi, InLeftPadW, InRightPadW),
636  make_tuple(
638  make_tuple(
640 
    // Step 2: embed each padded spatial axis as (filter, output) with coefficients
    // (dilation, stride); the visible upper-dimension ids are <1,2>, <3,4>, <5,6>, <7>.
641  const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
642  in_n_dip_hip_wip_c_grid_desc,
643  make_tuple(
645  make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
646  make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
647  make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
649  make_tuple(
652  Sequence<1, 2>{},
653  Sequence<3, 4>{},
654  Sequence<5, 6>{},
655  Sequence<7>{}));
656 
    // Step 3: merge (N, Do, Ho, Wo); the other transform of this call is on a listing
    // line missing here.
657  const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
658  in_n_z_do_y_ho_x_wo_c_grid_desc,
660  make_merge_transform(make_tuple(N, Do, Ho, Wo))),
663 
664  const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
665  in_gemmktotal_gemmn_grid_desc,
666  make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
670 
671  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
672  in_gemmkpad_gemmn_grid_desc,
673  make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1Number)),
677 
678  // Pad: right-pad the GemmM / GemmN dimensions of all three descriptors by PadGemmM / PadGemmN.
679  const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
681  out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
684  make_right_pad_transform(GemmM, PadGemmM),
685  make_pass_through_transform(GemmK1Number)),
688 
689  const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
691  in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
694  make_right_pad_transform(GemmN, PadGemmN),
695  make_pass_through_transform(GemmK1Number)),
698 
699  const auto wei_gemmm_gemmn_pad_grid_desc =
700  transform_tensor_descriptor(wei_grid_desc,
701  make_tuple(make_right_pad_transform(GemmM, PadGemmM),
702  make_right_pad_transform(GemmN, PadGemmN)),
705 
706  return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
707  in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
708  wei_gemmm_gemmn_pad_grid_desc);
709  }
710  } // function end
711 };
712 
713 } // namespace tensor_operation
714 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
ConvolutionBackwardWeightSpecialization
Definition: convolution_backward_weight_specialization.hpp:11
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: multi_index_transform_helper.hpp:48
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:19
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: transform_conv_bwd_weight_to_gemm.hpp:24
static constexpr auto I0
Definition: transform_conv_bwd_weight_to_gemm.hpp:25
constexpr static auto make_out_grid_desc(const index_t N, const index_t Do, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:82
constexpr static auto make_out_grid_desc(const index_t N, const index_t Ho, const index_t Wo, const index_t K, const std::array< index_t, NDimSpatial+3 > &output_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:30
static constexpr auto I1
Definition: transform_conv_bwd_weight_to_gemm.hpp:26
constexpr static auto make_in_grid_desc(const index_t N, const index_t Di, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:97
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &input_strides, const std::array< index_t, NDimSpatial+3 > &weights_strides, const std::array< index_t, NDimSpatial+3 > &output_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const index_t batch_k)
Definition: transform_conv_bwd_weight_to_gemm.hpp:325
constexpr static auto make_wei_grid_desc(const index_t K, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:69
constexpr static auto make_in_grid_desc(const index_t N, const index_t Hi, const index_t Wi, const index_t C, const std::array< index_t, NDimSpatial+3 > &input_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:44
constexpr static auto make_wei_grid_desc(const index_t K, const index_t Z, const index_t Y, const index_t X, const index_t C, const std::array< index_t, NDimSpatial+3 > &weights_strides)
Definition: transform_conv_bwd_weight_to_gemm.hpp:125
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(const index_t N, const index_t K, const index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial+3 > &, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const index_t batch_k)
Definition: transform_conv_bwd_weight_to_gemm.hpp:139