Composable Kernel: device_permute_impl.hpp Source File

include/ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <array>
#include <memory>
#include <utility>

#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_permute.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_permute.hpp"

#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Swap last 2 dimensions
// input shape:  [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2], d[NumDim-1]]
//                                                    ^^^^^^^^^^^^^^^^^^^^^^^^
// output shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-1], d[NumDim-2]]
//                                                    ^^^^^^^^^^^^^^^^^^^^^^^^
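// e.g. NumDim = 4: an input of shape [2, 3, 4, 5] is permuted to an output of shape [2, 3, 5, 4]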
template <index_t NumDim,
          typename InDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          index_t BlockSize,
          index_t NPerBlock,
          index_t HPerBlock,
          index_t WPerBlock,
          index_t InBlockLdsExtraW,
          typename InBlockTransferThreadClusterLengths,
          typename InBlockTransferThreadClusterArrangeOrder,
          index_t SrcVectorDim,
          index_t DstVectorDim,
          index_t SrcScalarPerVector,
          index_t DstScalarPerVector>
struct DevicePermuteImpl : DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>
{
    using BaseType = DevicePermute<NumDim, InDataType, OutDataType, ElementwiseOperation>;
    using typename BaseType::Lengths;
    using typename BaseType::Strides;

    static_assert(3 <= NumDim, "Only accept at least 3D dimension tensor");
    static_assert((NumDim - 2) <= SrcVectorDim && SrcVectorDim < NumDim);
    static_assert((NumDim - 2) <= DstVectorDim && DstVectorDim < NumDim);
    static_assert(SrcVectorDim != DstVectorDim);
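    // i.e. SrcVectorDim and DstVectorDim must each refer to one of the two swapped
    // (innermost) dimensions, and they cannot name the same dimension.
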
    template <index_t N = NumDim>
    static auto ConvertArrayToTuple(const std::array<index_t, NumDim>& array)
    {
        static_assert(1 <= N && N <= NumDim);

        return generate_tuple([&](auto I) { return array[I]; }, Number<N>{});
    }
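    // e.g. ConvertArrayToTuple<2>({2, 3, 4, 5}) yields make_tuple(2, 3); the default
    // N = NumDim converts the whole array.
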
    static auto MakeDescriptor_N_H_W(const Lengths& lengths, const Strides& stride)
    {
        // create nd descriptor, shape: [d[0], d[1], d[2], ..., d[NumDim-3], d[NumDim-2],
        // d[NumDim-1]]
        const auto desc =
            make_naive_tensor_descriptor(ConvertArrayToTuple(lengths), ConvertArrayToTuple(stride));

        // merge nd to 3d descriptor, shape: [(d[0] * d[1] * d[2] * ... * d[NumDim-3]), d[NumDim-2],
        // d[NumDim-1]]
        // => [N, H, W]
        const index_t H = *std::next(rbegin(lengths));
        const index_t W = *rbegin(lengths);
        const auto desc_n_h_w = transform_tensor_descriptor(
            desc,
            make_tuple(make_merge_transform(ConvertArrayToTuple<NumDim - 2>(lengths)),
                       make_pass_through_transform(H),
                       make_pass_through_transform(W)),
            make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim - 2>{}),
                       Sequence<NumDim - 2>{},
                       Sequence<NumDim - 1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

        return PadTensorDescriptor(
            desc_n_h_w, make_tuple(NPerBlock, HPerBlock, WPerBlock), Sequence<true, true, true>{});
    }
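    // e.g. NumDim = 4, lengths = [2, 3, 4, 5]: the merged view is N = 2 * 3 = 6, H = 4,
    // W = 5, and each extent is then padded up to a multiple of NPerBlock / HPerBlock /
    // WPerBlock respectively.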

    using InGridDesc  = decltype(MakeDescriptor_N_H_W({1, 1}, {1, 1}));
    using OutGridDesc = InGridDesc;

    using GridwisePermute = GridwisePermute<InGridDesc,
                                            OutGridDesc,
                                            InDataType,
                                            OutDataType,
                                            ElementwiseOperation,
                                            BlockSize,
                                            NPerBlock,
                                            HPerBlock,
                                            WPerBlock,
                                            InBlockLdsExtraW,
                                            InBlockTransferThreadClusterLengths,
                                            InBlockTransferThreadClusterArrangeOrder,
                                            SrcVectorDim - (NumDim - 3), // calculate new SrcVectorDim for the merged descriptor
                                            DstVectorDim - (NumDim - 3), // calculate new DstVectorDim for the merged descriptor
                                            SrcScalarPerVector,
                                            DstScalarPerVector>;
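    // (After the merge, the original dims NumDim-2 and NumDim-1 become dims 1 and 2 of the
    // 3D [N, H, W] descriptor, which is what the "- (NumDim - 3)" remapping above encodes.)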

    using Block2TileMap = typename GridwisePermute::DefaultBlock2TileMap;
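    // Maps a linear workgroup index to the [N, H, W] tile it processes (default map
    // supplied by GridwisePermute).
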
    struct Argument : public BaseArgument
    {
        Argument(const Lengths& in_lengths,
                 const Strides& in_strides,
                 const Lengths& out_lengths,
                 const Strides& out_strides,
                 const void* in_dev_buffer,
                 void* out_dev_buffer,
                 ElementwiseOperation elementwise_op)
            : in_dev_buffer_(static_cast<const InDataType*>(in_dev_buffer)),
              out_dev_buffer_(static_cast<OutDataType*>(out_dev_buffer)),
              in_grid_desc_(MakeDescriptor_N_H_W(in_lengths, in_strides)),
              out_grid_desc_(MakeDescriptor_N_H_W(out_lengths, out_strides)),
              in_lengths_(in_lengths),
              in_strides_(in_strides),
              out_lengths_(out_lengths),
              out_strides_(out_strides),
              elementwise_op_(elementwise_op),
              block_2_tile_map_(GridwisePermute::MakeDefaultBlock2TileMap(in_grid_desc_))
        {
        }

        const InDataType* in_dev_buffer_;
        OutDataType* out_dev_buffer_;
        InGridDesc in_grid_desc_;
        OutGridDesc out_grid_desc_;

        Lengths in_lengths_;
        Strides in_strides_;
        Lengths out_lengths_;
        Strides out_strides_;

        ElementwiseOperation elementwise_op_;

        Block2TileMap block_2_tile_map_;
    };

    struct Invoker : public BaseInvoker
    {
        static float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            const index_t grid_size = arg.block_2_tile_map_.CalculateGridSize(arg.in_grid_desc_);

            const auto kernel = kernel_nd_permute<GridwisePermute,
                                                  InGridDesc,
                                                  OutGridDesc,
                                                  InDataType,
                                                  OutDataType,
                                                  ElementwiseOperation,
                                                  Block2TileMap>;

            float elapsed_time = launch_and_time_kernel(stream_config,
                                                        kernel,
                                                        dim3(grid_size),
                                                        dim3(BlockSize),
                                                        0,
                                                        arg.in_grid_desc_,
                                                        arg.out_grid_desc_,
                                                        arg.in_dev_buffer_,
                                                        arg.out_dev_buffer_,
                                                        arg.elementwise_op_,
                                                        arg.block_2_tile_map_);
            return elapsed_time;
        }

        float Run(const BaseArgument* arg,
                  const StreamConfig& stream_config = StreamConfig{}) override final
        {
            const auto* const argument = dynamic_cast<const Argument*>(arg);
            if(!argument)
            {
                return NAN;
            }

            return Run(*argument, stream_config);
        }
    };

    static bool IsSupportedArgument(const Argument& arg)
    {
        constexpr auto GetPaddedLength = [](index_t length, index_t tile_length) {
            return math::integer_divide_ceil(length, tile_length) * tile_length;
        };

        constexpr auto IsScalarPerVectorValid =
            [](index_t length, index_t stride, index_t scalar_per_vector) {
                if(stride == 1 && length % scalar_per_vector == 0)
                {
                    return true;
                }
                else if(stride != 1 && scalar_per_vector == 1)
                {
                    return true;
                }

                return false;
            };
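
        // e.g. length = 5, stride = 1, scalar_per_vector = 4 is rejected (5 % 4 != 0),
        // while scalar_per_vector = 1 is accepted for any stride.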
        return IsScalarPerVectorValid(arg.in_lengths_[SrcVectorDim],
                                      arg.in_strides_[SrcVectorDim],
                                      SrcScalarPerVector) &&
               IsScalarPerVectorValid(
                   GetPaddedLength(arg.in_lengths_[SrcVectorDim],
                                   (SrcVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
                   arg.in_strides_[SrcVectorDim],
                   SrcScalarPerVector) &&
               IsScalarPerVectorValid(arg.out_lengths_[DstVectorDim],
                                      arg.out_strides_[DstVectorDim],
                                      DstScalarPerVector) &&
               IsScalarPerVectorValid(
                   GetPaddedLength(arg.out_lengths_[DstVectorDim],
                                   (DstVectorDim == NumDim - 2 ? HPerBlock : WPerBlock)),
                   arg.in_strides_[DstVectorDim],
                   DstScalarPerVector) &&
               GridwisePermute::CheckValidity(arg.in_grid_desc_, arg.out_grid_desc_);
    }

    // override methods inherited from 'BaseOperator'
    bool IsSupportedArgument(const BaseArgument* arg) override final
    {
        const auto* const argument = dynamic_cast<const Argument*>(arg);
        if(!argument)
        {
            return false;
        }

        return IsSupportedArgument(*argument);
    }

    // override methods inherited from 'DevicePermute'
    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const Lengths& in_lengths,
                        const Strides& in_strides,
                        const Lengths& out_lengths,
                        const Strides& out_strides,
                        const void* in_dev_buffer,
                        void* out_dev_buffer,
                        ElementwiseOperation elementwise_op) override final
    {
        return std::make_unique<Argument>(in_lengths,
                                          in_strides,
                                          out_lengths,
                                          out_strides,
                                          in_dev_buffer,
                                          out_dev_buffer,
                                          elementwise_op);
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override final
    {
        return std::make_unique<Invoker>();
    }

    // other constructor methods
    template <typename... Args>
    static std::enable_if_t<std::is_constructible_v<Argument, Args...>, Argument>
    MakeArgument(Args&&... args) noexcept(std::is_nothrow_constructible_v<Argument, Args...>)
    {
        return Argument{std::forward<Args>(args)...};
    }

    static std::enable_if_t<std::is_default_constructible_v<Invoker>, Invoker>
    MakeInvoker() noexcept(std::is_nothrow_default_constructible_v<Invoker>)
    {
        return Invoker{};
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
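
A minimal host-side usage sketch follows. It is illustrative only: the tile sizes, thread-cluster shape, vector widths, the PassThrough elementwise operation, and the output stride convention are assumptions for this sketch, not values mandated by this header, and any real configuration must still pass IsSupportedArgument.

#include <array>

#include "ck/tensor_operation/gpu/device/impl/device_permute_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// Swap the H and W axes of a [16, 64, 64] float tensor: out[n][w][h] = in[n][h][w].
// p_in / p_out are device buffers allocated elsewhere (e.g. via hipMalloc).
float RunPermute(const float* p_in, float* p_out)
{
    using DeviceOp = ck::tensor_operation::device::DevicePermuteImpl<
        3,                      // NumDim
        float,                  // InDataType
        float,                  // OutDataType
        PassThrough,            // ElementwiseOperation
        256,                    // BlockSize
        1,                      // NPerBlock
        32,                     // HPerBlock
        32,                     // WPerBlock
        1,                      // InBlockLdsExtraW
        ck::Sequence<1, 32, 8>, // InBlockTransferThreadClusterLengths (assumed)
        ck::Sequence<0, 1, 2>,  // InBlockTransferThreadClusterArrangeOrder (assumed)
        2,                      // SrcVectorDim: W is contiguous in the input
        1,                      // DstVectorDim: H is contiguous in the output
        4,                      // SrcScalarPerVector
        1>;                     // DstScalarPerVector (kept at 1; see IsSupportedArgument)

    // Lengths/strides are given in the input's dimension order; the output strides below
    // describe a layout in which the former H dimension is contiguous.
    const std::array<ck::index_t, 3> in_lengths{16, 64, 64}, in_strides{64 * 64, 64, 1};
    const std::array<ck::index_t, 3> out_lengths{16, 64, 64}, out_strides{64 * 64, 1, 64};

    DeviceOp device_op;
    auto argument = device_op.MakeArgumentPointer(
        in_lengths, in_strides, out_lengths, out_strides, p_in, p_out, PassThrough{});

    if(!device_op.IsSupportedArgument(argument.get()))
    {
        return -1.f; // configuration rejected; pick other tile/vector settings
    }

    auto invoker = device_op.MakeInvokerPointer();
    return invoker->Run(argument.get(), StreamConfig{nullptr, true});
}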