/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/grouped_convolution/kernel/grouped_convolution_forward_kernel.hpp Source File
grouped_convolution_forward_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <string>
8 
9 #include "ck_tile/core.hpp"
10 #include "ck_tile/ops/common.hpp"
11 #include "ck_tile/host/concat.hpp"
16 
17 namespace ck_tile {
18 
20 template <typename GroupedConvTraitsType>
22 {
23 
25  TransformConvFwdToGemm<GroupedConvTraitsType::NDimSpatial,
26  GroupedConvTraitsType::ConvSpecialization>;
27  static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
28 
29  template <
30  typename InLay = typename GroupedConvTraitsType::InLayout,
31  typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
32  typename OutLay = typename GroupedConvTraitsType::OutLayout,
33  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NWGC> &&
34  std::is_same_v<WeiLay, tensor_layout::convolution::GKXC> &&
35  std::is_same_v<OutLay, tensor_layout::convolution::NWGK>,
36  bool>::type = false>
38  {
39  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
40  static_cast<index_t>(args.N_),
41  static_cast<index_t>(args.C_),
42  static_cast<index_t>(args.input_spatial_lengths_[0])};
43  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
44  static_cast<index_t>(args.K_),
45  static_cast<index_t>(args.C_),
46  static_cast<index_t>(args.filter_spatial_lengths_[0])};
47  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
48  static_cast<index_t>(args.N_),
49  static_cast<index_t>(args.K_),
50  static_cast<index_t>(args.output_spatial_lengths_[0])};
51 
52  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0])};
53  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0])};
54  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0])};
55  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0])};
56 
57  k_batch = args.k_batch;
58 
59  GemmM = args.N_ * args.output_spatial_lengths_[0];
60  GemmN = args.K_;
61  GemmK = args.C_ * args.filter_spatial_lengths_[0];
62 
63  in_ptr = args.in_ptr;
64  wei_ptr = args.wei_ptr;
65  for(index_t d = 0; d < NumDTensor; d++)
66  {
67  ds_ptr[d] = args.ds_ptr[d];
68  }
69  out_ptr = args.out_ptr;
70 
71  ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
78 
80  conv_to_gemm_transformer
81  .template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
83  conv_to_gemm_transformer
84  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
86  conv_to_gemm_transformer
87  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
88 
89  group_stride_a = args.C_;
90  group_stride_b = args.K_ * args.C_ *
91  std::accumulate(args.filter_spatial_lengths_.begin(),
92  args.filter_spatial_lengths_.end(),
93  1,
94  std::multiplies<index_t>());
95  group_stride_c = args.K_;
96  }
97 
98  template <
99  typename InLay = typename GroupedConvTraitsType::InLayout,
100  typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
101  typename OutLay = typename GroupedConvTraitsType::OutLayout,
102  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NHWGC> &&
103  std::is_same_v<WeiLay, tensor_layout::convolution::GKYXC> &&
104  std::is_same_v<OutLay, tensor_layout::convolution::NHWGK>,
105  bool>::type = false>
107  {
108  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
109  static_cast<index_t>(args.N_),
110  static_cast<index_t>(args.C_),
111  static_cast<index_t>(args.input_spatial_lengths_[0]),
112  static_cast<index_t>(args.input_spatial_lengths_[1])};
113  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
114  static_cast<index_t>(args.K_),
115  static_cast<index_t>(args.C_),
116  static_cast<index_t>(args.filter_spatial_lengths_[0]),
117  static_cast<index_t>(args.filter_spatial_lengths_[1])};
118  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
119  static_cast<index_t>(args.N_),
120  static_cast<index_t>(args.K_),
121  static_cast<index_t>(args.output_spatial_lengths_[0]),
122  static_cast<index_t>(args.output_spatial_lengths_[1])};
123 
124  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
125  static_cast<index_t>(args.conv_filter_strides_[1])};
126  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
127  static_cast<index_t>(args.conv_filter_dilations_[1])};
128  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
129  static_cast<index_t>(args.input_left_pads_[1])};
130  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
131  static_cast<index_t>(args.input_right_pads_[1])};
132 
133  k_batch = args.k_batch;
134 
135  GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1];
136  GemmN = args.K_;
137  GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1];
138 
139  in_ptr = args.in_ptr;
140  wei_ptr = args.wei_ptr;
141  for(index_t d = 0; d < NumDTensor; d++)
142  {
143  ds_ptr[d] = args.ds_ptr[d];
144  }
145  out_ptr = args.out_ptr;
146 
147  ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
154 
156  conv_to_gemm_transformer
157  .template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
159  conv_to_gemm_transformer
160  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
162  conv_to_gemm_transformer
163  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
164 
165  group_stride_a = args.C_;
166  group_stride_b = args.K_ * args.C_ *
167  std::accumulate(args.filter_spatial_lengths_.begin(),
168  args.filter_spatial_lengths_.end(),
169  1,
170  std::multiplies<index_t>());
171  group_stride_c = args.K_;
172  }
173 
174  template <
175  typename InLay = typename GroupedConvTraitsType::InLayout,
176  typename WeiLay = typename GroupedConvTraitsType::WeiLayout,
177  typename OutLay = typename GroupedConvTraitsType::OutLayout,
178  typename std::enable_if<std::is_same_v<InLay, tensor_layout::convolution::NDHWGC> &&
179  std::is_same_v<WeiLay, tensor_layout::convolution::GKZYXC> &&
180  std::is_same_v<OutLay, tensor_layout::convolution::NDHWGK>,
181  bool>::type = false>
183  {
184  in_g_n_c_wis_lengths = {static_cast<index_t>(args.G_),
185  static_cast<index_t>(args.N_),
186  static_cast<index_t>(args.C_),
187  static_cast<index_t>(args.input_spatial_lengths_[0]),
188  static_cast<index_t>(args.input_spatial_lengths_[1]),
189  static_cast<index_t>(args.input_spatial_lengths_[2])};
190  wei_g_k_c_xs_lengths = {static_cast<index_t>(args.G_),
191  static_cast<index_t>(args.K_),
192  static_cast<index_t>(args.C_),
193  static_cast<index_t>(args.filter_spatial_lengths_[0]),
194  static_cast<index_t>(args.filter_spatial_lengths_[1]),
195  static_cast<index_t>(args.filter_spatial_lengths_[2])};
196  out_g_n_k_wos_lengths = {static_cast<index_t>(args.G_),
197  static_cast<index_t>(args.N_),
198  static_cast<index_t>(args.K_),
199  static_cast<index_t>(args.output_spatial_lengths_[0]),
200  static_cast<index_t>(args.output_spatial_lengths_[1]),
201  static_cast<index_t>(args.output_spatial_lengths_[2])};
202 
203  conv_filter_strides = {static_cast<index_t>(args.conv_filter_strides_[0]),
204  static_cast<index_t>(args.conv_filter_strides_[1]),
205  static_cast<index_t>(args.conv_filter_strides_[2])};
206  conv_filter_dilations = {static_cast<index_t>(args.conv_filter_dilations_[0]),
207  static_cast<index_t>(args.conv_filter_dilations_[1]),
208  static_cast<index_t>(args.conv_filter_dilations_[2])};
209  input_left_pads = {static_cast<index_t>(args.input_left_pads_[0]),
210  static_cast<index_t>(args.input_left_pads_[1]),
211  static_cast<index_t>(args.input_left_pads_[2])};
212  input_right_pads = {static_cast<index_t>(args.input_right_pads_[0]),
213  static_cast<index_t>(args.input_right_pads_[1]),
214  static_cast<index_t>(args.input_right_pads_[2])};
215 
216  k_batch = args.k_batch;
217 
218  GemmM = args.N_ * args.output_spatial_lengths_[0] * args.output_spatial_lengths_[1] *
219  args.output_spatial_lengths_[2];
220  GemmN = args.K_;
221  GemmK = args.C_ * args.filter_spatial_lengths_[0] * args.filter_spatial_lengths_[1] *
222  args.filter_spatial_lengths_[2];
223 
224  in_ptr = args.in_ptr;
225  wei_ptr = args.wei_ptr;
226  for(index_t d = 0; d < NumDTensor; d++)
227  {
228  ds_ptr[d] = args.ds_ptr[d];
229  }
230  out_ptr = args.out_ptr;
231 
232  ConvToGemmFwdTransformer conv_to_gemm_transformer{in_g_n_c_wis_lengths,
239 
241  conv_to_gemm_transformer
242  .template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>();
244  conv_to_gemm_transformer
245  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>();
247  conv_to_gemm_transformer
248  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>();
249 
250  group_stride_a = args.C_;
251  group_stride_b = args.K_ * args.C_ *
252  std::accumulate(args.filter_spatial_lengths_.begin(),
253  args.filter_spatial_lengths_.end(),
254  1,
255  std::multiplies<index_t>());
256  group_stride_c = args.K_;
257  }
258 
259  using AGridDescMK = remove_cvref_t<decltype(
261  .template MakeADescriptor_M_K<typename GroupedConvTraitsType::InLayout>())>;
262  using BGridDescNK = remove_cvref_t<decltype(
264  .template MakeBDescriptor_N_K<typename GroupedConvTraitsType::WeiLayout>())>;
265  using CGridDescMN = remove_cvref_t<decltype(
267  .template MakeCDescriptor_M_N<typename GroupedConvTraitsType::OutLayout>())>;
268 
269  static constexpr index_t NonSpatialDims = 3;
273 
278 
283 
284  const void* in_ptr;
285  const void* wei_ptr;
286  std::array<const void*, NumDTensor> ds_ptr;
287  void* out_ptr;
288 
292 
296 };
297 
336 template <typename GroupedConvTraitsType,
337  typename TilePartitioner_,
338  typename GemmPipeline_,
339  typename EpiloguePipeline_>
341 {
342  static constexpr index_t NDimSpatial = GroupedConvTraitsType::NDimSpatial;
344  GroupedConvTraitsType::ConvSpecialization;
351 
356 
358 
359  static constexpr index_t NumDTensor = GroupedConvTraitsType::NumDTensor;
360 
361  static constexpr index_t KernelBlockSize = GemmPipeline::BlockSize;
362 
366  // Below type is actually accumulation data type - the output of block GEMM.
368 
370 
371  // TODO: Enable this
372  static constexpr bool IsSplitKSupported = false;
373 
374  static constexpr auto I0 = number<0>();
375  static constexpr auto I1 = number<1>();
376  static constexpr auto I2 = number<2>();
377  static constexpr auto I3 = number<3>();
378 
379  static_assert(GemmPipeline::kPadM && GemmPipeline::kPadN && GemmPipeline::kPadK,
380  "Not supported!");
381  static_assert(std::is_same_v<GemmALayout, tensor_layout::gemm::RowMajor>, "Not supported!");
382  static_assert(std::is_same_v<GemmBLayout, tensor_layout::gemm::ColumnMajor>, "Not supported!");
383  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>, "Not supported!");
384 
385  [[nodiscard]] CK_TILE_HOST static const std::string GetName()
386  {
387  // clang-format off
388  return concat('_', "grouped_convolution_forward", gemm_prec_str<InDataType, WeiDataType>, GemmPipeline::GetName());
389  // clang-format on
390  }
391 
392  CK_TILE_HOST static constexpr auto GridSize(const GroupedConvHostArgs& args)
393  {
394  const index_t GemmM = args.N_ * std::accumulate(args.output_spatial_lengths_.begin(),
395  args.output_spatial_lengths_.end(),
396  1,
397  std::multiplies<index_t>());
398  const index_t GemmN = args.K_;
399  return dim3(TilePartitioner::GridSize(GemmM, GemmN), args.G_, args.k_batch);
400  }
401 
402  CK_TILE_HOST static constexpr auto BlockSize() { return dim3(KernelBlockSize); }
403 
406  {
407  return GroupedConvFwdKernelArgsSpecialized(hostArgs);
408  }
409 
411  {
412  return max(GemmPipeline::GetSmemSize(), EpiloguePipeline::GetSmemSize());
413  }
414 
416  {
417  if constexpr((EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
420  {
421  if(kargs.k_batch != 1)
422  {
423  if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
424  {
425  CK_TILE_ERROR("Conditions not met for Kbatch >1 !");
426  }
427  return false;
428  }
429  }
430 
431  const index_t ConvK = kargs.wei_g_k_c_xs_lengths[number<1>{}];
432  const index_t ConvC = kargs.wei_g_k_c_xs_lengths[number<2>{}];
433 
434  // check ConvolutionSpecialization
436  {
437  // check if it's 1x1, stride=1 conv
438  for(index_t i = 0; i < NDimSpatial; ++i)
439  {
440  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
441  const index_t ConvStride = kargs.conv_filter_strides[i];
442  const index_t LeftPad = kargs.input_left_pads[i];
443  const index_t RightPad = kargs.input_right_pads[i];
444 
445  if(!(SpatialDim == 1 && ConvStride == 1 && LeftPad == 0 && RightPad == 0))
446  {
447  return false;
448  }
449  }
450  }
452  {
453  // check if it's 1x1 conv
454  for(index_t i = 0; i < NDimSpatial; ++i)
455  {
456  const index_t SpatialDim = kargs.wei_g_k_c_xs_lengths[i + 3];
457  const index_t LeftPad = kargs.input_left_pads[i];
458  const index_t RightPad = kargs.input_right_pads[i];
459 
460  if(!(SpatialDim == 1 && LeftPad == 0 && RightPad == 0))
461  {
462  return false;
463  }
464  }
465  }
467  {
468  if(ConvC != 1)
469  {
470  return false;
471  }
472  for(index_t i = 0; i < NDimSpatial; ++i)
473  {
474  const index_t filter_spatial_dim = kargs.wei_g_k_c_xs_lengths[i + I3];
475 
476  if(filter_spatial_dim != I3)
477  {
478  return false;
479  }
480  }
481  }
482 
483  namespace ctc = tensor_layout::convolution;
484 
485  if constexpr(std::is_same_v<InLayout, ctc::NWGC> || std::is_same_v<InLayout, ctc::NHWGC> ||
486  std::is_same_v<InLayout, ctc::NDHWGC>)
487  {
488  // Check access per C
489  if(ConvC % GemmPipeline::GetVectorSizeA() != 0)
490  {
491  CK_TILE_ERROR("Conv C is not a multiple of vector load size for input image!");
492  return false;
493  }
494  }
495  else
496  {
497  CK_TILE_ERROR("Not supported input layout!");
498  return false;
499  }
500 
501  // check vector access of B
502  // FIXME: layout
503  if constexpr(std::is_same_v<WeiLayout, ctc::GKXC> ||
504  std::is_same_v<WeiLayout, ctc::GKYXC> ||
505  std::is_same_v<WeiLayout, ctc::GKZYXC>)
506  {
507  if(ConvC % GemmPipeline::GetVectorSizeB() != 0)
508  {
509  CK_TILE_ERROR("Conv C is not a multiple of vector load size for weight!");
510  return false;
511  }
512  }
513  else
514  {
515  CK_TILE_ERROR("Not supported weight layout!");
516  return false;
517  }
518 
519  // check vector access of E
520  if constexpr(std::is_same_v<OutLayout, ctc::NWGK> ||
521  std::is_same_v<OutLayout, ctc::NHWGK> ||
522  std::is_same_v<OutLayout, ctc::NDHWGK>)
523  {
524  if(ConvK % EpiloguePipeline::GetVectorSizeC() != 0)
525  {
526  CK_TILE_ERROR("Conv K is not a multiple of vector store size for output image!");
527  return false;
528  }
529  }
530  else
531  {
532  CK_TILE_ERROR("Not supported output layout!");
533  return false;
534  }
535 
536  return true;
537  }
538 
539  template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
540  CK_TILE_DEVICE static auto
542  const WeiDataType* b_ptr,
543  const std::array<const void*, NumDTensor>& ds_ptr,
544  OutDataType* c_ptr,
546  {
547  static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
548  static_assert(!TilePartitioner::BlockGemmShape::PermuteB, "Not implemented!");
549  const auto& a_tensor_view = [&]() {
550  return make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
551  }();
552 
553  const auto& b_tensor_view = [&]() {
554  return make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
555  }();
556 
557  // TODO: enable vector write for C in ColMajor
558  const auto& c_tensor_view = [&]() {
559  return make_tensor_view<address_space_enum::global>(c_ptr, kargs.c_grid_desc_m_n);
560  }();
561 
562  const auto& ds_tensor_view = generate_tuple(
563  [&](auto i) {
564  static_assert(std::is_same_v<std::tuple_element_t<i, DsLayout>, OutLayout>,
565  "Not supported!");
566  static_assert(std::is_same_v<GemmCLayout, tensor_layout::gemm::RowMajor>,
567  "Not supported!");
568  static_assert(std::is_same_v<std::tuple_element_t<i, DsDataType>, OutDataType>,
569  "Not supported!");
570 
571  return make_tensor_view<address_space_enum::global>(
572  static_cast<OutDataType*>(ds_ptr[i]), kargs.c_grid_desc_m_n);
573  },
575 
576  return make_tuple(a_tensor_view, b_tensor_view, ds_tensor_view, c_tensor_view);
577  }
578 
579  template <typename TensorView>
580  CK_TILE_DEVICE static auto MakeGemmPadViews(const TensorView& views)
581  {
582  const auto& a_pad_view = [&]() {
583  const auto& a_tensor_view = views.at(I0);
584  return pad_tensor_view(a_tensor_view,
588  }();
589 
590  const auto& b_pad_view = [&]() {
591  const auto& b_tensor_view = views.at(I1);
592  return pad_tensor_view(b_tensor_view,
596  }();
597 
598  const auto& ds_tensor_view = views.at(I2);
599  const auto& ds_pad_view = generate_tuple(
600  [&](auto i) {
601  return pad_tensor_view(ds_tensor_view[i],
605  },
607 
608  const auto& c_pad_view = [&]() {
609  const auto& c_tensor_view = views.at(I3);
610  return pad_tensor_view(c_tensor_view,
614  }();
615 
616  return make_tuple(a_pad_view, b_pad_view, ds_pad_view, c_pad_view);
617  }
618 
619  template <typename PadView>
620  CK_TILE_DEVICE static auto
621  MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
622  {
623  const auto& a_pad_view = views.at(I0);
624  const auto& b_pad_view = views.at(I1);
625  const auto& ds_pad_view = views.at(I2);
626  const auto& c_pad_view = views.at(I3);
627 
628  const auto& a_block_window = [&]() {
629  return make_tile_window(a_pad_view,
632  {i_m, 0});
633  }();
634 
635  const auto& b_block_window = [&]() {
636  return make_tile_window(b_pad_view,
639  {i_n, 0});
640  }();
641 
642  const auto ds_block_window = generate_tuple(
643  [&](auto i) {
644  return make_tile_window(ds_pad_view[i],
647  {i_m, i_n});
648  },
650 
651  auto c_block_window = make_tile_window(
652  c_pad_view,
654  {i_m, i_n});
655 
656  return make_tuple(a_block_window, b_block_window, ds_block_window, c_block_window);
657  }
658 
671  CK_TILE_DEVICE static void RunGemm(const InDataType* a_ptr,
672  const WeiDataType* b_ptr,
673  const std::array<const void*, NumDTensor>& ds_ptr,
674  OutDataType* c_ptr,
675  void* smem_ptr_0,
677  const index_t block_idx_m,
678  const index_t block_idx_n)
679  {
680  // Create Gemm tensor views, pad views and tile windows
681  const auto& gemm_tensor_views_tuple =
682  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
683  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
684 
685  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
686  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
687 
688  const index_t num_loop =
689  __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
690 
691  // Run GEMM cooperatively by whole workgroup.
692  const auto& a_block_window = gemm_tile_windows.at(I0);
693  const auto& b_block_window = gemm_tile_windows.at(I1);
694  const auto& d_block_window = gemm_tile_windows.at(I2);
695 
696  const auto& c_block_tile = GemmPipeline{}.template operator()(
697  a_block_window, b_block_window, num_loop, smem_ptr_0);
698 
699  // Run Epilogue Pipeline
700  auto& c_block_window = gemm_tile_windows.at(I3);
701 
702  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
703  c_block_window, c_block_tile, d_block_window, smem_ptr_0);
704  }
705 
721  CK_TILE_DEVICE static void RunGemm2LDS(const InDataType* a_ptr,
722  const WeiDataType* b_ptr,
723  const std::array<const void*, NumDTensor>& ds_ptr,
724  OutDataType* c_ptr,
725  void* __restrict__ smem_ptr_0,
726  void* __restrict__ smem_ptr_1,
728  const index_t block_idx_m,
729  const index_t block_idx_n)
730  {
731  // Create Gemm tensor views, pad views and tile windows
732  const auto& gemm_tensor_views_tuple =
733  MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
734  a_ptr, b_ptr, ds_ptr, c_ptr, kargs);
735  const auto& gemm_pad_views = MakeGemmPadViews(gemm_tensor_views_tuple);
736  auto gemm_tile_windows = MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
737 
738  const index_t num_loop =
739  __builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.GemmK));
740 
741  // Run GEMM cooperatively by whole workgroup.
742  const auto& a_block_window = gemm_tile_windows.at(I0);
743  const auto& b_block_window = gemm_tile_windows.at(I1);
744  const auto& d_block_window = gemm_tile_windows.at(I2);
745 
746  const auto& c_block_tile = GemmPipeline{}.template operator()(
747  a_block_window, b_block_window, num_loop, smem_ptr_0, smem_ptr_1);
748 
749  // Run Epilogue Pipeline
750  auto& c_block_window = gemm_tile_windows.at(I3);
751 
752  EpiloguePipeline{}.template operator()<decltype(c_block_window), decltype(c_block_tile)>(
753  c_block_window, c_block_tile, d_block_window, smem_ptr_0, smem_ptr_1);
754  }
755 
757  {
758  const auto blockIdX = __builtin_amdgcn_readfirstlane(blockIdx.x);
759  const auto [iM, iN] =
760  TilePartitioner{kargs.GemmM, kargs.GemmN}.GetOutputTileIndex(blockIdX);
761  const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
762  const index_t i_n = __builtin_amdgcn_readfirstlane(iN * TilePartitioner::NPerBlock);
763 
764  const auto blockIdY = __builtin_amdgcn_readfirstlane(blockIdx.y);
765  const auto group_offset_a = __builtin_amdgcn_readfirstlane(kargs.group_stride_a * blockIdY);
766  const auto group_offset_b = __builtin_amdgcn_readfirstlane(kargs.group_stride_b * blockIdY);
767  const auto group_offset_c = __builtin_amdgcn_readfirstlane(kargs.group_stride_c * blockIdY);
768 
769  // options
770  const InDataType* a_ptr = static_cast<const InDataType*>(kargs.in_ptr) + group_offset_a;
771  const WeiDataType* b_ptr = static_cast<const WeiDataType*>(kargs.wei_ptr) + group_offset_b;
772  OutDataType* c_ptr = static_cast<OutDataType*>(kargs.out_ptr) + group_offset_c;
773 
774  // allocate LDS
775  __shared__ char smem_ptr_0[GetSmemSize()];
776 
777  if constexpr(GemmPipeline::DoubleSmemBuffer == true)
778  {
779  __shared__ char smem_ptr_1[GetSmemSize()];
780  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
781  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
783  {
784  RunGemm2LDS(
785  a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, smem_ptr_1, kargs, i_m, i_n);
786  }
787  }
788  else
789  {
790  if constexpr(!(EpiloguePipeline::MemoryOperation == memory_operation_enum::atomic_add &&
791  EpiloguePipeline::GetVectorSizeC() % 2 != 0 &&
793  {
794  RunGemm(a_ptr, b_ptr, kargs.ds_ptr, c_ptr, smem_ptr_0, kargs, i_m, i_n);
795  }
796  }
797  }
798 };
799 
800 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
bool EnvIsEnabled(EnvVar)
Definition: env.hpp:156
void CK_TILE_ERROR(Args &&... args) noexcept
Definition: env.hpp:12
ConvolutionSpecialization
Definition: convolution_specialization.hpp:11
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:529
auto concat(const Ts &... xs) -> std::enable_if_t<!AllConvertibleToStringView< Ts... >, std::string >
Definition: concat.hpp:41
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
int64_t long_index_t
Definition: integer.hpp:11
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto generate_tuple(F &&f, number< N >)
Definition: tuple.hpp:406
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
constexpr bool is_same_v
Definition: type.hpp:283
The Grouped Convolution kernel device arguments.
Definition: grouped_convolution_forward_kernel.hpp:22
index_t k_batch
Definition: grouped_convolution_forward_kernel.hpp:279
array< index_t, NonSpatialDims+GroupedConvTraitsType::NDimSpatial > out_g_n_k_wos_lengths
Definition: grouped_convolution_forward_kernel.hpp:272
CK_TILE_HOST GroupedConvFwdKernelArgs(const GroupedConvHostArgs &args)
Definition: grouped_convolution_forward_kernel.hpp:37
array< index_t, GroupedConvTraitsType::NDimSpatial > input_left_pads
Definition: grouped_convolution_forward_kernel.hpp:276
array< index_t, NonSpatialDims+GroupedConvTraitsType::NDimSpatial > in_g_n_c_wis_lengths
Definition: grouped_convolution_forward_kernel.hpp:270
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeADescriptor_M_K< typename GroupedConvTraitsType::InLayout >())> AGridDescMK
Definition: grouped_convolution_forward_kernel.hpp:261
const void * wei_ptr
Definition: grouped_convolution_forward_kernel.hpp:285
long_index_t group_stride_c
Definition: grouped_convolution_forward_kernel.hpp:295
index_t GemmN
Definition: grouped_convolution_forward_kernel.hpp:281
index_t GemmM
Definition: grouped_convolution_forward_kernel.hpp:280
long_index_t group_stride_b
Definition: grouped_convolution_forward_kernel.hpp:294
CGridDescMN c_grid_desc_m_n
Definition: grouped_convolution_forward_kernel.hpp:291
index_t GemmK
Definition: grouped_convolution_forward_kernel.hpp:282
array< index_t, GroupedConvTraitsType::NDimSpatial > input_right_pads
Definition: grouped_convolution_forward_kernel.hpp:277
AGridDescMK a_grid_desc_m_k
Definition: grouped_convolution_forward_kernel.hpp:289
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeBDescriptor_N_K< typename GroupedConvTraitsType::WeiLayout >())> BGridDescNK
Definition: grouped_convolution_forward_kernel.hpp:264
array< index_t, NonSpatialDims+GroupedConvTraitsType::NDimSpatial > wei_g_k_c_xs_lengths
Definition: grouped_convolution_forward_kernel.hpp:271
static constexpr index_t NumDTensor
Definition: grouped_convolution_forward_kernel.hpp:27
BGridDescNK b_grid_desc_n_k
Definition: grouped_convolution_forward_kernel.hpp:290
array< index_t, GroupedConvTraitsType::NDimSpatial > conv_filter_dilations
Definition: grouped_convolution_forward_kernel.hpp:275
array< index_t, GroupedConvTraitsType::NDimSpatial > conv_filter_strides
Definition: grouped_convolution_forward_kernel.hpp:274
void * out_ptr
Definition: grouped_convolution_forward_kernel.hpp:287
static constexpr index_t NonSpatialDims
Definition: grouped_convolution_forward_kernel.hpp:269
long_index_t group_stride_a
Definition: grouped_convolution_forward_kernel.hpp:293
remove_cvref_t< decltype(ConvToGemmFwdTransformer{} .template MakeCDescriptor_M_N< typename GroupedConvTraitsType::OutLayout >())> CGridDescMN
Definition: grouped_convolution_forward_kernel.hpp:267
std::array< const void *, NumDTensor > ds_ptr
Definition: grouped_convolution_forward_kernel.hpp:286
const void * in_ptr
Definition: grouped_convolution_forward_kernel.hpp:284
The Grouped Convolution kernel host arguments.
Definition: grouped_convolution_utils.hpp:18
const void * wei_ptr
Definition: grouped_convolution_utils.hpp:36
index_t k_batch
Definition: grouped_convolution_utils.hpp:39
const void * in_ptr
Definition: grouped_convolution_utils.hpp:35
void * out_ptr
Definition: grouped_convolution_utils.hpp:38
const std::vector< const void * > ds_ptr
Definition: grouped_convolution_utils.hpp:37
The Grouped Convolution Forward kernel template.
Definition: grouped_convolution_forward_kernel.hpp:341
static CK_TILE_DEVICE auto MakeGemmPadViews(const TensorView &views)
Definition: grouped_convolution_forward_kernel.hpp:580
remove_cvref_t< GemmPipeline_ > GemmPipeline
Definition: grouped_convolution_forward_kernel.hpp:346
remove_cvref_t< EpiloguePipeline_ > EpiloguePipeline
Definition: grouped_convolution_forward_kernel.hpp:347
static constexpr CK_TILE_HOST auto GridSize(const GroupedConvHostArgs &args)
Definition: grouped_convolution_forward_kernel.hpp:392
static constexpr auto I2
Definition: grouped_convolution_forward_kernel.hpp:376
static constexpr ConvolutionSpecialization ConvSpecialization
Definition: grouped_convolution_forward_kernel.hpp:343
remove_cvref_t< typename EpiloguePipeline::DsDataType > DsDataType
Definition: grouped_convolution_forward_kernel.hpp:365
remove_cvref_t< typename GemmPipeline::BDataType > WeiDataType
Definition: grouped_convolution_forward_kernel.hpp:364
remove_cvref_t< typename GemmPipeline::ADataType > InDataType
Definition: grouped_convolution_forward_kernel.hpp:363
static CK_TILE_DEVICE void RunGemm2LDS(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *__restrict__ smem_ptr_0, void *__restrict__ smem_ptr_1, const GroupedConvFwdKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_convolution_forward_kernel.hpp:721
remove_cvref_t< typename GroupedConvTraitsType::DsLayout > DsLayout
Definition: grouped_convolution_forward_kernel.hpp:355
static CK_TILE_HOST bool IsSupportedArgument(const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:415
static constexpr auto I1
Definition: grouped_convolution_forward_kernel.hpp:375
static CK_TILE_DEVICE void RunGemm(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, void *smem_ptr_0, const GroupedConvFwdKernelArgsSpecialized &kargs, const index_t block_idx_m, const index_t block_idx_n)
Runs a single GEMM problem cooperatively across the whole workgroup.
Definition: grouped_convolution_forward_kernel.hpp:671
CK_TILE_DEVICE void operator()(GroupedConvFwdKernelArgsSpecialized kargs) const
Definition: grouped_convolution_forward_kernel.hpp:756
static constexpr bool IsSplitKSupported
Definition: grouped_convolution_forward_kernel.hpp:372
static constexpr index_t NDimSpatial
Definition: grouped_convolution_forward_kernel.hpp:342
static constexpr index_t NumDTensor
Definition: grouped_convolution_forward_kernel.hpp:359
remove_cvref_t< typename GemmPipeline::CLayout > GemmCLayout
Definition: grouped_convolution_forward_kernel.hpp:350
remove_cvref_t< typename GroupedConvTraitsType::OutLayout > OutLayout
Definition: grouped_convolution_forward_kernel.hpp:354
static CK_TILE_HOST const std::string GetName()
Definition: grouped_convolution_forward_kernel.hpp:385
remove_cvref_t< typename EpiloguePipeline::DsLayout > GemmDsLayout
Definition: grouped_convolution_forward_kernel.hpp:357
static constexpr index_t KernelBlockSize
Definition: grouped_convolution_forward_kernel.hpp:361
static constexpr auto I0
Definition: grouped_convolution_forward_kernel.hpp:374
remove_cvref_t< TilePartitioner_ > TilePartitioner
Definition: grouped_convolution_forward_kernel.hpp:345
remove_cvref_t< typename GemmPipeline::ALayout > GemmALayout
Definition: grouped_convolution_forward_kernel.hpp:348
static constexpr CK_TILE_HOST GroupedConvFwdKernelArgsSpecialized MakeKernelArgs(const GroupedConvHostArgs &hostArgs)
Definition: grouped_convolution_forward_kernel.hpp:405
remove_cvref_t< typename GroupedConvTraitsType::WeiLayout > WeiLayout
Definition: grouped_convolution_forward_kernel.hpp:353
static CK_TILE_DEVICE auto MakeGemmTileWindows(const PadView &views, const index_t i_m, const index_t i_n)
Definition: grouped_convolution_forward_kernel.hpp:621
static constexpr CK_TILE_HOST auto BlockSize()
Definition: grouped_convolution_forward_kernel.hpp:402
static constexpr auto I3
Definition: grouped_convolution_forward_kernel.hpp:377
GroupedConvFwdKernelArgs< GroupedConvTraitsType > GroupedConvFwdKernelArgsSpecialized
Definition: grouped_convolution_forward_kernel.hpp:369
remove_cvref_t< typename GemmPipeline::BLayout > GemmBLayout
Definition: grouped_convolution_forward_kernel.hpp:349
remove_cvref_t< typename EpiloguePipeline::ODataType > OutDataType
Definition: grouped_convolution_forward_kernel.hpp:367
static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
Definition: grouped_convolution_forward_kernel.hpp:410
remove_cvref_t< typename GroupedConvTraitsType::InLayout > InLayout
Definition: grouped_convolution_forward_kernel.hpp:352
static CK_TILE_DEVICE auto MakeGemmTensorViews(const InDataType *a_ptr, const WeiDataType *b_ptr, const std::array< const void *, NumDTensor > &ds_ptr, OutDataType *c_ptr, const GroupedConvFwdKernelArgsSpecialized &kargs)
Definition: grouped_convolution_forward_kernel.hpp:541
Definition: transform_conv_fwd_to_gemm.hpp:19
Definition: integral_constant.hpp:13
std::vector< ck_tile::long_index_t > input_spatial_lengths_
Definition: convolution_parameter.hpp:130
ck_tile::long_index_t K_
Definition: convolution_parameter.hpp:126
std::vector< ck_tile::long_index_t > output_spatial_lengths_
Definition: convolution_parameter.hpp:131
std::vector< ck_tile::long_index_t > input_right_pads_
Definition: convolution_parameter.hpp:137
ck_tile::long_index_t G_
Definition: convolution_parameter.hpp:124
std::vector< ck_tile::long_index_t > conv_filter_strides_
Definition: convolution_parameter.hpp:133
std::vector< ck_tile::long_index_t > filter_spatial_lengths_
Definition: convolution_parameter.hpp:129
ck_tile::long_index_t C_
Definition: convolution_parameter.hpp:127
ck_tile::long_index_t N_
Definition: convolution_parameter.hpp:125
std::vector< ck_tile::long_index_t > input_left_pads_
Definition: convolution_parameter.hpp:136
std::vector< ck_tile::long_index_t > conv_filter_dilations_
Definition: convolution_parameter.hpp:134
Definition: type_traits.hpp:115
Definition: sequence.hpp:52
#define CK_TILE_ENV(name)
Definition: env.hpp:145