/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp Source File
image_to_column_kernel.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck_tile/core.hpp"
7 #include "ck_tile/ops/common.hpp"
8 
9 namespace ck_tile {
10 
11 template <typename Problem_>
13 {
// Compile-time index constants (number<0> .. number<4>). The code below uses
// them to subscript the array members of Kargs, e.g.
// kargs.input_spatial_lengths[I0]. I0-I2 are used in the lines visible in this
// listing; I3/I4 are presumably used in lines that are missing from it.
14  static constexpr auto I0 = number<0>{};
15  static constexpr auto I1 = number<1>{};
16  static constexpr auto I2 = number<2>{};
17  static constexpr auto I3 = number<3>{};
18  static constexpr auto I4 = number<4>{};
19 
21 
24 
// Number of spatial dimensions, forwarded from the Problem type. The
// static_assert below restricts this kernel to exactly 2.
25  static constexpr index_t NDimSpatial = Problem::NDimSpatial;
26 
    // Alignment values forwarded from the Problem type (the "Aligment" spelling
    // is the identifier's actual name). They are not referenced in the lines
    // visible in this listing.
27  static constexpr index_t AligmentIn = Problem::AligmentIn;
28  static constexpr index_t AligmentOut = Problem::AligmentOut;
29 
    // Reject any instantiation that is not 2-D at compile time.
30  static_assert(NDimSpatial == 2, "Not supported.");
31 
    // Per-block tile extents along M and K. ConvTensorRearrange multiplies
    // blockIdx.x / blockIdx.y by these to get the block's tile origin.
32  static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
33  static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
34 
// Kernel argument bundle built by MakeKargs and passed to operator().
35  struct Kargs
36  {
    // Untyped input pointer; ConvTensorRearrange casts it to const InDataType*.
37  const void* p_in;
    // Untyped output pointer; ConvTensorRearrange casts it to OutDataType*.
38  void* p_out;
39 
    // G is stored but not read in the lines visible in this listing.
40  const long_index_t G;
    // N and C are used as descriptor lengths in MakeImageMKDesc and as factors
    // of M and K respectively in CalculateMKDims.
41  const long_index_t N;
42  const long_index_t C;
43 
    // NOTE: listing lines 44-52 are missing from this view. According to the
    // member index at the end of this page they declare, in order:
    // input_spatial_lengths, filter_spatial_lengths, output_spatial_lengths
    // (array<long_index_t, NDimSpatial>), image_g_n_c_wis_strides
    // (array<long_index_t, NDimSpatial + 3>), gemm_g_m_k_strides
    // (array<long_index_t, 3>), conv_filter_strides, conv_filter_dilations,
    // input_left_pads and input_right_pads (array<long_index_t, NDimSpatial>).
53  };
54 
// Host-side factory: packs the given arguments into a Kargs value.
//
// All parameters are taken by value and forwarded unchanged; no validation or
// conversion is performed here. Kargs is brace-initialized positionally, so the
// argument order in the return statement must stay in sync with the member
// declaration order of Kargs.
//
// Uses of the array arguments that are visible elsewhere in this listing:
//  - image_g_n_c_wis_strides[I0] scales the batch index into the input offset.
//  - gemm_g_m_k_strides[I0] scales the batch index into the output offset.
//  - input/filter/output_spatial_lengths and input_left/right_pads are read by
//    MakeImageMKDesc and CalculateMKDims.
55  CK_TILE_HOST static constexpr Kargs
56  MakeKargs(const void* p_in,
57  void* p_out,
58  const long_index_t G,
59  const long_index_t N,
60  const long_index_t C,
61  const array<long_index_t, NDimSpatial> input_spatial_lengths,
62  const array<long_index_t, NDimSpatial> filter_spatial_lengths,
63  const array<long_index_t, NDimSpatial> output_spatial_lengths,
64  const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides,
65  const array<long_index_t, 3> gemm_g_m_k_strides,
66  const array<long_index_t, NDimSpatial> conv_filter_strides,
67  const array<long_index_t, NDimSpatial> conv_filter_dilations,
68  const array<long_index_t, NDimSpatial> input_left_pads,
69  const array<long_index_t, NDimSpatial> input_right_pads)
70  {
    // Positional initialization: same order as the parameter list above.
71  return Kargs{p_in,
72  p_out,
73  G,
74  N,
75  C,
76  input_spatial_lengths,
77  filter_spatial_lengths,
78  output_spatial_lengths,
79  image_g_n_c_wis_strides,
80  gemm_g_m_k_strides,
81  conv_filter_strides,
82  conv_filter_dilations,
83  input_left_pads,
84  input_right_pads};
85  }
86 
// Host-side helper returning the launch grid as a dim3 built from the GEMM M
// extent, GEMM K extent and batch count.
//
// NOTE(review): listing line 90 (the dim3 arguments) is missing from this view.
// The kernel derives its M-tile, K-tile and batch index from blockIdx.x,
// blockIdx.y and blockIdx.z, so the grid is presumably
// (ceil(GemmM / kMPerBlock), ceil(GemmK / kKPerBlock), Batch) - confirm against
// the real header.
87  CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
88  {
89  return dim3(
91  }
92 
// Host-side helper returning the compile-time constant
// Problem::BlockShape::kBlockSize (presumably the launch block size, going by
// the function name).
93  CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
94 
// Builds the tensor descriptor through which the input image is read by
// ConvTensorRearrange, as a chain of descriptor transforms.
//
// NOTE: many lines of this function (102-104, 106, 111-112, 115, 118-120,
// 125-134, 136, 139, 141, 143-144) are missing from this listing, so the
// comments below describe only what the visible lines establish.
95  CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
96  {
97  static_assert(NDimSpatial == 2, "Not supported.");
98 
    // Step 1: naive descriptor with lengths
    // (N, input_spatial_lengths[0], input_spatial_lengths[1], C).
    // Most of the stride tuple is missing from this view; only the
    // image_g_n_c_wis_strides[I2] entry is visible.
99  const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
100  make_tuple(
101  kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C),
105  kargs.image_g_n_c_wis_strides[I2]),
107  I1);
108 
    // Step 2: transform of the step-1 descriptor that takes the left/right pads
    // of both spatial dimensions (presumably via make_pad_transform, which the
    // page's cross-reference index lists; the calls themselves are not visible).
109  const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
110  in_n_hi_wi_c_desc,
113  kargs.input_left_pads[I0],
114  kargs.input_right_pads[I0]),
116  kargs.input_left_pads[I1],
117  kargs.input_right_pads[I1]),
121 
    // Step 3: transform of the padded descriptor. The transform list is missing
    // from this view; the result name suggests dimensions (n, y, ho, x, wo, c) -
    // TODO confirm against the real header.
