Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp Source File

device_batchnorm_forward_impl.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

// device base classes, gridwise kernels, and host utilities used below
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce_common.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_forward_blockwise_welford.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/hip_check_error.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          index_t Rank,
          index_t NumBatchNormReduceDim,
          bool UseMultiblockInK,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcYDstVectorDim,
          index_t XSrcVectorSize,
          index_t YDstVectorSize,
          index_t ScaleSrcVectorSize,
          index_t BiasSrcVectorSize,
          index_t MeanVarSrcDstVectorSize>
struct DeviceBatchNormFwdImpl : public DeviceBatchNormFwd<XDataType,
                                                          YDataType,
                                                          AccDataType,
                                                          ScaleDataType,
                                                          BiasDataType,
                                                          MeanVarDataType,
                                                          YElementwiseOp,
                                                          Rank,
                                                          NumBatchNormReduceDim>
{
    static_assert(Rank <= 6, "Rank larger than 6 is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

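    // The Rank-d problem is viewed as 2D [M, K]: M merges the NumInvariantDim leading
    // (non-reduced) dimensions and K merges the NumBatchNormReduceDim trailing (reduced)
    // ones. Each workgroup of BlockSize threads covers one M_BlockTileSize x
    // K_BlockTileSize tile per iteration; e.g. cluster sizes 8 x 32 with slice sizes
    // 2 x 2 give a 256-thread block covering a 16 x 64 tile (illustrative numbers).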
    static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
                                   const std::array<index_t, Rank>& xyStrides,
                                   int blkGroupSize,
                                   int numBlockTileIteration)
    {
        const auto tupleXYLengths =
            generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
        const auto tupleXYStrides =
            generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});

        const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);

        const auto grid_desc_m_k = [&]() {
            using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
            using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

            const auto reduceDimLengths =
                generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
                               Number<NumBatchNormReduceDim>{});
            const auto invariantDimLengths =
                generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});

            return transform_tensor_descriptor(raw_grid_desc,
                                               make_tuple(make_merge_transform(invariantDimLengths),
                                                          make_merge_transform(reduceDimLengths)),
                                               make_tuple(InvariantDims{}, ReduceDims{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }();

        const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength = grid_desc_m_k.GetLength(Number<1>{});

        const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_right_pad_transform(reduceLength, kPad)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_k_padded);
    };
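
    // Worked example (illustrative): for an NHWC tensor with lengths {N, H, W, C} =
    // {256, 14, 14, 64} and reduce dims {0, 1, 2}, the dimensions are shuffled so that
    // C becomes the invariant (M) dimension and N*H*W the reduced (K) one, giving a 2D
    // view of M = 64 and K = 256 * 14 * 14 = 50176; both axes are then right-padded up
    // to multiples of the block tile sizes.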

    static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto grid_desc_m_g = make_naive_tensor_descriptor(
            make_tuple(invariantLength, blkGroupSize), make_tuple(1, invariantLength));

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto grid_desc_m_g_padded =
            transform_tensor_descriptor(grid_desc_m_g,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_pass_through_transform(blkGroupSize)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_g_padded);
    };
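
    // Note: strides (1, invariantLength) lay the [M, G] workspace out column-major, so
    // the per-group partial results for one invariant index are strided by
    // invariantLength; only the M axis needs padding since G equals blkGroupSize exactly.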

    static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto reduceLength = blkGroupSize;
        const auto grid_desc_m_k = make_naive_tensor_descriptor(
            make_tuple(invariantLength, reduceLength), make_tuple(1, invariantLength));

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto kPad =
            math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_right_pad_transform(reduceLength, kPad)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_k_padded);
    };

    static auto
    MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
                                     const std::array<index_t, NumInvariantDim>& strides)
    {
        const auto tupleLengths =
            generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
        const auto tupleStrides =
            generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});

        auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);

        auto grid_desc_m = transform_tensor_descriptor(
            raw_grid_desc,
            make_tuple(make_merge_transform(tupleLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto grid_desc_m_padded =
            transform_tensor_descriptor(grid_desc_m,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return (grid_desc_m_padded);
    };
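
    // The scale/bias/mean/variance tensors span only the invariant dimensions, so they
    // are merged into a single 1D M descriptor and padded to the same M_BlockTileSize
    // multiple as the X/Y descriptors above.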

    using XYGridDesc_M_K             = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
    using ScaleBiasMeanVarGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));

    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, Rank> xyLengths,
                 const std::array<index_t, Rank> xStrides,
                 const std::array<index_t, Rank> yStrides,
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
                 const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const ScaleDataType* p_scale,
                 const BiasDataType* p_bias,
                 const YElementwiseOp y_elementwise_op,
                 double epsilon,
                 YDataType* p_y,
                 MeanVarDataType* resultSaveMean,
                 MeanVarDataType* resultSaveInvVariance,
                 double averageFactor,
                 MeanVarDataType* resultRunningMean,
                 MeanVarDataType* resultRunningVariance)
            : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
              bnBiasStrides_(bnBiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              p_scale_(p_scale),
              p_bias_(p_bias),
              y_elementwise_op_(y_elementwise_op),
              p_y_(p_y),
              resultSaveMean_(resultSaveMean),
              resultSaveInvVariance_(resultSaveInvVariance),
              resultRunningMean_(resultRunningMean),
              resultRunningVariance_(resultRunningVariance)
        {
            xyLengths_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
            xStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
            yStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(yStrides, reduceDims);

            std::tie(invariant_length_, reduce_length_) =
                get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);

            epsilon_ = type_convert<AccDataType>(epsilon);
            averageFactor_ = type_convert<AccDataType>(averageFactor);

            updateMovingAverage_ =
                (resultRunningMean != nullptr && resultRunningVariance != nullptr);
            saveMeanInvVariance_ = (resultSaveMean != nullptr && resultSaveInvVariance != nullptr);

            if(UseMultiblockInK)
            {
                int iterations = 1;
                while(true)
                {
                    int testBlkGroupSize = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                           (K_BlockTileSize * iterations);

                    // we want the blkGroupSize to be at most 16
                    if(testBlkGroupSize <= 16)
                        break;

                    iterations++;
                };

                blkGroupSize_ = (reduce_length_ + (K_BlockTileSize * iterations) - 1) /
                                (K_BlockTileSize * iterations);

                numBlockTileIteration_ = iterations;
            }
            else
            {
                blkGroupSize_          = 1;
                numBlockTileIteration_ = (reduce_length_ + K_BlockTileSize - 1) / K_BlockTileSize;
            };
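
            // Example of the split above (illustrative numbers): with reduce_length_ =
            // 50176 and K_BlockTileSize = 128, the loop settles at iterations = 25,
            // giving blkGroupSize_ = 16 workgroups cooperating along K, each covering
            // 25 block tiles.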

            gridSize_ = (invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize_;

            x_grid_desc_m_k_ =
                MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize_, numBlockTileIteration_);
            y_grid_desc_m_k_ =
                MakeXY2dDescriptor(xyLengths_, yStrides_, blkGroupSize_, numBlockTileIteration_);

            scale_grid_desc_m_ =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides_);
            bias_grid_desc_m_ =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnBiasStrides_);
            mean_var_grid_desc_m_ =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides_);
        }

        AccDataType epsilon_;
        AccDataType averageFactor_;

        bool updateMovingAverage_;
        bool saveMeanInvVariance_;

        std::array<index_t, Rank> xyLengths_;
        std::array<index_t, Rank> xStrides_;
        std::array<index_t, Rank> yStrides_;

        std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;

        const XDataType* p_x_;
        const ScaleDataType* p_scale_;
        const BiasDataType* p_bias_;
        const YElementwiseOp y_elementwise_op_;
        YDataType* p_y_;

        MeanVarDataType* resultSaveMean_;
        MeanVarDataType* resultSaveInvVariance_;

        MeanVarDataType* resultRunningMean_;
        MeanVarDataType* resultRunningVariance_;

        long_index_t invariant_length_;
        long_index_t reduce_length_;

        int blkGroupSize_;
        int numBlockTileIteration_;
        size_t gridSize_;

        XYGridDesc_M_K x_grid_desc_m_k_;
        XYGridDesc_M_K y_grid_desc_m_k_;
        ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_;
        ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_;
        ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_;

        void* workspace_mean_;
        void* workspace_variance_;
        void* workspace_count_;

        void* control_;
    };

    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        size_t workspace_size = 0;

        if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
        {
            // workspace for welford intermediate mean
            workspace_size +=
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;

            // workspace for welford intermediate variance
            workspace_size +=
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType) + 64;

            // workspace for welford intermediate count
            workspace_size +=
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t) + 64;

            // workspace for barrier objects, each barrier object consists of two integers
            // TODO: allocate barrier object memory globally to reuse it by other operators
            workspace_size += (pArg_->invariant_length_ + M_BlockTileSize - 1) / M_BlockTileSize *
                              sizeof(int) * 2;
        }

        return (workspace_size);
    };
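
    // Workspace layout set up by SetWorkSpacePointer below: [mean | variance | count |
    // barrier objects], with the first three sections each padded to a 64-byte boundary
    // (hence the "+ 64" per section in the size computed above).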

    void SetWorkSpacePointer(BaseArgument* pArg,
                             void* p_workspace,
                             const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);

        pArg_->p_workspace_ = p_workspace;

        if(UseMultiblockInK && pArg_->blkGroupSize_ > 1)
        {
            // setup buffer used for intermediate welford mean
            pArg_->workspace_mean_ = static_cast<char*>(pArg_->p_workspace_);

            index_t mean_space_sz =
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);

            mean_space_sz = math::integer_least_multiple(mean_space_sz, 64);

            // setup buffer used for intermediate welford variance
            pArg_->workspace_variance_ =
                reinterpret_cast<char*>(pArg_->workspace_mean_) + mean_space_sz;

            index_t variance_space_sz =
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(MeanVarDataType);

            variance_space_sz = math::integer_least_multiple(variance_space_sz, 64);

            // setup buffer used for intermediate welford count
            pArg_->workspace_count_ =
                reinterpret_cast<char*>(pArg_->workspace_variance_) + variance_space_sz;

            index_t count_space_sz =
                pArg_->invariant_length_ * pArg_->blkGroupSize_ * sizeof(int32_t);

            count_space_sz = math::integer_least_multiple(count_space_sz, 64);

            pArg_->control_ = reinterpret_cast<char*>(pArg_->workspace_count_) + count_space_sz;

            index_t control_space_sz = (pArg_->invariant_length_ + M_BlockTileSize - 1) /
                                       M_BlockTileSize * sizeof(int) * 2;

            // the barrier objects must start from a cleared state
            hip_check_error(hipMemset(pArg_->control_, 0, control_space_sz));
        };
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            float avg_time = 0;

            if(UseMultiblockInK && arg.blkGroupSize_ > 1)
            {
                using GetReduceCountPerThreadFunctor =
                    GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;

                GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
                    arg.blkGroupSize_, arg.numBlockTileIteration_, arg.reduce_length_);

                const auto mean_var_count_grid_desc_m_g =
                    DeviceBatchNormFwdImpl::MakeMeanVarCountOutputMG2dDescriptor(
                        arg.invariant_length_, arg.blkGroupSize_);

                const auto mean_var_count_grid_desc_m_k =
                    DeviceBatchNormFwdImpl::MakeMeanVarCountInputMK2dDescriptor(
                        arg.invariant_length_, arg.blkGroupSize_);

                using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
                using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);

                using GridwiseMultiblockBatchNormForward_ =
                    GridwiseMultiblockBatchNormForward<XDataType,
                                                       YDataType,
                                                       AccDataType,
                                                       ScaleDataType,
                                                       BiasDataType,
                                                       MeanVarDataType,
                                                       YElementwiseOp,
                                                       XYGridDesc_M_K,
                                                       MeanVarCountGridDesc_M_G,
                                                       MeanVarCountGridDesc_M_K,
                                                       ScaleBiasMeanVarGridDesc_M,
                                                       ScaleBiasMeanVarGridDesc_M,
                                                       GetReduceCountPerThreadFunctor,
                                                       BlockSize,
                                                       MThreadClusterSize,
                                                       KThreadClusterSize,
                                                       MThreadSliceSize,
                                                       KThreadSliceSize,
                                                       XSrcYDstVectorDim,
                                                       XSrcVectorSize,
                                                       YDstVectorSize,
                                                       ScaleSrcVectorSize,
                                                       BiasSrcVectorSize,
                                                       MeanVarSrcDstVectorSize>;

                using GridwiseMultiblockWelfordFirstHalf_ =
                    GridwiseMultiblockWelfordFirstHalf<XDataType,
                                                       AccDataType,
                                                       MeanVarDataType,
                                                       XYGridDesc_M_K,
                                                       MeanVarCountGridDesc_M_G,
                                                       GetReduceCountPerThreadFunctor,
                                                       BlockSize,
                                                       MThreadClusterSize,
                                                       KThreadClusterSize,
                                                       MThreadSliceSize,
                                                       KThreadSliceSize,
                                                       XSrcYDstVectorDim,
                                                       XSrcVectorSize>;

                using GridwiseWelfordSecondHalfBatchNormForwardFinal_ =
                    GridwiseWelfordSecondHalfBatchNormForwardFinal<XDataType,
                                                                   YDataType,
                                                                   AccDataType,
                                                                   ScaleDataType,
                                                                   BiasDataType,
                                                                   MeanVarDataType,
                                                                   YElementwiseOp,
                                                                   XYGridDesc_M_K,
                                                                   MeanVarCountGridDesc_M_K,
                                                                   ScaleBiasMeanVarGridDesc_M,
                                                                   ScaleBiasMeanVarGridDesc_M,
                                                                   BlockSize,
                                                                   MThreadClusterSize,
                                                                   KThreadClusterSize,
                                                                   MThreadSliceSize,
                                                                   KThreadSliceSize,
                                                                   XSrcYDstVectorDim,
                                                                   XSrcVectorSize,
                                                                   YDstVectorSize,
                                                                   ScaleSrcVectorSize,
                                                                   BiasSrcVectorSize,
                                                                   MeanVarSrcDstVectorSize>;

                // It has been found that:
                // 1) gfx1030 does not support the GLC-enabled vector load/store, so the
                //    two-kernel method is used for gfx1030
                // 2) the profiler on gfx908 could hang even though it works when running examples
                // 3) the single-kernel method works on gfx1100, but its performance is not better
                //    than the two-kernel method (due to more warps participating in the barrier)
                if(ck::get_device_name() == "gfx90a")
                {
                    const auto kern_multiblock_batchnorm_fwd_ =
                        kernel_multiblock_batchnorm_forward<GridwiseMultiblockBatchNormForward_,
                                                            XDataType,
                                                            YDataType,
                                                            AccDataType,
                                                            ScaleDataType,
                                                            BiasDataType,
                                                            MeanVarDataType,
                                                            YElementwiseOp,
                                                            XYGridDesc_M_K,
                                                            MeanVarCountGridDesc_M_G,
                                                            MeanVarCountGridDesc_M_K,
                                                            ScaleBiasMeanVarGridDesc_M,
                                                            ScaleBiasMeanVarGridDesc_M,
                                                            GetReduceCountPerThreadFunctor>;

                    avg_time += launch_and_time_kernel(
                        stream_config,
                        kern_multiblock_batchnorm_fwd_,
                        dim3(arg.gridSize_),
                        dim3(BlockSize),
                        0,
                        arg.x_grid_desc_m_k_,
                        arg.y_grid_desc_m_k_,
                        mean_var_count_grid_desc_m_g, // for writing to mean/variance/count
                                                      // workspace by multiple workgroups
                        mean_var_count_grid_desc_m_k, // for reading from mean/variance/count
                                                      // workspace by each workgroup
                        arg.scale_grid_desc_m_,
                        arg.bias_grid_desc_m_,
                        arg.mean_var_grid_desc_m_,
                        get_reduce_count_per_thread,
                        arg.numBlockTileIteration_,
                        arg.epsilon_,
                        arg.p_x_,
                        static_cast<MeanVarDataType*>(arg.workspace_mean_),
                        static_cast<MeanVarDataType*>(arg.workspace_variance_),
                        static_cast<int32_t*>(arg.workspace_count_),
                        static_cast<int*>(arg.control_),
                        arg.p_scale_,
                        arg.p_bias_,
                        arg.y_elementwise_op_,
                        arg.p_y_,
                        arg.updateMovingAverage_, // true or false
                        arg.averageFactor_,
                        arg.resultRunningMean_,
                        arg.resultRunningVariance_,
                        arg.saveMeanInvVariance_, // true or false
                        arg.resultSaveMean_,
                        arg.resultSaveInvVariance_);
                }
                else
                {
                    const auto kern_multiblock_welford_first_half =
                        kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
                                                             XDataType,
                                                             MeanVarDataType,
                                                             XYGridDesc_M_K,
                                                             MeanVarCountGridDesc_M_G,
                                                             GetReduceCountPerThreadFunctor>;

                    const auto kern_welford_second_half_batchnorm_forward_final =
                        kernel_welford_second_half_batchnorm_forward_final<
                            GridwiseWelfordSecondHalfBatchNormForwardFinal_,
                            XDataType,
                            YDataType,
                            AccDataType,
                            ScaleDataType,
                            BiasDataType,
                            MeanVarDataType,
                            YElementwiseOp,
                            XYGridDesc_M_K,
                            MeanVarCountGridDesc_M_K,
                            ScaleBiasMeanVarGridDesc_M,
                            ScaleBiasMeanVarGridDesc_M>;

                    avg_time += launch_and_time_kernel(
                        stream_config,
                        kern_multiblock_welford_first_half,
                        dim3(arg.gridSize_),
                        dim3(BlockSize),
                        0,
                        arg.x_grid_desc_m_k_,
                        mean_var_count_grid_desc_m_g,
                        get_reduce_count_per_thread,
                        arg.numBlockTileIteration_,
                        arg.p_x_,
                        static_cast<MeanVarDataType*>(arg.workspace_mean_),
                        static_cast<MeanVarDataType*>(arg.workspace_variance_),
                        static_cast<int32_t*>(arg.workspace_count_));

                    avg_time += launch_and_time_kernel(
                        stream_config,
                        kern_welford_second_half_batchnorm_forward_final,
                        dim3(arg.gridSize_),
                        dim3(BlockSize),
                        0,
                        arg.x_grid_desc_m_k_,
                        arg.y_grid_desc_m_k_,
                        mean_var_count_grid_desc_m_k,
                        arg.scale_grid_desc_m_,
                        arg.bias_grid_desc_m_,
                        arg.mean_var_grid_desc_m_,
                        arg.blkGroupSize_,
                        arg.numBlockTileIteration_,
                        arg.epsilon_,
                        static_cast<MeanVarDataType*>(arg.workspace_mean_),
                        static_cast<MeanVarDataType*>(arg.workspace_variance_),
                        static_cast<int32_t*>(arg.workspace_count_),
                        arg.p_x_,
                        arg.p_scale_,
                        arg.p_bias_,
                        arg.y_elementwise_op_,
                        arg.p_y_,
                        arg.updateMovingAverage_,
                        arg.averageFactor_,
                        arg.resultRunningMean_,
                        arg.resultRunningVariance_,
                        arg.saveMeanInvVariance_,
                        arg.resultSaveMean_,
                        arg.resultSaveInvVariance_);
                };
            }
            else
            {
                using GetReduceCountPerThreadFunctor =
                    GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;

                GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
                    arg.numBlockTileIteration_, arg.reduce_length_);

                using GridwiseBatchNormForwardWithBlockwiseWelford_ =
                    GridwiseBatchNormForwardWithBlockwiseWelford<XDataType,
                                                                 YDataType,
                                                                 AccDataType,
                                                                 ScaleDataType,
                                                                 BiasDataType,
                                                                 MeanVarDataType,
                                                                 YElementwiseOp,
                                                                 XYGridDesc_M_K,
                                                                 ScaleBiasMeanVarGridDesc_M,
                                                                 ScaleBiasMeanVarGridDesc_M,
                                                                 GetReduceCountPerThreadFunctor,
                                                                 BlockSize,
                                                                 MThreadClusterSize,
                                                                 KThreadClusterSize,
                                                                 MThreadSliceSize,
                                                                 KThreadSliceSize,
                                                                 XSrcYDstVectorDim,
                                                                 XSrcVectorSize,
                                                                 YDstVectorSize,
                                                                 ScaleSrcVectorSize,
                                                                 BiasSrcVectorSize,
                                                                 MeanVarSrcDstVectorSize>;

                const auto kern_batchnorm_fwd = kernel_batchnorm_forward_with_blockwise_welford<
                    GridwiseBatchNormForwardWithBlockwiseWelford_,
                    XDataType,
                    YDataType,
                    AccDataType,
                    ScaleDataType,
                    BiasDataType,
                    MeanVarDataType,
                    YElementwiseOp,
                    XYGridDesc_M_K,
                    ScaleBiasMeanVarGridDesc_M,
                    ScaleBiasMeanVarGridDesc_M,
                    GetReduceCountPerThreadFunctor>;

                avg_time += launch_and_time_kernel(stream_config,
                                                   kern_batchnorm_fwd,
                                                   dim3(arg.gridSize_),
                                                   dim3(BlockSize),
                                                   0,
                                                   arg.x_grid_desc_m_k_,
                                                   arg.y_grid_desc_m_k_,
                                                   arg.scale_grid_desc_m_,
                                                   arg.bias_grid_desc_m_,
                                                   arg.mean_var_grid_desc_m_,
                                                   get_reduce_count_per_thread,
                                                   arg.numBlockTileIteration_,
                                                   arg.epsilon_,
                                                   arg.p_x_,
                                                   arg.p_scale_,
                                                   arg.p_bias_,
                                                   arg.y_elementwise_op_,
                                                   arg.p_y_,
                                                   arg.updateMovingAverage_, // true or false
                                                   arg.averageFactor_,
                                                   arg.resultRunningMean_,
                                                   arg.resultRunningVariance_,
                                                   arg.saveMeanInvVariance_, // true or false
                                                   arg.resultSaveMean_,
                                                   arg.resultSaveInvVariance_);
            };

            return (avg_time);
        };

        float Run(const BaseArgument* pArg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
        };
    };

    bool IsSupportedArgument(const BaseArgument* pArg) override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        if constexpr(XSrcYDstVectorDim == 0)
        {
            if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
               pArg_->yStrides_[NumInvariantDim - 1] != 1)
                return false;

            if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
               pArg_->xyLengths_[NumInvariantDim - 1] % YDstVectorSize != 0)
                return false;
        }
        else
        {
            if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->yStrides_[Rank - 1] != 1)
                return false;

            if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
               pArg_->xyLengths_[Rank - 1] % YDstVectorSize != 0)
                return false;
        };

        if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
            return false;
        if(pArg_->bnBiasStrides_[NumInvariantDim - 1] != 1 && BiasSrcVectorSize != 1)
            return false;

        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
            return false;
        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % BiasSrcVectorSize != 0)
            return false;

        if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcDstVectorSize != 1)
            return false;

        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcDstVectorSize != 0)
            return false;

        bool is_valid = true;

        static_for<0, NumInvariantDim, 1>{}([&](auto I) {
            if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
                is_valid = false;
        });

        if(!is_valid)
            return false;

        return true;
    };

    std::unique_ptr<BaseArgument> MakeArgumentPointer(
        const std::array<index_t, Rank> xyLengths,
        const std::array<index_t, Rank> xStrides,
        const std::array<index_t, Rank> yStrides,
        const std::array<int, NumBatchNormReduceDim> reduceDims,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
        const std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides,
        const void* p_x,
        const void* p_scale,
        const void* p_bias,
        double epsilon,
        const YElementwiseOp y_elementwise_op,
        void* p_y,
        void* resultSaveMean,
        void* resultSaveInvVariance,
        double averageFactor,
        void* resultRunningMean,
        void* resultRunningVariance) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          yStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnBiasStrides,
                                          bnMeanVarStrides,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<const BiasDataType*>(p_bias),
                                          y_elementwise_op,
                                          epsilon,
                                          static_cast<YDataType*>(p_y),
                                          static_cast<MeanVarDataType*>(resultSaveMean),
                                          static_cast<MeanVarDataType*>(resultSaveInvVariance),
                                          averageFactor,
                                          static_cast<MeanVarDataType*>(resultRunningMean),
                                          static_cast<MeanVarDataType*>(resultRunningVariance));
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchNormFwdImpl<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "XSrcYDstVectorDim_" << XSrcYDstVectorDim << ",";
        str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
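
A minimal host-side usage sketch (not part of the header above): it instantiates DeviceBatchNormFwdImpl for a packed NHWC fp16 tensor with fp32 accumulation and per-channel statistics, then drives the MakeArgumentPointer / IsSupportedArgument / SetWorkSpacePointer / Run sequence shown in the listing. The tile and vector template values, the PassThrough element-wise op, the tensor shape, and the helper function name are illustrative assumptions; any real configuration must satisfy the static_asserts and IsSupportedArgument checks, and p_workspace must point to at least GetWorkSpaceSize() bytes of device memory.

#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"

using F16  = ck::half_t;
using F32  = float;
using Pass = ck::tensor_operation::element_wise::PassThrough;

// NHWC input, reducing over {N, H, W}; scale/bias/statistics are per channel (C).
using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwdImpl<
    F16, F16, F32, F16, F16, F32, Pass,
    /*Rank=*/4, /*NumBatchNormReduceDim=*/3, /*UseMultiblockInK=*/true,
    /*BlockSize=*/256,
    /*MThreadClusterSize=*/16, /*KThreadClusterSize=*/16, // 16 * 16 == BlockSize
    /*MThreadSliceSize=*/1, /*KThreadSliceSize=*/1,
    /*XSrcYDstVectorDim=*/0, // vectorize along the invariant (C) dimension
    /*XSrcVectorSize=*/1, /*YDstVectorSize=*/1,
    /*ScaleSrcVectorSize=*/1, /*BiasSrcVectorSize=*/1,
    /*MeanVarSrcDstVectorSize=*/1>;

// Hypothetical helper: all pointers are device memory allocated by the caller.
float run_batchnorm_fwd(const F16* p_x, const F16* p_scale, const F16* p_bias,
                        F16* p_y, F32* p_save_mean, F32* p_save_inv_var,
                        F32* p_run_mean, F32* p_run_var, void* p_workspace)
{
    auto op  = DeviceOp{};
    auto arg = op.MakeArgumentPointer(
        {256, 14, 14, 64},   // xyLengths (N, H, W, C)
        {12544, 896, 64, 1}, // xStrides for packed NHWC
        {12544, 896, 64, 1}, // yStrides
        {0, 1, 2},           // reduceDims: N, H, W
        {64}, {1}, {1}, {1}, // per-channel lengths and strides
        p_x, p_scale, p_bias,
        1e-5,                // epsilon
        Pass{}, p_y,
        p_save_mean, p_save_inv_var,
        0.1,                 // averageFactor for the running estimates
        p_run_mean, p_run_var);

    if(!op.IsSupportedArgument(arg.get()))
        return -1.0f; // this instance rejects the problem; try another configuration

    // the multiblock path requires the workspace to be bound before Run()
    op.SetWorkSpacePointer(arg.get(), p_workspace);

    auto invoker = op.MakeInvokerPointer();
    return invoker->Run(arg.get(), StreamConfig{nullptr, false});
}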