Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp Source File

device_elementwise_normalization_impl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/utility/math.hpp"
#include "ck/utility/sequence.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise_normalization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp"
#include "ck/host_utility/kernel_launch.hpp"

// X = Elementwise(input1, input2, input3, ...)
// Y = Normalization(X, beta, gamma)
namespace ck {
template <typename GridwiseElementwiseReduction,
          typename InDataTypePointerTuple, // Datatype tuple of inputs
          typename XDataType,              // Datatype of X
          typename GammaDataType,          // Datatype of gamma
          typename BetaDataType,           // Datatype of beta
          typename YDataType,              // Datatype of Y
          typename AccDataType,            // Datatype used for accumulation
          typename XElementwiseOperation,  // Operation applied to the inputs
          typename YElementwiseOperation,  // Operation applied to the normalization output
          typename InGrid2dDescTuple,      // Descriptor tuple of inputs
          typename GridDesc_M_K>           // Descriptor of X, gamma, beta and Y
__global__ void kernel_elementwise_layernorm(
    const InGrid2dDescTuple in_grid_2d_desc_tuple,          // Descriptor tuple of inputs
    const GridDesc_M_K x_grid_desc_m_k,                     // Descriptor of X
    const GridDesc_M_K gamma_grid_desc_m_k,                 // Descriptor of gamma
    const GridDesc_M_K beta_grid_desc_m_k,                  // Descriptor of beta
    const GridDesc_M_K y_grid_desc_m_k,                     // Descriptor of Y
    index_t num_k_block_tile_iteration,                     // Number of K-tile iterations per block
    AccDataType epsilon,                                    // Epsilon added to the variance
    const InDataTypePointerTuple p_in_global_tuple,         // Pointer tuple of input matrices
    const GammaDataType* const __restrict__ p_gamma_global, // Pointer to gamma
    const BetaDataType* const __restrict__ p_beta_global,   // Pointer to beta
    YDataType* const __restrict__ p_y_global,               // Pointer to Y
    const XElementwiseOperation x_elementwise_op,           // Operation applied to the inputs
    const YElementwiseOperation y_elementwise_op) // Operation applied to the normalization output
{
    extern __shared__ XDataType p_x_lds[]; // LDS staging buffer for the elementwise result X
    GridwiseElementwiseReduction::Run(in_grid_2d_desc_tuple,
                                      x_grid_desc_m_k,
                                      gamma_grid_desc_m_k,
                                      beta_grid_desc_m_k,
                                      y_grid_desc_m_k,
                                      num_k_block_tile_iteration,
                                      epsilon,
                                      p_in_global_tuple,
                                      p_x_lds,
                                      p_gamma_global,
                                      p_beta_global,
                                      p_y_global,
                                      x_elementwise_op,
                                      y_elementwise_op);
}
} // namespace ck

namespace ck {
namespace tensor_operation {
namespace device {

// Y = LayerNorm(A + B, Beta, Gamma)
template <typename InDataTypeTuple,       // Datatype tuple of inputs
          typename GammaDataType,         // Datatype of gamma
          typename BetaDataType,          // Datatype of beta
          typename AccDataType,           // Datatype used for accumulation
          typename YDataType,             // Datatype of Y
          typename XElementwiseOperation, // Operation applied to the inputs
          typename YElementwiseOperation, // Operation applied to the normalization output
          index_t Rank,                   // Number of tensor dimensions
          index_t NumReduceDim,           // Number of dimensions to reduce
          index_t BlockSize,              // Number of threads in a block
          index_t MThreadClusterSize,     // Number of threads in a block along the M direction
          index_t KThreadClusterSize,     // Number of threads in a block along the K direction
          index_t MThreadSliceSize,       // Number of rows each thread processes
          index_t KThreadSliceSize,       // Number of columns each thread processes
          index_t XYSrcVectorDim,         // Vector access dimension of X and Y (0: M, 1: K)
          index_t XSrcVectorSize,         // Vector size for fetching source X
          index_t GammaSrcVectorDim,      // Vector access dimension of gamma
          index_t GammaSrcVectorSize,     // Vector size for fetching source gamma
          index_t BetaSrcVectorDim,       // Vector access dimension of beta
          index_t BetaSrcVectorSize,      // Vector size for fetching source beta
          index_t YDstVectorSize>         // Vector size for writing destination Y
struct DeviceElementwiseNormalizationImpl
    : public DeviceElementwiseNormalization<InDataTypeTuple,
                                            GammaDataType,
                                            BetaDataType,
                                            AccDataType,
                                            YDataType,
                                            XElementwiseOperation,
                                            YElementwiseOperation,
                                            Rank,
                                            NumReduceDim>
{
    static constexpr int NumInput = InDataTypeTuple::Size();

    using XDataType = YDataType;

    static_assert(
        (KThreadSliceSize % GammaSrcVectorSize == 0),
        "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

    static_assert(
        (KThreadSliceSize % BetaSrcVectorSize == 0),
        "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");

    static constexpr index_t M_BlockTileSize =
        MThreadClusterSize * MThreadSliceSize; // number of rows processed by a block
    static constexpr index_t K_BlockTileSize =
        KThreadClusterSize * KThreadSliceSize; // number of columns processed by a block

    static auto GenerateInDataTypePointerTuple()
    {
        return generate_tuple(
            [&](auto I) {
                using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
                return static_cast<const DataType*>(nullptr);
            },
            Number<NumInput>{});
    };

    using InDataTypePointerTuple = decltype(GenerateInDataTypePointerTuple());

    static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
                                    const std::vector<index_t>& inStrides,
                                    int blkGroupSize,
                                    int numBlockTileIteration)
    {
        constexpr index_t NumInvariantDim  = Rank - NumReduceDim;
        static constexpr index_t numSrcDim = Rank;
        static constexpr bool reduceAllDim = (NumInvariantDim == 0);

        const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
        const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});

        const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);

        const auto in_grid_desc_m_k = [&]() {
            if constexpr(reduceAllDim)
            {
                // merge all dimensions into one, then view it as [M = 1, K = total length]
                const auto one_dim_inDesc = transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(tupleSrcLengths)),
                    make_tuple(typename arithmetic_sequence_gen<0, numSrcDim, 1>::type{}),
                    make_tuple(Sequence<0>{}));

                return transform_tensor_descriptor(one_dim_inDesc,
                                                   make_tuple(make_unmerge_transform(make_tuple(
                                                       1, one_dim_inDesc.GetLength(Number<0>{})))),
                                                   make_tuple(Sequence<0>{}),
                                                   make_tuple(Sequence<0, 1>{}));
            }
            else
            {
                using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
                using ReduceDims = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

                const auto reduceDimLengths =
                    make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
                const auto invariantDimLengths =
                    make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});

                // merge the invariant dimensions into M and the reduced dimensions into K
                return transform_tensor_descriptor(
                    inDesc,
                    make_tuple(make_merge_transform(invariantDimLengths),
                               make_merge_transform(reduceDimLengths)),
                    make_tuple(InvariantDims{}, ReduceDims{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}));
            }
        }();

        const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = in_grid_desc_m_k.GetLength(Number<1>{});

        const int reduceSizePerBlock = K_BlockTileSize * numBlockTileIteration;
        const auto inPad_M =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto inPad_K = reduceSizePerBlock * blkGroupSize - reduceLength;

        auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
            in_grid_desc_m_k,
            make_tuple(make_right_pad_transform(invariantLength, inPad_M),
                       make_right_pad_transform(reduceLength, inPad_K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (in_grid_desc_m_k_padded);
    };

    template <index_t TupleSize>
    static auto GenerateSrcGrid2dDescTuple(Number<TupleSize>)
    {
        return generate_tuple([&](auto) { return MakeSrc2dDescriptor({1}, {1}, 1, 1); },
                              Number<TupleSize>{});
    };

    using InGrid2dDescTuple = decltype(GenerateSrcGrid2dDescTuple(Number<NumInput>{}));

    using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1, 1));

    using GridwiseReduceLayernormGeneric =
        GridwiseElementwiseLayernormWelfordVariance_mk_to_mk<InDataTypePointerTuple,
                                                             XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
                                                             YDataType,
                                                             AccDataType,
                                                             XElementwiseOperation,
                                                             YElementwiseOperation,
                                                             InGrid2dDescTuple,
                                                             GridDesc_M_K,
                                                             BlockSize,
                                                             MThreadClusterSize,
                                                             KThreadClusterSize,
                                                             MThreadSliceSize,
                                                             KThreadSliceSize,
                                                             XYSrcVectorDim,
                                                             XSrcVectorSize,
                                                             GammaSrcVectorDim,
                                                             GammaSrcVectorSize,
                                                             BetaSrcVectorDim,
                                                             BetaSrcVectorSize,
                                                             XYSrcVectorDim,
                                                             YDstVectorSize,
                                                             false>;

    using GridwiseReduceLayernormSweepOnce =
        GridwiseElementwiseLayernormWelfordVariance_mk_to_mk<InDataTypePointerTuple,
                                                             XDataType,
                                                             GammaDataType,
                                                             BetaDataType,
                                                             YDataType,
                                                             AccDataType,
                                                             XElementwiseOperation,
                                                             YElementwiseOperation,
                                                             InGrid2dDescTuple,
                                                             GridDesc_M_K,
                                                             BlockSize,
                                                             MThreadClusterSize,
                                                             KThreadClusterSize,
                                                             MThreadSliceSize,
                                                             KThreadSliceSize,
                                                             XYSrcVectorDim,
                                                             XSrcVectorSize,
                                                             GammaSrcVectorDim,
                                                             GammaSrcVectorSize,
                                                             BetaSrcVectorDim,
                                                             BetaSrcVectorSize,
                                                             XYSrcVectorDim,
                                                             YDstVectorSize,
                                                             true>;

    struct Argument : public BaseArgument
    {
        Argument(const std::vector<index_t> lengths,
                 const std::array<std::vector<index_t>, NumInput> inStridesArray,
                 const std::vector<index_t> gammaStrides,
                 const std::vector<index_t> betaStrides,
                 const std::vector<index_t> yStrides,
                 const std::vector<index_t> reduceDims,
                 XElementwiseOperation x_elementwise_op,
                 YElementwiseOperation y_elementwise_op,
                 double epsilon,
                 const std::array<const void*, NumInput> in_dev_buffers,
                 const GammaDataType* p_gamma,
                 const BetaDataType* p_beta,
                 YDataType* p_y)
            : p_gamma_(p_gamma),
              p_beta_(p_beta),
              p_y_(p_y),
              x_elementwise_op_(x_elementwise_op),
              y_elementwise_op_(y_elementwise_op)
        {
            epsilon_ = static_cast<AccDataType>(epsilon);

            Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            for(int i = 0; i < NumInput; i++)
            {
                inStridesArray_[i] =
                    shuffle_tensor_dimensions<Rank, NumReduceDim>(inStridesArray[i], reduceDims);
            }

            yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
            // X (the intermediate elementwise result) shares Y's layout
            xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);

            gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
            betaStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);

            in_dev_buffers_ = generate_tuple(
                [&](auto I) {
                    using DataType = remove_cvref_t<decltype(InDataTypeTuple{}[I])>;
                    return static_cast<const DataType*>(in_dev_buffers[I.value]);
                },
                Number<NumInput>{});

            long_index_t invariant_total_length;
            long_index_t reduce_total_length;

            std::tie(invariant_total_length, reduce_total_length) =
                get_2d_lengths<Rank, NumReduceDim>(Lengths_);

            blkGroupSize_          = 1;
            numBlockTileIteration_ = (reduce_total_length + K_BlockTileSize - 1) / K_BlockTileSize;

            gridSize_ = math::integer_least_multiple(invariant_total_length, M_BlockTileSize) /
                        M_BlockTileSize;

            in_grid_2d_desc_tuple_ = generate_tuple(
                [&](auto I) {
                    return MakeSrc2dDescriptor(
                        Lengths_, inStridesArray_[I.value], blkGroupSize_, numBlockTileIteration_);
                },
                Number<NumInput>{});

            x_grid_desc_m_k_ =
                MakeSrc2dDescriptor(Lengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);

            gamma_grid_desc_m_k_ =
                MakeSrc2dDescriptor(Lengths_, gammaStrides_, blkGroupSize_, numBlockTileIteration_);

            beta_grid_desc_m_k_ =
                MakeSrc2dDescriptor(Lengths_, betaStrides_, blkGroupSize_, numBlockTileIteration_);

            y_grid_desc_m_k_ =
                MakeSrc2dDescriptor(Lengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);

            sweep_once_ =
                x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;

            if(!sweep_once_) // if X is not swept in one pass, reserve LDS to stage the
                             // intermediate X tile between the elementwise and normalization steps
            {
                int block_TileSize = M_BlockTileSize * reduce_total_length;
                x_lds_size_        = block_TileSize * sizeof(XDataType);
            }
            else
                x_lds_size_ = 0;
        }

        AccDataType epsilon_;

        InDataTypePointerTuple in_dev_buffers_;
        const GammaDataType* p_gamma_;
        const BetaDataType* p_beta_;
        YDataType* p_y_;

        std::vector<index_t> Lengths_;
        std::array<std::vector<index_t>, NumInput> inStridesArray_;
        std::vector<index_t> xStrides_;
        std::vector<index_t> gammaStrides_;
        std::vector<index_t> betaStrides_;
        std::vector<index_t> yStrides_;

        XElementwiseOperation x_elementwise_op_;
        YElementwiseOperation y_elementwise_op_;

        int blkGroupSize_;
        int numBlockTileIteration_;
        size_t gridSize_;

        InGrid2dDescTuple in_grid_2d_desc_tuple_;
        GridDesc_M_K x_grid_desc_m_k_;
        GridDesc_M_K gamma_grid_desc_m_k_;
        GridDesc_M_K beta_grid_desc_m_k_;
        GridDesc_M_K y_grid_desc_m_k_;
        bool sweep_once_;
        int x_lds_size_;
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            // select the sweep-once kernel when the whole K range fits in one block tile
            const auto kernel_main =
                arg.sweep_once_
                    ? kernel_elementwise_layernorm<GridwiseReduceLayernormSweepOnce,
                                                   InDataTypePointerTuple,
                                                   XDataType,
                                                   GammaDataType,
                                                   BetaDataType,
                                                   YDataType,
                                                   AccDataType,
                                                   XElementwiseOperation,
                                                   YElementwiseOperation,
                                                   InGrid2dDescTuple,
                                                   GridDesc_M_K>
                    : kernel_elementwise_layernorm<GridwiseReduceLayernormGeneric,
                                                   InDataTypePointerTuple,
                                                   XDataType,
                                                   GammaDataType,
                                                   BetaDataType,
                                                   YDataType,
                                                   AccDataType,
                                                   XElementwiseOperation,
                                                   YElementwiseOperation,
                                                   InGrid2dDescTuple,
                                                   GridDesc_M_K>;

            float avg_time = 0;
            avg_time += launch_and_time_kernel(stream_config,
                                               kernel_main,
                                               dim3(arg.gridSize_),
                                               dim3(BlockSize),
                                               arg.x_lds_size_,
                                               arg.in_grid_2d_desc_tuple_,
                                               arg.x_grid_desc_m_k_,
                                               arg.gamma_grid_desc_m_k_,
                                               arg.beta_grid_desc_m_k_,
                                               arg.y_grid_desc_m_k_,
                                               arg.numBlockTileIteration_,
                                               arg.epsilon_,
                                               arg.in_dev_buffers_,
                                               arg.p_gamma_,
                                               arg.p_beta_,
                                               arg.p_y_,
                                               arg.x_elementwise_op_,
                                               arg.y_elementwise_op_);

            return (avg_time);
        };

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        };
    };

    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        constexpr index_t NumInvariantDim = Rank - NumReduceDim;

        if constexpr(XYSrcVectorDim == 0)
        {
            if constexpr(NumInvariantDim == 0)
            {
                return false;
            }
            else
            {
                for(int i = 0; i < NumInput; i++)
                {
                    if(p_arg_->inStridesArray_[i][NumInvariantDim - 1] != 1)
                        return false;
                }

                if(p_arg_->inStridesArray_[0][NumInvariantDim - 1] != 1 &&
                   p_arg_->inStridesArray_[1][NumInvariantDim - 1] != 1)
                    return false;

                if(p_arg_->invariant_lowest_length % XSrcVectorSize != 0)
                    return false;
            };
        }
        else
        {
            for(int i = 0; i < NumInput; i++)
            {
                if(p_arg_->inStridesArray_[i][Rank - 1] != 1)
                    return false;
            }

            if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
                return false;
        };

        if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
        {
            return false;
        }

        auto IsScalarPerVectorValid = [](bool isLastDimensionCoalesced, int scalarPerVector) {
            bool ret = true;

            if(!isLastDimensionCoalesced)
                ret = scalarPerVector == 1;
            else
                ret = KThreadSliceSize % scalarPerVector == 0;

            return ret;
        };

        if(!IsScalarPerVectorValid(p_arg_->gammaStrides_.back() == 1, GammaSrcVectorSize))
            return false;

        if(!IsScalarPerVectorValid(p_arg_->betaStrides_.back() == 1, BetaSrcVectorSize))
            return false;

        if constexpr(XYSrcVectorDim == 0) // fastest dimension is not reduced
        {
            if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
                return (false);

            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
                return (false);
        }
        else // fastest dimension is reduced
        {
            if(p_arg_->gammaStrides_[Rank - 1] != 1)
                return (false);

            if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
                return (false);
        }

        if constexpr(XYSrcVectorDim == 0) // fastest dimension is not reduced
        {
            if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
                return (false);

            if(p_arg_->invariant_lowest_length % BetaSrcVectorSize != 0)
                return (false);
        }
        else // fastest dimension is reduced
        {
            if(p_arg_->betaStrides_[Rank - 1] != 1)
                return (false);

            if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
                return (false);
        }

        // the staged X tile must fit in the 64 KiB of LDS available per workgroup
        if(p_arg_->x_lds_size_ >= 65536)
        {
            return (false);
        }

        return true;
    };

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> lengths,
                        const std::array<std::vector<index_t>, NumInput> inStridesArray,
                        const std::vector<index_t> gammaStrides,
                        const std::vector<index_t> betaStrides,
                        const std::vector<index_t> yStrides,
                        const std::vector<index_t> reduceDims,
                        double epsilon,
                        const std::array<const void*, NumInput> in_dev_buffers,
                        const void* p_gamma,
                        const void* p_beta,
                        void* p_y,
                        XElementwiseOperation x_elementwise_op,
                        YElementwiseOperation y_elementwise_op) override
    {
        return std::make_unique<Argument>(lengths,
                                          inStridesArray,
                                          gammaStrides,
                                          betaStrides,
                                          yStrides,
                                          reduceDims,
                                          x_elementwise_op,
                                          y_elementwise_op,
                                          epsilon,
                                          in_dev_buffers,
                                          static_cast<const GammaDataType*>(p_gamma),
                                          static_cast<const BetaDataType*>(p_beta),
                                          static_cast<YDataType*>(p_y));
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceElementwiseNormalizationImpl<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
        str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
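
Usage sketch: the listing above only defines the device operation; the fragment below illustrates how it might be instantiated for Y = LayerNorm(A + B, Beta, Gamma) on a 2-D [M, N] problem normalized over N. The tile and vector tuning parameters, the Add/PassThrough element-wise operations, and the stride choices are illustrative assumptions rather than values prescribed by this header; consult the Composable Kernel example suite for validated configurations.

// Hypothetical host-side configuration (not part of this header).
#include "ck/utility/data_type.hpp"
#include "ck/utility/tuple.hpp"
#include "ck/tensor_operation/gpu/element/binary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_elementwise_normalization_impl.hpp"

using F16         = ck::half_t;
using F32         = float;
using Add         = ck::tensor_operation::element_wise::Add;         // X = A + B
using PassThrough = ck::tensor_operation::element_wise::PassThrough; // Y left unchanged

// Rank-2 problem [M, N], normalized over the last dimension (NumReduceDim = 1).
// All tuning numbers below are assumptions chosen to satisfy the static_asserts
// (KThreadSliceSize divisible by the gamma/beta vector sizes) and
// BlockSize == MThreadClusterSize * KThreadClusterSize.
using DeviceOp = ck::tensor_operation::device::DeviceElementwiseNormalizationImpl<
    ck::Tuple<F16, F16>, // InDataTypeTuple: two inputs A and B
    F16,                 // GammaDataType
    F16,                 // BetaDataType
    F32,                 // AccDataType
    F16,                 // YDataType
    Add,                 // XElementwiseOperation
    PassThrough,         // YElementwiseOperation
    2,                   // Rank
    1,                   // NumReduceDim
    256,                 // BlockSize
    8,                   // MThreadClusterSize
    32,                  // KThreadClusterSize
    1,                   // MThreadSliceSize
    8,                   // KThreadSliceSize
    1,                   // XYSrcVectorDim (vectorize along K)
    8,                   // XSrcVectorSize
    1,                   // GammaSrcVectorDim
    8,                   // GammaSrcVectorSize
    1,                   // BetaSrcVectorDim
    8,                   // BetaSrcVectorSize
    8>;                  // YDstVectorSize

// With M, N, device pointers p_a, p_b, p_gamma, p_beta, p_y already set up:
//
// auto op       = DeviceOp{};
// auto argument = op.MakeArgumentPointer({M, N},           // lengths
//                                        {{N, 1}, {N, 1}}, // input strides (A, B)
//                                        {0, 1},           // gamma strides (broadcast over M)
//                                        {0, 1},           // beta strides (broadcast over M)
//                                        {N, 1},           // y strides
//                                        {1},              // reduceDims
//                                        1e-5,             // epsilon
//                                        {p_a, p_b}, p_gamma, p_beta, p_y,
//                                        Add{}, PassThrough{});
// if(op.IsSupportedArgument(argument.get()))
//     op.MakeInvokerPointer()->Run(argument.get(), StreamConfig{});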