/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_image_to_column_impl.hpp Source File
device_image_to_column_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
20 #include "ck/host_utility/io.hpp"
21 
22 namespace ck {
23 namespace tensor_operation {
24 namespace device {
25 
26 // Image to column (shapes written for three spatial dimensions):
27 // G-N-spatial-C layouts -- input : input image [G, N, Di, Hi, Wi, C]
28 //                          output : gemm form [G * N * Do * Ho * Wo, Z * Y * X * C]
29 // N-spatial-G-C layouts -- input : input image [N, Di, Hi, Wi, G, C]
30 //                          output : gemm form [N * Do * Ho * Wo * G, Z * Y * X * C]
31 template <index_t NDimSpatial,
32  typename ImageLayout,
33  typename InputDataType,
34  typename OutputDataType,
35  index_t BlockSize,
36  index_t MPerBlock,
37  index_t KPerBlock,
38  typename ThreadClusterLengths,
39  index_t ScalarPerVector,
40  typename std::enable_if<NDimSpatial >= 1 && NDimSpatial <= 3, bool>::type = false>
42  : public DeviceConvTensorRearrange<NDimSpatial,
43  ImageLayout,
44  InputDataType,
45  OutputDataType,
46  conv_tensor_rearrange_op::ImageToColumn>
47 {
48  static constexpr bool is_NSpatialGC =
49  std::is_same_v<ImageLayout, tensor_layout::convolution::NWGC> ||
50  std::is_same_v<ImageLayout, tensor_layout::convolution::NHWGC> ||
51  std::is_same_v<ImageLayout, tensor_layout::convolution::NDHWGC>;
52  static constexpr bool is_GNSpatialC =
53  std::is_same_v<ImageLayout, tensor_layout::convolution::GNWC> ||
54  std::is_same_v<ImageLayout, tensor_layout::convolution::GNHWC> ||
55  std::is_same_v<ImageLayout, tensor_layout::convolution::GNDHWC>;
56 
57  static constexpr auto I0 = Number<0>{};
58  static constexpr auto I1 = Number<1>{};
59  static constexpr auto I2 = Number<2>{};
60 
63 
64  static constexpr auto matrix_padder =
66  MPerBlock, 0 /* NPerBlock*/, KPerBlock};
67 
68  // Use MakeADescriptor_M_K from grouped convolution forward
69  static auto
71  const ck::index_t C,
72  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
73  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
74  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
75  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
76  const std::array<index_t, NDimSpatial>& conv_filter_strides,
77  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
78  const std::array<index_t, NDimSpatial>& input_left_pads,
79  const std::array<index_t, NDimSpatial>& input_right_pads)
80  {
81  std::array<index_t, NDimSpatial + 3> a_g_n_c_wis_lengths{1};
82  std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths{1};
83  std::array<index_t, NDimSpatial + 3> c_g_n_k_wos_lengths{1};
84 
85  auto copy = [](const auto& x, auto& y, index_t dst_offset) {
86  std::copy(x.begin(), x.end(), y.begin() + dst_offset);
87  };
88 
89  constexpr index_t spatial_offset = 3;
90 
91  copy(input_spatial_lengths, a_g_n_c_wis_lengths, spatial_offset);
92  copy(filter_spatial_lengths, b_g_k_c_xs_lengths, spatial_offset);
93  copy(output_spatial_lengths, c_g_n_k_wos_lengths, spatial_offset);
94 
95  // fill only significant values (C and N)
96  a_g_n_c_wis_lengths[I1] = N;
97  a_g_n_c_wis_lengths[I2] = C;
98  b_g_k_c_xs_lengths[I2] = C;
99  c_g_n_k_wos_lengths[I1] = N;
100 
101  ConvToGemmFwdTransformer conv_to_gemm_transformer{a_g_n_c_wis_lengths,
102  image_g_n_c_wis_strides,
103  b_g_k_c_xs_lengths,
104  {}, // not needed for A Descriptor
105  c_g_n_k_wos_lengths,
106  {}, // not needed for A Descriptor
107  conv_filter_strides,
108  conv_filter_dilations,
109  input_left_pads,
110  input_right_pads};
111 
112  const auto in_gemmmraw_gemmkraw_desc =
113  conv_to_gemm_transformer.template MakeADescriptor_M_K<ImageLayout>();
114 
115  const auto in_gemmm_gemmk_desc =
116  matrix_padder.PadADescriptor_M_K(in_gemmmraw_gemmkraw_desc);
117  return in_gemmm_gemmk_desc;
118  }
119 
120  static auto
122  const ck::index_t C,
123  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
124  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
125  const std::array<index_t, 3>& gemm_g_m_k_strides)
126  {
127  const index_t NDoHoWo =
128  N * ck::accumulate_n<index_t>(
129  output_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
130  const index_t CZYX =
131  C * ck::accumulate_n<index_t>(
132  filter_spatial_lengths.begin(), NDimSpatial, 1, std::multiplies<>());
133 
134  const auto desc_mraw_kraw = make_naive_tensor_descriptor(
135  make_tuple(NDoHoWo, CZYX), make_tuple(gemm_g_m_k_strides[I1], gemm_g_m_k_strides[I2]));
136  return matrix_padder.PadADescriptor_M_K(desc_mraw_kraw);
137  }
138 
140  remove_cvref_t<decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}, {}, {}))>;
141  using OutputGridDesc = remove_cvref_t<decltype(MakeOutDescriptor_M_K(1, 1, {}, {}, {}))>;
142 
145  OutputGridDesc{}))>;
146 
148  InputDataType,
150  OutputDataType,
151  BlockSize,
152  MPerBlock,
153  KPerBlock,
154  ThreadClusterLengths,
155  ScalarPerVector,
158  ComputePtrOffsetOfStridedBatch<>>;
159 
160  struct Argument : public BaseArgument
161  {
162  Argument(const void* p_in, // input image
163  void* p_out, // gemm form
164  const ck::index_t G,
165  const ck::index_t N,
166  const ck::index_t C,
167  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
168  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
169  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
170  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
171  const std::array<index_t, 3>& gemm_g_m_k_strides,
172  const std::array<index_t, NDimSpatial>& conv_filter_strides,
173  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
174  const std::array<index_t, NDimSpatial>& input_left_pads,
175  const std::array<index_t, NDimSpatial>& input_right_pads)
176  : G_(G),
177  C_(C),
178  X_(filter_spatial_lengths[NDimSpatial - I1]),
179  p_in_{static_cast<const InputDataType*>(p_in)},
180  p_out_{static_cast<OutputDataType*>(p_out)},
181  image_g_n_c_wis_strides_{image_g_n_c_wis_strides},
182  conv_filter_strides_{conv_filter_strides},
183  conv_filter_dilations_{conv_filter_dilations},
184  input_left_pads_{input_left_pads},
185  input_right_pads_{input_right_pads}
186  {
187 
189  C,
190  input_spatial_lengths,
191  filter_spatial_lengths,
192  output_spatial_lengths,
193  image_g_n_c_wis_strides,
194  conv_filter_strides,
195  conv_filter_dilations,
196  input_left_pads,
197  input_right_pads);
198 
200  N, C, filter_spatial_lengths, output_spatial_lengths, gemm_g_m_k_strides);
201 
202  compute_ptr_offset_of_batch_.BatchStrideA_ = image_g_n_c_wis_strides[I0];
203  compute_ptr_offset_of_batch_.BatchStrideC_ = gemm_g_m_k_strides[I0];
204  }
205 
206  void Print() const
207  {
208  std::cout << in_grid_desc_m_k_ << std::endl;
209  std::cout << out_grid_desc_m_k_ << std::endl;
210  }
211 
215 
216  const InputDataType* p_in_;
217  OutputDataType* p_out_;
218 
219  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides_;
220  const std::array<index_t, NDimSpatial>& conv_filter_strides_;
221  const std::array<index_t, NDimSpatial>& conv_filter_dilations_;
222  const std::array<index_t, NDimSpatial>& input_left_pads_;
223  const std::array<index_t, NDimSpatial>& input_right_pads_;
224 
227 
228  ComputePtrOffsetOfStridedBatch<> compute_ptr_offset_of_batch_;
229  };
230 
231  struct Invoker : public BaseInvoker
232  {
233  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
234  {
235  if(stream_config.log_level_ > 0)
236  {
237  arg.Print();
238  }
239 
240  const auto block_2_tile_map =
242  arg.out_grid_desc_m_k_);
243  const index_t grid_size =
244  block_2_tile_map.CalculateGridSize(arg.out_grid_desc_m_k_) * arg.G_;
245  const auto kernel = kernel_tensor_rearrange<InputGridDesc,
246  InputDataType,
248  OutputDataType,
250  ComputePtrOffsetOfStridedBatch<>,
252 
253  float elapsed_time = launch_and_time_kernel(stream_config,
254  kernel,
255  dim3(grid_size),
256  dim3(BlockSize),
257  0,
258  arg.in_grid_desc_m_k_,
259  arg.p_in_,
260  arg.out_grid_desc_m_k_,
261  arg.p_out_,
262  arg.G_,
263  block_2_tile_map,
265  return elapsed_time;
266  }
267 
268  float Run(const BaseArgument* p_arg,
269  const StreamConfig& stream_config = StreamConfig{}) override
270  {
271  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
272  }
273  };
274 
275  bool IsSupportedArgument(const Argument& arg)
276  {
277  if constexpr(!(is_NSpatialGC || is_GNSpatialC))
278  {
279  return false;
280  }
281 
282  const auto w_pad_left = arg.input_left_pads_[NDimSpatial - I1];
283  const auto w_pad_right = arg.input_right_pads_[NDimSpatial - I1];
284  const auto dilation_x = arg.conv_filter_dilations_[NDimSpatial - I1];
285  const auto stride_x = arg.conv_filter_strides_[NDimSpatial - I1];
286  bool is_w_packed = arg.image_g_n_c_wis_strides_[NDimSpatial + I2] == arg.C_;
287  bool is_c_packed = arg.image_g_n_c_wis_strides_[I2] == 1;
288 
289  // check vector acces with c not packed
290  if(!is_c_packed && ScalarPerVector != 1)
291  return false;
292  // check vector access of filter window row (only C if C is not packed)
293  if(!is_w_packed && arg.C_ % ScalarPerVector != 0)
294  return false;
295  // check vector access of filter window row (X * C)
296  if(arg.X_ * arg.C_ % ScalarPerVector != 0)
297  return false;
298  // check vector access of pads (w_pad_left/w_pad_right * C)
299  if(w_pad_left * arg.C_ % ScalarPerVector != 0 ||
300  w_pad_right * arg.C_ % ScalarPerVector != 0)
301  return false;
302  // check vector access of with stride and pad
303  if((w_pad_left != 0 || w_pad_right != 0) && stride_x > 1 && arg.C_ % ScalarPerVector != 0)
304  return false;
305  // check vector access of with dilation
306  if(dilation_x > 1 && arg.C_ % ScalarPerVector != 0)
307  return false;
308 
310  arg.out_grid_desc_m_k_);
311  }
312 
313  bool IsSupportedArgument(const BaseArgument* p_arg) override
314  {
315  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
316  }
317 
318  static auto MakeArgument(const void* p_in, // input image
319  void* p_out, // gemm form
320  const ck::index_t G,
321  const ck::index_t N,
322  const ck::index_t C,
323  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
324  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
325  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
326  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
327  const std::array<index_t, 3>& gemm_g_m_k_strides,
328  const std::array<index_t, NDimSpatial>& conv_filter_strides,
329  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
330  const std::array<index_t, NDimSpatial>& input_left_pads,
331  const std::array<index_t, NDimSpatial>& input_right_pads)
332  {
333  return Argument{static_cast<const InputDataType*>(p_in),
334  static_cast<OutputDataType*>(p_out),
335  G,
336  N,
337  C,
338  input_spatial_lengths,
339  filter_spatial_lengths,
340  output_spatial_lengths,
341  image_g_n_c_wis_strides,
342  gemm_g_m_k_strides,
343  conv_filter_strides,
344  conv_filter_dilations,
345  input_left_pads,
346  input_right_pads};
347  }
348 
349  static auto MakeInvoker() { return Invoker{}; }
350 
351  std::unique_ptr<BaseArgument>
352  MakeArgumentPointer(const void* p_in, // input image
353  void* p_out, // gemm form
354  const ck::index_t G,
355  const ck::index_t N,
356  const ck::index_t C,
357  const std::array<index_t, NDimSpatial>& input_spatial_lengths,
358  const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
359  const std::array<index_t, NDimSpatial>& output_spatial_lengths,
360  const std::array<index_t, NDimSpatial + 3>& image_g_n_c_wis_strides,
361  const std::array<index_t, 3>& gemm_g_m_k_strides,
362  const std::array<index_t, NDimSpatial>& conv_filter_strides,
363  const std::array<index_t, NDimSpatial>& conv_filter_dilations,
364  const std::array<index_t, NDimSpatial>& input_left_pads,
365  const std::array<index_t, NDimSpatial>& input_right_pads) override
366  {
367  return std::make_unique<Argument>(static_cast<const InputDataType*>(p_in),
368  static_cast<OutputDataType*>(p_out),
369  G,
370  N,
371  C,
372  input_spatial_lengths,
373  filter_spatial_lengths,
374  output_spatial_lengths,
375  image_g_n_c_wis_strides,
376  gemm_g_m_k_strides,
377  conv_filter_strides,
378  conv_filter_dilations,
379  input_left_pads,
380  input_right_pads);
381  }
382 
383  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
384  {
385  return std::make_unique<Invoker>(Invoker{});
386  }
387 
388  std::string GetTypeString() const override
389  {
390  auto str = std::stringstream();
391 
392  // clang-format off
393  str << "DeviceImageToColumn"
394  << "<"
395  << BlockSize << ", "
396  << MPerBlock << ", "
397  << KPerBlock << ", "
398  << ScalarPerVector
399  << ">";
400  // clang-format on
401 
402  return str.str();
403  }
404 };
405 
406 } // namespace device
407 } // namespace tensor_operation
408 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
auto copy(InputRange &&range, OutputIterator iter) -> decltype(std::copy(std::begin(std::forward< InputRange >(range)), std::end(std::forward< InputRange >(range)), iter))
Definition: algorithm.hpp:14
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__global__ void kernel_tensor_rearrange(const InputGridDesc in_grid_desc, const InputDataType *__restrict__ p_in_global, const OutputGridDesc out_grid_desc, OutputDataType *__restrict__ p_out_global, const index_t batch_count, const Block2ETileMap block_2_tile_map, const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
Definition: gridwise_tensor_rearrange.hpp:30
Definition: stream_config.hpp:10
Definition: block_to_ctile_map.hpp:260
Definition: gridwise_tensor_rearrange.hpp:72
static constexpr __host__ bool CheckValidity(const InputGridDesc &in_grid_desc, const OutputGridDesc &out_grid_desc)
Definition: gridwise_tensor_rearrange.hpp:138
Definition: integral_constant.hpp:10
Definition: transform_conv_fwd_to_gemm.hpp:24
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Convolution Tensor Rearrange.
Definition: device_conv_tensor_rearrange.hpp:36
Definition: device_image_to_column_impl.hpp:161
InputGridDesc in_grid_desc_m_k_
Definition: device_image_to_column_impl.hpp:225
const std::array< index_t, NDimSpatial+3 > & image_g_n_c_wis_strides_
Definition: device_image_to_column_impl.hpp:219
const std::array< index_t, NDimSpatial > & input_right_pads_
Definition: device_image_to_column_impl.hpp:223
const std::array< index_t, NDimSpatial > & conv_filter_strides_
Definition: device_image_to_column_impl.hpp:220
const InputDataType * p_in_
Definition: device_image_to_column_impl.hpp:216
const std::array< index_t, NDimSpatial > & conv_filter_dilations_
Definition: device_image_to_column_impl.hpp:221
const std::array< index_t, NDimSpatial > & input_left_pads_
Definition: device_image_to_column_impl.hpp:222
const ck::index_t C_
Definition: device_image_to_column_impl.hpp:213
Argument(const void *p_in, void *p_out, const ck::index_t G, const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &image_g_n_c_wis_strides, const std::array< index_t, 3 > &gemm_g_m_k_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads)
Definition: device_image_to_column_impl.hpp:162
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_
Definition: device_image_to_column_impl.hpp:228
OutputDataType * p_out_
Definition: device_image_to_column_impl.hpp:217
const ck::index_t X_
Definition: device_image_to_column_impl.hpp:214
OutputGridDesc out_grid_desc_m_k_
Definition: device_image_to_column_impl.hpp:226
void Print() const
Definition: device_image_to_column_impl.hpp:206
const ck::index_t G_
Definition: device_image_to_column_impl.hpp:212
Definition: device_image_to_column_impl.hpp:232
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_image_to_column_impl.hpp:233
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_image_to_column_impl.hpp:268
Definition: device_image_to_column_impl.hpp:47
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_in, void *p_out, const ck::index_t G, const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &image_g_n_c_wis_strides, const std::array< index_t, 3 > &gemm_g_m_k_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads) override
Make argument pointer for image to column.
Definition: device_image_to_column_impl.hpp:352
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_image_to_column_impl.hpp:383
bool IsSupportedArgument(const Argument &arg)
Definition: device_image_to_column_impl.hpp:275
static auto MakeInvoker()
Definition: device_image_to_column_impl.hpp:349
static constexpr auto I0
Definition: device_image_to_column_impl.hpp:57
std::string GetTypeString() const override
Definition: device_image_to_column_impl.hpp:388
static auto MakeArgument(const void *p_in, void *p_out, const ck::index_t G, const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &image_g_n_c_wis_strides, const std::array< index_t, 3 > &gemm_g_m_k_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads)
Definition: device_image_to_column_impl.hpp:318
static auto MakeOutDescriptor_M_K(const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, 3 > &gemm_g_m_k_strides)
Definition: device_image_to_column_impl.hpp:121
remove_cvref_t< decltype(MakeInputDescriptor_M_K(1, 1, {}, {}, {}, {}, {}, {}, {}, {}))> InputGridDesc
Definition: device_image_to_column_impl.hpp:140
static auto MakeInputDescriptor_M_K(const ck::index_t N, const ck::index_t C, const std::array< index_t, NDimSpatial > &input_spatial_lengths, const std::array< index_t, NDimSpatial > &filter_spatial_lengths, const std::array< index_t, NDimSpatial > &output_spatial_lengths, const std::array< index_t, NDimSpatial+3 > &image_g_n_c_wis_strides, const std::array< index_t, NDimSpatial > &conv_filter_strides, const std::array< index_t, NDimSpatial > &conv_filter_dilations, const std::array< index_t, NDimSpatial > &input_left_pads, const std::array< index_t, NDimSpatial > &input_right_pads)
Definition: device_image_to_column_impl.hpp:70
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_image_to_column_impl.hpp:313
remove_cvref_t< decltype(MakeOutDescriptor_M_K(1, 1, {}, {}, {}))> OutputGridDesc
Definition: device_image_to_column_impl.hpp:141
remove_cvref_t< decltype(BlockToCTileMap_M00_N0_M01Adapt< MPerBlock, KPerBlock, OutputGridDesc >(OutputGridDesc{}))> Block2ETileMap
Definition: device_image_to_column_impl.hpp:145
static constexpr auto matrix_padder
Definition: device_image_to_column_impl.hpp:64
GridwiseTensorRearrange< InputGridDesc, InputDataType, OutputGridDesc, OutputDataType, BlockSize, MPerBlock, KPerBlock, ThreadClusterLengths, ScalarPerVector, InMemoryDataOperationEnum::Set, Block2ETileMap, ComputePtrOffsetOfStridedBatch<> > GridwiseTensorRearrangeKernel
Definition: device_image_to_column_impl.hpp:158
static constexpr bool is_GNSpatialC
Definition: device_image_to_column_impl.hpp:52
static constexpr auto I2
Definition: device_image_to_column_impl.hpp:59
static constexpr auto I1
Definition: device_image_to_column_impl.hpp:58
static constexpr bool is_NSpatialGC
Definition: device_image_to_column_impl.hpp:48
Definition: matrix_padder.hpp:180