/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp Source File
device_column_to_image_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
13 
15 
21 #include "ck/host_utility/io.hpp"
22 
23 namespace ck {
24 namespace tensor_operation {
25 namespace device {
26 
27 // Column to Image (gemm-form columns are added back into the image):
28 // GNSpatialC layouts - input : gemm form [G, N * Do * Ho * Wo, Z * Y * X * C]
29 //                      output: input image [G, N, Di, Hi, Wi, C]
30 // NSpatialGC layouts - input : gemm form [N * Do * Ho * Wo, G, Z * Y * X * C]
31 //                      output: input image [N, Di, Hi, Wi, G, C]
32 template <index_t NDimSpatial,
33  typename ImageLayout,
34  typename InputDataType,
35  typename OutputDataType,
36  index_t BlockSize,
37  index_t MPerBlock,
38  index_t KPerBlock,
39  typename ThreadClusterLengths,
40  index_t ScalarPerVector,
41  typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
43  : public DeviceConvTensorRearrange<NDimSpatial,
44  ImageLayout,
45  InputDataType,
46  OutputDataType,
47  conv_tensor_rearrange_op::ColumnToImage>
48 {
49  static constexpr bool is_NSpatialGC =
50  std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
51  std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
52  std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
53  static constexpr bool is_GNSpatialC =
54  std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
55  std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
56  std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
57 
58  static constexpr auto I0 = Number<0>{};
59  static constexpr auto I1 = Number<1>{};
60  static constexpr auto I2 = Number<2>{};
61 
62  static constexpr auto ZIdx = Number<I0>{};
63  static constexpr auto YIdx = NDimSpatial == 1 ? I0 : Number<NDimSpatial - I2>{};
64  static constexpr auto XIdx = Number<NDimSpatial - I1>{};
65 
66  static constexpr auto spatial_offset = Number<3>{};
67 
70  static constexpr auto matrix_padder =
72  MPerBlock, 0 /* NPerBlock*/, KPerBlock};
73 
74  // Calculate number of independent filters for given conv params
75  static index_t GetNumberOfIndependentFilters(const index_t input_spatial_len,
76  const index_t left_pad,
77  const index_t right_pad,
78  const index_t filter_len,
79  const index_t filter_stride,
80  const index_t filter_dilation,
81  const index_t image_offset)
82  {
83  const index_t x_eff = (filter_len - 1) * filter_dilation + 1;
84  const index_t next_filter_padded =
85  math::integer_divide_ceil(x_eff, filter_stride) * filter_stride;
86  // If filter_stride >= x_eff then each filter is independent
87  const index_t independent_filter_stride =
88  filter_stride >= x_eff ? filter_stride : next_filter_padded;
89  const index_t w_eff = input_spatial_len - image_offset + left_pad + right_pad - x_eff;
90  // There are no independent filters
91  if(w_eff < 0)
92  return 0;
93  const index_t independent_kernels_num = w_eff / independent_filter_stride + 1;
94  return independent_kernels_num;
95  }
96 
97  // Make column form descriptor
98  static auto
100  const ck::index_t C,
101  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
102  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
103  const std::array<index_t, NDimSpatial>& conv_filter_strides,
104  const std::array<index_t, 3>& gemm_g_m_k_strides,
105  const std::array<index_t, NDimSpatial>& independent_filters,
106  const std::array<index_t, NDimSpatial>& effs)
107  {
108  const index_t DoHoWo = ck::accumulate_n<index_t>(
109  output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
110  const index_t CZYX =
111  C * ck::accumulate_n<index_t>(
112  filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
113 
114  const index_t NStride = DoHoWo * gemm_g_m_k_strides[I1] * gemm_g_m_k_strides[I2];
115  // Calculate the appropriate stride for each set of independent filters
116  // in each dimension
117  const index_t WStride = math::integer_divide_ceil(effs[XIdx], conv_filter_strides[XIdx]) *
118  gemm_g_m_k_strides[I1];
119  const index_t HStride = math::integer_divide_ceil(effs[YIdx], conv_filter_strides[YIdx]) *
120  output_spatial_lengths[XIdx] * gemm_g_m_k_strides[I1];
121  const index_t DStride = math::integer_divide_ceil(effs[ZIdx], conv_filter_strides[ZIdx]) *
122  output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx] *
123  gemm_g_m_k_strides[I1];
124  // Create descriptor for independent filters in each dimension and
125  // then merge them into column form
126  if constexpr(NDimSpatial == 1)
127  {
128  const auto desc_gemm_form =
129  make_naive_tensor_descriptor(make_tuple(N, independent_filters[XIdx], CZYX),
130  make_tuple(NStride, WStride, gemm_g_m_k_strides[I2]));
131  const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
132  desc_gemm_form,
133  make_tuple(make_merge_transform(make_tuple(N, independent_filters[XIdx])),
137  const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
138  return desc_m_k;
139  }
140  else if constexpr(NDimSpatial == 2)
141  {
142  const auto desc_gemm_form = make_naive_tensor_descriptor(
143  make_tuple(N, independent_filters[YIdx], independent_filters[XIdx], CZYX),
144  make_tuple(NStride, HStride, WStride, gemm_g_m_k_strides[I2]));
145  const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
146  desc_gemm_form,
148  make_tuple(N, independent_filters[YIdx], independent_filters[XIdx])),
152  const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
153  return desc_m_k;
154  }
155  else if constexpr(NDimSpatial == 3)
156  {
157  const auto desc_gemm_form = make_naive_tensor_descriptor(
158  make_tuple(N,
159  independent_filters[ZIdx],
160  independent_filters[YIdx],
161  independent_filters[XIdx],
162  CZYX),
163  make_tuple(NStride, DStride, HStride, WStride, gemm_g_m_k_strides[I2]));
164  const auto desc_gemm_form_merged_filters = transform_tensor_descriptor(
165  desc_gemm_form,
167  independent_filters[ZIdx],
168  independent_filters[YIdx],
169  independent_filters[XIdx])),
173  const auto desc_m_k = matrix_padder.PadADescriptor_M_K(desc_gemm_form_merged_filters);
174  return desc_m_k;
175  }
176  }
177 
178  // Use MakeADescriptor_M_K from grouped convolution forward
179  static auto
181  const ck::index_t C,
182  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
183  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
184  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
185  const std::array<index_t, NDimSpatial>& conv_filter_strides,
186  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
187  const std::array<index_t, NDimSpatial>& input_left_pads,
188  const std::array<index_t, NDimSpatial>& input_right_pads,
189  const std::array<index_t, NDimSpatial>& image_offsets,
190  const std::array<index_t, NDimSpatial>& independent_filters,
191  const std::array<index_t, NDimSpatial>& effs)
192  {
193  std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{1};
194  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{1};
195  std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths{1};
196 
197  auto copy = [](const auto& x, auto& y, index_t dst_offset) {
198  std::copy(x.begin(), x.end(), y.begin() + dst_offset);
199  };
200 
201  copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset);
202  copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset);
203  // Calculate descriptor only for independent filters
204  copy(independent_filters, c_g_n_k_wos_lengths, spatial_offset);
205 
206  // fill only significant values (C and N)
207  a_g_n_c_wis_lengths[I1] = N;
208  a_g_n_c_wis_lengths[I2] = C;
209  b_g_k_c_xs_lengths[I2] = C;
210  c_g_n_k_wos_lengths[I1] = N;
211 
212  // Modify pads to apply offsets
213  std::array<index_t, NDimSpatial> input_left_pads_with_offset;
214  for(index_t i = 0; i < NDimSpatial; i++)
215  {
216  input_left_pads_with_offset[i] = math::max(0, input_left_pads[i] - image_offsets[i]);
217  }
218  // Modify input spatial lengths to apply offsets
219  for(index_t i = 0; i < NDimSpatial; i++)
220  {
221  a_g_n_c_wis_lengths[i + spatial_offset] -=
222  math::max(0, image_offsets[i] - input_left_pads[i]);
223  }
224 
225  // Strides to next independent filters
226  std::array<index_t, NDimSpatial> independent_filter_strides;
227  for(index_t i = 0; i < NDimSpatial; i++)
228  {
229  index_t independent_filter_stride =
230  math::integer_divide_ceil(effs[i], conv_filter_strides[i]) * conv_filter_strides[i];
231  // If conv stride is greater than whole filter size, use conv stride
232  independent_filter_strides[i] = conv_filter_strides[i] >= effs[i]
233  ? conv_filter_strides[i]
234  : independent_filter_stride;
235  }
236 
237  ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
238  image_g_n_c_wis_strides,
239  b_g_k_c_xs_lengths,
240  {}, // not needed for A Descriptor
241  c_g_n_k_wos_lengths,
242  {}, // not needed for A Descriptor
243  // conv_filter_strides,
244  independent_filter_strides,
245  conv_filter_dilations,
246  input_left_pads_with_offset,
247  input_right_pads};
248 
249  // Calculate image form descriptor for the modified convolution problem
250  const auto in_gemmmraw_gemmkraw_desc =
251  conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
252 
253  const auto in_gemmm_gemmk_desc =
254  matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
255  return in_gemmm_gemmk_desc;
256  }
257 
259  remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}))>;
261  1, 1, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))>;
262 
265  InputGridDesc{}))>;
266 
268  InputDataType,
270  OutputDataType,
271  BlockSize,
272  MPerBlock,
273  KPerBlock,
274  ThreadClusterLengths,
275  ScalarPerVector,
278  ComputePtrOffsetOfStridedBatch<>>;
279 
280  struct Argument : public BaseArgument
281  {
282  Argument(const void* p_in, // input image
283  void* p_out, // output image
284  const ck::index_t G,
285  const ck::index_t N,
286  const ck::index_t C,
287  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
288  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
289  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
290  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
291  const std::array<index_t, 3>& gemm_g_m_k_strides,
292  const std::array<index_t, NDimSpatial>& conv_filter_strides,
293  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
294  const std::array<index_t, NDimSpatial>& input_left_pads,
295  const std::array<index_t, NDimSpatial>& input_right_pads)
296  : G_(G),
297  C_(C),
298  X_(filter_spatial_lengths[NDimSpatial - I1]),
299  p_in_{static_cast<const InputDataType*>(p_in)},
300  p_out_{static_cast<OutputDataType*>(p_out)},
301  image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
302  conv_filter_strides_{conv_filter_strides},
303  conv_filter_dilations_{conv_filter_dilations},
304  input_left_pads_{input_left_pads},
305  input_right_pads_{input_right_pads}
306  {
307  compute_ptr_offset_of_batch_.BatchStrideA_ = gemm_g_m_k_strides[I0];
308  compute_ptr_offset_of_batch_.BatchStrideC_ = image_g_n_c_wis_strides[I0];
309 
310  const index_t x_eff =
311  (filter_spatial_lengths[XIdx] - 1) * conv_filter_dilations[XIdx] + 1;
312  const index_t y_eff =
313  NDimSpatial < 2
314  ? I1
315  : (filter_spatial_lengths[YIdx] - 1) * conv_filter_dilations[YIdx] + 1;
316  const index_t z_eff =
317  NDimSpatial < 3
318  ? I1
319  : (filter_spatial_lengths[ZIdx] - 1) * conv_filter_dilations[ZIdx] + 1;
320 
321  // Iterate over sets of independent filters
322  for(int z_img_offset = 0; z_img_offset < z_eff;
323  z_img_offset += conv_filter_strides[ZIdx])
324  {
325  for(int y_img_offset = 0; y_img_offset < y_eff;
326  y_img_offset += conv_filter_strides[YIdx])
327  {
328  for(int x_img_offset = 0; x_img_offset < x_eff;
329  x_img_offset += conv_filter_strides[XIdx])
330  {
331 
332  std::array<index_t, NDimSpatial> image_offsets;
333  std::array<index_t, NDimSpatial> effs;
334  // Calculate the starting offset for a given set of
335  // independent filters
336  if constexpr(NDimSpatial == 1)
337  {
338  image_offsets = {x_img_offset};
339  effs = {x_eff};
340  }
341  if constexpr(NDimSpatial == 2)
342  {
343  image_offsets = {y_img_offset, x_img_offset};
344  effs = {y_eff, x_eff};
345  }
346  else if constexpr(NDimSpatial == 3)
347  {
348  image_offsets = {z_img_offset, y_img_offset, x_img_offset};
349  effs = {z_eff, y_eff, x_eff};
350  }
351 
352  std::array<index_t, NDimSpatial> independent_filters;
353  for(index_t i = 0; i < NDimSpatial; i++)
354  {
355  independent_filters[i] =
356  GetNumberOfIndependentFilters(input_spatial_lengths[i],
357  input_left_pads[i],
358  input_right_pads[i],
359  filter_spatial_lengths[i],
360  conv_filter_strides[i],
361  conv_filter_dilations[i],
362  image_offsets[i]);
363  }
364  const index_t independent_filters_acum = ck::accumulate_n<index_t>(
365  independent_filters.begin(), NDimSpatial, 1, std::multiplies<>());
366  if(independent_filters_acum <= 0)
367  continue;
368 
369  const auto in_grid_desc_m_k =
371  C,
372  filter_spatial_lengths,
373  output_spatial_lengths,
374  conv_filter_strides,
375  gemm_g_m_k_strides,
376  independent_filters,
377  effs);
378  const auto out_grid_desc_m_k =
380  C,
381  input_spatial_lengths,
382  filter_spatial_lengths,
383  image_g_n_c_wis_strides,
384  conv_filter_strides,
385  conv_filter_dilations,
386  input_left_pads,
387  input_right_pads,
388  image_offsets,
389  independent_filters,
390  effs);
391  in_grid_desc_m_k_container_.push_back(in_grid_desc_m_k);
392  out_grid_desc_m_k_container_.push_back(out_grid_desc_m_k);
393 
394  const index_t x_idx = x_img_offset / conv_filter_strides[XIdx];
395  const index_t y_idx = y_img_offset / conv_filter_strides[YIdx];
396  const index_t z_idx = z_img_offset / conv_filter_strides[ZIdx];
397 
398  const index_t x_offset_with_pad =
399  math::max(0, x_img_offset - input_left_pads[XIdx]);
400  const index_t y_offset_with_pad =
401  math::max(0, y_img_offset - input_left_pads[YIdx]);
402  const index_t z_offset_with_pad =
403  math::max(0, z_img_offset - input_left_pads[ZIdx]);
404 
405  // Memory offsets to next set of independent filters,
406  // move to independent filters in each dimension
407  const index_t in_offset =
408  (x_idx + y_idx * output_spatial_lengths[XIdx] +
409  z_idx * output_spatial_lengths[YIdx] * output_spatial_lengths[XIdx]) *
410  gemm_g_m_k_strides[I1];
411  // Move to independent filters in appropriate dimensions
412  const index_t out_offset =
413  x_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + XIdx] +
414  y_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + YIdx] +
415  z_offset_with_pad * image_g_n_c_wis_strides[spatial_offset + ZIdx];
416 
417  const InputDataType* p_in_with_offset =
418  static_cast<const InputDataType*>(p_in) + in_offset;
419  OutputDataType* p_out_with_offset =
420  static_cast<OutputDataType*>(p_out) + out_offset;
421  p_in_container_.push_back(p_in_with_offset);
422  p_out_container_.push_back(p_out_with_offset);
423  }
424  }
425  }
426  }
427 
428  void Print() const
429  {
430  for(std::size_t i = 0; i < in_grid_desc_m_k_container_.size(); i++)
431  {
432  std::cout << in_grid_desc_m_k_container_[i] << std::endl;
433  std::cout << out_grid_desc_m_k_container_[i] << std::endl;
434  }
435  }
436 
440 
441  const InputDataType* p_in_;
442  OutputDataType* p_out_;
443 
444  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
445  const std::array<index_t, NDimSpatial>& conv_filter_strides_;
446  const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
447  const std::array<index_t, NDimSpatial>& input_left_pads_;
448  const std::array<index_t, NDimSpatial>& input_right_pads_;
449 
450  std::vector<InputGridDesc> in_grid_desc_m_k_container_;
451  std::vector<OutputGridDesc> out_grid_desc_m_k_container_;
452 
453  std::vector<const InputDataType*> p_in_container_;
454  std::vector<OutputDataType*> p_out_container_;
455 
456  ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
457  };
458 
459  struct Invoker : public BaseInvoker
460  {
461  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
462  {
463  if(stream_config.log_level_ > 0)
464  {
465  arg.Print();
466  }
467 
468  float elapsed_time = 0.f;
469  const auto kernel = kernel_tensor_rearrange<InputGridDesc,
470  InputDataType,
472  OutputDataType,
474  ComputePtrOffsetOfStridedBatch<>,
476 
477  // Execute each set of independent filters
478  for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)
479  {
480  const auto block_2_tile_map =
483  const index_t grid_size =
484  block_2_tile_map.CalculateGridSize(arg.in_grid_desc_m_k_container_[i]) * arg.G_;
485  elapsed_time += launch_and_time_kernel(stream_config,
486  kernel,
487  dim3(grid_size),
488  dim3(BlockSize),
489  0,
491  arg.p_in_container_[i],
493  arg.p_out_container_[i],
494  arg.G_,
495  block_2_tile_map,
497  }
498  return elapsed_time;
499  }
500 
501  float Run(const BaseArgument* p_arg,
502  const StreamConfig& stream_config = StreamConfig{}) override
503  {
504  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
505  }
506  };
507 
508  bool IsSupportedArgument(const Argument& arg)
509  {
510  using namespace tensor_layout::convolution;
511  if constexpr(!(is_NSpatialGC || is_GNSpatialC))
512  {
513  return false;
514  }
515 
516  const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1];
517  const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
518  const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
519  const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
520  bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
521  bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;
522 
523  // check vector access when C is not packed
524  if(!is_c_packed && ScalarPerVector != 1)
525  return false;
526  // check vector access of filter window row (only C if W is not packed)
527  if(!is_w_packed && arg.C_ % ScalarPerVector != 0)
528  return false;
529  // check vector access of filter window row (X * C)
530  if(arg.X_ * arg.C_ % ScalarPerVector != 0)
531  return false;
532  // check vector access of pads (w_pad_left/w_pad_right * C)
533  if(w_pad_left * arg.C_ % ScalarPerVector != 0 ||
534  w_pad_right * arg.C_ % ScalarPerVector != 0)
535  return false;
536  // check vector access with stride and pad
537  if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0)
538  return false;
539  // check vector access with dilation
540  if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
541  return false;
542 
543  bool valid = true;
544  for(std::size_t i = 0; i < arg.in_grid_desc_m_k_container_.size(); i++)
545  {
548  }
549  return valid;
550  }
551 
552  bool IsSupportedArgument(const BaseArgument* p_arg) override
553  {
554  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
555  }
556 
557  static auto MakeArgument(const void* p_in, // input image
558  void* p_out, // output image
559  const ck::index_t G,
560  const ck::index_t N,
561  const ck::index_t C,
562  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
563  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
564  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
565  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
566  const std::array<index_t, 3>& gemm_g_m_k_strides,
567  const std::array<index_t, NDimSpatial>& conv_filter_strides,
568  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
569  const std::array<index_t, NDimSpatial>& input_left_pads,
570  const std::array<index_t, NDimSpatial>& input_right_pads)
571  {
572  return Argument{static_cast<const InputDataType*>(p_in),
573  static_cast<OutputDataType*>(p_out),
574  G,
575  N,
576  C,
577  input_spatial_lengths,
578  filter_spatial_lengths,
579  output_spatial_lengths,
580  image_g_n_c_wis_strides,
581  gemm_g_m_k_strides,
582  conv_filter_strides,
583  conv_filter_dilations,
584  input_left_pads,
585  input_right_pads};
586  }
587 
588  static auto MakeInvoker() { return Invoker{}; }
589 
590  std::unique_ptr<BaseArgument>
591  MakeArgumentPointer(const void* p_in, // input image
592  void* p_out, // output image
593  const ck::index_t G,
594  const ck::index_t N,
595  const ck::index_t C,
596  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
597  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
598  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
599  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
600  const std::array<index_t, 3>& gemm_g_m_k_strides,
601  const std::array<index_t, NDimSpatial>& conv_filter_strides,
602  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
603  const std::array<index_t, NDimSpatial>& input_left_pads,
604  const std::array<index_t, NDimSpatial>& input_right_pads) override
605  {
606  return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
607  static_cast<OutputDataType*>(p_out),
608  G,
609  N,
610  C,
611  input_spatial_lengths,
612  filter_spatial_lengths,
613  output_spatial_lengths,
614  image_g_n_c_wis_strides,
615  gemm_g_m_k_strides,
616  conv_filter_strides,
617  conv_filter_dilations,
618  input_left_pads,
619  input_right_pads);
620  }
621 
622  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
623  {
624  return std::make_unique<Invoker>(Invoker{});
625  }
626 
627  std::string GetTypeString() const override
628  {
629  auto str = std::stringstream();
630 
631  // clang-format off
632  str << "DeviceColumnToImage"
633  << "<"
634  << BlockSize << ", "
635  << MPerBlock << ", "
636  << KPerBlock << ", "
637  << ScalarPerVector
638  << ">";
639  // clang-format on
640 
641  return str.str();
642  }
643 };
644 
645 } // namespace device
646 } // namespace tensor_operation
647 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__global__ void kernel_tensor_rearrange(const InputGridDesc in_grid_desc, const InputDataType *__restrict__ p_in_global, const OutputGridDesc out_grid_desc, OutputDataType *__restrict__ p_out_global, const index_t batch_count, const Block2ETileMap block_2_tile_map, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
Definition: gridwise_tensor_rearrange.hpp:30
Definition: stream_config.hpp:10
Definition: block_to_ctile_map.hpp:260
Definition: gridwise_tensor_rearrange.hpp:72
static constexpr __host__ bool CheckValidity(const InputGridDesc &in_grid_desc, const OutputGridDesc &out_grid_desc)
Definition: gridwise_tensor_rearrange.hpp:138
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: transform_conv_fwd_to_gemm.hpp:24
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_column_to_image_impl.hpp:281
std::vector< InputGridDesc > in_grid_desc_m_k_container_
Definition: device_column_to_image_impl.hpp:450
Argument(const void *p_in, void *p_out, const ck::index_t G, const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &image_g_n_c_wis_strides, const std::array< index_t, 3 > &gemm_g_m_k_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads)
Definition: device_column_to_image_impl.hpp:282
std::vector< OutputGridDesc > out_grid_desc_m_k_container_
Definition: device_column_to_image_impl.hpp:451
std::vector< OutputDataType * > p_out_container_
Definition: device_column_to_image_impl.hpp:454
const std::array< index_t, NDimSpatial+3 > & image_g_n_c_wis_strides_
Definition: device_column_to_image_impl.hpp:444
const std::array< index_t, NDimSpatial > & conv_filter_dilations_
Definition: device_column_to_image_impl.hpp:446
const ck::index_t X_
Definition: device_column_to_image_impl.hpp:439
OutputDataType * p_out_
Definition: device_column_to_image_impl.hpp:442
const std::array< index_t, NDimSpatial > & input_right_pads_
Definition: device_column_to_image_impl.hpp:448
const ck::index_t G_
Definition: device_column_to_image_impl.hpp:437
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition: device_column_to_image_impl.hpp:456
void Print() const
Definition: device_column_to_image_impl.hpp:428
const InputDataType * p_in_
Definition: device_column_to_image_impl.hpp:441
const ck::index_t C_
Definition: device_column_to_image_impl.hpp:438
std::vector< const InputDataType * > p_in_container_
Definition: device_column_to_image_impl.hpp:453
const std::array< index_t, NDimSpatial > & input_left_pads_
Definition: device_column_to_image_impl.hpp:447
const std::array< index_t, NDimSpatial > & conv_filter_strides_
Definition: device_column_to_image_impl.hpp:445
Definition: device_column_to_image_impl.hpp:460
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_column_to_image_impl.hpp:501
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_column_to_image_impl.hpp:461
Definition: device_column_to_image_impl.hpp:48
remove_cvref_t< decltype(BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, KPerBlock, InputGridDesc >(InputGridDesc{}))> Block2ETileMap
Definition: device_column_to_image_impl.hpp:265
std::string GetTypeString() const override
Definition: device_column_to_image_impl.hpp:627
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_column_to_image_impl.hpp:622
bool IsSupportedArgument(const Argument &arg)
Definition: device_column_to_image_impl.hpp:508
static auto MakeOutDescriptor_M_K(const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &image_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads, const std::array< index_t, NDimSpatial > &image_offsets, const std::array< index_t, NDimSpatial > &independent_filters, const std::array< index_t, NDimSpatial > &effs)
Definition: device_column_to_image_impl.hpp:180
GridwiseTensorRearrange< InputGridDesc, InputDataType, OutputGridDesc, OutputDataType, BlockSize, MPerBlock, KPerBlock, ThreadClusterLengths, ScalarPerVector, InMemoryDataOperationEnum::Add, Block2ETileMap, ComputePtrOffsetOfStridedBatch<> > GridwiseTensorRearrangeKernel
Definition: device_column_to_image_impl.hpp:278
static constexpr auto matrix_padder
Definition: device_column_to_image_impl.hpp:70
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, void *p_out, const ck::index_t G, const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &image_g_n_c_wis_strides, const std::array< index_t, 3 > &gemm_g_m_k_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads) override
Make argument pointer for column to image.
Definition: device_column_to_image_impl.hpp:591
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_column_to_image_impl.hpp:552
static constexpr auto ZIdx
Definition: device_column_to_image_impl.hpp:62
static auto MakeInputDescriptor_M_K(const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, 3 > &gemm_g_m_k_strides, const std::array< index_t, NDimSpatial > &independent_filters, const std::array< index_t, NDimSpatial > &effs)
Definition: device_column_to_image_impl.hpp:99
remove_cvref_t< decltype(MakeOutDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}, {}, {}, {}, {}))> OutputGridDesc
Definition: device_column_to_image_impl.hpp:261
static constexpr bool is_GNSpatialC
Definition: device_column_to_image_impl.hpp:53
static index_t GetNumberOfIndependentFilters(const index_t input_spatial_len, const index_t left_pad, const index_t right_pad, const index_t filter_len, const index_t filter_stride, const index_t filter_dilation, const index_t image_offset)
Definition: device_column_to_image_impl.hpp:75
static constexpr auto I1
Definition: device_column_to_image_impl.hpp:59
static constexpr auto I0
Definition: device_column_to_image_impl.hpp:58
static constexpr auto XIdx
Definition: device_column_to_image_impl.hpp:64
static constexpr auto I2
Definition: device_column_to_image_impl.hpp:60
remove_cvref_t< decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}))> InputGridDesc
Definition: device_column_to_image_impl.hpp:259
static constexpr auto spatial_offset
Definition: device_column_to_image_impl.hpp:66
static auto MakeInvoker()
Definition: device_column_to_image_impl.hpp:588
static auto MakeArgument(const void *p_in, void *p_out, const ck::index_t G, const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &image_g_n_c_wis_strides, const std::array< index_t, 3 > &gemm_g_m_k_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads)
Definition: device_column_to_image_impl.hpp:557
static constexpr auto YIdx
Definition: device_column_to_image_impl.hpp:63
static constexpr bool is_NSpatialGC
Definition: device_column_to_image_impl.hpp:49
Convolution Tensor Rearrange.
Definition: device_conv_tensor_rearrange.hpp:36
Definition: matrix_padder.hpp:180