/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_weight_kernel.hpp Source File
grouped_convolution_backward_weight_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
20 template <typename GroupedConvTraitsType_>
22 {
23 
25  TransformConvBwdWeightToGemm<GroupedConvTraitsType_::NDimSpatial,
26  GroupedConvTraitsType_::ConvSpecialization,
27  GroupedConvTraitsType_::VectorSizeA,
28  GroupedConvTraitsType_::VectorSizeB,
29  GroupedConvTraitsType_::VectorSizeC>;
30  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor;
31 
32  template <
33  typename InLay = typename GroupedConvTraitsType_::InLayout,
34  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
35  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
36  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
37  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
38  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
39  bool>::type = false>
41  {
42  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
43  static_cast<index_t>(args.N_),
44  static_cast<index_t>(args.C_),
45  static_cast<index_t>(args.input_spatial_lengths_[0])};
46  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
47  static_cast<index_t>(args.K_),
48  static_cast<index_t>(args.C_),
49  static_cast<index_t>(args.filter_spatial_lengths_[0])};
50  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
51  static_cast<index_t>(args.N_),
52  static_cast<index_t>(args.K_),
53  static_cast<index_t>(args.output_spatial_lengths_[0])};
54 
55  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
56  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
57  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
58  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
59 
60  k_batch = args.k_batch;
61 
62  in_ptr = args.in_ptr;
63  wei_ptr = args.wei_ptr;
64  for(index_t d = 0; d < NumDTensor; d++)
65  {
66  ds_ptr[d] = args.ds_ptr[d];
67  }
68  out_ptr = args.out_ptr;
69 
70  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
77 
78  // tuple
79  auto grid_descs =
80  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
81  GroupedConvTraitsType_::NDimSpatial>();
82 
83  a_grid_desc_k_m = grid_descs.at(number<0>{});
84  b_grid_desc_k_n = grid_descs.at(number<1>{});
85  c_grid_desc_m_n = grid_descs.at(number<2>{});
86 
87  group_stride_a = args.K_; // A: Out NWGK
88  group_stride_b = args.C_; // B: In NWGC
89  group_stride_c = args.K_ * args.C_ * // C: Wei GKXC
90  std::accumulate(args.filter_spatial_lengths_.begin(),
91  args.filter_spatial_lengths_.end(),
92  1,
93  std::multiplies<index_t>());
94 
95  GemmM = a_grid_desc_k_m.get_length(number<1>{});
96  GemmN = b_grid_desc_k_n.get_length(number<1>{});
97  GemmK = a_grid_desc_k_m.get_length(number<0>{});
98  GemmBatch = args.G_;
99  }
100 
101  template <
102  typename InLay = typename GroupedConvTraitsType_::InLayout,
103  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
104  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
105  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
106  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
107  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
108  bool>::type = false>
110  {
111  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
112  static_cast<index_t>(args.N_),
113  static_cast<index_t>(args.C_),
114  static_cast<index_t>(args.input_spatial_lengths_[0]),
115  static_cast<index_t>(args.input_spatial_lengths_[1])};
116  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
117  static_cast<index_t>(args.K_),
118  static_cast<index_t>(args.C_),
119  static_cast<index_t>(args.filter_spatial_lengths_[0]),
120  static_cast<index_t>(args.filter_spatial_lengths_[1])};
121  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
122  static_cast<index_t>(args.N_),
123  static_cast<index_t>(args.K_),
124  static_cast<index_t>(args.output_spatial_lengths_[0]),
125  static_cast<index_t>(args.output_spatial_lengths_[1])};
126 
127  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
128  static_cast<index_t>(args.conv_filter_strides_[1])};
129  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
130  static_cast<index_t>(args.conv_filter_dilations_[1])};
131  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
132  static_cast<index_t>(args.input_left_pads_[1])};
133  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
134  static_cast<index_t>(args.input_right_pads_[1])};
135 
136  k_batch = args.k_batch;
137 
138  in_ptr = args.in_ptr;
139  wei_ptr = args.wei_ptr;
140  for(index_t d = 0; d < NumDTensor; d++)
141  {
142  ds_ptr[d] = args.ds_ptr[d];
143  }
144  out_ptr = args.out_ptr;
145 
146  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
153 
154  // tuple
155  auto grid_descs =
156  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
157  GroupedConvTraitsType_::NDimSpatial>();
158 
159  a_grid_desc_k_m = grid_descs.at(number<0>{});
160  b_grid_desc_k_n = grid_descs.at(number<1>{});
161  c_grid_desc_m_n = grid_descs.at(number<2>{});
162 
163  group_stride_a = args.K_; // A: Out NHWGK
164  group_stride_b = args.C_; // B: In NHWGC
165  group_stride_c = args.K_ * args.C_ * // C: Wei GKYXC
166  std::accumulate(args.filter_spatial_lengths_.begin(),
167  args.filter_spatial_lengths_.end(),
168  1,
169  std::multiplies<index_t>());
170 
171  GemmM = a_grid_desc_k_m.get_length(number<1>{});
172  GemmN = b_grid_desc_k_n.get_length(number<1>{});
173  GemmK = a_grid_desc_k_m.get_length(number<0>{});
174  GemmBatch = args.G_;
175  }
176 
// Overload enabled (through the enable_if template parameter) only for the 3D layout
// combination In = NDHWGC, Wei = GKZYXC, Out = NDHWGK. It copies the G/N/C/K and spatial
// lengths, strides, dilations and pads from the host args, stores k_batch and the tensor
// pointers, builds the A/B/C GEMM grid descriptors through ConvToGemmTransformer, and
// records the per-group element strides and the GEMM M/N/K/batch sizes.
177  template <
178  typename InLay = typename GroupedConvTraitsType_::InLayout,
179  typename WeiLay = typename GroupedConvTraitsType_::WeiLayout,
180  typename OutLay = typename GroupedConvTraitsType_::OutLayout,
181  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
182  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
183  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
184  bool>::type = false>
186  {
187  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
188  static_cast<index_t>(args.N_),
189  static_cast<index_t>(args.C_),
190  static_cast<index_t>(args.input_spatial_lengths_[0]),
191  static_cast<index_t>(args.input_spatial_lengths_[1]),
192  static_cast<index_t>(args.input_spatial_lengths_[2])};
193  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
194  static_cast<index_t>(args.K_),
195  static_cast<index_t>(args.C_),
196  static_cast<index_t>(args.filter_spatial_lengths_[0]),
197  static_cast<index_t>(args.filter_spatial_lengths_[1]),
198  static_cast<index_t>(args.filter_spatial_lengths_[2])};
199  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
200  static_cast<index_t>(args.N_),
201  static_cast<index_t>(args.K_),
202  static_cast<index_t>(args.output_spatial_lengths_[0]),
203  static_cast<index_t>(args.output_spatial_lengths_[1]),
204  static_cast<index_t>(args.output_spatial_lengths_[2])};
205 
206  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
207  static_cast<index_t>(args.conv_filter_strides_[1]),
208  static_cast<index_t>(args.conv_filter_strides_[2])};
209  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
210  static_cast<index_t>(args.conv_filter_dilations_[1]),
211  static_cast<index_t>(args.conv_filter_dilations_[2])};
212  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
213  static_cast<index_t>(args.input_left_pads_[1]),
214  static_cast<index_t>(args.input_left_pads_[2])};
215  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
216  static_cast<index_t>(args.input_right_pads_[1]),
217  static_cast<index_t>(args.input_right_pads_[2])};
218 
219  k_batch = args.k_batch;
220 
221  in_ptr = args.in_ptr;
222  wei_ptr = args.wei_ptr;
223  for(index_t d = 0; d < NumDTensor; d++)
224  {
225  ds_ptr[d] = args.ds_ptr[d];
226  }
227  out_ptr = args.out_ptr;
228 
229  ConvToGemmTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
236 
237  // tuple of the A (k_m), B (k_n) and C (m_n) grid descriptors
238  auto grid_descs =
239  conv_to_gemm_transformer.template MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<
240  GroupedConvTraitsType_::NDimSpatial>();
241 
242  a_grid_desc_k_m = grid_descs.at(number<0>{});
243  b_grid_desc_k_n = grid_descs.at(number<1>{});
244  c_grid_desc_m_n = grid_descs.at(number<2>{});
245 
// Per-group element offsets; the kernel's operator() adds group_stride_* * blockIdx.y to the
// Out (A), In (B) and Wei (C) base pointers. The Wei offset spans a full K*C*Z*Y*X slice.
246  group_stride_a = args.K_; // A: Out NDHWGK
247  group_stride_b = args.C_; // B: In NDHWGC
248  group_stride_c = args.K_ * args.C_ * // C: Wei GKZYXC
249  std::accumulate(args.filter_spatial_lengths_.begin(),
250  args.filter_spatial_lengths_.end(),
251  1,
252  std::multiplies<index_t>());
253 
// GEMM sizes come from the K-major descriptors: M and K from A (k_m), N from B (k_n);
// one GEMM batch per convolution group.
254  GemmM = a_grid_desc_k_m.get_length(number<1>{});
255  GemmN = b_grid_desc_k_n.get_length(number<1>{});
256  GemmK = a_grid_desc_k_m.get_length(number<0>{});
257  GemmBatch = args.G_;
258  }
259 
262 
266 
267  static constexpr index_t NonSpatialDims = 3; // leading non-spatial entries (G,N,C / G,K,C / G,N,K) of the *_lengths arrays; spatial lengths follow
271 
276 
282 
283  const void* out_ptr; // Out tensor; read by the kernel as GEMM operand A
284  const void* in_ptr; // In tensor; read by the kernel as GEMM operand B
285  std::array<const void*, NumDTensor> ds_ptr; // NumDTensor auxiliary D tensors handed to the epilogue
286  void* wei_ptr; // Wei tensor; GEMM result C, written through the epilogue's C tile window
287 
291 
295 };
296 
335 template <typename GroupedConvTraitsType_,
336  typename TilePartitioner_,
337  typename GemmPipeline_,
338  typename EpiloguePipeline_>
340 {
341  static constexpr index_t NDimSpatial = GroupedConvTraitsType_::NDimSpatial_;
343  GroupedConvTraitsType_::ConvSpecialization;
350 
355 
357  static constexpr index_t NumDTensor = GroupedConvTraitsType_::NumDTensor; // number of D tensors (ds_ptr entries) forwarded to the epilogue
358 
359  static constexpr index_t kBlockSize = GemmPipeline::BlockSize; // workgroup size from the GEMM pipeline; BlockSize() halves it on wave32
360 
365 
368 
369  // Split-K is active: GridSize() puts kargs.k_batch on grid z, operator() offsets the K
// range by blockIdx.z, and Preprocess() zero-fills the weight buffer when k_batch > 1.
// IsSupportedArgument() may still reject k_batch != 1 for some VectorSizeC configurations.
370  static constexpr bool IsSplitKSupported = true;
371 
// Compile-time indices used with .at() on the view / window tuples (and as the offset of
// the first spatial entry in the *_lengths arrays).
372  static constexpr auto I0 = number<0>();
373  static constexpr auto I1 = number<1>();
374  static constexpr auto I2 = number<2>();
375  static constexpr auto I3 = number<3>();
376 
// Only a GEMM pipeline that pads M, N and K is accepted.
377  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
378  "Not supported!");
// Expected GEMM layouts: A (Out) column-major, B (In) and C (Wei) row-major.
379  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
380  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
381  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
382 
383  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
384  {
385  // clang-format off
386  return concat('_', "grouped_convolution_backward_weight", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
387  // clang-format on
388  }
389 
390  CK_TILE_HOST static constexpr auto
392  {
393  return dim3(
394  TilePartitioner::GridSize(kargs.GemmM, kargs.GemmN), kargs.GemmBatch, kargs.k_batch);
395  }
396 
397  CK_TILE_HOST static constexpr auto BlockSize()
398  {
399  return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
400  }
401 
404  {
406  }
407 
409  {
410  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
411  }
412 
414  {
416  const std::size_t k_id = blockIdx.z)
417  {
418  constexpr auto K1 = TilePartitioner::BlockGemmShape::WarpTile::at(number<2>{});
419  const index_t K_t = amd_wave_read_first_lane(kargs.k_batch * K1);
420  const index_t KRead = amd_wave_read_first_lane((kargs.GemmK + K_t - 1) / K_t * K1);
421 
424 
425  if(k_id < static_cast<uint32_t>(kargs.k_batch - 1))
426  {
428  }
429  else
430  {
431  splitted_k = amd_wave_read_first_lane(kargs.GemmK - KRead * (kargs.k_batch - 1));
432  }
433  }
434 
438  };
439 
441  const stream_config& s)
442  {
// Returns a host callable that, when split-K is active (k_batch > 1), asynchronously
// zero-fills the weight buffer on stream s.stream_id_; presumably so the k_batch partial
// results can be accumulated into it (TODO confirm against EpiloguePipeline::MemoryOperation).
// NOTE(review): the lambda captures kargs and s by reference ([&]); it must be invoked while
//               both referenced objects are still alive.
// NOTE(review): GemmBatch * GemmM * GemmN are index_t (int32_t) and are multiplied in int32
//               before the promotion caused by sizeof(WeiDataType), so very large weight tensors
//               can overflow; consider casting the first operand to std::size_t.
// NOTE(review): the hipError_t from hipMemsetAsync is passed to hipGetErrorString and the
//               resulting string is discarded, so a memset failure is not reported.
443  return [&]() {
444  if(kargs.k_batch > 1)
445  hipGetErrorString(hipMemsetAsync(kargs.wei_ptr,
446  0,
447  kargs.GemmBatch * kargs.GemmM * kargs.GemmN *
448  sizeof(WeiDataType),
449  s.stream_id_));
450  };
451  }
452 
453  CK_TILE_HOST static bool
455  {
456  if constexpr((GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
459  {
460  if(kargs.k_batch != 1)
461  {
462  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
463  {
464  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
465  }
466  return false;
467  }
468  }
469 
470  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
471  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
472 
473  // check ConvSpecialization
475  {
476  // check if it's 1x1, stride=1 conv
477  for(index_t i = 0; i < NDimSpatial; ++i)
478  {
479  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
480  const index_t ConvStride = kargs.conv_filter_strides[i];
481  const index_t LeftPad = kargs.input_left_pads[i];
482  const index_t RightPad = kargs.input_right_pads[i];
483 
484  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
485  {
486  return false;
487  }
488  }
489  }
491  {
492  // check if it's 1x1 conv
493  for(index_t i = 0; i < NDimSpatial; ++i)
494  {
495  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
496  const index_t LeftPad = kargs.input_left_pads[i];
497  const index_t RightPad = kargs.input_right_pads[i];
498 
499  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
500  {
501  return false;
502  }
503  }
504  }
506  {
507  if(ConvC != 1)
508  {
509  return false;
510  }
511  for(index_t i = 0; i < NDimSpatial; ++i)
512  {
513  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
514 
515  if(filter_spatial_dim != I3)
516  {
517  return false;
518  }
519  }
520  }
521 
522  namespace ctc = tensor_layout::convolution;
523 
524  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
525  std::is_same_v<InLayout, ctc::NDHWGC>)
526  {
527  // Check access per C
528  if(ConvC % GroupedConvTraitsType_::VectorSizeB != 0)
529  {
530  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
531  return false;
532  }
533  }
534  else
535  {
536  CK_TILE_ERROR("Not supported input layout!");
537  return false;
538  }
539 
540  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
541  std::is_same_v<WeiLayout, ctc::GKYXC> ||
542  std::is_same_v<WeiLayout, ctc::GKZYXC>)
543  {
544  if(ConvC % GroupedConvTraitsType_::VectorSizeC != 0)
545  {
546  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
547  return false;
548  }
549  }
550  else
551  {
552  CK_TILE_ERROR("Not supported weight layout!");
553  return false;
554  }
555 
556  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
557  std::is_same_v<OutLayout, ctc::NHWGK> ||
558  std::is_same_v<OutLayout, ctc::NDHWGK>)
559  {
560  if(ConvK % GroupedConvTraitsType_::VectorSizeA != 0)
561  {
562  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
563  return false;
564  }
565  }
566  else
567  {
568  CK_TILE_ERROR("Not supported output layout!");
569  return false;
570  }
571 
572  return true;
573  }
574 
575  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
576  CK_TILE_DEVICE static auto
578  const InDataType* b_ptr,
579  const std::array<const void*, NumDTensor>& ds_ptr,
580  WeiDataType* c_ptr,
582  {
583  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
584  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
585  const auto& a_tensor_view = [&]() {
586  return make_tensor_view<address_space_enum::global>(a_ptr,
587  kargs.a_grid_desc_k_m); // A: out
588  }();
589 
590  const auto& b_tensor_view = [&]() {
591  return make_tensor_view<address_space_enum::global>(b_ptr,
592  kargs.b_grid_desc_k_n); // B: in
593  }();
594 
595  const auto& c_tensor_view = [&]() {
596  return make_tensor_view<address_space_enum::global, DstInMemOp>(c_ptr,
597  kargs.c_grid_desc_m_n);
598  }();
599 
600  const auto& ds_tensor_view = generate_tuple(
601  [&](auto i) {
602  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
603  "Not supported!");
604  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
605  "Not supported!");
606  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, WeiDataType>,
607  "Not supported!");
608 
609  return make_tensor_view<address_space_enum::global>(
610  static_cast<WeiDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
611  },
613 
614  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
615  }
616 
617  template <typename TensorView>
618  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views, const index_t k_batch)
619  {
620  const auto& a_pad_view = [&]() {
621  const auto& a_tensor_view = views.at(I0);
622  return pad_tensor_view(a_tensor_view,
626  }();
627 
628  const auto& b_pad_view = [&]() {
629  const auto& b_tensor_view = views.at(I1);
630  return pad_tensor_view(b_tensor_view,
634  }();
635 
636  const auto& ds_tensor_view = views.at(I2);
637  const auto& ds_pad_view = generate_tuple(
638  [&](auto i) {
639  return pad_tensor_view(ds_tensor_view[i],
643  },
645 
646  const auto& c_pad_view = [&]() {
647  const auto& c_tensor_view = views.at(I3);
648  return pad_tensor_view(c_tensor_view,
652  }();
653 
654  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
655  }
656 
657  template <typename PadView>
658  CK_TILE_DEVICE static auto MakeGemmTileWindows(const PadView& views,
659  const index_t i_m,
660  const index_t i_n,
661  const index_t i_k)
662  {
663  const auto& a_pad_view = views.at(I0);
664  const auto& b_pad_view = views.at(I1);
665  const auto& ds_pad_view = views.at(I2);
666  const auto& c_pad_view = views.at(I3);
667 
668  const auto& a_block_window = [&]() {
669  return make_tile_window(a_pad_view,
672  {i_k, i_m});
673  }();
674 
675  const auto& b_block_window = [&]() {
676  return make_tile_window(b_pad_view,
679  {i_k, i_n});
680  }();
681 
682  const auto ds_block_window = generate_tuple(
683  [&](auto i) {
684  return make_tile_window(ds_pad_view[i],
687  {i_m, i_n});
688  },
690 
691  auto c_block_window = make_tile_window(
692  c_pad_view,
694  {i_m, i_n});
695 
696  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
697  }
698 
711  CK_TILE_DEVICE static void RunGemm(const OutDataType* a_ptr,
712  const InDataType* b_ptr,
713  const std::array<const void*, NumDTensor>& ds_ptr,
714  WeiDataType* c_ptr,
715  void* smem_ptr_0,
717  const index_t num_loop,
718  const index_t block_idx_m,
719  const index_t block_idx_n,
720  const index_t block_idx_k)
721  {
722  // Create Gemm tensor views, pad views and tile windows
723  const auto& gemm_tensor_views_tuple =
724  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
725  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
726 
727  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
728  auto gemm_tile_windows =
729  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
730 
731  // Run GEMM cooperatively by whole workgroup.
732  const auto& a_block_window = gemm_tile_windows.at(I0);
733  const auto& b_block_window = gemm_tile_windows.at(I1);
734  const auto& d_block_window = gemm_tile_windows.at(I2);
735 
736  const auto& c_block_tile = GemmPipeline{}.template operator()(
737  a_block_window, b_block_window, num_loop, smem_ptr_0);
738 
739  // Run Epilogue Pipeline
740  auto& c_block_window = gemm_tile_windows.at(I3);
741 
742  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
743  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
744  }
745 
761  CK_TILE_DEVICE static void RunGemm2LDS(const OutDataType* a_ptr,
762  const InDataType* b_ptr,
763  const std::array<const void*, NumDTensor>& ds_ptr,
764  WeiDataType* c_ptr,
765  void* __restrict__ smem_ptr_0,
766  void* __restrict__ smem_ptr_1,
768  const index_t num_loop,
769  const index_t block_idx_m,
770  const index_t block_idx_n,
771  const index_t block_idx_k)
772  {
773  // Create Gemm tensor views, pad views and tile windows
774  const auto& gemm_tensor_views_tuple =
775  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
776  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
777  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple, kargs.k_batch);
778  auto gemm_tile_windows =
779  MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n, block_idx_k);
780 
781  // Run GEMM cooperatively by whole workgroup.
782  const auto& a_block_window = gemm_tile_windows.at(I0);
783  const auto& b_block_window = gemm_tile_windows.at(I1);
784  const auto& d_block_window = gemm_tile_windows.at(I2);
785 
786  const auto& c_block_tile = GemmPipeline{}.template operator()(
787  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
788 
789  // Run Epilogue Pipeline
790  auto& c_block_window = gemm_tile_windows.at(I3);
791 
792  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
793  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
794  }
795 
797  {
798  const auto blockIdX = amd_wave_read_first_lane(blockIdx.x);
799  const auto [iM, iN] =
800  TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
801  const index_t i_m = amd_wave_read_first_lane(iM * TilePartitioner::MPerBlock);
802  const index_t i_n = amd_wave_read_first_lane(iN * TilePartitioner::NPerBlock);
803 
804  const auto blockIdZ = amd_wave_read_first_lane(blockIdx.z);
805  const index_t num_loop = amd_wave_read_first_lane(
806  ck_tile::integer_divide_ceil(kargs.GemmK, kargs.k_batch * TilePartitioner::KPerBlock));
807  const index_t i_k =
808  amd_wave_read_first_lane(blockIdZ * num_loop * TilePartitioner::KPerBlock);
809 
810  const auto blockIdY = amd_wave_read_first_lane(blockIdx.y);
811  const auto group_offset_a = amd_wave_read_first_lane(kargs.group_stride_a * blockIdY);
812  const auto group_offset_b = amd_wave_read_first_lane(kargs.group_stride_b * blockIdY);
813  const auto group_offset_c = amd_wave_read_first_lane(kargs.group_stride_c * blockIdY);
814 
815  // options
816  // conv_bwd_weight = Out * In = Weight
817  const OutDataType* a_ptr = static_cast<const OutDataType*>(kargs.out_ptr) + group_offset_a;
818  const InDataType* b_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_b;
819  WeiDataType* c_ptr = static_cast<WeiDataType*>(kargs.wei_ptr) + group_offset_c;
820 
821  // allocate LDS
822  __shared__ char smem_ptr_0[GetSmemSize()];
823 
824  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
825  {
826  __shared__ char smem_ptr_1[GetSmemSize()];
827  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
828  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
830  {
831  RunGemm2LDS(a_ptr,
832  b_ptr,
833  kargs.ds_ptr,
834  c_ptr,
835  smem_ptr_0,
836  smem_ptr_1,
837  kargs,
838  num_loop,
839  i_m,
840  i_n,
841  i_k);
842  }
843  }
844  else
845  {
846  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
847  GroupedConvTraitsType_::VectorSizeC % 2 != 0 &&
849  {
850  RunGemm(
851  a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, num_loop, i_m, i_n, i_k);
852  }
853  }
854  }
855 };
856 
857 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
__device__ uint32_t amd_wave_read_first_lane(uint16_t v)
Definition: amd_buffer_addressing.hpp:33
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:530
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:43
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:75
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:429
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:360
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
__device__ X atomic_add(X *p_dst, const X &x)
unsigned int uint32_t
Definition: stdint.h:126
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_backward_weight_kernel.hpp:22
long_index_t group_stride_a
Definition: grouped_convolution_backward_weight_kernel.hpp:292
remove_cvref_t< decltype(ConvToGemmTransformer{}.MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N())> ABCGridDescs
Definition: grouped_convolution_backward_weight_kernel.hpp:261
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_backward_weight_kernel.hpp:272
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:269
void * wei_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:286
long_index_t group_stride_b
Definition: grouped_convolution_backward_weight_kernel.hpp:293
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_backward_weight_kernel.hpp:290
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:268
array< index_t, GroupedConvTraitsType_::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_backward_weight_kernel.hpp:273
AGridDescKM a_grid_desc_k_m
Definition: grouped_convolution_backward_weight_kernel.hpp:288
BGridDescKN b_grid_desc_k_n
Definition: grouped_convolution_backward_weight_kernel.hpp:289
index_t GemmN
Definition: grouped_convolution_backward_weight_kernel.hpp:279
index_t GemmBatch
Definition: grouped_convolution_backward_weight_kernel.hpp:281
array< index_t, NonSpatialDims+GroupedConvTraitsType_::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_backward_weight_kernel.hpp:270
CK_TILE_HOST GroupedConvBwdWeightKernelArgs(const GroupedConvBwdWeightHostArgs &args)
Definition: grouped_convolution_backward_weight_kernel.hpp:40
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_left_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:274
remove_cvref_t< decltype(ABCGridDescs{}[number< 1 >{}])> BGridDescKN
Definition: grouped_convolution_backward_weight_kernel.hpp:264
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:285
index_t GemmM
Definition: grouped_convolution_backward_weight_kernel.hpp:278
remove_cvref_t< decltype(ABCGridDescs{}[number< 2 >{}])> CGridDescMN
Definition: grouped_convolution_backward_weight_kernel.hpp:265
array< index_t, GroupedConvTraitsType_::NDimSpatial > input_right_pads
Definition: grouped_convolution_backward_weight_kernel.hpp:275
index_t GemmK
Definition: grouped_convolution_backward_weight_kernel.hpp:280
const void * in_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:284
index_t k_batch
Definition: grouped_convolution_backward_weight_kernel.hpp:277
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_backward_weight_kernel.hpp:267
const void * out_ptr
Definition: grouped_convolution_backward_weight_kernel.hpp:283
remove_cvref_t< decltype(ABCGridDescs{}[number< 0 >{}])> AGridDescKM
Definition: grouped_convolution_backward_weight_kernel.hpp:263
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:30
long_index_t group_stride_c
Definition: grouped_convolution_backward_weight_kernel.hpp:294
The Grouped Conv kernel host arguments.
Definition: grouped_convolution_utils.hpp:19
index_t k_batch
Definition: grouped_convolution_utils.hpp:40
InPtr in_ptr
Definition: grouped_convolution_utils.hpp:36
WeiPtr wei_ptr
Definition: grouped_convolution_utils.hpp:37
OutPtr out_ptr
Definition: grouped_convolution_utils.hpp:39
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:38
Definition: grouped_convolution_backward_weight_kernel.hpp:414
index_t b_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:436
index_t splitted_k
Definition: grouped_convolution_backward_weight_kernel.hpp:437
__device__ SplitKBatchOffset(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const std::size_t k_id=blockIdx.z)
Definition: grouped_convolution_backward_weight_kernel.hpp:415
index_t a_k_split_offset
Definition: grouped_convolution_backward_weight_kernel.hpp:435
The Grouped Convolution Backward Weight kernel template.
Definition: grouped_convolution_backward_weight_kernel.hpp:340
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:356
static constexpr index_t kBlockSize
Definition: grouped_convolution_backward_weight_kernel.hpp:359
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views, const index_t k_batch)
Definition: grouped_convolution_backward_weight_kernel.hpp:618
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:397
remove_cvref_t< typename GroupedConvTraitsType_::OutLayout > OutLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:353
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_backward_weight_kernel.hpp:344
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:345
GroupedConvBwdWeightKernelArgs< GroupedConvTraitsType_ > GroupedConvBwdWeightKernelArgsSpecialized
Definition: grouped_convolution_backward_weight_kernel.hpp:367
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_backward_weight_kernel.hpp:383
static CK_TILE_DEVICE void RunGemm2LDS(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:761
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:349
static constexpr auto I2
Definition: grouped_convolution_backward_weight_kernel.hpp:374
static constexpr CK_TILE_HOST GroupedConvBwdWeightKernelArgsSpecialized MakeKernelArgs(const GroupedConvBwdWeightHostArgs &hostArgs)
Definition: grouped_convolution_backward_weight_kernel.hpp:403
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_backward_weight_kernel.hpp:347
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:454
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_backward_weight_kernel.hpp:346
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_backward_weight_kernel.hpp:342
remove_cvref_t< typename GroupedConvTraitsType_::WeiLayout > WeiLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:352
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_backward_weight_kernel.hpp:370
static CK_TILE_DEVICE void RunGemm(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, void *smem_ptr_0, const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const index_t num_loop, const index_t block_idx_m, const index_t block_idx_n, const index_t block_idx_k)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_convolution_backward_weight_kernel.hpp:711
static constexpr index_t NDimSpatial
Definition: grouped_convolution_backward_weight_kernel.hpp:341
remove_cvref_t< typename GroupedConvTraitsType_::DsLayout > DsLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:354
remove_cvref_t< typename GroupedConvTraitsType_::InLayout > InLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:351
static CK_TILE_HOST auto Preprocess(const GroupedConvBwdWeightKernelArgsSpecialized &kargs, const stream_config &s)
Definition: grouped_convolution_backward_weight_kernel.hpp:440
remove_cvref_t< typename EpiloguePipeline::ODataType > WeiDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:364
static constexpr auto I3
Definition: grouped_convolution_backward_weight_kernel.hpp:375
static constexpr auto I0
Definition: grouped_convolution_backward_weight_kernel.hpp:372
CK_TILE_DEVICE void operator()(GroupedConvBwdWeightKernelArgsSpecialized kargs) const
Definition: grouped_convolution_backward_weight_kernel.hpp:796
static constexpr auto I1
Definition: grouped_convolution_backward_weight_kernel.hpp:373
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:363
static constexpr index_t NumDTensor
Definition: grouped_convolution_backward_weight_kernel.hpp:357
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_backward_weight_kernel.hpp:408
remove_cvref_t< typename GemmPipeline::ADataType > OutDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:361
static CK_TILE_DEVICE auto MakeGemmTensorViews(const OutDataType *a_ptr, const InDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, WeiDataType *c_ptr, const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:577
remove_cvref_t< typename GemmPipeline::BDataType > InDataType
Definition: grouped_convolution_backward_weight_kernel.hpp:362
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_backward_weight_kernel.hpp:348
static constexpr CK_TILE_HOST auto GridSize(const GroupedConvBwdWeightKernelArgsSpecialized &kargs)
Definition: grouped_convolution_backward_weight_kernel.hpp:391
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n, const index_t i_k)
Definition: grouped_convolution_backward_weight_kernel.hpp:658
Definition: transform_conv_bwd_weight_to_gemm.hpp:22
CK_TILE_HOST auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N() const
Definition: transform_conv_bwd_weight_to_gemm.hpp:548
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:49
Definition: stream_config.hpp:30
hipStream_t stream_id_
Definition: stream_config.hpp:31
#define CK_TILE_ENV(name)
Definition: env.hpp:145