122  const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
123  in_n_hip_wip_c_desc,
124  make_tuple(
135 
    // Step 4 (the return statement's first line is missing): a final transform
    // of the step-3 descriptor built from two length groups,
    // (N, output_spatial_lengths[0], output_spatial_lengths[1]) and
    // (filter_spatial_lengths[0], filter_spatial_lengths[1], C). These are the
    // same factors CalculateMKDims multiplies to obtain M and K.
137  in_n_y_ho_x_wo_c_desc,
138  make_tuple(
140  kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])),
142  kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))),
145  }
146 
// Computes the two extents of the output matrix and returns them as
// make_tuple(M, K):
//   M = N * output_spatial_lengths[0] * output_spatial_lengths[1]
//   K = C * filter_spatial_lengths[0] * filter_spatial_lengths[1]
147  CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
148  {
149  static_assert(NDimSpatial == 2, "Not supported.");
    // NOTE(review): the spatial product is cast to index_t (int32_t) before it
    // is multiplied by the long_index_t (int64_t) N / C, and the result is then
    // stored in an index_t. Products that do not fit in 32 bits would be
    // truncated.
150  const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
151  kargs.output_spatial_lengths[I1]);
152  const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
153  kargs.filter_spatial_lengths[I1]);
154  return make_tuple(M, K);
155  }
156 
158  {
    // Body of MakeBlockTileDistribution(). The signature (listing line 157) is
    // missing from this view; the page's index gives it as
    // `static constexpr CK_TILE_DEVICE auto MakeBlockTileDistribution()`.
    // The result is used by ConvTensorRearrange as the distribution argument of
    // both tile windows.
    // NOTE: most of the encoding (lines 162-163 and 165-169) is also missing.
    // Local shorthand for the block-shape type.
159  using P = typename Problem::BlockShape;
    // NOTE(review): "P:" and "Y:" in the two comments below presumably name the
    // encoding's P and Y dimension groups, not the alias P above - confirm.
160  // P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
161  // Y: {kMPerThread, kKPerThread}
164  sequence<1>,
170  sequence<2, 2>>{});
171  }
172 
// Per-block kernel body: loads one tile from the image view (described by
// MakeImageMKDesc) and stores it at the same (M, K) origin in the output
// matrix view.
//
// NOTE: several argument lines (189-190, 195-196, 199-200, 206, 211) are
// missing from this listing; they are called out where they occur.
173  CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
174  {
    // Extents of the output matrix.
175  const auto [M, K] = CalculateMKDims(kargs);
176 
    // Tile origin along M and K plus the batch index, taken from the block
    // index. readfirstlane returns the first active lane's value, so each
    // result is uniform across the wavefront.
177  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
178  const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock);
179  const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z);
180 
    // Element offsets of this batch in the input and output buffers:
    // batch index times the first entry of the respective stride array.
