/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp Source File
device_softmax_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
16 
17 namespace ck {
18 namespace tensor_operation {
19 namespace device {
20 
21 template <typename InDataType,
22  typename AccDataType,
23  typename OutDataType,
24  typename InElementwiseOp,
25  typename AccElementwiseOp,
26  index_t Rank,
27  index_t NumReduceDim,
28  index_t BlockSize,
29  index_t MThreadClusterSize,
30  index_t KThreadClusterSize,
31  index_t MThreadSliceSize,
32  index_t KThreadSliceSize,
33  index_t InSrcVectorDim,
34  index_t InSrcVectorSize,
35  index_t OutDstVectorSize>
36 struct DeviceSoftmaxImpl : public DeviceSoftmax<InDataType,
37  AccDataType,
38  OutDataType,
39  InElementwiseOp,
40  AccElementwiseOp,
41  Rank,
42  NumReduceDim>
43 {
44  static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
45 
46  static constexpr index_t NumSrcDim = Rank;
47  static constexpr index_t NumDstDim = (NumInvariantDim == 0) ? 1 : NumInvariantDim;
48  static constexpr bool reduceAllDim = (NumInvariantDim == 0);
49 
50  static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
51  static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
52 
53  static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
54  const std::vector<index_t>& inStrides,
55  int blkGroupSize,
56  int numBlockTileIteration)
57  {
58  const auto tupleSrcLengths =
59  generate_tuple([&](auto I) { return inLengths[I]; }, Number<Rank>{});
60  const auto tupleSrcStrides =
61  generate_tuple([&](auto I) { return inStrides[I]; }, Number<Rank>{});
62 
63  const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
64 
65  const auto in_grid_desc_m_k = [&]() {
66  if constexpr(reduceAllDim)
67  {
68  const auto one_dim_inDesc = transform_tensor_descriptor(
69  inDesc,
70  make_tuple(make_merge_transform(tupleSrcLengths)),
73 
74  return transform_tensor_descriptor(one_dim_inDesc,
76  1, one_dim_inDesc.GetLength(Number<0>{})))),
79  }
80  else
81  {
82  using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
84 
85  const auto reduceDimLengths = generate_tuple(
86  [&](auto I) { return inLengths[NumInvariantDim + I]; }, Number<NumReduceDim>{});
87  const auto invariantDimLengths =
88  generate_tuple([&](auto I) { return inLengths[I]; }, Number<NumInvariantDim>{});
89 
91  inDesc,
92  make_tuple(make_merge_transform(invariantDimLengths),
93  make_merge_transform(reduceDimLengths)),
94  make_tuple(InvariantDims{}, ReduceDims{}),
96  }
97  }();
98 
99  const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
100  const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
101 
102  const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
103  const auto inPad_M =
104  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
105  const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;
106 
107  auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
108  in_grid_desc_m_k,
109  make_tuple(make_right_pad_transform(invariantLength, inPad_M),
110  make_right_pad_transform(reduceLength, inPad_K)),
113 
114  return (in_grid_desc_m_k_padded);
115  };
116 
117  using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));
118 
120  OutDataType,
121  AccDataType,
122  GridDesc_M_K,
123  BlockSize,
124  MThreadClusterSize,
125  KThreadClusterSize,
126  MThreadSliceSize,
127  KThreadSliceSize,
128  InSrcVectorDim,
129  InSrcVectorSize,
130  OutDstVectorSize,
131  false>;
132 
134  OutDataType,
135  AccDataType,
136  GridDesc_M_K,
137  BlockSize,
138  MThreadClusterSize,
139  KThreadClusterSize,
140  MThreadSliceSize,
141  KThreadSliceSize,
142  InSrcVectorDim,
143  InSrcVectorSize,
144  OutDstVectorSize,
145  true>;
146 
147  struct Argument : public BaseArgument
148  {
149  Argument(const std::vector<index_t> inLengths,
150  const std::vector<index_t> inStrides,
151  const std::vector<index_t> reduceDims,
152  double alpha,
153  double beta,
154  const InDataType* in_dev,
155  OutDataType* out_dev,
156  InElementwiseOp in_elementwise_op,
157  AccElementwiseOp acc_elementwise_op)
158  : in_dev_{in_dev},
159  out_dev_{out_dev},
160  in_elementwise_op_{in_elementwise_op},
161  acc_elementwise_op_{acc_elementwise_op}
162  {
163  alpha_ = static_cast<AccDataType>(alpha);
164  beta_ = static_cast<AccDataType>(beta);
165 
166  if(Rank != inLengths.size() || Rank != inStrides.size() ||
167  NumReduceDim != reduceDims.size())
168  {
169  throw std::runtime_error(
170  "One of inLengths/inStrides/reduceDims has invalid size!"
171  "\nExpected size inLengths: " +
172  std::to_string(Rank) + ", inStrides: " + std::to_string(Rank) +
173  ", reduceDims: " + std::to_string(NumReduceDim) +
174  "\nBut have inLengths: " + std::to_string(inLengths.size()) +
175  ", inStrides: " + std::to_string(inStrides.size()) +
176  ", reduceDims: " + std::to_string(reduceDims.size()));
177  }
178 
179  for(std::size_t i = 0; i < reduceDims.size(); ++i)
180  {
181  if(reduceDims[i] < 0 || reduceDims[i] >= Rank)
182  {
183  throw std::runtime_error("Provided reduce dimension exceed input tensor Rank!"
184  "\nHave reduceDims[" +
185  std::to_string(i) +
186  "]: " + std::to_string(reduceDims[i]));
187  }
188  }
189 
190  inLengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inLengths, reduceDims);
191  inStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(inStrides, reduceDims);
192 
193  long_index_t invariant_total_length;
194  long_index_t reduce_total_length;
195 
196  std::tie(invariant_total_length, reduce_total_length) =
197  get_2d_lengths<Rank, NumReduceDim>(inLengths_);
198 
199  if constexpr(NumInvariantDim == 0)
201  else
203 
204  blkGroupSize = 1;
205  numBlockTileIteration = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;
206 
207  gridSize = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
209  }
210 
211  std::vector<index_t> inLengths_;
212  std::vector<index_t> inStrides_;
213 
214  AccDataType alpha_;
215  AccDataType beta_;
216 
217  const InDataType* in_dev_;
218  OutDataType* out_dev_;
219 
220  InElementwiseOp in_elementwise_op_;
221  AccElementwiseOp acc_elementwise_op_;
222 
224 
227  size_t gridSize;
228  };
229 
230  struct Invoker : public BaseInvoker
231  {
232  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
233  {
234  const auto in_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
236  const auto out_grid_desc_m_k = DeviceSoftmaxImpl::MakeSrc2dDescriptor(
238 
239  bool sweep_once =
240  in_grid_desc_m_k.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
241 
242  const auto kernel_main = sweep_once ? kernel_softmax<GridwiseSoftmaxSweepOnce,
243  InDataType,
244  OutDataType,
245  AccDataType,
246  GridDesc_M_K>
248  InDataType,
249  OutDataType,
250  AccDataType,
251  GridDesc_M_K>;
252 
253  float avg_time = 0;
254 
255  avg_time += launch_and_time_kernel(stream_config,
256  kernel_main,
257  dim3(arg.gridSize),
258  dim3(BlockSize),
259  0,
260  in_grid_desc_m_k,
261  out_grid_desc_m_k,
262  arg.blkGroupSize,
264  arg.alpha_,
265  arg.in_dev_,
266  arg.beta_,
267  arg.out_dev_);
268 
269  return (avg_time);
270  };
271 
272  float Run(const BaseArgument* p_arg,
273  const StreamConfig& stream_config = StreamConfig{}) override
274  {
275  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
276  };
277  };
278 
279  static bool IsSupportedArgument(const Argument& arg)
280  {
281  if constexpr(InSrcVectorDim == 0)
282  {
283  if constexpr(NumInvariantDim == 0)
284  {
285  return false;
286  }
287  else
288  {
289  if(arg.inStrides_[NumInvariantDim - 1] != 1 && InSrcVectorSize != 1)
290  {
291  return false;
292  }
293  if(arg.invariant_lowest_length_ % InSrcVectorSize != 0)
294  {
295  return false;
296  }
297  }
298  }
299  else
300  {
301  if(arg.inStrides_[Rank - 1] != 1 && InSrcVectorSize != 1)
302  {
303  return false;
304  }
305  if(arg.inLengths_[Rank - 1] % InSrcVectorSize != 0)
306  {
307  return false;
308  }
309  }
310 
311  // To improve
312  if(NumInvariantDim > 0 && arg.invariant_lowest_length_ % OutDstVectorSize != 0)
313  {
314  return false;
315  }
316 
317  if(arg.inLengths_[Rank - 1] % OutDstVectorSize != 0)
318  {
319  return false;
320  }
321 
322  return true;
323  };
324 
325  bool IsSupportedArgument(const BaseArgument* p_arg) override
326  {
327  return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
328  }
329 
330  static auto MakeArgument(const std::vector<index_t> inLengths,
331  const std::vector<index_t> inStrides,
332  const std::vector<int> reduceDims,
333  double alpha,
334  double beta,
335  const InDataType* in_dev,
336  OutDataType* out_dev,
337  InElementwiseOp in_elementwise_op,
338  AccElementwiseOp acc_elementwise_op)
339  {
340  return Argument{inLengths,
341  inStrides,
342  reduceDims,
343  alpha,
344  beta,
345  in_dev,
346  out_dev,
347  in_elementwise_op,
348  acc_elementwise_op};
349  };
350 
351  //
352  // @brief Makes a pointer to Argument class.
353  //
354  // @param[in] inLengths Input tensor extent(s) from high to low dimension
355  // @param[in] inStrides Input tensor stride(s) from high to low dimension
356  // @param[in] reduceDims The dimension(s) the normalization operation is applied
357  // @param[in] alpha Typeless pointer in host memory storing the alpha scaling
358  // value as type AccDataType
359  // @param[in] beta Typeless pointer in host memory storing the beta scaling
360  // value as type AccDataType
361  // @param[in] in_dev Typeless const pointer in device memory storing the input
362  // tensor
363  // @param out_dev Typeless pointer in device memory storing the output tensor
364  // @param[in] in_elementwise_op The input elementwise operation.
365  // @param[in] acc_elementwise_op The accumulation elementwise operation.
366  //
367  // @return Unique pointer to the Argument class.
368  //
369  std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> inLengths,
370  const std::vector<index_t> inStrides,
371  const std::vector<int> reduceDims,
372  double alpha,
373  double beta,
374  const void* in_dev,
375  void* out_dev,
376  InElementwiseOp in_elementwise_op,
377  AccElementwiseOp acc_elementwise_op) override
378  {
379  return std::make_unique<Argument>(inLengths,
380  inStrides,
381  reduceDims,
382  alpha,
383  beta,
384  static_cast<const InDataType*>(in_dev),
385  static_cast<OutDataType*>(out_dev),
386  in_elementwise_op,
387  acc_elementwise_op);
388  };
389 
390  static auto MakeInvoker() { return Invoker{}; }
391 
392  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
393  {
394  return std::make_unique<Invoker>();
395  };
396 
397  std::string GetTypeString() const override
398  {
399  auto str = std::stringstream();
400 
401  // clang-format off
402  str << "DeviceReduceSoftmax<"
403  << Rank << "," << NumReduceDim << "," << BlockSize << ","
404  << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ","
405  << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ","
406  << "InSrcVectorDim_" << InSrcVectorDim
407  << "_InSrcVectorSize_" << InSrcVectorSize
408  << "_OutDstVectorSize_" << OutDstVectorSize << ">";
409  // clang-format on
410 
411  return str.str();
412  }
413 };
414 
415 } // namespace device
416 } // namespace tensor_operation
417 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
int64_t long_index_t
Definition: ck.hpp:290
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, const GridDesc_M_K out_grid_desc_m_k, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition: gridwise_softmax.hpp:22
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_softmax.hpp:55
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:256
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_softmax.hpp:24
Definition: device_softmax_impl.hpp:148
std::vector< index_t > inLengths_
Definition: device_softmax_impl.hpp:211
AccDataType alpha_
Definition: device_softmax_impl.hpp:214
index_t invariant_lowest_length_
Definition: device_softmax_impl.hpp:223
AccElementwiseOp acc_elementwise_op_
Definition: device_softmax_impl.hpp:221
const InDataType * in_dev_
Definition: device_softmax_impl.hpp:217
AccDataType beta_
Definition: device_softmax_impl.hpp:215
size_t gridSize
Definition: device_softmax_impl.hpp:227
Argument(const std::vector< index_t > inLengths, const std::vector< index_t > inStrides, const std::vector< index_t > reduceDims, double alpha, double beta, const InDataType *in_dev, OutDataType *out_dev, InElementwiseOp in_elementwise_op, AccElementwiseOp acc_elementwise_op)
Definition: device_softmax_impl.hpp:149
int blkGroupSize
Definition: device_softmax_impl.hpp:225
InElementwiseOp in_elementwise_op_
Definition: device_softmax_impl.hpp:220
OutDataType * out_dev_
Definition: device_softmax_impl.hpp:218
int numBlockTileIteration
Definition: device_softmax_impl.hpp:226
std::vector< index_t > inStrides_
Definition: device_softmax_impl.hpp:212
Definition: device_softmax_impl.hpp:231
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_softmax_impl.hpp:232
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_softmax_impl.hpp:272
Definition: device_softmax_impl.hpp:43
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int blkGroupSize, int numBlockTileIteration)
Definition: device_softmax_impl.hpp:53
static constexpr index_t NumInvariantDim
Definition: device_softmax_impl.hpp:44
static auto MakeInvoker()
Definition: device_softmax_impl.hpp:390
static bool IsSupportedArgument(const Argument &arg)
Definition: device_softmax_impl.hpp:279
static constexpr index_t NumSrcDim
Definition: device_softmax_impl.hpp:46
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_softmax_impl.hpp:392
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_softmax_impl.hpp:325
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > inLengths, const std::vector< index_t > inStrides, const std::vector< int > reduceDims, double alpha, double beta, const void *in_dev, void *out_dev, InElementwiseOp in_elementwise_op, AccElementwiseOp acc_elementwise_op) override
Definition: device_softmax_impl.hpp:369
static constexpr index_t M_BlockTileSize
Definition: device_softmax_impl.hpp:50
decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1)) GridDesc_M_K
Definition: device_softmax_impl.hpp:117
GridwiseSoftmax_mk_to_mk< InDataType, OutDataType, AccDataType, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, false > GridwiseSoftmaxGeneric
Definition: device_softmax_impl.hpp:131
static auto MakeArgument(const std::vector< index_t > inLengths, const std::vector< index_t > inStrides, const std::vector< int > reduceDims, double alpha, double beta, const InDataType *in_dev, OutDataType *out_dev, InElementwiseOp in_elementwise_op, AccElementwiseOp acc_elementwise_op)
Definition: device_softmax_impl.hpp:330
static constexpr index_t NumDstDim
Definition: device_softmax_impl.hpp:47
static constexpr index_t K_BlockTileSize
Definition: device_softmax_impl.hpp:51
std::string GetTypeString() const override
Definition: device_softmax_impl.hpp:397
GridwiseSoftmax_mk_to_mk< InDataType, OutDataType, AccDataType, GridDesc_M_K, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, InSrcVectorDim, InSrcVectorSize, OutDstVectorSize, true > GridwiseSoftmaxSweepOnce
Definition: device_softmax_impl.hpp:145
static constexpr bool reduceAllDim
Definition: device_softmax_impl.hpp:48