/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/helper.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/helper.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/helper.hpp Source File
helper.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 #include <fstream>
12 #include <variant>
13 
14 // functions to return the corresponding structs based on generated template parameters
15 
21 // return the layout type: currently this is the only type supported in MIOpen
22 auto layout_type(std::string type)
23 {
24  if(type == "ck::tensor_layout::convolution::NHWGK")
25  {
27  }
28  throw std::runtime_error("Incorrect layout");
29 }
30 // return the right gemm spec based on the generated template parameters
32 {
33  if(type == "ck::tensor_operation::device::GemmSpecialization::Default")
34  {
36  }
37  if(type == "ck::tensor_operation::device::GemmSpecialization::MNKPadding")
38  {
40  }
41  throw std::runtime_error("Incorrect gemm spec: " + type);
42 }
43 
44 // return the type of convolution
46 {
47  if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Default")
48  {
50  }
51  if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0")
52  {
54  }
55  if(type ==
56  "ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0")
57  {
59  }
60  if(type == "ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC")
61  {
63  }
64  throw std::runtime_error("Incorrect conv spec: " + type);
65 }
66 
67 // Function to call on MatrixPadder via a wrapper struct
68 // NOTE: CK only uses MNKPadding for forward convolution
69 template <typename CDesc_MRaw_NRaw>
70 auto pad(ck::index_t mpb,
71  ck::index_t npb,
72  ck::index_t kpb,
74  CDesc_MRaw_NRaw conv)
75 {
77  {
83  a;
84  a.MPerTile_ = mpb;
85  a.NPerTile_ = npb;
86  a.KPerTile_ = kpb;
87  auto tmp = grid_desc(a, conv);
88  return tmp;
89  }
90  throw std::runtime_error("Incorrect template parameters, check gemm spec");
91 }
92 
93 // Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
94 // dims
95 // FIXME: add a way to properly pass in the layout
98  ck::Array<ck::index_t, 5> out_lengths,
99  ck::Array<ck::index_t, 5> out_strides)
100 {
101  ck::Array<ck::index_t, 5> dummy_dims;
102  ck::Array<ck::index_t, 2> dummy_spatial_dims;
103  if(num_dim == 2 &&
105  {
107  2,
109  conv_fwd{dummy_dims,
110  dummy_dims,
111  dummy_dims,
112  dummy_dims,
113  out_lengths,
114  out_strides,
115  dummy_spatial_dims,
116  dummy_spatial_dims,
117  dummy_spatial_dims,
118  dummy_spatial_dims};
119 
121  return res.transform_func(conv_fwd);
122  }
123  if(num_dim == 2 &&
125  {
127  2,
129  conv_fwd{dummy_dims,
130  dummy_dims,
131  dummy_dims,
132  dummy_dims,
133  out_lengths,
134  out_strides,
135  dummy_spatial_dims,
136  dummy_spatial_dims,
137  dummy_spatial_dims,
138  dummy_spatial_dims};
139 
141  return res.transform_func(conv_fwd);
142  }
143  if(num_dim == 2 &&
145  {
147  2,
149  conv_fwd{dummy_dims,
150  dummy_dims,
151  dummy_dims,
152  dummy_dims,
153  out_lengths,
154  out_strides,
155  dummy_spatial_dims,
156  dummy_spatial_dims,
157  dummy_spatial_dims,
158  dummy_spatial_dims};
159 
161  return res.transform_func(conv_fwd);
162  }
164  {
166  2,
168  conv_fwd{dummy_dims,
169  dummy_dims,
170  dummy_dims,
171  dummy_dims,
172  out_lengths,
173  out_strides,
174  dummy_spatial_dims,
175  dummy_spatial_dims,
176  dummy_spatial_dims,
177  dummy_spatial_dims};
178 
180  return res.transform_func(conv_fwd);
181  }
182  throw std::runtime_error("Incorrect conv spec");
183 }
184 
187  ck::Array<ck::index_t, 6> out_lengths,
188  ck::Array<ck::index_t, 6> out_strides)
189 {
190  ck::Array<ck::index_t, 6> dummy_dims;
191  ck::Array<ck::index_t, 3> dummy_spatial_dims;
192 
193  if(num_dim == 3 &&
195  {
197  3,
199  conv_fwd{dummy_dims,
200  dummy_dims,
201  dummy_dims,
202  dummy_dims,
203  out_lengths,
204  out_strides,
205  dummy_spatial_dims,
206  dummy_spatial_dims,
207  dummy_spatial_dims,
208  dummy_spatial_dims};
209 
211  return res.transform_func(conv_fwd);
212  }
213  if(num_dim == 3 &&
215  {
217  3,
219  conv_fwd{dummy_dims,
220  dummy_dims,
221  dummy_dims,
222  dummy_dims,
223  out_lengths,
224  out_strides,
225  dummy_spatial_dims,
226  dummy_spatial_dims,
227  dummy_spatial_dims,
228  dummy_spatial_dims};
229 
231  return res.transform_func(conv_fwd);
232  }
233  if(num_dim == 3 &&
235  {
237  3,
239  conv_fwd{dummy_dims,
240  dummy_dims,
241  dummy_dims,
242  dummy_dims,
243  out_lengths,
244  out_strides,
245  dummy_spatial_dims,
246  dummy_spatial_dims,
247  dummy_spatial_dims,
248  dummy_spatial_dims};
249 
251  return res.transform_func(conv_fwd);
252  }
254  {
256  3,
258  conv_fwd{dummy_dims,
259  dummy_dims,
260  dummy_dims,
261  dummy_dims,
262  out_lengths,
263  out_strides,
264  dummy_spatial_dims,
265  dummy_spatial_dims,
266  dummy_spatial_dims,
267  dummy_spatial_dims};
268 
270  return res.transform_func(conv_fwd);
271  }
272  throw std::runtime_error("Incorrect conv spec");
273 }
274 
277  ck::Array<ck::index_t, 4> out_lengths,
278  ck::Array<ck::index_t, 4> out_strides)
279 {
280  ck::Array<ck::index_t, 4> dummy_dims;
281  ck::Array<ck::index_t, 1> dummy_spatial_dims;
282 
283  if(num_dim == 1 &&
285  {
287  1,
289  conv_fwd{dummy_dims,
290  dummy_dims,
291  dummy_dims,
292  dummy_dims,
293  out_lengths,
294  out_strides,
295  dummy_spatial_dims,
296  dummy_spatial_dims,
297  dummy_spatial_dims,
298  dummy_spatial_dims};
299 
301  return res.transform_func(conv_fwd);
302  }
303  if(num_dim == 1 &&
305  {
307  1,
309  conv_fwd{dummy_dims,
310  dummy_dims,
311  dummy_dims,
312  dummy_dims,
313  out_lengths,
314  out_strides,
315  dummy_spatial_dims,
316  dummy_spatial_dims,
317  dummy_spatial_dims,
318  dummy_spatial_dims};
319 
321  return res.transform_func(conv_fwd);
322  }
323  if(num_dim == 1 &&
325  {
327  1,
329  conv_fwd{dummy_dims,
330  dummy_dims,
331  dummy_dims,
332  dummy_dims,
333  out_lengths,
334  out_strides,
335  dummy_spatial_dims,
336  dummy_spatial_dims,
337  dummy_spatial_dims,
338  dummy_spatial_dims};
339 
341  return res.transform_func(conv_fwd);
342  }
344  {
346  1,
348  conv_fwd{dummy_dims,
349  dummy_dims,
350  dummy_dims,
351  dummy_dims,
352  out_lengths,
353  out_strides,
354  dummy_spatial_dims,
355  dummy_spatial_dims,
356  dummy_spatial_dims,
357  dummy_spatial_dims};
358 
360  return res.transform_func(conv_fwd);
361  }
362  throw std::runtime_error("Incorrect dims or conv spec");
363 }
364 
365 template <typename CGridDesc_M_N>
366 auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
367 {
368  if(m_per_block == 32 && n_per_block == 64)
369  {
371  return b2e.CalculateGridSize(matrix_padder);
372  }
373  if(m_per_block == 32 && n_per_block == 128)
374  {
376  return b2e.CalculateGridSize(matrix_padder);
377  }
378  if(m_per_block == 64 && n_per_block == 32)
379  {
381  return b2e.CalculateGridSize(matrix_padder);
382  }
383  if(m_per_block == 64 && n_per_block == 64)
384  {
386  return b2e.CalculateGridSize(matrix_padder);
387  }
388  if(m_per_block == 64 && n_per_block == 128)
389  {
391  return b2e.CalculateGridSize(matrix_padder);
392  }
393  if(m_per_block == 128 && n_per_block == 32)
394  {
396  return b2e.CalculateGridSize(matrix_padder);
397  }
398  if(m_per_block == 128 && n_per_block == 64)
399  {
401  return b2e.CalculateGridSize(matrix_padder);
402  }
403  if(m_per_block == 128 && n_per_block == 128)
404  {
406  return b2e.CalculateGridSize(matrix_padder);
407  }
408  if(m_per_block == 128 && n_per_block == 256)
409  {
411  return b2e.CalculateGridSize(matrix_padder);
412  }
413  if(m_per_block == 256 && n_per_block == 128)
414  {
416  return b2e.CalculateGridSize(matrix_padder);
417  }
418  throw std::runtime_error("Incorrect template parameters");
419 }
420 
421 // wrapper functions by dims to get grid size - uses above 3 functions
422 // TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
423 auto get_launch_params_1d(ck::host::Solution solution,
424  ck::Array<ck::index_t, 4> out_lengths,
425  ck::Array<ck::index_t, 4> out_strides)
426 {
427  auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
428  auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
429  auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
430  auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
431  auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
432  auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
435  auto conv_to_gemm_transformer = transform_conv_1d(num_dim, ConvSpec, out_lengths, out_strides);
436  auto matrix_padder =
437  pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
438  auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
439  return b2e;
440 }
441 
442 auto get_launch_params(ck::host::Solution solution,
443  ck::Array<ck::index_t, 5> out_lengths,
444  ck::Array<ck::index_t, 5> out_strides)
445 {
446  auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
447  auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
448  auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
449  auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
450  auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
451  auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
454  auto conv_to_gemm_transformer = transform_conv(num_dim, ConvSpec, out_lengths, out_strides);
455  auto matrix_padder =
456  pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
457  auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
458  return b2e;
459 }
460 
461 auto get_launch_params_3d(ck::host::Solution solution,
462  ck::Array<ck::index_t, 6> out_lengths,
463  ck::Array<ck::index_t, 6> out_strides)
464 {
465  auto num_dim = solution.GetTemplateParameter<ck::index_t>("NumDim");
466  auto m_per_block = solution.GetTemplateParameter<ck::index_t>("MPerBlock");
467  auto n_per_block = solution.GetTemplateParameter<ck::index_t>("NPerBlock");
468  auto k_per_block = solution.GetTemplateParameter<ck::index_t>("KPerBlock");
469  auto GemmType = solution.GetTemplateParameter<std::string>("GemmSpecialization");
470  auto ConvType = solution.GetTemplateParameter<std::string>("ConvSpecialization");
473  auto conv_to_gemm_transformer = transform_conv_3d(num_dim, ConvSpec, out_lengths, out_strides);
474  auto matrix_padder =
475  pad(m_per_block, n_per_block, k_per_block, GemmSpec, conv_to_gemm_transformer);
476  auto b2e = block_2_etile(m_per_block, n_per_block, matrix_padder);
477  return b2e;
478 }
auto transform_conv(ck::index_t num_dim, ck::tensor_operation::device::ConvolutionForwardSpecialization spec, ck::Array< ck::index_t, 5 > out_lengths, ck::Array< ck::index_t, 5 > out_strides)
Definition: helper.hpp:96
auto block_2_etile(ck::index_t m_per_block, ck::index_t n_per_block, CGridDesc_M_N matrix_padder)
Definition: helper.hpp:366
auto transform_conv_1d(ck::index_t num_dim, ck::tensor_operation::device::ConvolutionForwardSpecialization spec, ck::Array< ck::index_t, 4 > out_lengths, ck::Array< ck::index_t, 4 > out_strides)
Definition: helper.hpp:275
auto layout_type(std::string type)
Definition: helper.hpp:22
auto get_launch_params_3d(ck::host::Solution solution, ck::Array< ck::index_t, 6 > out_lengths, ck::Array< ck::index_t, 6 > out_strides)
Definition: helper.hpp:461
auto get_launch_params(ck::host::Solution solution, ck::Array< ck::index_t, 5 > out_lengths, ck::Array< ck::index_t, 5 > out_strides)
Definition: helper.hpp:442
ck::tensor_operation::device::GemmSpecialization gemm_type(std::string type)
Definition: helper.hpp:31
auto transform_conv_3d(ck::index_t num_dim, ck::tensor_operation::device::ConvolutionForwardSpecialization spec, ck::Array< ck::index_t, 6 > out_lengths, ck::Array< ck::index_t, 6 > out_strides)
Definition: helper.hpp:185
auto get_launch_params_1d(ck::host::Solution solution, ck::Array< ck::index_t, 4 > out_lengths, ck::Array< ck::index_t, 4 > out_strides)
Definition: helper.hpp:423
std::variant< ck::tensor_layout::convolution::GNWK, ck::tensor_layout::convolution::GNHWK, ck::tensor_layout::convolution::NHWGK, ck::tensor_layout::convolution::GNDHWK, ck::tensor_layout::convolution::NDHWGK > layouts
Definition: helper.hpp:20
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition: helper.hpp:70
ck::tensor_operation::device::ConvolutionForwardSpecialization conv_type(std::string type)
Definition: helper.hpp:45
auto grid_desc(MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc)
Definition: matrix_padder.hpp:190
GemmSpecialization
Definition: gemm_specialization.hpp:11
ConvolutionForwardSpecialization
Definition: convolution_forward_specialization.hpp:15
int32_t index_t
Definition: ck.hpp:289
Definition: array.hpp:14
Definition: block_to_ctile_map.hpp:260
Definition: tensor_layout.hpp:324
Definition: tensor_layout.hpp:319
Definition: tensor_layout.hpp:314
Definition: tensor_layout.hpp:341
Definition: tensor_layout.hpp:336
Definition: transform_conv_fwd_to_gemm.hpp:24
Definition: transform_conv_fwd_to_gemm.hpp:1559
Definition: matrix_padder.hpp:180