181  const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0];
182  const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];
183 
    // Global-memory view of the input, addressed through the image descriptor.
184  const auto image_m_k = make_tensor_view<address_space_enum::global>(
185  static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
    // Global-memory naive view of the output with lengths (M, K).
    // The stride tuple (lines 189-190) is missing from this view.
186  const auto gemm_m_k = make_naive_tensor_view<address_space_enum::global>(
187  static_cast<OutDataType*>(kargs.p_out) + out_offset,
188  make_tuple(M, K),
191  I1);
192 
    // Padded versions of both views via pad_tensor_view. The tile-length and
    // pad-flag arguments are missing from this view.
193  const auto image_m_k_padded =
194  pad_tensor_view(image_m_k,
197  const auto gemm_m_k_padded =
198  pad_tensor_view(gemm_m_k,
201 
    // Compile-time tile distribution shared by the load and store windows.
202  constexpr auto dstr = MakeBlockTileDistribution();
203 
    // Source window at origin {iM, iK}. The window-length argument (line 206)
    // is missing from this view.
204  const auto image_tile =
205  make_tile_window(image_m_k_padded,
207  {iM, iK},
208  dstr);
209 
    // Destination window at the same origin. The window-length argument
    // (line 211) is missing from this view.
210  auto gemm_tile = make_tile_window(gemm_m_k_padded,
212  {iM, iK},
213  dstr);
214 
215  // load from Global
216  const auto loaded_tile = load_tile(image_tile);
217  // save to Global
218  store_tile(gemm_tile, loaded_tile);
219  }
220 
// Device-side call operator: forwards the kernel arguments unchanged to
// ConvTensorRearrange. kargs is taken by non-const reference here, but the
// callee only takes it as const Kargs&.
221  CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
222 };
223 
224 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto make_naive_tensor_descriptor(const tuple< Lengths... > &lengths, const tuple< Strides... > &strides, number< GuaranteedLastDimensionVectorLength >=number<-1 >{}, number< GuaranteedLastDimensionVectorStride >=number<-1 >{})
Definition: tensor_descriptor.hpp:255
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
constexpr CK_TILE_HOST_DEVICE auto make_merge_transform(const LowLengths &low_lengths)
Definition: coordinate_transform.hpp:1672
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE auto pad_tensor_view(const TensorView &tensor_view, const TileLengths &tile_lengths, DoPads)
Definition: tensor_view.hpp:480
constexpr CK_TILE_HOST_DEVICE auto make_pass_through_transform(const LowLength &low_length)
Definition: coordinate_transform.hpp:1615
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:20
constexpr CK_TILE_HOST_DEVICE auto make_pad_transform(const LowLength &low_length, const LeftPad &left_pad, const RightPad &right_pad, bool_constant< SkipIsValidCheck >=bool_constant< false >{})
Definition: coordinate_transform.hpp:1622
CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution< BottomTensorView_, WindowLengths_, TileDistribution_, NumCoord > &tile_window, number< i_access >={}, bool_constant< oob_conditional_check >={})
Definition: load_tile.hpp:27
int64_t long_index_t
Definition: integer.hpp:10
constexpr CK_TILE_HOST_DEVICE auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldTopIdss, NewUpperDimensionNewTopIdss)
Definition: tensor_descriptor.hpp:184
constexpr CK_TILE_DEVICE auto make_tile_window(null_tensor_view, const WindowLengths &window_lengths, const multi_index< WindowLengths::size()> &, Ts &&...)
Definition: null_tile_window.hpp:72
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
CK_TILE_DEVICE void store_tile(tile_window_with_static_lengths< BottomTensorView_, WindowLengths_ > &tile_window_tmp, const static_distributed_tensor< DataType_, TileDistribution_ > &dstr_tensor)
Definition: store_tile.hpp:23
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
constexpr CK_TILE_HOST_DEVICE auto make_embed_transform(const UpLengths &up_lengths, const Coefficients &coefficients)
Definition: coordinate_transform.hpp:1651
struct Kargs
Definition: image_to_column_kernel.hpp:36
const array< long_index_t, NDimSpatial > input_right_pads
Definition: image_to_column_kernel.hpp:52
const array< long_index_t, NDimSpatial > conv_filter_strides
Definition: image_to_column_kernel.hpp:49
const array< long_index_t, NDimSpatial > conv_filter_dilations
Definition: image_to_column_kernel.hpp:50
const long_index_t C
Definition: image_to_column_kernel.hpp:42
const long_index_t N
Definition: image_to_column_kernel.hpp:41
const long_index_t G
Definition: image_to_column_kernel.hpp:40
const array< long_index_t, NDimSpatial+3 > image_g_n_c_wis_strides
Definition: image_to_column_kernel.hpp:47
const void * p_in
Definition: image_to_column_kernel.hpp:37
const array< long_index_t, NDimSpatial > filter_spatial_lengths
Definition: image_to_column_kernel.hpp:45
void * p_out
Definition: image_to_column_kernel.hpp:38
const array< long_index_t, NDimSpatial > input_left_pads
Definition: image_to_column_kernel.hpp:51
const array< long_index_t, 3 > gemm_g_m_k_strides
Definition: image_to_column_kernel.hpp:48
const array< long_index_t, NDimSpatial > input_spatial_lengths
Definition: image_to_column_kernel.hpp:44
const array< long_index_t, NDimSpatial > output_spatial_lengths
Definition: image_to_column_kernel.hpp:46
Definition: image_to_column_kernel.hpp:13
CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs &kargs) const
Definition: image_to_column_kernel.hpp:95
static constexpr CK_TILE_HOST Kargs MakeKargs(const void *p_in, void *p_out, const long_index_t G, const long_index_t N, const long_index_t C, const array< long_index_t, NDimSpatial > input_spatial_lengths, const array< long_index_t, NDimSpatial > filter_spatial_lengths, const array< long_index_t, NDimSpatial > output_spatial_lengths, const array< long_index_t, NDimSpatial+3 > image_g_n_c_wis_strides, const array< long_index_t, 3 > gemm_g_m_k_strides, const array< long_index_t, NDimSpatial > conv_filter_strides, const array< long_index_t, NDimSpatial > conv_filter_dilations, const array< long_index_t, NDimSpatial > input_left_pads, const array< long_index_t, NDimSpatial > input_right_pads)
Definition: image_to_column_kernel.hpp:56
static constexpr auto I2
Definition: image_to_column_kernel.hpp:16
CK_TILE_DEVICE void ConvTensorRearrange(const Kargs &kargs) const
Definition: image_to_column_kernel.hpp:173
static constexpr auto I4
Definition: image_to_column_kernel.hpp:18
static constexpr auto I1
Definition: image_to_column_kernel.hpp:15
static constexpr CK_TILE_DEVICE auto MakeBlockTileDistribution()
Definition: image_to_column_kernel.hpp:157
remove_cvref_t< typename Problem::InDataType > InDataType
Definition: image_to_column_kernel.hpp:22
remove_cvref_t< Problem_ > Problem
Definition: image_to_column_kernel.hpp:20
static constexpr index_t AligmentOut
Definition: image_to_column_kernel.hpp:28
static constexpr CK_TILE_HOST auto BlockSize()
Definition: image_to_column_kernel.hpp:93
static constexpr index_t AligmentIn
Definition: image_to_column_kernel.hpp:27
static constexpr CK_TILE_HOST auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
Definition: image_to_column_kernel.hpp:87
CK_TILE_DEVICE void operator()(Kargs &kargs) const
Definition: image_to_column_kernel.hpp:221
static constexpr auto I3
Definition: image_to_column_kernel.hpp:17
static constexpr index_t kKPerBlock
Definition: image_to_column_kernel.hpp:33
remove_cvref_t< typename Problem::OutDataType > OutDataType
Definition: image_to_column_kernel.hpp:23
CK_TILE_DEVICE auto CalculateMKDims(const Kargs &kargs) const
Definition: image_to_column_kernel.hpp:147
static constexpr index_t kMPerBlock
Definition: image_to_column_kernel.hpp:32
static constexpr index_t NDimSpatial
Definition: image_to_column_kernel.hpp:25
static constexpr auto I0
Definition: image_to_column_kernel.hpp:14
Definition: array.hpp:24
Definition: integral_constant.hpp:13
Definition: sequence.hpp:52
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192