Composable Kernel — source listing of
include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp

grouped_convolution_backward_weight_kernel.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 #ifdef CK_EXPERIMENTAL_BUILDER
18 #include "ck_tile/builder/reflect/instance_traits_tile_grouped_convolution_backward_weight.hpp"
19 #endif
20 
21 namespace ck_tile {
22 
24 template <typename GroupedConvTraitsType_>
26 {
27 
29  TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
30  GroupedConvTraitsType_::ConvSpecialization,
31  GroupedConvTraitsType_::VectorSizeA,
32  GroupedConvTraitsType_::VectorSizeB,
33  GroupedConvTraitsType_::VectorSizeC,
34  GroupedConvTraitsType_::NumGroupsToMerge>;
35  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
36 
37  template <
38  typename InLay = typename GroupedConvTraitsType_::InLayout,
39  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
40  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
41  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
42  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
43  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
44  bool>::type = false>
46  {
47  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
48  static_cast<index_t>(args.N_),
49  static_cast<index_t>(args.C_),
50  static_cast<index_t>(args.input_spatial_lengths_[0])};
51  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
52  static_cast<index_t>(args.K_),
53  static_cast<index_t>(args.C_),
54  static_cast<index_t>(args.filter_spatial_lengths_[0])};
55  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
56  static_cast<index_t>(args.N_),
57  static_cast<index_t>(args.K_),
58  static_cast<index_t>(args.output_spatial_lengths_[0])};
59 
60  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
61  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
62  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
63  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
64 
65  k_batch = args.k_batch;
66 
67  in_ptr = args.in_ptr;
68  wei_ptr = args.wei_ptr;
69  for(index_t d = 0; d < NumDTensor; d++)
70  {
71  ds_ptr[d] = args.ds_ptr[d];
72  }
73  out_ptr = args.out_ptr;
74 
75  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
82 
83  // tuple
84  auto grid_descs =
85  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
86  GroupedConvTraitsType_::NDimSpatial>();
87 
88  a_grid_desc_k_m = grid_descs.at(number<0>{});
89  b_grid_desc_k_n = grid_descs.at(number<1>{});
90  c_grid_desc_m_n = grid_descs.at(number<2>{});
91 
92  NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
93  group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NWGK
94  group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NWGC
95  group_stride_c = args.K_ * args.C_ // C: Wei GKXC
97  std::accumulate(args.filter_spatial_lengths_.begin(),
98  args.filter_spatial_lengths_.end(),
99  1,
100  std::multiplies<index_t>());
101 
102  GemmM = a_grid_desc_k_m.get_length(number<1>{});
103  GemmN = b_grid_desc_k_n.get_length(number<1>{});
104  GemmK = a_grid_desc_k_m.get_length(number<0>{});
106 
107  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
108  {
109  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
110  << ", GemmBatch: " << GemmBatch
111  << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
112  }
113  }
114 
115  template <
116  typename InLay = typename GroupedConvTraitsType_::InLayout,
117  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
118  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
119  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
120  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
121  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
122  bool>::type = false>
124  {
125  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
126  static_cast<index_t>(args.N_),
127  static_cast<index_t>(args.C_),
128  static_cast<index_t>(args.input_spatial_lengths_[0]),
129  static_cast<index_t>(args.input_spatial_lengths_[1])};
130  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
131  static_cast<index_t>(args.K_),
132  static_cast<index_t>(args.C_),
133  static_cast<index_t>(args.filter_spatial_lengths_[0]),
134  static_cast<index_t>(args.filter_spatial_lengths_[1])};
135  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
136  static_cast<index_t>(args.N_),
137  static_cast<index_t>(args.K_),
138  static_cast<index_t>(args.output_spatial_lengths_[0]),
139  static_cast<index_t>(args.output_spatial_lengths_[1])};
140 
141  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
142  static_cast<index_t>(args.conv_filter_strides_[1])};
143  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
144  static_cast<index_t>(args.conv_filter_dilations_[1])};
145  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
146  static_cast<index_t>(args.input_left_pads_[1])};
147  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
148  static_cast<index_t>(args.input_right_pads_[1])};
149 
150  k_batch = args.k_batch;
151 
152  in_ptr = args.in_ptr;
153  wei_ptr = args.wei_ptr;
154  for(index_t d = 0; d < NumDTensor; d++)
155  {
156  ds_ptr[d] = args.ds_ptr[d];
157  }
158  out_ptr = args.out_ptr;
159 
160  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
167 
168  // tuple
169  auto grid_descs =
170  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
171  GroupedConvTraitsType_::NDimSpatial>();
172 
173  a_grid_desc_k_m = grid_descs.at(number<0>{});
174  b_grid_desc_k_n = grid_descs.at(number<1>{});
175  c_grid_desc_m_n = grid_descs.at(number<2>{});
176 
177  NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
178  group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NHWGK
179  group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NHWGC
180  group_stride_c = args.K_ * args.C_ // C: Wei GKYXC
182  std::accumulate(args.filter_spatial_lengths_.begin(),
183  args.filter_spatial_lengths_.end(),
184  1,
185  std::multiplies<index_t>());
186 
187  GemmM = a_grid_desc_k_m.get_length(number<1>{});
188  GemmN = b_grid_desc_k_n.get_length(number<1>{});
189  GemmK = a_grid_desc_k_m.get_length(number<0>{});
191 
192  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
193  {
194  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
195  << ", GemmBatch: " << GemmBatch
196  << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
197  }
198  }
199 
200  template <
201  typename InLay = typename GroupedConvTraitsType_::InLayout,
202  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
203  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
204  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
205  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
206  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
207  bool>::type = false>
209  {
210  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
211  static_cast<index_t>(args.N_),
212  static_cast<index_t>(args.C_),
213  static_cast<index_t>(args.input_spatial_lengths_[0]),
214  static_cast<index_t>(args.input_spatial_lengths_[1]),
215  static_cast<index_t>(args.input_spatial_lengths_[2])};
216  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
217  static_cast<index_t>(args.K_),
218  static_cast<index_t>(args.C_),
219  static_cast<index_t>(args.filter_spatial_lengths_[0]),
220  static_cast<index_t>(args.filter_spatial_lengths_[1]),
221  static_cast<index_t>(args.filter_spatial_lengths_[2])};
222  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
223  static_cast<index_t>(args.N_),
224  static_cast<index_t>(args.K_),
225  static_cast<index_t>(args.output_spatial_lengths_[0]),
226  static_cast<index_t>(args.output_spatial_lengths_[1]),
227  static_cast<index_t>(args.output_spatial_lengths_[2])};
228 
229  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
230  static_cast<index_t>(args.conv_filter_strides_[1]),
231  static_cast<index_t>(args.conv_filter_strides_[2])};
232  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
233  static_cast<index_t>(args.conv_filter_dilations_[1]),
234  static_cast<index_t>(args.conv_filter_dilations_[2])};
235  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
236  static_cast<index_t>(args.input_left_pads_[1]),
237  static_cast<index_t>(args.input_left_pads_[2])};
238  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
239  static_cast<index_t>(args.input_right_pads_[1]),
240  static_cast<index_t>(args.input_right_pads_[2])};
241 
242  k_batch = args.k_batch;
243 
244  in_ptr = args.in_ptr;
245  wei_ptr = args.wei_ptr;
246  for(index_t d = 0; d < NumDTensor; d++)
247  {
248  ds_ptr[d] = args.ds_ptr[d];
249  }
250  out_ptr = args.out_ptr;
251 
252  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
259 
260  // tuple
261  auto grid_descs =
262  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
263  GroupedConvTraitsType_::NDimSpatial>();
264 
265  a_grid_desc_k_m = grid_descs.at(number<0>{});
266  b_grid_desc_k_n = grid_descs.at(number<1>{});
267  c_grid_desc_m_n = grid_descs.at(number<2>{});
268 
269  NumGroupsPerBatch = GroupedConvTraitsType_::NumGroupsToMerge;
270  group_stride_a = args.K_ * NumGroupsPerBatch; // A: Out NDHWGK
271  group_stride_b = args.C_ * NumGroupsPerBatch; // B: In NDHWGC
272  group_stride_c = args.K_ * args.C_ // C: Wei GKZYXC
274  std::accumulate(args.filter_spatial_lengths_.begin(),
275  args.filter_spatial_lengths_.end(),
276  1,
277  std::multiplies<index_t>());
278 
279  GemmM = a_grid_desc_k_m.get_length(number<1>{});
280  GemmN = b_grid_desc_k_n.get_length(number<1>{});
281  GemmK = a_grid_desc_k_m.get_length(number<0>{});
283 
284  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
285  {
286  std::cout << "GemmM: " << GemmM << ", GemmN: " << GemmN << ", GemmK: " << GemmK
287  << ", GemmBatch: " << GemmBatch
288  << ", NumGroupsPerBatch: " << NumGroupsPerBatch << std::endl;
289  }
290  }
291 
294 
298 
299  static constexpr index_t NonSpatialDims = 3;
303 
308 
315 
316  const void* out_ptr;
317  const void* in_ptr;
318  std::array<const void*, NumDTensor> ds_ptr;
319  void* wei_ptr;
320 
324 
328 };
329 
367 template <typename GroupedConvTraitsType_,
368  typename TilePartitioner_,
369  typename GemmPipeline_,
370  typename EpiloguePipeline_>
372 {
373  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial;
375  GroupedConvTraitsType_::ConvSpecialization;
382 
387 
389  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
390 
391  static constexpr index_t kBlockSize = GemmPipeline::BlockSize;
392 
397 
400 
401  // TODO: Enable this
402  static constexpr bool IsSplitKSupported = true;
403 
404  static constexpr auto I0 = number<0>();
405  static constexpr auto I1 = number<1>();
406  static constexpr auto I2 = number<2>();
407  static constexpr auto I3 = number<3>();
408 
409  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
410  "Not supported!");
411  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
412  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
413  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
414  static_assert(GroupedConvTraitsType_::ExplicitGemm == false ||
415  GroupedConvTraitsType_::NumGroupsToMerge == 1,
416  "Not supported!");
417 
418  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
419  {
420  constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
421  // clang-format off
422  if (NumGroupsToMerge > 1) {
423  return concat('_', "grouped_convolution_backward_weight",
424  gemm_prec_str<InDataType, WeiDataType>(),
425  "gemm",
426  GemmPipeline::GetName(),
427  "epilogue",
428  EpiloguePipeline::GetName());
429  } else {
430  return concat('_', "grouped_convolution_backward_weight",
431  gemm_prec_str<InDataType, WeiDataType>(),
432  "gemm",
433  GemmPipeline::GetName(),
434  "epilogue",
435  EpiloguePipeline::GetName(), "merge", NumGroupsToMerge);
436  }
437  // clang-format on
438  }
439 
440 #ifdef CK_EXPERIMENTAL_BUILDER
// Returns the reflection-generated identifier string for this exact kernel
// instantiation. Available only in experimental-builder builds.
441  CK_TILE_HOST std::string GetInstanceString() const
442  {
// Produce a readable compile-time error when no instance_traits
// specialization has been registered for these template parameters.
443  static_assert(ck_tile::reflect::HasInstanceTraits<GroupedConvolutionBackwardWeightKernel>,
444  "Specialization of instance_traits not found. Please check that a "
445  "specialization exists in file "
446  "ck_tile/builder/reflect/"
447  "instance_traits_tile_grouped_convolution_backward_weight.hpp "
448  "for the given template parameters.");
449  return ck_tile::reflect::instance_string<GroupedConvolutionBackwardWeightKernel>();
450  }
451 #endif
452 
453  CK_TILE_HOST static constexpr auto
455  {
456  return dim3(
457  TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
458  }
459 
460  CK_TILE_HOST static constexpr auto BlockSize()
461  {
462  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
463  }
464 
467  {
468  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
469  {
470  std::cout << "MPerBlock: " << number<TilePartitioner::MPerBlock>{} << std::endl;
471  std::cout << "NPerBlock: " << number<TilePartitioner::NPerBlock>{} << std::endl;
472  std::cout << "KPerBlock: " << number<TilePartitioner::KPerBlock>{} << std::endl;
473  }
475  }
476 
478  {
479  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
480  }
481 
483  {
485  const std::size_t k_id = blockIdx.z)
486  {
487  constexpr auto K1 = GemmPipeline::BlockGemmShape::WarpTile::at(number<2>{});
488  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
489  const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
490 
493 
494  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
495  {
497  }
498  else
499  {
500  splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
501  }
502  }
503 
507  };
508 
509  CK_TILE_HOST static bool
511  {
512  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
515  {
516  if(kargs.k_batch != 1)
517  {
518  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
519  {
520  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
521  }
522  return false;
523  }
524  }
525 
526  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
527  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
528 
529  // check ConvSpecialization
531  {
532  // check if it's 1x1, stride=1 conv
533  for(index_t i = 0; i < NDimSpatial; ++i)
534  {
535  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
536  const index_t ConvStride = kargs.conv_filter_strides[i];
537  const index_t LeftPad = kargs.input_left_pads[i];
538  const index_t RightPad = kargs.input_right_pads[i];
539 
540  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
541  {
542  return false;
543  }
544  }
545  }
547  {
548  // check if it's 1x1 conv
549  for(index_t i = 0; i < NDimSpatial; ++i)
550  {
551  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
552  const index_t LeftPad = kargs.input_left_pads[i];
553  const index_t RightPad = kargs.input_right_pads[i];
554 
555  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
556  {
557  return false;
558  }
559  }
560  }
562  {
563  if(ConvC != 1)
564  {
565  return false;
566  }
567  for(index_t i = 0; i < NDimSpatial; ++i)
568  {
569  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
570 
571  if(filter_spatial_dim != I3)
572  {
573  return false;
574  }
575  }
576  }
577 
578  if constexpr(GroupedConvTraitsType_::ExplicitGemm &&
580  {
582  "Explicit Gemm is supported only for Filter1x1Stride1Pad0 specialization!");
583  return false;
584  }
585 
586  namespace ctc = tensor_layout::convolution;
587 
588  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
589  std::is_same_v<InLayout, ctc::NDHWGC>)
590  {
591  // Check access per C
592  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
593  {
594  CK_TILE_ERROR("Conv C is not a multiple of vector load size for "
595  "input image!");
596  return false;
597  }
598  }
599  else
600  {
601  CK_TILE_ERROR("Not supported input layout!");
602  return false;
603  }
604 
605  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
606  std::is_same_v<WeiLayout, ctc::GKYXC> ||
607  std::is_same_v<WeiLayout, ctc::GKZYXC>)
608  {
609  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
610  {
611  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
612  return false;
613  }
614  }
615  else
616  {
617  CK_TILE_ERROR("Not supported weight layout!");
618  return false;
619  }
620 
621  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
622  std::is_same_v<OutLayout, ctc::NHWGK> ||
623  std::is_same_v<OutLayout, ctc::NDHWGK>)
624  {
625  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
626  {
627  CK_TILE_ERROR("Conv K is not a multiple of vector store size "
628  "for output image!");
629  return false;
630  }
631  }
632  else
633  {
634  CK_TILE_ERROR("Not supported output layout!");
635  return false;
636  }
637 
638  if constexpr(GroupedConvTraitsType_::NumGroupsToMerge > 1)
639  {
640  const index_t ConvG = kargs.wei_g_k_c_xs_lengths[number<0>{}];
641  if(ConvG % GroupedConvTraitsType_::NumGroupsToMerge != 0)
642  {
643  CK_TILE_ERROR("ConvG must be a multiple of NumGroupsToMerge!");
644  return false;
645  }
646  }
647 
648  return true;
649  }
650 
651  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
652  CK_TILE_DEVICE static auto
654  const InDataType* b_ptr,
655  const std::array<const void*, NumDTensor>& ds_ptr,
656  WeiDataType* c_ptr,
658  {
659  static_assert(!GemmPipeline::BlockGemmShape::PermuteA, "Not implemented!");
660  static_assert(!GemmPipeline::BlockGemmShape::PermuteB, "Not implemented!");
661  const auto& a_tensor_view = [&]() {
662  return make_tensor_view<address_space_enum::global>(a_ptr,
663  kargs.a_grid_desc_k_m); // A: out
664  }();
665 
666  const auto& b_tensor_view = [&]() {
667  return make_tensor_view<address_space_enum::global>(b_ptr,
668  kargs.b_grid_desc_k_n); // B: in
669  }();
670 
671  const auto& c_tensor_view = [&]() {
672  return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
673  kargs.c_grid_desc_m_n);
674  }();
675 
676  const auto& ds_tensor_view = generate_tuple(
677  [&](auto i) {
678  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
679  "Not supported!");
680  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
681  "Not supported!");
682  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
683  "Not supported!");
684 
685  return make_tensor_view<address_space_enum::global>(
686  static_cast<WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
687  },
689 
690  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
691  }
692 
// Wraps the raw (A, B, Ds, C) tensor views in padded views so that tiles at
// the M/N/K boundaries can be processed as full tiles.
//   views   — tuple produced by MakeGemmTensorViews, order (A, B, Ds, C).
//   k_batch — split-K factor; presumably feeds the K-dimension padding of
//             A/B — TODO confirm, the pad_tensor_view argument lines were
//             elided in this extracted listing.
// Returns a tuple of padded views in the same (A, B, Ds, C) order.
693  template <typename TensorView>
694  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
695  {
// A view (index I0).
696  const auto& a_pad_view = [&]() {
697  const auto& a_tensor_view = views.at(I0);
698  return pad_tensor_view(a_tensor_view,
702  }();
703 
// B view (index I1).
704  const auto& b_pad_view = [&]() {
705  const auto& b_tensor_view = views.at(I1);
706  return pad_tensor_view(b_tensor_view,
710  }();
711 
// Ds views (index I2): one padded view per D tensor.
712  const auto& ds_tensor_view = views.at(I2);
713  const auto& ds_pad_view = generate_tuple(
714  [&](auto i) {
715  return pad_tensor_view(ds_tensor_view[i],
719  },
721 
// C view (index I3).
722  const auto& c_pad_view = [&]() {
723  const auto& c_tensor_view = views.at(I3);
724  return pad_tensor_view(c_tensor_view,
728  }();
729 
730  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
731  }
732 
// Creates the per-block tile windows over the padded (A, B, Ds, C) views at
// this block's tile origin.
//   views — tuple of padded views, order (A, B, Ds, C).
//   i_m, i_n, i_k — tile origin in the M, N and K GEMM dimensions.
// A/B windows start at {i_k, i_m} / {i_k, i_n} (K-major, matching the
// ColumnMajor-A / RowMajor-B layout asserts above); Ds/C windows at
// {i_m, i_n}. The window-length arguments were elided in this extracted
// listing — see the original header.
743  template <typename PadView>
744  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
745  const index_t i_m,
746  const index_t i_n,
747  const index_t i_k)
748  {
749  const auto& a_pad_view = views.at(I0);
750  const auto& b_pad_view = views.at(I1);
751  const auto& ds_pad_view = views.at(I2);
752  const auto& c_pad_view = views.at(I3);
753 
754  const auto& a_block_window = [&]() {
755  return make_tile_window(a_pad_view,
758  {i_k, i_m});
759  }();
760 
761  const auto& b_block_window = [&]() {
762  return make_tile_window(b_pad_view,
765  {i_k, i_n});
766  }();
767 
// One window per D tensor, all sharing the C tile origin.
768  const auto ds_block_window = generate_tuple(
769  [&](auto i) {
770  return make_tile_window(ds_pad_view[i],
773  {i_m, i_n});
774  },
776 
// Non-const: the epilogue writes through the C window.
777  auto c_block_window = make_tile_window(
778  c_pad_view,
780  {i_m, i_n});
781 
782  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
783  }
784 
// Computes one C(M,N) tile of the weight gradient with a single LDS buffer:
// builds views/windows, runs the GEMM pipeline over num_loop K-steps, then
// streams the accumulator through the epilogue into the C window.
//   a_ptr — A operand (output-gradient image, per the "A: out" view above)
//   b_ptr — B operand (input image, "B: in")
//   ds_ptr — auxiliary D tensor pointers (NumDTensor of them)
//   c_ptr — C operand (weight gradient)
//   smem_ptr_0 — LDS workspace shared by GEMM and epilogue
//   num_loop — number of K-loop iterations
//   block_idx_m/n/k — this block's tile origin
// NOTE(review): the kargs parameter line (orig. line 802) was elided in this
// extracted listing.
797  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
798  const InDataType* b_ptr,
799  const std::array<const void*, NumDTensor>& ds_ptr,
800  WeiDataType* c_ptr,
801  void* smem_ptr_0,
803  const index_t num_loop,
804  const index_t block_idx_m,
805  const index_t block_idx_n,
806  const index_t block_idx_k)
807  {
808  // Create Gemm tensor views, pad views and tile windows
809  const auto& gemm_tensor_views_tuple =
810  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
811  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
812 
813  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
814  auto gemm_tile_windows =
815  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
816 
817  // Run GEMM cooperatively by whole workgroup.
818  const auto& a_block_window = gemm_tile_windows.at(I0);
819  const auto& b_block_window = gemm_tile_windows.at(I1);
820  const auto& d_block_window = gemm_tile_windows.at(I2);
821 
822  const auto& c_block_tile = GemmPipeline{}.template operator()(
823  a_block_window, b_block_window, num_loop, smem_ptr_0);
824 
825  // Run Epilogue Pipeline
826  auto& c_block_window = gemm_tile_windows.at(I3);
827 
828  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
829  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
830  }
831 
// Double-buffered variant of RunGemm: identical flow, but the GEMM pipeline
// ping-pongs between two LDS buffers (smem_ptr_0 / smem_ptr_1); the epilogue
// reuses smem_ptr_0 only. Used when GemmPipeline::DoubleSmemBuffer is true
// (see operator() below).
//   a_ptr / b_ptr / ds_ptr / c_ptr — same operand roles as RunGemm
//   smem_ptr_0, smem_ptr_1 — the two LDS workspaces (declared __restrict__:
//     they must not alias)
//   num_loop — number of K-loop iterations
//   block_idx_m/n/k — this block's tile origin
// NOTE(review): the kargs parameter line (orig. line 853) was elided in this
// extracted listing.
847  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
848  const InDataType* b_ptr,
849  const std::array<const void*, NumDTensor>& ds_ptr,
850  WeiDataType* c_ptr,
851  void* __restrict__ smem_ptr_0,
852  void* __restrict__ smem_ptr_1,
854  const index_t num_loop,
855  const index_t block_idx_m,
856  const index_t block_idx_n,
857  const index_t block_idx_k)
858  {
859  // Create Gemm tensor views, pad views and tile windows
860  const auto& gemm_tensor_views_tuple =
861  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
862  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
863  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
864  auto gemm_tile_windows =
865  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
866 
867  // Run GEMM cooperatively by whole workgroup.
868  const auto& a_block_window = gemm_tile_windows.at(I0);
869  const auto& b_block_window = gemm_tile_windows.at(I1);
870  const auto& d_block_window = gemm_tile_windows.at(I2);
871 
872  const auto& c_block_tile = GemmPipeline{}.template operator()(
873  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
874 
875  // Run Epilogue Pipeline
876  auto& c_block_window = gemm_tile_windows.at(I3);
877 
878  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
879  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
880  }
881 
883  {
884  static_assert(NumDTensor == 0, "Not supported!");
885  using ExplicitBatchedGemmKernel =
887  const auto batched_gemm_kargs = typename ExplicitBatchedGemmKernel::BatchedGemmKernelArgs{
888  {{kargs.out_ptr},
889  {kargs.in_ptr},
890  {},
891  kargs.wei_ptr,
892  kargs.GemmM,
893  kargs.GemmN,
894  kargs.GemmK,
895  {kargs.GemmM * kargs.GemmBatch},
896  {kargs.GemmN * kargs.GemmBatch},
897  {},
898  kargs.GemmN,
899  kargs.k_batch},
900  kargs.GemmM,
901  kargs.GemmN,
902  kargs.GemmM * kargs.GemmN,
903  kargs.GemmBatch};
904  ExplicitBatchedGemmKernel{}(batched_gemm_kargs);
905  }
906 
908  {
909  if constexpr(GroupedConvTraitsType_::ExplicitGemm)
910  {
911  CallExplicitGemm(kargs);
912  }
913  else
914  {
915  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
916  const auto [iM, iN] =
917  TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
918  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
919  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
920 
921  const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
923  kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
924  const index_t i_k =
925  amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
926 
927  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
928  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
929  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
930  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
931 
932  // options
933  // conv_bwd_weight = Out * In = Weight
934  const OutDataType* a_ptr =
935  static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
936  const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
937  WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
938 
939  __shared__ char smem_ptr_0[GetSmemSize()];
940 
941  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
942  {
943  __shared__ char smem_ptr_1[GetSmemSize()];
944  if constexpr(!(EpiloguePipeline::MemoryOperation ==
946  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
948  {
949  RunGemm2LDS(a_ptr,
950  b_ptr,
951  kargs.ds_ptr,
952  c_ptr,
953  smem_ptr_0,
954  smem_ptr_1,
955  kargs,
956  num_loop,
957  i_m,
958  i_n,
959  i_k);
960  }
961  }
962  else
963  {
964  if constexpr(!(EpiloguePipeline::MemoryOperation ==
966  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
968  {
969  RunGemm(a_ptr,
970  b_ptr,
971  kargs.ds_ptr,
972  c_ptr,
973  smem_ptr_0,
974  kargs,
975  num_loop,
976  i_m,
977  i_n,
978  i_k);
979  }
980  }
981  }
982  }
983 };
984 
985 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:35
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:545
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
unsigned int uint32_t
Definition: stdint.h:126
Definition: batched_gemm_kernel.hpp:62
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_weight_kernel.hpp:26
long_index_t group_stride_a
Definition: grouped_convolution_backward_weight_kernel.hpp:325
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())> ABCGridDescs
Definition: grouped_convolution_backward_weight_kernel.hpp:293
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_weight_kernel.hpp:304
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:301
void * wei_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:319
long_index_t group_stride_b
Definition: grouped_convolution_backward_weight_kernel.hpp:326
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_backward_weight_kernel.hpp:323
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:300
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_weight_kernel.hpp:305
AGridDescKM a_grid_desc_k_m
Definition: grouped_convolution_backward_weight_kernel.hpp:321
BGridDescKN b_grid_desc_k_n
Definition: grouped_convolution_backward_weight_kernel.hpp:322
index_t GemmN
Definition: grouped_convolution_backward_weight_kernel.hpp:311
index_t GemmBatch
Definition: grouped_convolution_backward_weight_kernel.hpp:313
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:302
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs &args)
Definition: grouped_convolution_backward_weight_kernel.hpp:45
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:306
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescKN
Definition: grouped_convolution_backward_weight_kernel.hpp:296
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:318
index_t GemmM
Definition: grouped_convolution_backward_weight_kernel.hpp:310
index_t NumGroupsPerBatch
Definition: grouped_convolution_backward_weight_kernel.hpp:314
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_weight_kernel.hpp:297
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:307
index_t GemmK
Definition: grouped_convolution_backward_weight_kernel.hpp:312
const void * in_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:317
index_t k_batch
Definition: grouped_convolution_backward_weight_kernel.hpp:309
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_weight_kernel.hpp:299
const void * out_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:316
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescKM
Definition: grouped_convolution_backward_weight_kernel.hpp:295
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:35
long_index_t group_stride_c
Definition: grouped_convolution_backward_weight_kernel.hpp:327
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:20
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:39
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:42
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:40
index_t k_batch
Definition: grouped_convolution_utils.hpp:43
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:41
Definition: grouped_convolution_backward_weight_kernel.hpp:483
index_t b_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:505
index_t splitted_k
Definition: grouped_convolution_backward_weight_kernel.hpp:506
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const std::size_t k_id=blockIdx.z)
Definition: grouped_convolution_backward_weight_kernel.hpp:484
index_t a_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:504
The Grouped Convolution Backward Weight kernel template.
Definition: grouped_convolution_backward_weight_kernel.hpp:372
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:388
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_weight_kernel.hpp:391
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views, const index_t k_batch)
Definition: grouped_convolution_backward_weight_kernel.hpp:694
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:460
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:385
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_weight_kernel.hpp:376
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:377
GroupedConvBwdWeightKernelArgs< GroupedConvTraitsType_ > GroupedConvBwdWeightKernelArgsSpecialized
Definition: grouped_convolution_backward_weight_kernel.hpp:399
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_weight_kernel.hpp:418
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:847
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:381
static constexpr auto I2
Definition: grouped_convolution_backward_weight_kernel.hpp:406
static constexpr CK_TILE_HOST GroupedConvBwdWeightKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdWeightHostArgs &hostArgs)
Definition: grouped_convolution_backward_weight_kernel.hpp:466
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized &kargs) const
Definition: grouped_convolution_backward_weight_kernel.hpp:907
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_weight_kernel.hpp:379
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:510
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:378
CK_TILE_DEVICE void CallExplicitGemm(GroupedConvBwdWeightKernelArgsSpecialized &kargs) const
Definition: grouped_convolution_backward_weight_kernel.hpp:882
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_weight_kernel.hpp:374
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:384
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_weight_kernel.hpp:402
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs single GEMM problem cooperatively by whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:797
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_weight_kernel.hpp:373
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:386
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:383
remove_cvref_t< typename EpiloguePipeline::ODataType > WeiDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:396
static constexpr auto I3
Definition: grouped_convolution_backward_weight_kernel.hpp:407
static constexpr auto I0
Definition: grouped_convolution_backward_weight_kernel.hpp:404
static constexpr auto I1
Definition: grouped_convolution_backward_weight_kernel.hpp:405
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:395
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:389
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:477
remove_cvref_t< typename GemmPipeline::ADataType > OutDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:393
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:653
remove_cvref_t< typename GemmPipeline::BDataType > InDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:394
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:380
static constexpr CK_TILE_HOST auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:454
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k)
Create views to the data that each workgroup will process.
Definition: grouped_convolution_backward_weight_kernel.hpp:744
Definition: transform_conv_bwd_weight_to_gemm.hpp:21
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
Definition: transform_conv_bwd_weight_to_gemm.hpp:818
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
#define CK_TILE_ENV(name)
Definition: env.hpp:145