/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_elementwise_scale_impl.hpp Source File
device_elementwise_scale_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
#include <stdexcept>
8 
9 #include "ck/utility/math.hpp"
10 #include "ck/utility/sequence.hpp"
14 
17 
18 namespace ck {
19 namespace tensor_operation {
20 namespace device {
21 
26 template <typename InDataTypeTuple,
27  typename OutDataTypeTuple,
28  typename ElementwiseOperation,
29  typename UnaryOperation,
30  typename Scale,
31  index_t NumDim,
32  index_t MPerThread,
33  typename InScalarPerVectorSeq,
34  typename OutScalarPerVectorSeq>
35 struct DeviceElementwiseImpl : public DeviceElementwise<InDataTypeTuple,
36  OutDataTypeTuple,
37  ElementwiseOperation,
38  UnaryOperation,
39  Scale,
40  NumDim>
41 {
42  static constexpr int NumInput = InDataTypeTuple::Size();
43  static constexpr int NumOutput = OutDataTypeTuple::Size();
44 
45  static_assert(NumInput == InScalarPerVectorSeq::Size() &&
46  NumOutput == OutScalarPerVectorSeq::Size(),
47  "Tuple size is inconsistent with the number of in/out!");
48 
50  {
51  return generate_tuple(
52  [&](auto I) {
53  using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
54 
55  return static_cast<const DataType*>(nullptr);
56  },
58  };
59 
61  {
62  return generate_tuple(
63  [&](auto I) {
64  using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
65 
66  return static_cast<DataType*>(nullptr);
67  },
69  };
70 
73 
74  template <typename Desc_M>
75  static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
76  {
77  constexpr auto I0 = Number<0>{};
78 
79  const auto m = desc_m.GetLength(I0);
80  const index_t loop_step = gridSize * blockSize * MPerThread;
81  const auto pad = math::integer_least_multiple(m, loop_step) - m;
82  const auto desc_m_pad =
87  return desc_m_pad;
88  }
89 
90  static auto MakeDescriptor_M(const std::array<index_t, NumDim>& lengths,
91  const std::array<index_t, NumDim>& stride,
92  index_t gridSize,
93  index_t blockSize)
94  {
95  auto tupleOfShape = generate_tuple([&](auto I) { return lengths[I]; }, Number<NumDim>{});
96  auto tupleOfStride = generate_tuple([&](auto I) { return stride[I]; }, Number<NumDim>{});
97 
98  // nd desc - [s0, s1, s2, ...]
99  const auto desc = make_naive_tensor_descriptor(tupleOfShape, tupleOfStride);
100 
101  // merge nd to 1d desc - [s0 * s1 * ...]
102  if constexpr(NumDim > 1)
103  {
104  const auto desc_m = transform_tensor_descriptor(
105  desc,
106  make_tuple(make_merge_transform(tupleOfShape)),
107  make_tuple(generate_sequence_v2([&](auto I) { return I; }, Number<NumDim>{})),
109 
110  return PadDescriptor_M_1d(desc_m, gridSize, blockSize);
111  }
112  else
113  return PadDescriptor_M_1d(desc, gridSize, blockSize);
114  }
115 
116  template <index_t TupleSize>
118  {
119  return generate_tuple(
120  [&](auto) {
121  if constexpr(NumDim > 1)
122  {
123  return MakeDescriptor_M({1, 1}, {1, 1}, 1, 1);
124  }
125  else
126  {
127  return MakeDescriptor_M({1}, {1}, 1, 1);
128  };
129  },
131  };
132 
135 
140  ElementwiseOperation,
141  UnaryOperation,
142  Scale,
143  MPerThread,
144  InScalarPerVectorSeq,
145  OutScalarPerVectorSeq>;
146 
147  struct Argument : public BaseArgument
148  {
149  Argument(const std::array<index_t, NumDim> lengths,
150  const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
151  const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
152  const std::array<const void*, NumInput> in_dev_buffers,
153  const std::array<void*, NumOutput> out_dev_buffers,
154  ElementwiseOperation elementwise_op,
155  UnaryOperation unary_op,
156  Scale scale_op)
157 
158  : lengths_(lengths),
159  inStridesArray_(inStridesArray),
160  outStridesArray_(outStridesArray),
161  elementwise_op_(elementwise_op),
162  unary_op_(unary_op),
163  scale_op_(scale_op),
164  blockSize_(256)
165  {
167  [&](auto I) {
168  using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
169  return static_cast<const DataType*>(in_dev_buffers[I.value]);
170  },
171  Number<NumInput>{});
172 
174  [&](auto I) {
175  using DataType = remove_cvref_t<decltype(OutDataTypeTuple{}[I])>;
176  return static_cast<DataType*>(out_dev_buffers[I.value]);
177  },
179  }
180 
183 
184  std::array<index_t, NumDim> lengths_;
185  std::array<std::array<index_t, NumDim>, NumInput> inStridesArray_;
186  std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray_;
187 
188  ElementwiseOperation elementwise_op_;
189  UnaryOperation unary_op_;
190  Scale scale_op_;
192  };
193 
194  struct Invoker : public BaseInvoker
195  {
196  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
197  {
198  index_t gridSize = getAvailableComputeUnitCount(stream_config);
199 
200  auto in_grid_1d_desc_tuple = generate_tuple(
201  [&](auto I) {
202  return MakeDescriptor_M(
203  arg.lengths_, arg.inStridesArray_[I.value], gridSize, arg.blockSize_);
204  },
205  Number<NumInput>{});
206 
207  auto out_grid_1d_desc_tuple = generate_tuple(
208  [&](auto I) {
209  return MakeDescriptor_M(
210  arg.lengths_, arg.outStridesArray_[I.value], gridSize, arg.blockSize_);
211  },
213 
214  const auto kernel = kernel_elementwise_1d<GridwiseElementwise,
219  ElementwiseOperation,
220  UnaryOperation,
221  Scale>;
222 
223  float elapsed_time = launch_and_time_kernel(stream_config,
224  kernel,
225  dim3(gridSize),
226  dim3(arg.blockSize_),
227  0,
228  in_grid_1d_desc_tuple,
229  out_grid_1d_desc_tuple,
230  arg.in_dev_buffers_,
231  arg.out_dev_buffers_,
232  arg.elementwise_op_,
233  arg.unary_op_,
234  arg.scale_op_);
235  return elapsed_time;
236  }
237 
238  // polymorphic
239  float Run(const BaseArgument* p_arg,
240  const StreamConfig& stream_config = StreamConfig{}) override
241  {
242  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
243  }
244  };
245 
246  static bool IsSupportedArgument(const Argument& arg)
247  {
248  if(arg.lengths_.back() % MPerThread != 0)
249  return false;
250 
251  auto IsScalarPerVectorValid = [&](const std::array<index_t, NumDim>& lengths,
252  const std::array<index_t, NumDim>& strides,
253  index_t scalarPerVector) {
254  if(strides.back() == 1 && lengths.back() % scalarPerVector == 0)
255  return true;
256 
257  if(strides.back() != 1 && scalarPerVector == 1)
258  return true;
259 
260  return false;
261  };
262 
263  bool valid = true;
264  static_for<0, NumInput, 1>{}([&](auto I) {
265  if(!IsScalarPerVectorValid(
266  arg.lengths_, arg.inStridesArray_[I.value], InScalarPerVectorSeq::At(I)))
267  valid = false;
268  });
269 
270  static_for<0, NumOutput, 1>{}([&](auto I) {
271  if(!IsScalarPerVectorValid(
272  arg.lengths_, arg.outStridesArray_[I.value], OutScalarPerVectorSeq::At(I)))
273  valid = false;
274  });
275 
276  return valid;
277  };
278 
279  bool IsSupportedArgument(const BaseArgument* p_arg) override
280  {
281  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
282  }
283 
284  static auto
285  MakeArgument(const std::array<index_t, NumDim> lengths,
286  const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
287  const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
288  const std::array<const void*, NumInput> in_dev_buffers,
289  const std::array<void*, NumOutput> out_dev_buffers,
290  ElementwiseOperation elementwise_op,
291  UnaryOperation unary_op,
292  Scale scale_op)
293  {
294  return Argument{lengths,
295  inStridesArray,
296  outStridesArray,
297  in_dev_buffers,
298  out_dev_buffers,
299  elementwise_op,
300  unary_op,
301  scale_op};
302  }
303 
304  std::unique_ptr<BaseArgument>
305  MakeArgumentPointer(const std::array<index_t, NumDim> lengths,
306  const std::array<std::array<index_t, NumDim>, NumInput> inStridesArray,
307  const std::array<std::array<index_t, NumDim>, NumOutput> outStridesArray,
308  const std::array<const void*, NumInput> in_dev_buffers,
309  const std::array<void*, NumOutput> out_dev_buffers,
310  ElementwiseOperation elementwise_op,
311  UnaryOperation unary_op,
312  Scale scale_op) override
313  {
314  return std::make_unique<Argument>(lengths,
315  inStridesArray,
316  outStridesArray,
317  in_dev_buffers,
318  out_dev_buffers,
319  elementwise_op,
320  unary_op,
321  scale_op);
322  }
323 
324  static auto MakeInvoker() { return Invoker{}; }
325  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
326  {
327  return std::make_unique<Invoker>();
328  };
329 
330  std::string GetTypeString() const override
331  {
332  auto str = std::stringstream();
333 
334  // clang-format off
335  str << "DeviceElementwiseNormalizationImpl<";
336  str << NumDim << ", ";
337  str << MPerThread << ">";
338  // clang-format on
339 
340  return str.str();
341  }
342 }; // namespace device
343 
344 } // namespace device
345 } // namespace tensor_operation
346 } // namespace ck
auto pad(ck::index_t mpb, ck::index_t npb, ck::index_t kpb, ck::tensor_operation::device::GemmSpecialization gemm, CDesc_MRaw_NRaw conv)
Definition: helper.hpp:70
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const ElementwiseOperation elementwise_op, const UnaryOperation unary_op, const Scale scale_op)
Definition: gridwise_elementwise_1d_scale.hpp:21
__host__ constexpr __device__ auto generate_sequence_v2(F &&f, Number< N >)
Definition: sequence_helper.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_elementwise_1d_scale.hpp:49
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:214
Scale scale_op_
Definition: device_elementwise_scale_impl.hpp:190
InDataTypePointerTuple in_dev_buffers_
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:242
UnaryOperation unary_op_
Definition: device_elementwise_scale_impl.hpp:189
std::array< index_t, NumDim > lengths_
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:245
Argument(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op, UnaryOperation unary_op, Scale scale_op)
Definition: device_elementwise_scale_impl.hpp:149
OutDataTypePointerTuple out_dev_buffers_
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:243
index_t blockSize_
Definition: device_elementwise_scale_impl.hpp:191
ElementwiseOperation elementwise_op_
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:249
std::array< std::array< index_t, NumDim >, NumInput > inStridesArray_
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:246
std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray_
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:247
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:253
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_elementwise_scale_impl.hpp:196
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_elementwise_scale_impl.hpp:239
static auto MakeInvoker()
Definition: device_elementwise_scale_impl.hpp:324
decltype(GenerateInOutGrid1dDescTuple(Number< NumInput >{})) InGrid1dDescTuple
Definition: device_elementwise_scale_impl.hpp:133
static auto MakeArgument(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op, UnaryOperation unary_op, Scale scale_op)
Definition: device_elementwise_scale_impl.hpp:285
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_elementwise_scale_impl.hpp:325
static auto PadDescriptor_M_1d(Desc_M desc_m, index_t gridSize, index_t blockSize)
Definition: device_elementwise_scale_impl.hpp:75
static constexpr auto I0
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:41
decltype(GenerateInDataTypePointerTuple()) InDataTypePointerTuple
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:70
GridwiseElementwise_1D< InGrid1dDescTuple, OutGrid1dDescTuple, InDataTypePointerTuple, OutDataTypePointerTuple, ElementwiseOperation, UnaryOperation, Scale, MPerThread, InScalarPerVectorSeq, OutScalarPerVectorSeq > GridwiseElementwise
Definition: device_elementwise_scale_impl.hpp:145
static auto MakeDescriptor_M(const std::array< index_t, NumDim > &lengths, const std::array< index_t, NumDim > &stride, index_t gridSize, index_t blockSize)
Definition: device_elementwise_scale_impl.hpp:90
decltype(GenerateInOutGrid1dDescTuple(Number< NumOutput >{})) OutGrid1dDescTuple
Definition: device_elementwise_scale_impl.hpp:134
decltype(GenerateOutDataTypePointerTuple()) OutDataTypePointerTuple
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:71
static auto GenerateInOutGrid1dDescTuple(Number< TupleSize >)
Definition: device_elementwise_scale_impl.hpp:117
static constexpr int NumInput
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:38
static bool IsSupportedArgument(const Argument &arg)
Definition: device_elementwise_scale_impl.hpp:246
std::string GetTypeString() const override
Definition: device_elementwise_scale_impl.hpp:330
static constexpr int NumOutput
Definition: device_elementwise_dynamic_vector_dims_impl.hpp:39
static auto GenerateInDataTypePointerTuple()
Definition: device_elementwise_scale_impl.hpp:49
static auto GenerateOutDataTypePointerTuple()
Definition: device_elementwise_scale_impl.hpp:60
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, NumDim > lengths, const std::array< std::array< index_t, NumDim >, NumInput > inStridesArray, const std::array< std::array< index_t, NumDim >, NumOutput > outStridesArray, const std::array< const void *, NumInput > in_dev_buffers, const std::array< void *, NumOutput > out_dev_buffers, ElementwiseOperation elementwise_op, UnaryOperation unary_op, Scale scale_op) override
Definition: device_elementwise_scale_impl.hpp:305
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_elementwise_scale_impl.hpp:279