Composable Kernel: include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp Source File

grouped_convolution_backward_data_kernel.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include <functional>
#include <iostream>
#include <numeric>
#include <string>

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/host/concat.hpp"

#ifdef CK_EXPERIMENTAL_BUILDER
#include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_data.hpp"
#endif
namespace ck_tile {

/// The Grouped Convolution kernel device arguments.
template <typename GroupedConvTraitsType_, typename TilePartitioner_>
struct GroupedConvBwdDataKernelArgs
{
    using TilePartitioner = remove_cvref_t<TilePartitioner_>;

    using ConvToGemmTransformer =
        TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
                                   GroupedConvTraitsType_::ConvSpecialization,
                                   GroupedConvTraitsType_::VectorSizeA,
                                   GroupedConvTraitsType_::VectorSizeB,
                                   GroupedConvTraitsType_::VectorSizeC,
                                   true>; // Split N enabled
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0])};

        k_batch = args.k_batch;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        const index_t X               = wei_g_k_c_xs_lengths[3];
        const index_t ConvStrideW     = conv_filter_strides[0];
        const index_t ConvDilationW   = conv_filter_dilations[0];
        const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
        const auto XTilde             = ConvStrideW / GcdStrideDilationW;

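        // The backward-data transform decomposes the filter along each spatial
        // dimension into XTilde = ConvStride / gcd(ConvStride, ConvDilation)
        // interleaved sub-filters; each i_xtilde slice below becomes its own
        // GEMM. E.g. stride 2 with dilation 1 gives XTilde = 2: one GEMM over
        // the even filter taps and one over the odd taps.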
        for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
        {
            const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);

            if(XDotSlice <= 0)
            {
                continue;
            }

            if(gemm_count >= MaxGroupedGemmGroupsNum)
            {
                gemm_count++;
                // Avoid array segfault
                continue;
            }

            tildes = {i_xtilde};

            ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                           wei_g_k_c_xs_lengths,
                                                           out_g_n_k_wos_lengths,
                                                           conv_filter_strides,
                                                           conv_filter_dilations,
                                                           input_left_pads,
                                                           input_right_pads,
                                                           tildes};

            auto grid_descs =
                conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                    GroupedConvTraitsType_::NDimSpatial>(1);

            a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
            b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
            c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});

            const index_t grid_size_grp =
                TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
                                          c_grid_descs_m_n[gemm_count].get_length(I1));

            block_starts[gemm_count] = grid_size_;
            block_ends[gemm_count]   = grid_size_ + grid_size_grp;

            grid_size_ += grid_size_grp;

            // Get the actual split N from transformer
            n_per_split = conv_to_gemm_transformer.GetN();
            original_n  = conv_to_gemm_transformer.GetOriginalN();
            n_splits    = integer_divide_ceil(original_n, n_per_split);

            ++gemm_count;
        }
        group_stride_a = args.K_; // A: Out NWGK
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>()); // B: Wei GKXC
        group_stride_c = args.C_; // C: In NWGC

        input_batch_stride  = args.C_ * args.G_ * args.input_spatial_lengths_[0];
        output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0];

        GemmBatch = args.G_;
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1])};

        k_batch = args.k_batch;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        const index_t Y               = wei_g_k_c_xs_lengths[3];
        const index_t X               = wei_g_k_c_xs_lengths[4];
        const index_t ConvStrideH     = conv_filter_strides[0];
        const index_t ConvStrideW     = conv_filter_strides[1];
        const index_t ConvDilationH   = conv_filter_dilations[0];
        const index_t ConvDilationW   = conv_filter_dilations[1];
        const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
        const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
        const auto YTilde             = ConvStrideH / GcdStrideDilationH;
        const auto XTilde             = ConvStrideW / GcdStrideDilationW;

        for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
        {
            for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
            {
                const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
                const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);

                if(XDotSlice * YDotSlice <= 0)
                {
                    continue;
                }

                if(gemm_count >= MaxGroupedGemmGroupsNum)
                {
                    gemm_count++;
                    // Avoid array segfault
                    continue;
                }

                tildes = {i_ytilde, i_xtilde};

                ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                               wei_g_k_c_xs_lengths,
                                                               out_g_n_k_wos_lengths,
                                                               conv_filter_strides,
                                                               conv_filter_dilations,
                                                               input_left_pads,
                                                               input_right_pads,
                                                               tildes};

                auto grid_descs = conv_to_gemm_transformer
                                      .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                                          GroupedConvTraitsType_::NDimSpatial>(1);

                a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
                b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
                c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});

                const index_t grid_size_grp =
                    TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
                                              c_grid_descs_m_n[gemm_count].get_length(I1));

                block_starts[gemm_count] = grid_size_;
                block_ends[gemm_count]   = grid_size_ + grid_size_grp;

                grid_size_ += grid_size_grp;

                // Get the actual split N from transformer
                n_per_split = conv_to_gemm_transformer.GetN();
                original_n  = conv_to_gemm_transformer.GetOriginalN();
                n_splits    = integer_divide_ceil(original_n, n_per_split);

                ++gemm_count;
            }
        }
        group_stride_a = args.K_; // A: Out NHWGK
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>()); // B: Wei GKYXC
        group_stride_c = args.C_; // C: In NHWGC

        input_batch_stride =
            args.C_ * args.G_ * args.input_spatial_lengths_[0] * args.input_spatial_lengths_[1];
        output_batch_stride =
            args.K_ * args.G_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];

        GemmBatch = args.G_;
    }

    template <
        typename InLay  = typename GroupedConvTraitsType_::InLayout,
        typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
        typename OutLay = typename GroupedConvTraitsType_::OutLayout,
        typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
                                    std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
                                    std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
                                bool>::type = false>
    CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs& args)
    {
        in_g_n_c_wis_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.input_spatial_lengths_[0]),
                                 static_cast<index_t>(args.input_spatial_lengths_[1]),
                                 static_cast<index_t>(args.input_spatial_lengths_[2])};
        wei_g_k_c_xs_lengths  = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.C_),
                                 static_cast<index_t>(args.filter_spatial_lengths_[0]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[1]),
                                 static_cast<index_t>(args.filter_spatial_lengths_[2])};
        out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
                                 static_cast<index_t>(args.N_),
                                 static_cast<index_t>(args.K_),
                                 static_cast<index_t>(args.output_spatial_lengths_[0]),
                                 static_cast<index_t>(args.output_spatial_lengths_[1]),
                                 static_cast<index_t>(args.output_spatial_lengths_[2])};

        conv_filter_strides   = {static_cast<index_t>(args.conv_filter_strides_[0]),
                                 static_cast<index_t>(args.conv_filter_strides_[1]),
                                 static_cast<index_t>(args.conv_filter_strides_[2])};
        conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
                                 static_cast<index_t>(args.conv_filter_dilations_[1]),
                                 static_cast<index_t>(args.conv_filter_dilations_[2])};
        input_left_pads       = {static_cast<index_t>(args.input_left_pads_[0]),
                                 static_cast<index_t>(args.input_left_pads_[1]),
                                 static_cast<index_t>(args.input_left_pads_[2])};
        input_right_pads      = {static_cast<index_t>(args.input_right_pads_[0]),
                                 static_cast<index_t>(args.input_right_pads_[1]),
                                 static_cast<index_t>(args.input_right_pads_[2])};

        k_batch = args.k_batch;

        in_ptr  = args.in_ptr;
        wei_ptr = args.wei_ptr;
        for(index_t d = 0; d < NumDTensor; d++)
        {
            ds_ptr[d] = args.ds_ptr[d];
        }
        out_ptr = args.out_ptr;

        const index_t Z               = wei_g_k_c_xs_lengths[3];
        const index_t Y               = wei_g_k_c_xs_lengths[4];
        const index_t X               = wei_g_k_c_xs_lengths[5];
        const index_t ConvStrideD     = conv_filter_strides[0];
        const index_t ConvStrideH     = conv_filter_strides[1];
        const index_t ConvStrideW     = conv_filter_strides[2];
        const index_t ConvDilationD   = conv_filter_dilations[0];
        const index_t ConvDilationH   = conv_filter_dilations[1];
        const index_t ConvDilationW   = conv_filter_dilations[2];
        const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD);
        const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
        const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
        const auto ZTilde             = ConvStrideD / GcdStrideDilationD;
        const auto YTilde             = ConvStrideH / GcdStrideDilationH;
        const auto XTilde             = ConvStrideW / GcdStrideDilationW;

        for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
        {
            for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
            {
                for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
                {
                    const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde);
                    const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
                    const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);

                    if(ZDotSlice * XDotSlice * YDotSlice <= 0)
                    {
                        continue;
                    }

                    if(gemm_count >= MaxGroupedGemmGroupsNum)
                    {
                        gemm_count++;
                        // Avoid array segfault
                        continue;
                    }

                    tildes = {i_ztilde, i_ytilde, i_xtilde};

                    ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
                                                                   wei_g_k_c_xs_lengths,
                                                                   out_g_n_k_wos_lengths,
                                                                   conv_filter_strides,
                                                                   conv_filter_dilations,
                                                                   input_left_pads,
                                                                   input_right_pads,
                                                                   tildes};

                    auto grid_descs =
                        conv_to_gemm_transformer
                            .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
                                GroupedConvTraitsType_::NDimSpatial>(1);

                    a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
                    b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
                    c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});

                    const index_t grid_size_grp =
                        TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
                                                  c_grid_descs_m_n[gemm_count].get_length(I1));

                    block_starts[gemm_count] = grid_size_;
                    block_ends[gemm_count]   = grid_size_ + grid_size_grp;

                    grid_size_ += grid_size_grp;

                    // Get the actual split N from transformer
                    n_per_split = conv_to_gemm_transformer.GetN();
                    original_n  = conv_to_gemm_transformer.GetOriginalN();
                    n_splits    = integer_divide_ceil(original_n, n_per_split);

                    ++gemm_count;
                }
            }
        }

        group_stride_a = args.K_; // A: Out NDHWGK
        group_stride_b = args.K_ * args.C_ *
                         std::accumulate(args.filter_spatial_lengths_.begin(),
                                         args.filter_spatial_lengths_.end(),
                                         1,
                                         std::multiplies<index_t>()); // B: Wei GKZYXC
        group_stride_c = args.C_; // C: In NDHWGC

        input_batch_stride = args.C_ * args.G_ * args.input_spatial_lengths_[0] *
                             args.input_spatial_lengths_[1] * args.input_spatial_lengths_[2];
        output_batch_stride = args.K_ * args.G_ * args.output_spatial_lengths_[0] *
                              args.output_spatial_lengths_[1] * args.output_spatial_lengths_[2];

        GemmBatch = args.G_;
    }

    static constexpr index_t MaxGroupedGemmGroupsNum = 128;

    using ABCGridDescs = remove_cvref_t<
        decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))>;

    using AGridDescMK = remove_cvref_t<decltype(ABCGridDescs{}[number<0>{}])>;
    using BGridDescNK = remove_cvref_t<decltype(ABCGridDescs{}[number<1>{}])>;
    using CGridDescMN = remove_cvref_t<decltype(ABCGridDescs{}[number<2>{}])>;

    static constexpr index_t NonSpatialDims = 3;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> in_g_n_c_wis_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> wei_g_k_c_xs_lengths;
    array<index_t, NonSpatialDims + GroupedConvTraitsType_::NDimSpatial> out_g_n_k_wos_lengths;

    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_strides;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> conv_filter_dilations;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_left_pads;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> input_right_pads;
    array<index_t, GroupedConvTraitsType_::NDimSpatial> tildes;

    index_t k_batch;
    index_t GemmBatch;
    index_t grid_size_ = 0;
    index_t gemm_count = 0;

    const void* out_ptr;
    void* in_ptr;
    std::array<const void*, NumDTensor> ds_ptr;
    const void* wei_ptr;

    array<AGridDescMK, MaxGroupedGemmGroupsNum> a_grid_descs_m_k;
    array<BGridDescNK, MaxGroupedGemmGroupsNum> b_grid_descs_n_k;
    array<CGridDescMN, MaxGroupedGemmGroupsNum> c_grid_descs_m_n;

    array<index_t, MaxGroupedGemmGroupsNum> block_starts;
    array<index_t, MaxGroupedGemmGroupsNum> block_ends;

    long_index_t group_stride_a;
    long_index_t group_stride_b;
    long_index_t group_stride_c;

    // Split-N support fields - initialize to safe defaults
    index_t n_splits    = 1;         // Number of batch splits (e.g., 2 for 128→64×2)
    index_t n_per_split = 1;         // Batches per split (N_ from transformer)
    index_t original_n  = 1;         // Original batch size before splitting
    index_t input_batch_stride  = 0; // Stride to next batch in input tensor
    index_t output_batch_stride = 0; // Stride to next batch in output tensor
};

/// The Grouped Convolution Backward Data kernel template.
template <typename GroupedConvTraitsType_,
          typename TilePartitioner_,
          typename GemmPipeline_,
          typename EpiloguePipeline_>
struct GroupedConvolutionBackwardDataKernel
{
    static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
    static constexpr ConvolutionSpecialization ConvSpecialization =
        GroupedConvTraitsType_::ConvSpecialization;
    using TilePartitioner  = remove_cvref_t<TilePartitioner_>;
    using GemmPipeline     = remove_cvref_t<GemmPipeline_>;
    using EpiloguePipeline = remove_cvref_t<EpiloguePipeline_>;
    using GemmALayout      = remove_cvref_t<typename GemmPipeline::ALayout>;
    using GemmBLayout      = remove_cvref_t<typename GemmPipeline::BLayout>;
    using GemmCLayout      = remove_cvref_t<typename GemmPipeline::CLayout>;

    using InLayout  = remove_cvref_t<typename GroupedConvTraitsType_::InLayout>;
    using WeiLayout = remove_cvref_t<typename GroupedConvTraitsType_::WeiLayout>;
    using OutLayout = remove_cvref_t<typename GroupedConvTraitsType_::OutLayout>;
    using DsLayout  = remove_cvref_t<typename GroupedConvTraitsType_::DsLayout>;

    using GemmDsLayout = remove_cvref_t<typename EpiloguePipeline::DsLayout>;
    static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;

    static constexpr index_t kBlockSize = GemmPipeline::BlockSize;

    using InDataType  = remove_cvref_t<typename GemmPipeline::ADataType>;
    using WeiDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
    using DsDataType  = remove_cvref_t<typename EpiloguePipeline::DsDataType>;

    using OutDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;

    using GroupedConvBwdDataKernelArgsSpecialized =
        GroupedConvBwdDataKernelArgs<GroupedConvTraitsType_, TilePartitioner>;
    static constexpr index_t MaxGroupedGemmGroupsNum =
        GroupedConvBwdDataKernelArgsSpecialized::MaxGroupedGemmGroupsNum;

    // TODO: Enable this
    static constexpr bool IsSplitKSupported = false;

    static constexpr auto I0 = number<0>();
    static constexpr auto I1 = number<1>();
    static constexpr auto I2 = number<2>();
    static constexpr auto I3 = number<3>();

    static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
                  "Not supported!");
    static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
    static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
                  "Not supported C GEMM layout!");
    static_assert(GroupedConvTraitsType_::ExplicitGemm == false, "Not supported yet!");

    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "grouped_convolution_backward_data",
                      gemm_prec_str<InDataType, WeiDataType>(),
                      "gemm",
                      GemmPipeline::GetName(),
                      "epilogue",
                      EpiloguePipeline::GetName());
        // clang-format on
    }

#ifdef CK_EXPERIMENTAL_BUILDER
    CK_TILE_HOST std::string GetInstanceString() const
    {
        static_assert(ck_tile::reflect::HasInstanceTraits<GroupedConvolutionBackwardDataKernel>,
                      "Specialization of instance_traits not found. Please check that a "
                      "specialization exists in file "
                      "ck_tile/builder/reflect/"
                      "instance_traits_tile_grouped_convolution_backward_data.hpp "
                      "for the given template parameters.");
        return ck_tile::reflect::instance_string<GroupedConvolutionBackwardDataKernel>();
    }
#endif

    CK_TILE_HOST static auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
    {
        // enable batched grouped gemm
        return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.n_splits * kargs.k_batch);
    }
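    // Grid mapping: x spans the concatenated tile ranges of all tilde-GEMMs
    // (tracked by block_starts/block_ends), y indexes GemmBatch (the G
    // convolution groups), and z enumerates the n_splits * k_batch combinations.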

    CK_TILE_HOST static constexpr auto BlockSize()
    {
        return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
    }

    CK_TILE_HOST static constexpr GroupedConvBwdDataKernelArgsSpecialized
    MakeKernelArgs(const GroupedConvBwdDataHostArgs& hostArgs)
    {
        return GroupedConvBwdDataKernelArgsSpecialized(hostArgs);
    }

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
    }
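    // Host-side usage sketch (assumes a concrete instantiation `Kernel` of this
    // template and the generic ck_tile launch helpers; `host_args` and `stream`
    // are placeholder names, not part of this header):
    //
    //   auto kargs = Kernel::MakeKernelArgs(host_args);
    //   if(Kernel::IsSupportedArgument(kargs))
    //   {
    //       const dim3 grids  = Kernel::GridSize(kargs);
    //       const dim3 blocks = Kernel::BlockSize();
    //       launch_kernel(stream_config{stream},
    //                     make_kernel(Kernel{}, grids, blocks, 0, kargs));
    //   }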

    CK_TILE_HOST static bool
    IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized& kargs)
    {
        if constexpr(GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                     (std::is_same_v<InDataType, fp16_t> || std::is_same_v<InDataType, bf16_t>))
        {
            if(kargs.k_batch != 1)
            {
                if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
                {
                    CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
                }
                return false;
            }
        }

        if(kargs.gemm_count > MaxGroupedGemmGroupsNum)
        {
            return false;
        }

        const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
        const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];

        // check ConvSpecialization
        if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Stride1Pad0)
        {
            // check if it's 1x1, stride=1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t ConvStride = kargs.conv_filter_strides[i];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter1x1Pad0)
        {
            // check if it's 1x1 conv
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
                const index_t LeftPad    = kargs.input_left_pads[i];
                const index_t RightPad   = kargs.input_right_pads[i];

                if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
                {
                    return false;
                }
            }
        }
        else if constexpr(ConvSpecialization == ConvolutionSpecialization::Filter3x3)
        {
            if(ConvC != 1)
            {
                return false;
            }
            for(index_t i = 0; i < NDimSpatial; ++i)
            {
                const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];

                if(filter_spatial_dim != I3)
                {
                    return false;
                }
            }
        }

        namespace ctc = tensor_layout::convolution;

        if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
                     std::is_same_v<InLayout, ctc::NDHWGC>)
        {
            // Check access per C
            if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported input layout!");
            return false;
        }

        // FIXME: layout
        if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
                     std::is_same_v<WeiLayout, ctc::GKYXC> ||
                     std::is_same_v<WeiLayout, ctc::GKZYXC>)
        {
            if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
            {
                CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported weight layout!");
            return false;
        }

        if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
                     std::is_same_v<OutLayout, ctc::NHWGK> ||
                     std::is_same_v<OutLayout, ctc::NDHWGK>)
        {
            if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
            {
                CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
                return false;
            }
        }
        else
        {
            CK_TILE_ERROR("Not supported output layout!");
            return false;
        }

        return true;
    }

    template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
    CK_TILE_DEVICE static auto
    MakeGemmTensorViews(const OutDataType* a_ptr,
                        const InDataType* b_ptr,
                        const std::array<const void*, NumDTensor>& ds_ptr,
                        WeiDataType* c_ptr,
                        const GroupedConvBwdDataKernelArgsSpecialized& kargs,
                        const index_t group_id)
    {
        static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
        static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
        const auto& a_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(
                a_ptr,
                kargs.a_grid_descs_m_k[group_id]); // A: out
        }();

        const auto& b_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(
                b_ptr,
                kargs.b_grid_descs_n_k[group_id]); // B: weight
        }();

        const auto& c_tensor_view = [&]() {
            return make_tensor_view<address_space_enum::global>(c_ptr,
                                                                kargs.c_grid_descs_m_n[group_id]);
        }();

        const auto& ds_tensor_view = generate_tuple(
            [&](auto i) {
                static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
                              "Not supported!");
                static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
                              "Not supported!");
                static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
                              "Not supported!");

                return make_tensor_view<address_space_enum::global>(
                    static_cast<const OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
            },
            number<NumDTensor>{});

        return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
    }
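    // Mapping note: conv_bwd_data computes In = Out * Weight, so A is the
    // output-gradient view (M x K), B the weight view, and C the input-gradient
    // view (M x N); each optional D tensor reuses C's M x N grid descriptor.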

    template <typename TensorView>
    CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
    {
        const auto& a_pad_view = [&]() {
            const auto& a_tensor_view = views.at(I0);
            return pad_tensor_view(a_tensor_view,
                                   make_tuple(number<TilePartitioner::MPerBlock>{},
                                              number<TilePartitioner::KPerBlock>{}),
                                   sequence<GemmPipeline::kPadM, GemmPipeline::kPadK>{});
        }();

        const auto& b_pad_view = [&]() {
            const auto& b_tensor_view = views.at(I1);
            return pad_tensor_view(b_tensor_view,
                                   make_tuple(number<TilePartitioner::KPerBlock>{},
                                              number<TilePartitioner::NPerBlock>{}),
                                   sequence<GemmPipeline::kPadK, GemmPipeline::kPadN>{});
        }();

        const auto& ds_tensor_view = views.at(I2);
        const auto& ds_pad_view    = generate_tuple(
            [&](auto i) {
                return pad_tensor_view(ds_tensor_view[i],
                                       make_tuple(number<TilePartitioner::MPerBlock>{},
                                                  number<TilePartitioner::NPerBlock>{}),
                                       sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
            },
            number<NumDTensor>{});

        const auto& c_pad_view = [&]() {
            const auto& c_tensor_view = views.at(I3);
            return pad_tensor_view(c_tensor_view,
                                   make_tuple(number<TilePartitioner::MPerBlock>{},
                                              number<TilePartitioner::NPerBlock>{}),
                                   sequence<GemmPipeline::kPadM, GemmPipeline::kPadN>{});
        }();

        return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
    }

    template <typename PadView>
    CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
                                                   const index_t i_m,
                                                   const index_t i_n,
                                                   const index_t i_k = 0)
    {
        const auto& a_pad_view  = views.at(I0);
        const auto& b_pad_view  = views.at(I1);
        const auto& ds_pad_view = views.at(I2);
        const auto& c_pad_view  = views.at(I3);

        const auto& a_block_window = [&]() {
            return make_tile_window(a_pad_view,
                                    make_tuple(number<TilePartitioner::MPerBlock>{},
                                               number<TilePartitioner::KPerBlock>{}),
                                    {i_m, i_k});
        }();

        const auto& b_block_window = [&]() {
            return make_tile_window(b_pad_view,
                                    make_tuple(number<TilePartitioner::KPerBlock>{},
                                               number<TilePartitioner::NPerBlock>{}),
                                    {i_k, i_n});
        }();

        const auto ds_block_window = generate_tuple(
            [&](auto i) {
                return make_tile_window(ds_pad_view[i],
                                        make_tuple(number<TilePartitioner::MPerBlock>{},
                                                   number<TilePartitioner::NPerBlock>{}),
                                        {i_m, i_n});
            },
            number<NumDTensor>{});

        auto c_block_window = make_tile_window(
            c_pad_view,
            make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
            {i_m, i_n});

        return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
    }

    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup.
     */
    CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
                                       const InDataType* b_ptr,
                                       const std::array<const void*, NumDTensor>& ds_ptr,
                                       WeiDataType* c_ptr,
                                       void* smem_ptr_0,
                                       const GroupedConvBwdDataKernelArgsSpecialized& kargs,
                                       const index_t block_idx_m,
                                       const index_t block_idx_n,
                                       const index_t group_id)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);

        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
            gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    /**
     * @brief Runs single GEMM problem cooperatively by whole workgroup,
     *        using two LDS buffers for double-buffered pipelines.
     */
    CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
                                           const InDataType* b_ptr,
                                           const std::array<const void*, NumDTensor>& ds_ptr,
                                           WeiDataType* c_ptr,
                                           void* __restrict__ smem_ptr_0,
                                           void* __restrict__ smem_ptr_1,
                                           const GroupedConvBwdDataKernelArgsSpecialized& kargs,
                                           const index_t block_idx_m,
                                           const index_t block_idx_n,
                                           const index_t group_id)
    {
        // Create Gemm tensor views, pad views and tile windows
        const auto& gemm_tensor_views_tuple =
            MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
                a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
        const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
        auto gemm_tile_windows     = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);

        const index_t num_loop = amd_wave_read_first_lane(
            TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));

        // Run GEMM cooperatively by whole workgroup.
        const auto& a_block_window = gemm_tile_windows.at(I0);
        const auto& b_block_window = gemm_tile_windows.at(I1);
        const auto& d_block_window = gemm_tile_windows.at(I2);

        const auto& c_block_tile = GemmPipeline{}.template operator()(
            a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);

        // Run Epilogue Pipeline
        auto& c_block_window = gemm_tile_windows.at(I3);

        EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
            c_block_window, c_block_tile, d_block_window, smem_ptr_0);
    }

    CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized& kargs,
                                       index_t block_id) const
    {
        index_t left     = 0;
        index_t right    = kargs.gemm_count;
        index_t group_id = index_t((left + right) >> 1);

        while((!(block_id >= kargs.block_starts[group_id] &&
                 block_id < kargs.block_ends[group_id])) &&
              left <= right)
        {
            if(block_id < kargs.block_starts[group_id])
            {
                right = group_id;
            }
            else
            {
                left = group_id;
            }
            group_id = index_t((left + right) >> 1);
        }

        return group_id;
    }
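    // block_starts/block_ends form consecutive [start, end) ranges over the
    // x-grid, so e.g. block_starts = {0, 12}, block_ends = {12, 20} places
    // block_id 15 in group 1, since 12 <= 15 < 20.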

    CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized& kargs) const
    {
        const auto blockIdX    = amd_wave_read_first_lane(blockIdx.x);
        const index_t group_id = FindGroupId(kargs, blockIdX);

        const auto [iM, iN] =
            TilePartitioner::GetOffsetedTileIndex(kargs.block_starts[group_id],
                                                  kargs.c_grid_descs_m_n[group_id].get_length(I0),
                                                  kargs.c_grid_descs_m_n[group_id].get_length(I1));

        const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
        const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);

        const auto blockIdY       = amd_wave_read_first_lane(blockIdx.y);
        const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
        const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
        const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);

        const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);

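        // blockIdx.z enumerates (batch split, k split) pairs: the grid's z-dim
        // is n_splits * k_batch, so z / k_batch selects the batch split and
        // z % k_batch would select the K split once SplitK lands. E.g. with
        // n_splits = 2 and k_batch = 1, z in {0, 1} picks batch halves 0 and 1.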
        // SplitN
        const index_t split_n_idx = __builtin_amdgcn_readfirstlane(blockIdZ / kargs.k_batch);
        const index_t split_n_offset =
            __builtin_amdgcn_readfirstlane(split_n_idx * kargs.n_per_split);

        const long_index_t output_batch_offset =
            static_cast<long_index_t>(split_n_offset) *
            static_cast<long_index_t>(kargs.output_batch_stride);
        const long_index_t input_batch_offset = static_cast<long_index_t>(split_n_offset) *
                                                static_cast<long_index_t>(kargs.input_batch_stride);

        // SplitK
        // TODO: Implement SplitK support
        // const index_t split_k_idx =
        //     __builtin_amdgcn_readfirstlane(blockIdZ - split_n_idx * kargs.k_batch);

        // conv_bwd_data = Out * Weight = In
        const OutDataType* a_ptr =
            static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a + output_batch_offset;
        const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
        InDataType* c_ptr =
            static_cast<InDataType*>(kargs.in_ptr) + group_offset_c + input_batch_offset;

        // allocate LDS
        __shared__ char smem_ptr_0[GetSmemSize()];

        if constexpr(GemmPipeline::DoubleSmemBuffer == true)
        {
            __shared__ char smem_ptr_1[GetSmemSize()];
            if constexpr(!(EpiloguePipeline::MemoryOperation ==
                               memory_operation_enum::atomic_add &&
                           GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           (std::is_same_v<InDataType, fp16_t> ||
                            std::is_same_v<InDataType, bf16_t>)))
            {
                RunGemm2LDS(a_ptr,
                            b_ptr,
                            kargs.ds_ptr,
                            c_ptr,
                            smem_ptr_0,
                            smem_ptr_1,
                            kargs,
                            i_m,
                            i_n,
                            group_id);
            }
        }
        else
        {
            if constexpr(!(EpiloguePipeline::MemoryOperation ==
                               memory_operation_enum::atomic_add &&
                           GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
                           (std::is_same_v<InDataType, fp16_t> ||
                            std::is_same_v<InDataType, bf16_t>)))
            {
                RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
            }
        }
    }
};

} // namespace ck_tile