Composable Kernel — source listing for:
include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp
(checkout: develop)
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
20 template <typename GroupedConvTraitsType_, typename TilePartitioner_>
22 {
24 
26  TransformConvBwdDataToGemm<GroupedConvTraitsType_::NDimSpatial,
27  GroupedConvTraitsType_::ConvSpecialization,
28  GroupedConvTraitsType_::VectorSizeA,
29  GroupedConvTraitsType_::VectorSizeB,
30  GroupedConvTraitsType_::VectorSizeC>;
31  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
32 
33  static constexpr auto I0 = number<0>();
34  static constexpr auto I1 = number<1>();
35 
36  template <
37  typename InLay = typename GroupedConvTraitsType_::InLayout,
38  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
39  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
40  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
41  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
42  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
43  bool>::type = false>
45  {
46  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
47  static_cast<index_t>(args.N_),
48  static_cast<index_t>(args.C_),
49  static_cast<index_t>(args.input_spatial_lengths_[0])};
50  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
51  static_cast<index_t>(args.K_),
52  static_cast<index_t>(args.C_),
53  static_cast<index_t>(args.filter_spatial_lengths_[0])};
54  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
55  static_cast<index_t>(args.N_),
56  static_cast<index_t>(args.K_),
57  static_cast<index_t>(args.output_spatial_lengths_[0])};
58 
59  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
60  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
61  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
62  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
63 
64  k_batch = args.k_batch;
65 
66  in_ptr = args.in_ptr;
67  wei_ptr = args.wei_ptr;
68  for(index_t d = 0; d < NumDTensor; d++)
69  {
70  ds_ptr[d] = args.ds_ptr[d];
71  }
72  out_ptr = args.out_ptr;
73 
74  const index_t X = wei_g_k_c_xs_lengths[3];
75  const index_t ConvStrideW = conv_filter_strides[0];
76  const index_t ConvDilationW = conv_filter_dilations[0];
77  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
78  const auto XTilde = ConvStrideW / GcdStrideDilationW;
79 
80  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
81  {
82  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
83 
84  if(XDotSlice <= 0)
85  {
86  continue;
87  }
88 
90  {
91  gemm_count++;
92  // Avoid array segfault
93  continue;
94  }
95 
96  tildes = {i_xtilde};
97 
98  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
105  tildes};
106 
107  auto grid_descs =
108  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
109  GroupedConvTraitsType_::NDimSpatial>(1);
110 
111  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
112  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
113  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
114 
115  const index_t grid_size_grp =
116  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
117  c_grid_descs_m_n[gemm_count].get_length(I1));
118 
120  block_ends[gemm_count] = grid_size_ + grid_size_grp;
121 
122  grid_size_ += grid_size_grp;
123 
124  ++gemm_count;
125  }
126  group_stride_a = args.K_; // A: Out NWGK
127  group_stride_b = args.K_ * args.C_ *
128  std::accumulate(args.filter_spatial_lengths_.begin(),
129  args.filter_spatial_lengths_.end(),
130  1,
131  std::multiplies<index_t>()); // B: Wei GKXC
132  group_stride_c = args.C_; // C: In NWGC
133 
134  GemmBatch = args.G_;
135  }
136 
137  template <
138  typename InLay = typename GroupedConvTraitsType_::InLayout,
139  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
140  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
141  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
142  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
143  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
144  bool>::type = false>
146  {
147  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
148  static_cast<index_t>(args.N_),
149  static_cast<index_t>(args.C_),
150  static_cast<index_t>(args.input_spatial_lengths_[0]),
151  static_cast<index_t>(args.input_spatial_lengths_[1])};
152  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
153  static_cast<index_t>(args.K_),
154  static_cast<index_t>(args.C_),
155  static_cast<index_t>(args.filter_spatial_lengths_[0]),
156  static_cast<index_t>(args.filter_spatial_lengths_[1])};
157  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
158  static_cast<index_t>(args.N_),
159  static_cast<index_t>(args.K_),
160  static_cast<index_t>(args.output_spatial_lengths_[0]),
161  static_cast<index_t>(args.output_spatial_lengths_[1])};
162 
163  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
164  static_cast<index_t>(args.conv_filter_strides_[1])};
165  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
166  static_cast<index_t>(args.conv_filter_dilations_[1])};
167  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
168  static_cast<index_t>(args.input_left_pads_[1])};
169  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
170  static_cast<index_t>(args.input_right_pads_[1])};
171 
172  k_batch = args.k_batch;
173 
174  in_ptr = args.in_ptr;
175  wei_ptr = args.wei_ptr;
176  for(index_t d = 0; d < NumDTensor; d++)
177  {
178  ds_ptr[d] = args.ds_ptr[d];
179  }
180  out_ptr = args.out_ptr;
181 
182  const index_t Y = wei_g_k_c_xs_lengths[3];
183  const index_t X = wei_g_k_c_xs_lengths[4];
184  const index_t ConvStrideH = conv_filter_strides[0];
185  const index_t ConvStrideW = conv_filter_strides[1];
186  const index_t ConvDilationH = conv_filter_dilations[0];
187  const index_t ConvDilationW = conv_filter_dilations[1];
188  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
189  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
190  const auto YTilde = ConvStrideH / GcdStrideDilationH;
191  const auto XTilde = ConvStrideW / GcdStrideDilationW;
192 
193  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
194  {
195  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
196  {
197  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
198  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
199 
200  if(XDotSlice * YDotSlice <= 0)
201  {
202  continue;
203  }
204 
206  {
207  gemm_count++;
208  // Avoid array segfault
209  continue;
210  }
211 
212  tildes = {i_ytilde, i_xtilde};
213 
214  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
221  tildes};
222 
223  auto grid_descs = conv_to_gemm_transformer
224  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
225  GroupedConvTraitsType_::NDimSpatial>(1);
226 
227  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
228  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
229  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
230 
231  const index_t grid_size_grp =
232  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
233  c_grid_descs_m_n[gemm_count].get_length(I1));
234 
236  block_ends[gemm_count] = grid_size_ + grid_size_grp;
237 
238  grid_size_ += grid_size_grp;
239 
240  ++gemm_count;
241  }
242  }
243  group_stride_a = args.K_; // A: Out NWGK
244  group_stride_b = args.K_ * args.C_ *
245  std::accumulate(args.filter_spatial_lengths_.begin(),
246  args.filter_spatial_lengths_.end(),
247  1,
248  std::multiplies<index_t>()); // B: Wei GKXC
249  group_stride_c = args.C_; // C: In NWGC
250 
251  GemmBatch = args.G_;
252  }
253 
254  template <
255  typename InLay = typename GroupedConvTraitsType_::InLayout,
256  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
257  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
258  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
259  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
260  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
261  bool>::type = false>
263  {
264  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
265  static_cast<index_t>(args.N_),
266  static_cast<index_t>(args.C_),
267  static_cast<index_t>(args.input_spatial_lengths_[0]),
268  static_cast<index_t>(args.input_spatial_lengths_[1]),
269  static_cast<index_t>(args.input_spatial_lengths_[2])};
270  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
271  static_cast<index_t>(args.K_),
272  static_cast<index_t>(args.C_),
273  static_cast<index_t>(args.filter_spatial_lengths_[0]),
274  static_cast<index_t>(args.filter_spatial_lengths_[1]),
275  static_cast<index_t>(args.filter_spatial_lengths_[2])};
276  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
277  static_cast<index_t>(args.N_),
278  static_cast<index_t>(args.K_),
279  static_cast<index_t>(args.output_spatial_lengths_[0]),
280  static_cast<index_t>(args.output_spatial_lengths_[1]),
281  static_cast<index_t>(args.output_spatial_lengths_[2])};
282 
283  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
284  static_cast<index_t>(args.conv_filter_strides_[1]),
285  static_cast<index_t>(args.conv_filter_strides_[2])};
286  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
287  static_cast<index_t>(args.conv_filter_dilations_[1]),
288  static_cast<index_t>(args.conv_filter_dilations_[2])};
289  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
290  static_cast<index_t>(args.input_left_pads_[1]),
291  static_cast<index_t>(args.input_left_pads_[2])};
292  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
293  static_cast<index_t>(args.input_right_pads_[1]),
294  static_cast<index_t>(args.input_right_pads_[2])};
295 
296  k_batch = args.k_batch;
297 
298  in_ptr = args.in_ptr;
299  wei_ptr = args.wei_ptr;
300  for(index_t d = 0; d < NumDTensor; d++)
301  {
302  ds_ptr[d] = args.ds_ptr[d];
303  }
304  out_ptr = args.out_ptr;
305 
306  const index_t Z = wei_g_k_c_xs_lengths[3];
307  const index_t Y = wei_g_k_c_xs_lengths[4];
308  const index_t X = wei_g_k_c_xs_lengths[5];
309  const index_t ConvStrideD = conv_filter_strides[0];
310  const index_t ConvStrideH = conv_filter_strides[1];
311  const index_t ConvStrideW = conv_filter_strides[2];
312  const index_t ConvDilationD = conv_filter_dilations[0];
313  const index_t ConvDilationH = conv_filter_dilations[1];
314  const index_t ConvDilationW = conv_filter_dilations[2];
315  const auto GcdStrideDilationD = gcd(ConvStrideD, ConvDilationD);
316  const auto GcdStrideDilationH = gcd(ConvStrideH, ConvDilationH);
317  const auto GcdStrideDilationW = gcd(ConvStrideW, ConvDilationW);
318  const auto ZTilde = ConvStrideD / GcdStrideDilationD;
319  const auto YTilde = ConvStrideH / GcdStrideDilationH;
320  const auto XTilde = ConvStrideW / GcdStrideDilationW;
321 
322  for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
323  {
324  for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
325  {
326  for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
327  {
328  const auto ZDotSlice = integer_divide_ceil(Z - i_ztilde, ZTilde);
329  const auto YDotSlice = integer_divide_ceil(Y - i_ytilde, YTilde);
330  const auto XDotSlice = integer_divide_ceil(X - i_xtilde, XTilde);
331 
332  if(ZDotSlice * XDotSlice * YDotSlice <= 0)
333  {
334  continue;
335  }
336 
338  {
339  gemm_count++;
340  // Avoid array segfault
341  continue;
342  }
343 
344  tildes = {i_ztilde, i_ytilde, i_xtilde};
345 
346  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
353  tildes};
354 
355  auto grid_descs = conv_to_gemm_transformer
356  .template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
357  GroupedConvTraitsType_::NDimSpatial>(1);
358 
359  a_grid_descs_m_k[gemm_count] = grid_descs.at(number<0>{});
360  b_grid_descs_n_k[gemm_count] = grid_descs.at(number<1>{});
361  c_grid_descs_m_n[gemm_count] = grid_descs.at(number<2>{});
362 
363  const index_t grid_size_grp =
364  TilePartitioner::GridSize(c_grid_descs_m_n[gemm_count].get_length(I0),
365  c_grid_descs_m_n[gemm_count].get_length(I1));
366 
368  block_ends[gemm_count] = grid_size_ + grid_size_grp;
369 
370  grid_size_ += grid_size_grp;
371 
372  ++gemm_count;
373  }
374  }
375  }
376 
377  group_stride_a = args.K_; // A: Out NWGK
378  group_stride_b = args.K_ * args.C_ *
379  std::accumulate(args.filter_spatial_lengths_.begin(),
380  args.filter_spatial_lengths_.end(),
381  1,
382  std::multiplies<index_t>()); // B: Wei GKXC
383  group_stride_c = args.C_; // C: In NWGC
384 
385  GemmBatch = args.G_; // C: In NWGC
386  }
387 
388  static constexpr index_t MaxGroupedGemmGroupsNum = 128;
389 
392 
396 
397  static constexpr index_t NonSpatialDims = 3;
401 
407 
412 
413  const void* out_ptr;
414  void* in_ptr;
415  std::array<const void*, NumDTensor> ds_ptr;
416  const void* wei_ptr;
417 
421 
424 
428 };
429 
468 template <typename GroupedConvTraitsType_,
469  typename TilePartitioner_,
470  typename GemmPipeline_,
471  typename EpiloguePipeline_>
473 {
474  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
476  GroupedConvTraitsType_::ConvSpecialization;
483 
488 
490  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
491 
492  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
493 
497 
499 
502  static constexpr index_t MaxGroupedGemmGroupsNum =
504 
505  // TODO: Enable this
506  static constexpr bool IsSplitKSupported = false;
507 
508  static constexpr auto I0 = number<0>();
509  static constexpr auto I1 = number<1>();
510  static constexpr auto I2 = number<2>();
511  static constexpr auto I3 = number<3>();
512 
513  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
514  "Not supported!");
515  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
516  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
517  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
518  "Not supported C GEMM layout!");
519 
    /// @brief Build a human-readable identifier for this kernel instantiation.
    ///
    /// The parts are joined with '_': the kernel family tag
    /// ("grouped_convolution_backward_data"), the GEMM precision string derived
    /// from the input/weight data types, and the GEMM pipeline's own name.
    ///
    /// @return The composed kernel name string.
    [[nodiscard]] CK_TILE_HOST static const std::string GetName()
    {
        // clang-format off
        return concat('_', "grouped_convolution_backward_data", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
        // clang-format on
    }
526 
528  {
529  // enable batched grouped gemm
530  return dim3(kargs.grid_size_, kargs.GemmBatch, kargs.k_batch);
531  }
532 
533  CK_TILE_HOST static constexpr auto BlockSize()
534  {
535  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
536  }
537 
540  {
542  }
543 
545  {
546  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
547  }
548 
549  CK_TILE_HOST static bool
551  {
552  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
555  {
556  if(kargs.k_batch != 1)
557  {
558  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
559  {
560  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
561  }
562  return false;
563  }
564  }
565 
567  {
568  return false;
569  }
570 
571  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
572  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
573 
574  // check ConvSpecialization
576  {
577  // check if it's 1x1, stride=1 conv
578  for(index_t i = 0; i < NDimSpatial; ++i)
579  {
580  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
581  const index_t ConvStride = kargs.conv_filter_strides[i];
582  const index_t LeftPad = kargs.input_left_pads[i];
583  const index_t RightPad = kargs.input_right_pads[i];
584 
585  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
586  {
587  return false;
588  }
589  }
590  }
592  {
593  // check if it's 1x1 conv
594  for(index_t i = 0; i < NDimSpatial; ++i)
595  {
596  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
597  const index_t LeftPad = kargs.input_left_pads[i];
598  const index_t RightPad = kargs.input_right_pads[i];
599 
600  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
601  {
602  return false;
603  }
604  }
605  }
607  {
608  if(ConvC != 1)
609  {
610  return false;
611  }
612  for(index_t i = 0; i < NDimSpatial; ++i)
613  {
614  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
615 
616  if(filter_spatial_dim != I3)
617  {
618  return false;
619  }
620  }
621  }
622 
623  namespace ctc = tensor_layout::convolution;
624 
625  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
626  std::is_same_v<InLayout, ctc::NDHWGC>)
627  {
628  // Check access per C
629  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
630  {
631  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
632  return false;
633  }
634  }
635  else
636  {
637  CK_TILE_ERROR("Not supported input layout!");
638  return false;
639  }
640 
641  // FIXME: layout
642  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
643  std::is_same_v<WeiLayout, ctc::GKYXC> ||
644  std::is_same_v<WeiLayout, ctc::GKZYXC>)
645  {
646  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
647  {
648  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
649  return false;
650  }
651  }
652  else
653  {
654  CK_TILE_ERROR("Not supported weight layout!");
655  return false;
656  }
657 
658  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
659  std::is_same_v<OutLayout, ctc::NHWGK> ||
660  std::is_same_v<OutLayout, ctc::NDHWGK>)
661  {
662  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
663  {
664  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
665  return false;
666  }
667  }
668  else
669  {
670  CK_TILE_ERROR("Not supported output layout!");
671  return false;
672  }
673 
674  return true;
675  }
676 
677  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
678  CK_TILE_DEVICE static auto
680  const InDataType* b_ptr,
681  const std::array<const void*, NumDTensor>& ds_ptr,
682  WeiDataType* c_ptr,
684  const index_t group_id)
685  {
686  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
687  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
688  const auto& a_tensor_view = [&]() {
689  return make_tensor_view<address_space_enum::global>(
690  a_ptr,
691  kargs.a_grid_descs_m_k[group_id]); // A: out
692  }();
693 
694  const auto& b_tensor_view = [&]() {
695  return make_tensor_view<address_space_enum::global>(
696  b_ptr,
697  kargs.b_grid_descs_n_k[group_id]); // B: weight
698  }();
699 
700  const auto& c_tensor_view = [&]() {
701  return make_tensor_view<address_space_enum::global>(c_ptr,
702  kargs.c_grid_descs_m_n[group_id]);
703  }();
704 
705  const auto& ds_tensor_view = generate_tuple(
706  [&](auto i) {
707  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
708  "Not supported!");
709  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
710  "Not supported!");
711  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
712  "Not supported!");
713 
714  return make_tensor_view<address_space_enum::global>(
715  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_descs_m_n[group_id]);
716  },
718 
719  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
720  }
721 
722  template <typename TensorView>
723  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
724  {
725  const auto& a_pad_view = [&]() {
726  const auto& a_tensor_view = views.at(I0);
727  return pad_tensor_view(a_tensor_view,
731  }();
732 
733  const auto& b_pad_view = [&]() {
734  const auto& b_tensor_view = views.at(I1);
735  return pad_tensor_view(b_tensor_view,
739  }();
740 
741  const auto& ds_tensor_view = views.at(I2);
742  const auto& ds_pad_view = generate_tuple(
743  [&](auto i) {
744  return pad_tensor_view(ds_tensor_view[i],
748  },
750 
751  const auto& c_pad_view = [&]() {
752  const auto& c_tensor_view = views.at(I3);
753  return pad_tensor_view(c_tensor_view,
757  }();
758 
759  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
760  }
761 
762  template <typename PadView>
763  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
764  const index_t i_m,
765  const index_t i_n,
766  const index_t i_k = 0)
767  {
768  const auto& a_pad_view = views.at(I0);
769  const auto& b_pad_view = views.at(I1);
770  const auto& ds_pad_view = views.at(I2);
771  const auto& c_pad_view = views.at(I3);
772 
773  const auto& a_block_window = [&]() {
774  return make_tile_window(a_pad_view,
777  {i_m, i_k});
778  }();
779 
780  const auto& b_block_window = [&]() {
781  return make_tile_window(b_pad_view,
784  {i_k, i_n});
785  }();
786 
787  const auto ds_block_window = generate_tuple(
788  [&](auto i) {
789  return make_tile_window(ds_pad_view[i],
792  {i_m, i_n});
793  },
795 
796  auto c_block_window = make_tile_window(
797  c_pad_view,
799  {i_m, i_n});
800 
801  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
802  }
803 
816  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
817  const InDataType* b_ptr,
818  const std::array<const void*, NumDTensor>& ds_ptr,
819  WeiDataType* c_ptr,
820  void* smem_ptr_0,
822  const index_t block_idx_m,
823  const index_t block_idx_n,
824  const index_t group_id)
825  {
826  // Create Gemm tensor views, pad views and tile windows
827  const auto& gemm_tensor_views_tuple =
828  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
829  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
830 
831  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
832  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
833 
834  const index_t num_loop = amd_wave_read_first_lane(TilePartitioner::GetLoopNum(
835  gemm_pad_views.at(I0).get_tensor_descriptor().get_length(I1)));
836 
837  // Run GEMM cooperatively by whole workgroup.
838  const auto& a_block_window = gemm_tile_windows.at(I0);
839  const auto& b_block_window = gemm_tile_windows.at(I1);
840  const auto& d_block_window = gemm_tile_windows.at(I2);
841 
842  const auto& c_block_tile = GemmPipeline{}.template operator()(
843  a_block_window, b_block_window, num_loop, smem_ptr_0);
844 
845  // Run Epilogue Pipeline
846  auto& c_block_window = gemm_tile_windows.at(I3);
847 
848  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
849  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
850  }
851 
867  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
868  const InDataType* b_ptr,
869  const std::array<const void*, NumDTensor>& ds_ptr,
870  WeiDataType* c_ptr,
871  void* __restrict__ smem_ptr_0,
872  void* __restrict__ smem_ptr_1,
874  const index_t block_idx_m,
875  const index_t block_idx_n,
876  const index_t group_id)
877  {
878  // Create Gemm tensor views, pad views and tile windows
879  const auto& gemm_tensor_views_tuple =
880  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
881  a_ptr, b_ptr, ds_ptr, c_ptr, kargs, group_id);
882  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
883  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
884 
885  const index_t num_loop = amd_wave_read_first_lane(
886  TilePartitioner::GetLoopNum(gemm_tile_windows.at(I0).get_length(I1)));
887 
888  // Run GEMM cooperatively by whole workgroup.
889  const auto& a_block_window = gemm_tile_windows.at(I0);
890  const auto& b_block_window = gemm_tile_windows.at(I1);
891  const auto& d_block_window = gemm_tile_windows.at(I2);
892 
893  const auto& c_block_tile = GemmPipeline{}.template operator()(
894  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
895 
896  // Run Epilogue Pipeline
897  auto& c_block_window = gemm_tile_windows.at(I3);
898 
899  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
900  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
901  }
902 
904  index_t block_id) const
905  {
906  index_t left = 0;
907  index_t right = kargs.gemm_count;
908  index_t group_id = index_t((left + right) >> 1);
909 
910  while((!(block_id >= kargs.block_starts[group_id] &&
911  block_id < kargs.block_ends[group_id])) &&
912  left <= right)
913  {
914  if(block_id < kargs.block_starts[group_id])
915  {
916  right = group_id;
917  }
918  else
919  {
920  left = group_id;
921  }
922  group_id = index_t((left + right) >> 1);
923  }
924 
925  return group_id;
926  }
927 
929  {
930  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
931  const index_t group_id = FindGroupId(kargs, blockIdX);
932 
934  kargs.block_starts[group_id],
935  kargs.c_grid_descs_m_n[group_id].get_length(I0),
936  kargs.c_grid_descs_m_n[group_id].get_length(I1));
937 
938  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
939  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
940 
941  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
942  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
943  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
944  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
945 
946  // options
947  // conv_bwd_data = Out * Weight = In
948  const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
949  const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
950  InDataType* c_ptr = static_cast<InDataType*>(kargs.in_ptr) + group_offset_c;
951 
952  // allocate LDS
953  __shared__ char smem_ptr_0[GetSmemSize()];
954 
955  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
956  {
957  __shared__ char smem_ptr_1[GetSmemSize()];
958  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
959  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
961  {
962  RunGemm2LDS(a_ptr,
963  b_ptr,
964  kargs.ds_ptr,
965  c_ptr,
966  smem_ptr_0,
967  smem_ptr_1,
968  kargs,
969  i_m,
970  i_n,
971  group_id);
972  }
973  }
974  else
975  {
976  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
977  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
979  {
980  RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n, group_id);
981  }
982  }
983  }
984 };
985 
986 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:33
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE index_t gcd(index_t x, index_t y)
Definition: math.hpp:268
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_data_kernel.hpp:22
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:399
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:34
CK_TILE_HOST GroupedConvBwdDataKernelArgs(const GroupedConvBwdDataHostArgs &args)
Definition: grouped_convolution_backward_data_kernel.hpp:44
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_data_kernel.hpp:403
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:415
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_data_kernel.hpp:402
array< index_t, MaxGroupedGemmGroupsNum > block_starts
Definition: grouped_convolution_backward_data_kernel.hpp:422
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_data_kernel.hpp:404
long_index_t group_stride_b
Definition: grouped_convolution_backward_data_kernel.hpp:426
long_index_t group_stride_c
Definition: grouped_convolution_backward_data_kernel.hpp:427
array< index_t, MaxGroupedGemmGroupsNum > block_ends
Definition: grouped_convolution_backward_data_kernel.hpp:423
const void * out_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:413
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(1))> ABCGridDescs
Definition: grouped_convolution_backward_data_kernel.hpp:391
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescNK
Definition: grouped_convolution_backward_data_kernel.hpp:394
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:23
array< index_t, GroupedConvTraitsType_::NDimSpatial > tildes
Definition: grouped_convolution_backward_data_kernel.hpp:406
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescMK
Definition: grouped_convolution_backward_data_kernel.hpp:393
const void * wei_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:416
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:400
long_index_t group_stride_a
Definition: grouped_convolution_backward_data_kernel.hpp:425
index_t GemmBatch
Definition: grouped_convolution_backward_data_kernel.hpp:409
void * in_ptr
Definition: grouped_convolution_backward_data_kernel.hpp:414
index_t gemm_count
Definition: grouped_convolution_backward_data_kernel.hpp:411
array< CGridDescMN, MaxGroupedGemmGroupsNum > c_grid_descs_m_n
Definition: grouped_convolution_backward_data_kernel.hpp:420
index_t grid_size_
Definition: grouped_convolution_backward_data_kernel.hpp:410
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_data_kernel.hpp:405
array< BGridDescNK, MaxGroupedGemmGroupsNum > b_grid_descs_n_k
Definition: grouped_convolution_backward_data_kernel.hpp:419
index_t k_batch
Definition: grouped_convolution_backward_data_kernel.hpp:408
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:33
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:388
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_data_kernel.hpp:398
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:31
array< AGridDescMK, MaxGroupedGemmGroupsNum > a_grid_descs_m_k
Definition: grouped_convolution_backward_data_kernel.hpp:418
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_data_kernel.hpp:395
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_data_kernel.hpp:397
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
index_t k_batch
Definition: grouped_convolution_utils.hpp:40
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:36
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:37
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:39
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
The Grouped Convolution Backward Data kernel template.
Definition: grouped_convolution_backward_data_kernel.hpp:473
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_data_kernel.hpp:474
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_data_kernel.hpp:533
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_data_kernel.hpp:478
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k=0)
Definition: grouped_convolution_backward_data_kernel.hpp:763
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_backward_data_kernel.hpp:723
GroupedConvBwdDataKernelArgs< GroupedConvTraitsType_, TilePartitioner > GroupedConvBwdDataKernelArgsSpecialized
Definition: grouped_convolution_backward_data_kernel.hpp:501
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_backward_data_kernel.hpp:494
static constexpr index_t MaxGroupedGemmGroupsNum
Definition: grouped_convolution_backward_data_kernel.hpp:502
static constexpr auto I1
Definition: grouped_convolution_backward_data_kernel.hpp:509
static constexpr auto I3
Definition: grouped_convolution_backward_data_kernel.hpp:511
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_data_kernel.hpp:486
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_data_kernel.hpp:475
static constexpr CK_TILE_HOST GroupedConvBwdDataKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdDataHostArgs &hostArgs)
Definition: grouped_convolution_backward_data_kernel.hpp:539
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_data_kernel.hpp:490
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_backward_data_kernel.hpp:495
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_data_kernel.hpp:479
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_backward_data_kernel.hpp:498
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_data_kernel.hpp:477
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_data_kernel.hpp:485
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_data_kernel.hpp:492
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:550
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_data_kernel.hpp:481
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:487
static constexpr auto I2
Definition: grouped_convolution_backward_data_kernel.hpp:510
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t group_id)
Definition: grouped_convolution_backward_data_kernel.hpp:679
static CK_TILE_HOST auto GridSize(const GroupedConvBwdDataKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_data_kernel.hpp:527
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_data_kernel.hpp:480
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_data_kernel.hpp:489
CK_TILE_DEVICE index_t FindGroupId(const GroupedConvBwdDataKernelArgsSpecialized &kargs, index_t block_id) const
Definition: grouped_convolution_backward_data_kernel.hpp:903
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:816
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_data_kernel.hpp:544
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdDataKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n, const index_t group_id)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_data_kernel.hpp:867
CK_TILE_DEVICE void operator()(GroupedConvBwdDataKernelArgsSpecialized kargs) const
Definition: grouped_convolution_backward_data_kernel.hpp:928
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_data_kernel.hpp:506
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_data_kernel.hpp:484
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_data_kernel.hpp:482
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_data_kernel.hpp:496
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_data_kernel.hpp:520
static constexpr auto I0
Definition: grouped_convolution_backward_data_kernel.hpp:508
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from 1D raw-indexes.
Definition: gemm_tile_partitioner.hpp:192
Definition: transform_conv_bwd_data_to_gemm.hpp:22
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N([[maybe_unused]] const index_t GemmKBatch) const
Definition: transform_conv_bwd_data_to_gemm.hpp:569
constexpr CK_TILE_HOST_DEVICE auto & at(index_t i)
Definition: array.hpp:110
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145