/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_normalization_fwd_impl.hpp Source File
device_normalization_fwd_impl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include <iostream>
7 #include <sstream>
8 
16 
17 namespace ck {
18 namespace tensor_operation {
19 namespace device {
20 
21 // Y = Normalization(X, Beta, Gamma)
22 // M: Invariant length
23 // K: Reduce length (mean and variance are calculated along the K dimension)
24 // e.g. Length = [N, C, H, W], reduce dim = [C, H, W]
25 // then M = N and K = C * H * W
26 template <typename XDataType,
27  typename GammaDataType,
28  typename BetaDataType,
29  typename ComputeDataType,
30  typename YDataType,
31  typename SaveMeanInvStdDataType,
32  typename YElementwiseOperation,
33  index_t Rank,
34  index_t NumReduceDim,
35  index_t BlockSize,
36  index_t MThreadClusterSize,
37  index_t KThreadClusterSize,
38  index_t MThreadSliceSize,
39  index_t KThreadSliceSize,
40  index_t XYSrcVectorDim,
41  index_t XSrcVectorSize,
42  index_t GammaSrcVectorDim,
43  index_t GammaSrcVectorSize,
44  index_t BetaSrcVectorDim,
45  index_t BetaSrcVectorSize,
46  index_t YDstVectorSize,
47  index_t SaveMeanInvStdDstVectorSize,
48  bool UseWelford = true>
50  GammaDataType,
51  BetaDataType,
52  YDataType,
53  SaveMeanInvStdDataType,
54  YElementwiseOperation,
55  Rank,
56  NumReduceDim>
57 {
// Compile-time sanity checks of the tuning parameters, followed by the derived tile sizes.
// The M x K thread cluster must use every thread of the block exactly once.
58  static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);
// Gamma is read in vectors along dimension 0 (M) or 1 (K); the per-thread slice along
// the chosen dimension must be a whole number of gamma vectors.
59  static_assert(
60  ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
61  (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
62  "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");
63 
// Same requirement for beta.
64  static_assert(
65  ((BetaSrcVectorDim == 0 && MThreadSliceSize % BetaSrcVectorSize == 0) ||
66  (BetaSrcVectorDim == 1 && KThreadSliceSize % BetaSrcVectorSize == 0)),
67  "Invalid thread slice sizes and/or beta vector sizes configuration, please check!");
68 
// The saved mean / inverse-std vector size is checked against the M thread slice only.
69  static_assert(MThreadSliceSize % SaveMeanInvStdDstVectorSize == 0,
70  "Invalid thread slice sizes and/or save mean and inverse std vector sizes "
71  "configuration, please check!");
72 
// NOTE(review): one source line is missing from this listing at this point.
74 
// Number of dimensions that are not reduced.
75  static constexpr index_t NumInvariantDim = Rank - NumReduceDim;
// Tile extents covered by one thread block: cluster size times per-thread slice size.
76  static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
77  static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
78 
// Reducing over every dimension (no invariant dimension) is rejected at compile time.
79  static constexpr bool reduceAllDim = (NumInvariantDim == 0);
80  static_assert(!reduceAllDim); // TODO (presumably: support NumInvariantDim == 0)
81 
// Builds the padded two-dimensional (M, K) grid descriptor of a Rank-dimensional tensor.
//
// inLengths / inStrides: per-dimension lengths and strides; Rank entries are read from each.
// numBlockTileIteration: the K extent is padded up to K_BlockTileSize * numBlockTileIteration.
// Returns the descriptor with M right-padded to a multiple of M_BlockTileSize and K
// right-padded as described above.
//
// NOTE(review): several source lines of this function are missing from this listing
// (gaps in the embedded line numbers), so the code below is not complete as shown.
82  static auto MakeSrc2dDescriptor(const std::vector<index_t>& inLengths,
83  const std::vector<index_t>& inStrides,
84  int numBlockTileIteration)
85  {
86  static constexpr index_t numSrcDim = Rank;
87 
88  const auto tupleSrcLengths = make_tuple_from_array(inLengths, Number<numSrcDim>{});
89  const auto tupleSrcStrides = make_tuple_from_array(inStrides, Number<numSrcDim>{});
90 
91  const auto inDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
92 
// Collapse the Rank-dimensional descriptor into an (M, K) view.
93  const auto in_grid_desc_m_k = [&]() {
94  if constexpr(reduceAllDim)
95  {
// All dimensions are merged into one, which is then split again using a leading
// length of 1 (this branch is currently excluded by static_assert(!reduceAllDim)).
96  const auto one_dim_inDesc = transform_tensor_descriptor(
97  inDesc,
98  make_tuple(make_merge_transform(tupleSrcLengths)),
101 
102  return transform_tensor_descriptor(one_dim_inDesc,
104  1, one_dim_inDesc.GetLength(Number<0>{})))),
107  }
108  else
109  {
// The first NumInvariantDim dimensions are merged into M.
// NOTE(review): the definition of ReduceDims is not visible in this listing;
// presumably it names the remaining (trailing) dimensions, which are merged into K.
110  using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
112 
113  const auto reduceDimLengths =
114  make_tuple_from_array_and_index_seq(inLengths, ReduceDims{});
115  const auto invariantDimLengths =
116  make_tuple_from_array_and_index_seq(inLengths, InvariantDims{});
117 
119  inDesc,
120  make_tuple(make_merge_transform(invariantDimLengths),
121  make_merge_transform(reduceDimLengths)),
122  make_tuple(InvariantDims{}, ReduceDims{}),
124  }
125  }();
126 
127  const auto invariantLength = in_grid_desc_m_k.GetLength(Number<0>{});
128  const auto reduceLength = in_grid_desc_m_k.GetLength(Number<1>{});
129 
// Pad amounts: M up to the next multiple of the M block tile, K up to the requested
// number of K block tiles.
// NOTE(review): inPad_K is negative if K_BlockTileSize * numBlockTileIteration is smaller
// than reduceLength; assumes the caller passes a large enough iteration count - confirm.
130  const auto inPad_M =
131  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
132  const auto inPad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;
133 
134  auto in_grid_desc_m_k_padded = transform_tensor_descriptor(
135  in_grid_desc_m_k,
136  make_tuple(make_right_pad_transform(invariantLength, inPad_M),
137  make_right_pad_transform(reduceLength, inPad_K)),
140 
141  return (in_grid_desc_m_k_padded);
142  };
143 
// Builds the padded one-dimensional (M) grid descriptor used for the saved mean and
// inverse-std outputs.
//
// lengths / strides: per-dimension lengths and strides; only the first NumInvariantDim
// entries are read.
// Returns the descriptor with the merged invariant dimensions right-padded to a multiple
// of M_BlockTileSize.
//
// NOTE(review): some source lines of this function are missing from this listing
// (gaps in the embedded line numbers), so the code below is not complete as shown.
144  static auto MakeSaveMeanInvStdDescriptor_M(const std::vector<index_t>& lengths,
145  const std::vector<index_t>& strides)
146  {
147  using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
148 
149  const auto tupleSrcLengths = make_tuple_from_array_and_index_seq(lengths, InvariantDims{});
150  const auto tupleSrcStrides = make_tuple_from_array_and_index_seq(strides, InvariantDims{});
151 
152  const auto desc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides);
153 
// Merge all invariant dimensions into the single M dimension.
154  const auto grid_desc_m =
156  make_tuple(make_merge_transform(tupleSrcLengths)),
157  make_tuple(InvariantDims{}),
159 
// Pad M up to the next multiple of the M block tile.
160  const auto invariantLength = grid_desc_m.GetLength(Number<0>{});
161  const auto pad_M =
162  math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
163 
164  auto grid_desc_m_padded = transform_tensor_descriptor(
165  grid_desc_m,
166  make_tuple(make_right_pad_transform(invariantLength, pad_M)),
169 
170  return grid_desc_m_padded;
171  }
172 
// Descriptor types produced by the two factory functions above. The dummy arguments only
// serve type deduction: the operand of decltype is not evaluated.
173  using GridDesc_M_K = decltype(MakeSrc2dDescriptor({1}, {1}, 1));
174  using GridDesc_M = decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1}));
175 
// Bundles everything one kernel launch needs: data pointers, epsilon, the element-wise
// output operation, the stored lengths/strides and values derived from them.
//
// NOTE(review): several constructor statements and member declarations are missing from
// this listing (gaps in the embedded line numbers), so the code is not complete as shown.
176  struct Argument : public BaseArgument
177  {
// lengths:            tensor lengths, one entry per dimension.
// x/gamma/beta/yStrides: strides of the corresponding tensors, one entry per dimension.
// saveMeanStrides / saveInvStdStrides: strides of the saved statistics; stored unchanged.
// reduceDims:         dimensions that are reduced (passed to shuffle_tensor_dimensions).
// y_elementwise_op:   operation object stored for the kernel launch.
// epsilon:            converted to ComputeDataType before being stored.
// p_*:                data pointers, stored as given (not owned by this object as far as
//                     this listing shows).
178  Argument(const std::vector<index_t> lengths,
179  const std::vector<index_t> xStrides,
180  const std::vector<index_t> gammaStrides,
181  const std::vector<index_t> betaStrides,
182  const std::vector<index_t> yStrides,
183  const std::vector<index_t> saveMeanStrides,
184  const std::vector<index_t> saveInvStdStrides,
185  const std::vector<index_t> reduceDims,
186  YElementwiseOperation y_elementwise_op,
187  double epsilon,
188  const XDataType* p_x,
189  const GammaDataType* p_gamma,
190  const BetaDataType* p_beta,
191  YDataType* p_y,
192  SaveMeanInvStdDataType* p_saveMean,
193  SaveMeanInvStdDataType* p_saveInvStd)
194  : p_x_(p_x),
195  p_gamma_(p_gamma),
196  p_beta_(p_beta),
197  p_y_(p_y),
198  p_saveMean_(p_saveMean),
199  p_saveInvStd_(p_saveInvStd),
200  y_elementwise_op_(y_elementwise_op)
201  {
202  epsilon_ = static_cast<ComputeDataType>(epsilon);
203 
// NOTE(review): shuffle_tensor_dimensions is defined elsewhere; presumably it reorders
// the entries so that invariant dimensions come first and reduce dimensions last
// (IsSupportedArgument indexes the results with NumInvariantDim - 1 and Rank - 1) - confirm.
204  Lengths_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
205  xStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
206  yStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(yStrides, reduceDims);
207  gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
208  betaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(betaStrides, reduceDims);
209  saveMeanStrides_ = saveMeanStrides;
210  saveInvStdStrides_ = saveInvStdStrides;
211 
// Flattened invariant (M) and reduce (K) lengths.
212  std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(Lengths_);
213 
// NOTE(review): the statements that belong here (listing lines 214-225) are not visible.
215 
217 
226 
// True when the K length of the x descriptor fits into a single K block tile.
227  isSweeponce_ =
228  x_grid_desc_m_k_.GetLength(Number<1>{}) <= KThreadClusterSize * KThreadSliceSize;
229 
// NOTE(review): the bodies of both branches (listing lines 231 and 233) are not visible.
230  if constexpr(NumInvariantDim == 0)
232  else
234  }
235 
236  ComputeDataType epsilon_;
237 
// Data pointers as passed to the constructor.
238  const XDataType* p_x_;
239  const GammaDataType* p_gamma_;
240  const BetaDataType* p_beta_;
241  YDataType* p_y_;
242  SaveMeanInvStdDataType* p_saveMean_;
243  SaveMeanInvStdDataType* p_saveInvStd_;
244 
// Lengths and x/gamma/beta/y strides after shuffle_tensor_dimensions; the saved-statistics
// strides are stored exactly as passed in.
245  std::vector<index_t> Lengths_;
246  std::vector<index_t> xStrides_;
247  std::vector<index_t> gammaStrides_;
248  std::vector<index_t> betaStrides_;
249  std::vector<index_t> yStrides_;
250  std::vector<index_t> saveMeanStrides_;
251  std::vector<index_t> saveInvStdStrides_;
252 
253  YElementwiseOperation y_elementwise_op_;
254 
// Used as the grid dimension of the kernel launch in Invoker::Run.
256  size_t gridSize_;
257 
265 
266  index_t MRaw_; // Invariant length
267  index_t KRaw_; // reduce length
268 
270  };
271 
272  struct Invoker : public BaseInvoker
273  {
// Selects the kernel variant for this instance and launches it through
// launch_and_time_kernel.
//
// arg:           launch description built by the Argument constructor.
// stream_config: forwarded unchanged to launch_and_time_kernel.
// Returns the value reported by launch_and_time_kernel (units are not visible here).
//
// NOTE(review): the kernel arguments on listing lines 308-313 are missing from this
// listing, so the call below is not complete as shown.
274  float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
275  {
// The selector receives isSweeponce_ at run time and returns the kernel to launch.
// XYSrcVectorDim appears twice in the template argument list below.
// NOTE(review): presumably once for the X source and once for the Y destination
// vector dimension - confirm against NormalizationKernelSelector.
276  auto kernel_main = NormalizationKernelSelector<XDataType,
277  GammaDataType,
278  BetaDataType,
279  YDataType,
280  SaveMeanInvStdDataType,
281  ComputeDataType,
282  YElementwiseOperation,
283  GridDesc_M_K,
284  GridDesc_M,
285  BlockSize,
286  MThreadClusterSize,
287  KThreadClusterSize,
288  MThreadSliceSize,
289  KThreadSliceSize,
290  XYSrcVectorDim,
291  XSrcVectorSize,
292  GammaSrcVectorDim,
293  GammaSrcVectorSize,
294  BetaSrcVectorDim,
295  BetaSrcVectorSize,
296  XYSrcVectorDim,
297  YDstVectorSize,
298  SaveMeanInvStdDstVectorSize,
299  UseWelford>(arg.isSweeponce_);
300 
// Launch with gridSize_ blocks of BlockSize threads; the literal 0 is the lds_byte
// parameter of launch_and_time_kernel.
301  float avg_time = 0;
302  avg_time += launch_and_time_kernel(stream_config,
303  kernel_main,
304  dim3(arg.gridSize_),
305  dim3(BlockSize),
306  0,
307  arg.x_grid_desc_m_k_,
310  arg.y_grid_desc_m_k_,
314  arg.epsilon_,
315  arg.p_x_,
316  arg.p_gamma_,
317  arg.p_beta_,
318  arg.p_y_,
319  arg.p_saveMean_,
320  arg.p_saveInvStd_,
321  arg.y_elementwise_op_);
322 
323  return (avg_time);
324  };
325 
326  float Run(const BaseArgument* p_arg,
327  const StreamConfig& stream_config = StreamConfig{}) override
328  {
329  return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
330  };
331  };
332 
333  bool IsSupportedArgument(const BaseArgument* p_arg) override
334  {
335  const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);
336 
337  if constexpr(XYSrcVectorDim == 0)
338  {
339  if constexpr(NumInvariantDim == 0)
340  {
341  return false;
342  }
343  else
344  {
345  if(p_arg_->xStrides_[NumInvariantDim - 1] != 1)
346  return false;
347 
348  if(p_arg_->invariant_lowest_length_ % XSrcVectorSize != 0)
349  return false;
350 
351  if(p_arg_->invariant_lowest_length_ % YDstVectorSize != 0)
352  return false;
353  };
354  }
355  else
356  {
357  if(p_arg_->xStrides_[Rank - 1] != 1)
358  return false;
359 
360  if(p_arg_->Lengths_[Rank - 1] % XSrcVectorSize != 0)
361  return false;
362 
363  if(p_arg_->Lengths_[Rank - 1] % YDstVectorSize != 0)
364  {
365  return false;
366  }
367  };
368 
369  // if fastest dim is not reduced
370  if constexpr(GammaSrcVectorDim == 0)
371  {
372  if(p_arg_->gammaStrides_[NumInvariantDim - 1] != 1)
373  return (false);
374 
375  if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
376  return (false);
377  }
378  else // if fastest dim is reduced
379  {
380  if(p_arg_->gammaStrides_[Rank - 1] != 1)
381  return (false);
382 
383  if(p_arg_->Lengths_[Rank - 1] % GammaSrcVectorSize != 0)
384  return (false);
385  }
386 
387  // if fastest dim is not reduced
388  if constexpr(BetaSrcVectorDim == 0)
389  {
390  if(p_arg_->betaStrides_[NumInvariantDim - 1] != 1)
391  return (false);
392 
393  if(p_arg_->invariant_lowest_length_ % BetaSrcVectorSize != 0)
394  return (false);
395  }
396  else // if fastest dim is reduced
397  {
398  if(p_arg_->betaStrides_[Rank - 1] != 1)
399  return (false);
400 
401  if(p_arg_->Lengths_[Rank - 1] % BetaSrcVectorSize != 0)
402  return (false);
403  }
404 
405  if(p_arg_->invariant_lowest_length_ % SaveMeanInvStdDstVectorSize != 0)
406  return false;
407 
408  return true;
409  };
410 
411  std::unique_ptr<BaseArgument>
412  MakeArgumentPointer(const std::vector<index_t> lengths,
413  const std::vector<index_t> xStrides,
414  const std::vector<index_t> gammaStrides,
415  const std::vector<index_t> betaStrides,
416  const std::vector<index_t> yStrides,
417  const std::vector<index_t> saveMeanStrides,
418  const std::vector<index_t> saveInvStdStrides,
419  const std::vector<index_t> reduceDims,
420  double epsilon,
421  const void* p_x,
422  const void* p_gamma,
423  const void* p_beta,
424  void* p_y,
425  void* p_saveMean,
426  void* p_saveInvStd,
427  YElementwiseOperation y_elementwise_op) override
428  {
429  if(lengths.size() != Rank || xStrides.size() != Rank || gammaStrides.size() != Rank ||
430  betaStrides.size() != Rank || yStrides.size() != Rank ||
431  saveMeanStrides.size() != NumInvariantDim || saveInvStdStrides.size() != NumInvariantDim)
432  throw std::runtime_error("dimension is incorrect");
433 
434  return std::make_unique<Argument>(lengths,
435  xStrides,
436  gammaStrides,
437  betaStrides,
438  yStrides,
439  saveMeanStrides,
440  saveInvStdStrides,
441  reduceDims,
442  y_elementwise_op,
443  epsilon,
444  static_cast<const XDataType*>(p_x),
445  static_cast<const GammaDataType*>(p_gamma),
446  static_cast<const BetaDataType*>(p_beta),
447  static_cast<YDataType*>(p_y),
448  static_cast<SaveMeanInvStdDataType*>(p_saveMean),
449  static_cast<SaveMeanInvStdDataType*>(p_saveInvStd));
450  };
451 
452  std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
453  {
454  return std::make_unique<Invoker>();
455  };
456 
457  std::string GetTypeString() const override
458  {
459  auto str = std::stringstream();
460 
461  // clang-format off
462  str << "DeviceNormalizationFwdImpl<" << BlockSize << ",";
463  str << "Cluster_MK_" << MThreadClusterSize << "_" << KThreadClusterSize << ",";
464  str << "Slice_MK_" << MThreadSliceSize << "_" << KThreadSliceSize << ",";
465  str << "XYSrcVectorDim_" << XYSrcVectorDim << ",";
466  str << "VectorSize_X" << XSrcVectorSize << "_Gamma" << GammaSrcVectorSize << "_Beta" << BetaSrcVectorSize << "_Y" << YDstVectorSize << ">";
467  // clang-format on
468 
469  return str.str();
470  }
471 };
472 
473 } // namespace device
474 } // namespace tensor_operation
475 } // namespace ck
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
auto make_tuple_from_array(const std::vector< index_t > &lengths, Number< arraySize >)
Definition: device_reduce_common.hpp:65
auto make_tuple_from_array_and_index_seq(const std::vector< index_t > &lengths, Sequence< Ns... >)
Definition: device_reduce_common.hpp:59
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
constexpr Tuple< Args &... > tie(Args &... args) noexcept
Definition: tuple.hpp:218
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
auto NormalizationKernelSelector(bool isSweepOnce)
Definition: gridwise_normalization_selector.hpp:78
Definition: stream_config.hpp:10
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:256
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_normalization_fwd.hpp:23
Definition: device_normalization_fwd_impl.hpp:177
const GammaDataType * p_gamma_
Definition: device_normalization_fwd_impl.hpp:239
const XDataType * p_x_
Definition: device_normalization_fwd_impl.hpp:238
YDataType * p_y_
Definition: device_normalization_fwd_impl.hpp:241
GridDesc_M save_mean_grid_desc_m_
Definition: device_normalization_fwd_impl.hpp:262
Argument(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, YElementwiseOperation y_elementwise_op, double epsilon, const XDataType *p_x, const GammaDataType *p_gamma, const BetaDataType *p_beta, YDataType *p_y, SaveMeanInvStdDataType *p_saveMean, SaveMeanInvStdDataType *p_saveInvStd)
Definition: device_normalization_fwd_impl.hpp:178
index_t MRaw_
Definition: device_normalization_fwd_impl.hpp:266
ComputeDataType epsilon_
Definition: device_normalization_fwd_impl.hpp:236
std::vector< index_t > saveInvStdStrides_
Definition: device_normalization_fwd_impl.hpp:251
GridDesc_M_K x_grid_desc_m_k_
Definition: device_normalization_fwd_impl.hpp:258
SaveMeanInvStdDataType * p_saveInvStd_
Definition: device_normalization_fwd_impl.hpp:243
std::vector< index_t > Lengths_
Definition: device_normalization_fwd_impl.hpp:245
SaveMeanInvStdDataType * p_saveMean_
Definition: device_normalization_fwd_impl.hpp:242
index_t invariant_lowest_length_
Definition: device_normalization_fwd_impl.hpp:269
std::vector< index_t > betaStrides_
Definition: device_normalization_fwd_impl.hpp:248
std::vector< index_t > saveMeanStrides_
Definition: device_normalization_fwd_impl.hpp:250
bool isSweeponce_
Definition: device_normalization_fwd_impl.hpp:264
std::vector< index_t > xStrides_
Definition: device_normalization_fwd_impl.hpp:246
index_t KRaw_
Definition: device_normalization_fwd_impl.hpp:267
std::vector< index_t > gammaStrides_
Definition: device_normalization_fwd_impl.hpp:247
GridDesc_M_K y_grid_desc_m_k_
Definition: device_normalization_fwd_impl.hpp:261
GridDesc_M_K gamma_grid_desc_m_k_
Definition: device_normalization_fwd_impl.hpp:259
GridDesc_M save_inv_std_grid_desc_m_
Definition: device_normalization_fwd_impl.hpp:263
YElementwiseOperation y_elementwise_op_
Definition: device_normalization_fwd_impl.hpp:253
GridDesc_M_K beta_grid_desc_m_k_
Definition: device_normalization_fwd_impl.hpp:260
int numBlockTileIteration_
Definition: device_normalization_fwd_impl.hpp:255
std::vector< index_t > yStrides_
Definition: device_normalization_fwd_impl.hpp:249
const BetaDataType * p_beta_
Definition: device_normalization_fwd_impl.hpp:240
size_t gridSize_
Definition: device_normalization_fwd_impl.hpp:256
Definition: device_normalization_fwd_impl.hpp:273
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_normalization_fwd_impl.hpp:326
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_normalization_fwd_impl.hpp:274
Definition: device_normalization_fwd_impl.hpp:57
static constexpr index_t M_BlockTileSize
Definition: device_normalization_fwd_impl.hpp:76
decltype(MakeSaveMeanInvStdDescriptor_M({1}, {1})) GridDesc_M
Definition: device_normalization_fwd_impl.hpp:174
static auto MakeSrc2dDescriptor(const std::vector< index_t > &inLengths, const std::vector< index_t > &inStrides, int numBlockTileIteration)
Definition: device_normalization_fwd_impl.hpp:82
static constexpr index_t NumInvariantDim
Definition: device_normalization_fwd_impl.hpp:75
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_normalization_fwd_impl.hpp:333
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_normalization_fwd_impl.hpp:452
static auto MakeSaveMeanInvStdDescriptor_M(const std::vector< index_t > &lengths, const std::vector< index_t > &strides)
Definition: device_normalization_fwd_impl.hpp:144
static constexpr index_t K_BlockTileSize
Definition: device_normalization_fwd_impl.hpp:77
decltype(MakeSrc2dDescriptor({1}, {1}, 1)) GridDesc_M_K
Definition: device_normalization_fwd_impl.hpp:173
std::string GetTypeString() const override
Definition: device_normalization_fwd_impl.hpp:457
static constexpr bool reduceAllDim
Definition: device_normalization_fwd_impl.hpp:79
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::vector< index_t > lengths, const std::vector< index_t > xStrides, const std::vector< index_t > gammaStrides, const std::vector< index_t > betaStrides, const std::vector< index_t > yStrides, const std::vector< index_t > saveMeanStrides, const std::vector< index_t > saveInvStdStrides, const std::vector< index_t > reduceDims, double epsilon, const void *p_x, const void *p_gamma, const void *p_beta, void *p_y, void *p_saveMean, void *p_saveInvStd, YElementwiseOperation y_elementwise_op) override
Definition: device_normalization_fwd_impl.hpp:412
Definition: unary_element_wise_operation.hpp:241