Composable Kernel: include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp (docs-6.4.3 source listing)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>

#include "ck/ck.hpp"
#include "ck/stream_config.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_backward.hpp"
#include "ck/tensor_operation/gpu/device/welford_helper.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batchnorm_backward_blockwise_welford.hpp"
#include "ck/host_utility/kernel_launch.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename XDataType,
          typename DxDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          index_t Rank,
          index_t NumBatchNormReduceDim,
          bool UseMultiblockInK,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XDyDxVectorDim,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t DxDstVectorSize,
          index_t ScaleSrcVectorSize,
          index_t DscaleDbiasDstVectorSize,
          index_t MeanVarSrcVectorSize>
struct DeviceBatchNormBwdImpl : public DeviceBatchNormBwd<XDataType,
                                                          DxDataType,
                                                          DyDataType,
                                                          AccDataType,
                                                          ScaleDataType,
                                                          DscaleDbiasDataType,
                                                          MeanVarDataType,
                                                          DyElementwiseOp,
                                                          Rank,
                                                          NumBatchNormReduceDim>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
                  "Invalid thread cluster size assignments!");

    static_assert((XDyDxVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0 &&
                   MThreadSliceSize % DxDstVectorSize == 0) ||
                      (XDyDxVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0 &&
                       KThreadSliceSize % DxDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

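    // Build a 2-d (M, K) view of the Rank-d x/dy/dx tensors: the NumInvariantDim
    // leading dimensions are merged into M, the NumBatchNormReduceDim trailing
    // dimensions into K, and both axes are right-padded to block-tile multiples
    // so every thread block operates on a full tile.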
    static auto MakeXY2dDescriptor(const std::array<index_t, Rank>& xyLengths,
                                   const std::array<index_t, Rank>& xyStrides,
                                   int blkGroupSize,
                                   int numBlockTileIteration)
    {
        const auto tupleXYLengths =
            generate_tuple([&](auto I) { return xyLengths[I]; }, Number<Rank>{});
        const auto tupleXYStrides =
            generate_tuple([&](auto I) { return xyStrides[I]; }, Number<Rank>{});

        const auto raw_grid_desc = make_naive_tensor_descriptor(tupleXYLengths, tupleXYStrides);

        const auto grid_desc_m_k = [&]() {
            using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
            using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

            const auto reduceDimLengths =
                generate_tuple([&](auto I) { return xyLengths[NumInvariantDim + I]; },
                               Number<NumBatchNormReduceDim>{});
            const auto invariantDimLengths =
                generate_tuple([&](auto I) { return xyLengths[I]; }, Number<NumInvariantDim>{});

            return transform_tensor_descriptor(raw_grid_desc,
                                               make_tuple(make_merge_transform(invariantDimLengths),
                                                          make_merge_transform(reduceDimLengths)),
                                               make_tuple(InvariantDims{}, ReduceDims{}),
                                               make_tuple(Sequence<0>{}, Sequence<1>{}));
        }();

        const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = grid_desc_m_k.GetLength(Number<1>{});

        const int workSizePerBlock = K_BlockTileSize * numBlockTileIteration;
        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto kPad = workSizePerBlock * blkGroupSize - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_right_pad_transform(reduceLength, kPad)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_k_padded);
    };

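    // Descriptor for the (M, blkGroupSize) workspace written by the multiblock
    // first pass: one partial result per block group, padded along M.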
    static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto grid_desc_m_g =
            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, blkGroupSize));

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto grid_desc_m_g_padded =
            transform_tensor_descriptor(grid_desc_m_g,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_pass_through_transform(blkGroupSize)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_g_padded);
    };

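    // Descriptor for reading those partial results back as a small (M, K)
    // reduction, with K = blkGroupSize padded to a thread-cluster multiple.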
    static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize)
    {
        const auto reduceLength = blkGroupSize;
        const auto grid_desc_m_k =
            make_naive_tensor_descriptor_packed(make_tuple(invariantLength, reduceLength));

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto kPad =
            math::integer_least_multiple(reduceLength, KThreadClusterSize) - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad),
                                                   make_right_pad_transform(reduceLength, kPad)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return (grid_desc_m_k_padded);
    };

    static auto
    MakeScaleBiasMeanVar1dDescriptor(const std::array<index_t, NumInvariantDim>& lengths,
                                     const std::array<index_t, NumInvariantDim>& strides)
    {
        const auto tupleLengths =
            generate_tuple([&](auto I) { return lengths[I]; }, Number<NumInvariantDim>{});
        const auto tupleStrides =
            generate_tuple([&](auto I) { return strides[I]; }, Number<NumInvariantDim>{});

        auto raw_grid_desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);

        auto grid_desc_m = transform_tensor_descriptor(
            raw_grid_desc,
            make_tuple(make_merge_transform(tupleLengths)),
            make_tuple(typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type{}),
            make_tuple(Sequence<0>{}));

        const auto invariantLength = grid_desc_m.GetLength(Number<0>{});

        const auto mPad =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;

        auto grid_desc_m_padded =
            transform_tensor_descriptor(grid_desc_m,
                                        make_tuple(make_right_pad_transform(invariantLength, mPad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return (grid_desc_m_padded);
    };

    using XYGridDesc_M_K      = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1));
    using ScaleBiasGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}));
    using MeanVarGridDesc_M   = ScaleBiasGridDesc_M;

    struct Argument : public BaseArgument
    {
        Argument(const std::array<index_t, Rank> xyLengths,
                 const std::array<index_t, Rank> xStrides,
                 const std::array<index_t, Rank> dyStrides,
                 const std::array<index_t, Rank> dxStrides,
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                 const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
                 const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
                 const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const DyDataType* p_dy,
                 const ScaleDataType* p_scale,
                 const MeanVarDataType* p_savedMean,
                 const MeanVarDataType* p_savedInvVar,
                 const DyElementwiseOp dy_elementwise_op,
                 double epsilon,
                 DxDataType* p_dx,
                 DscaleDbiasDataType* p_dscale,
                 DscaleDbiasDataType* p_dbias)
            : bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
              bnDscaleDbiasStrides_(bnDscaleDbiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              p_dy_(p_dy),
              p_scale_(p_scale),
              p_savedMean_(p_savedMean),
              p_savedInvVar_(p_savedInvVar),
              dy_elementwise_op_(dy_elementwise_op),
              p_dx_(p_dx),
              p_dscale_(p_dscale),
              p_dbias_(p_dbias)
        {
            xyLengths_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xyLengths, reduceDims);
            xStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(xStrides, reduceDims);
            dyStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(dyStrides, reduceDims);
            dxStrides_ =
                shuffle_tensor_dimensions<Rank, NumBatchNormReduceDim>(dxStrides, reduceDims);

            std::tie(invariant_length, reduce_length) =
                get_2d_lengths<Rank, NumBatchNormReduceDim>(xyLengths_);

            epsilon_ = type_convert<AccDataType>(epsilon);

            haveSavedMeanInvVar_ = (p_savedMean_ != nullptr && p_savedInvVar_ != nullptr);

            if(UseMultiblockInK)
            {
                int iterations = 1;
                while(true)
                {
                    int testBlkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) /
                                           (K_BlockTileSize * iterations);

                    // we want blkGroupSize to be no bigger than 128
                    if(testBlkGroupSize <= 128)
                        break;

                    iterations++;
                };

                blkGroupSize = (reduce_length + (K_BlockTileSize * iterations) - 1) /
                               (K_BlockTileSize * iterations);

                numBlockTileIteration = iterations;
            }
            else
            {
                blkGroupSize          = 1;
                numBlockTileIteration = (reduce_length + K_BlockTileSize - 1) / K_BlockTileSize;
            };

            gridSize = (invariant_length + M_BlockTileSize - 1) / M_BlockTileSize * blkGroupSize;

            x_grid_desc_m_k =
                MakeXY2dDescriptor(xyLengths_, xStrides_, blkGroupSize, numBlockTileIteration);
            dy_grid_desc_m_k =
                MakeXY2dDescriptor(xyLengths_, dyStrides_, blkGroupSize, numBlockTileIteration);
            dx_grid_desc_m_k =
                MakeXY2dDescriptor(xyLengths_, dxStrides_, blkGroupSize, numBlockTileIteration);

            scale_grid_desc_m =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnScaleStrides);
            dscale_dbias_grid_desc_m =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnDscaleDbiasStrides);
            mean_var_grid_desc_m =
                MakeScaleBiasMeanVar1dDescriptor(bnScaleBiasMeanVarLengths, bnMeanVarStrides);
        }

        AccDataType epsilon_;

        bool haveSavedMeanInvVar_;

        std::array<index_t, Rank> xyLengths_;
        std::array<index_t, Rank> xStrides_;
        std::array<index_t, Rank> dyStrides_;
        std::array<index_t, Rank> dxStrides_;

        std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnDscaleDbiasStrides_;
        std::array<index_t, Rank - NumBatchNormReduceDim> bnMeanVarStrides_;

        const XDataType* p_x_;
        const DyDataType* p_dy_;
        const ScaleDataType* p_scale_;
        const MeanVarDataType* p_savedMean_;
        const MeanVarDataType* p_savedInvVar_;
        const DyElementwiseOp dy_elementwise_op_;
        DxDataType* p_dx_;
        DscaleDbiasDataType* p_dscale_;
        DscaleDbiasDataType* p_dbias_;

        long_index_t invariant_length;
        long_index_t reduce_length;

        int blkGroupSize;
        int numBlockTileIteration;
        size_t gridSize;

        XYGridDesc_M_K x_grid_desc_m_k;
        XYGridDesc_M_K dy_grid_desc_m_k;
        XYGridDesc_M_K dx_grid_desc_m_k;
        ScaleBiasGridDesc_M scale_grid_desc_m;
        ScaleBiasGridDesc_M dscale_dbias_grid_desc_m;
        MeanVarGridDesc_M mean_var_grid_desc_m;

        void* workspace_mean;
        void* workspace_variance;
        void* workspace_count;

        void* workspace_savedMean;
        void* workspace_savedInvVar;

        void* workspace_reduce_dscale;
        void* workspace_reduce_dbias;
    };

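    // Multiblock runs stage partial dscale/dbias results in device workspace and,
    // when the forward pass did not save mean/inv-variance, the Welford
    // intermediate mean/variance/count plus the recomputed results as well; each
    // sub-buffer below is padded to a 64-byte boundary.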
    size_t GetWorkSpaceSize(const BaseArgument* pArg) const override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        size_t workspace_size = 0;

        if(UseMultiblockInK && pArg_->blkGroupSize > 1)
        {
            // workspace for the partial reduced result for dscale
            workspace_size +=
                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;

            // workspace for the partial reduced result for dbias
            workspace_size +=
                pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType) + 64;

            if(!pArg_->haveSavedMeanInvVar_)
            {
                // workspace for welford intermediate mean
                workspace_size +=
                    pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64;

                // workspace for welford intermediate variance
                workspace_size +=
                    pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType) + 64;

                // workspace for welford intermediate count
                workspace_size +=
                    pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t) + 64;

                // workspace for welford result mean
                workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64;

                // workspace for welford result inv_variance
                workspace_size += pArg_->invariant_length * sizeof(MeanVarDataType) + 64;
            };
        }

        return (workspace_size);
    };

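    // Carve the single workspace allocation into per-buffer pointers, in the same
    // order and with the same 64-byte alignment assumed by GetWorkSpaceSize().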
    void SetWorkSpacePointer(BaseArgument* pArg,
                             void* p_workspace,
                             const StreamConfig& = StreamConfig{}) const override
    {
        Argument* pArg_ = dynamic_cast<Argument*>(pArg);

        pArg_->p_workspace_ = p_workspace;

        index_t space_sz;

        // setup buffer for the partial reduced result for dscale
        pArg_->workspace_reduce_dscale = pArg_->p_workspace_;

        space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
        space_sz = math::integer_least_multiple(space_sz, 64);

        // setup buffer for the partial reduced result for dbias
        pArg_->workspace_reduce_dbias =
            reinterpret_cast<char*>(pArg_->workspace_reduce_dscale) + space_sz;

        if(UseMultiblockInK && pArg_->blkGroupSize > 1)
        {
            space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(DscaleDbiasDataType);
            space_sz = math::integer_least_multiple(space_sz, 64);

            // setup buffer for welford intermediate mean
            pArg_->workspace_mean =
                reinterpret_cast<char*>(pArg_->workspace_reduce_dbias) + space_sz;

            space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
            space_sz = math::integer_least_multiple(space_sz, 64);

            // setup buffer for welford intermediate variance
            pArg_->workspace_variance = reinterpret_cast<char*>(pArg_->workspace_mean) + space_sz;

            space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(MeanVarDataType);
            space_sz = math::integer_least_multiple(space_sz, 64);

            // setup buffer for welford intermediate count
            pArg_->workspace_count = reinterpret_cast<char*>(pArg_->workspace_variance) + space_sz;

            space_sz = pArg_->invariant_length * pArg_->blkGroupSize * sizeof(int32_t);
            space_sz = math::integer_least_multiple(space_sz, 64);

            // setup buffer for welford result mean
            pArg_->workspace_savedMean = reinterpret_cast<char*>(pArg_->workspace_count) + space_sz;

            space_sz = pArg_->invariant_length * sizeof(MeanVarDataType);
            space_sz = math::integer_least_multiple(space_sz, 64);

            // setup buffer for welford result inv_variance
            pArg_->workspace_savedInvVar =
                reinterpret_cast<char*>(pArg_->workspace_savedMean) + space_sz;
        };
    };

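    // When UseMultiblockInK is enabled and the reduction is split across block
    // groups, Run() launches up to three kernels: a Welford first half (only if
    // mean/inv-variance were not saved by the forward pass), a fused Welford
    // second half plus first dscale/dbias reduction, and a final kernel that
    // finishes the dscale/dbias reduction and computes dx. Otherwise a single
    // blockwise-Welford kernel performs the whole backward pass.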
    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            float avg_time = 0;

            const auto mean_var_count_grid_desc_m_g =
                DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
                    arg.invariant_length, arg.blkGroupSize);

            const auto dscale_dbias_grid_desc_m_g =
                DeviceBatchNormBwdImpl::MakeMultiblockFirstReduceOutputMG2dDescriptor(
                    arg.invariant_length, arg.blkGroupSize);

            const auto mean_var_count_grid_desc_m_k =
                DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
                    arg.invariant_length, arg.blkGroupSize);

            const auto dscale_dbias_grid_desc_m_k =
                DeviceBatchNormBwdImpl::MakeMultiblockFinalReduceInputMK2dDescriptor(
                    arg.invariant_length, arg.blkGroupSize);

            using MeanVarCountGridDesc_M_G = decltype(mean_var_count_grid_desc_m_g);
            using MeanVarCountGridDesc_M_K = decltype(mean_var_count_grid_desc_m_k);
            using DscaleDbiasGridDesc_M_G  = decltype(dscale_dbias_grid_desc_m_g);
            using DscaleDbiasGridDesc_M_K  = decltype(dscale_dbias_grid_desc_m_k);

            using GridwiseWelfordSecondHalfReduceFirstHalf_ =
                GridwiseWelfordSecondHalfReduceFirstHalf<XDataType,
                                                         DyDataType,
                                                         AccDataType,
                                                         ScaleDataType,
                                                         DscaleDbiasDataType,
                                                         MeanVarDataType,
                                                         DyElementwiseOp,
                                                         XYGridDesc_M_K,
                                                         MeanVarGridDesc_M,
                                                         MeanVarCountGridDesc_M_K,
                                                         DscaleDbiasGridDesc_M_G,
                                                         BlockSize,
                                                         MThreadClusterSize,
                                                         KThreadClusterSize,
                                                         MThreadSliceSize,
                                                         KThreadSliceSize,
                                                         XDyDxVectorDim,
                                                         XSrcVectorSize,
                                                         DySrcVectorSize,
                                                         MeanVarSrcVectorSize>;

            using GridwiseReduceSecondHalfBatchNormBwdFinal_ =
                GridwiseReduceSecondHalfBatchNormBwdFinal<XDataType,
                                                          DyDataType,
                                                          DxDataType,
                                                          AccDataType,
                                                          ScaleDataType,
                                                          DscaleDbiasDataType,
                                                          MeanVarDataType,
                                                          DyElementwiseOp,
                                                          XYGridDesc_M_K,
                                                          DscaleDbiasGridDesc_M_K,
                                                          MeanVarGridDesc_M,
                                                          ScaleBiasGridDesc_M,
                                                          BlockSize,
                                                          MThreadClusterSize,
                                                          KThreadClusterSize,
                                                          MThreadSliceSize,
                                                          KThreadSliceSize,
                                                          XDyDxVectorDim,
                                                          XSrcVectorSize,
                                                          DySrcVectorSize,
                                                          DxDstVectorSize,
                                                          ScaleSrcVectorSize,
                                                          DscaleDbiasDstVectorSize,
                                                          MeanVarSrcVectorSize>;

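            // Multiblock path: K is split across arg.blkGroupSize block groups,
            // so partial results are staged in the workspace buffers.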
            if(UseMultiblockInK && arg.blkGroupSize > 1)
            {
                using GetReduceCountPerThreadFunctor =
                    GetReduceCountPerThreadForMultiblockWelford<K_BlockTileSize, KThreadSliceSize>;

                GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
                    arg.blkGroupSize, arg.numBlockTileIteration, arg.reduce_length);

                if(!arg.haveSavedMeanInvVar_)
                {
                    using GridwiseMultiblockWelfordFirstHalf_ =
                        GridwiseMultiblockWelfordFirstHalf<XDataType,
                                                           AccDataType,
                                                           MeanVarDataType,
                                                           XYGridDesc_M_K,
                                                           MeanVarCountGridDesc_M_G,
                                                           GetReduceCountPerThreadFunctor,
                                                           BlockSize,
                                                           MThreadClusterSize,
                                                           KThreadClusterSize,
                                                           MThreadSliceSize,
                                                           KThreadSliceSize,
                                                           XDyDxVectorDim,
                                                           XSrcVectorSize>;

                    const auto kern_multiblock_welford_first_half =
                        kernel_multiblock_welford_first_half<GridwiseMultiblockWelfordFirstHalf_,
                                                             XDataType,
                                                             MeanVarDataType,
                                                             XYGridDesc_M_K,
                                                             MeanVarCountGridDesc_M_G,
                                                             GetReduceCountPerThreadFunctor>;

                    avg_time += launch_and_time_kernel(
                        stream_config,
                        kern_multiblock_welford_first_half,
                        dim3(arg.gridSize),
                        dim3(BlockSize),
                        0,
                        arg.x_grid_desc_m_k,
                        mean_var_count_grid_desc_m_g,
                        get_reduce_count_per_thread,
                        arg.numBlockTileIteration,
                        arg.p_x_,
                        static_cast<MeanVarDataType*>(arg.workspace_mean),
                        static_cast<MeanVarDataType*>(arg.workspace_variance),
                        static_cast<int32_t*>(arg.workspace_count));
                };

                const auto kern_welford_second_half_reduce_first_half =
                    kernel_welford_second_half_reduce_first_half<
                        GridwiseWelfordSecondHalfReduceFirstHalf_,
                        XDataType,
                        DyDataType,
                        AccDataType,
                        ScaleDataType,
                        DscaleDbiasDataType,
                        MeanVarDataType,
                        DyElementwiseOp,
                        XYGridDesc_M_K,
                        MeanVarGridDesc_M,
                        MeanVarCountGridDesc_M_K,
                        DscaleDbiasGridDesc_M_G>;

                const auto kern_reduce_second_half_batchnorm_backward_final =
                    kernel_reduce_second_half_batchnorm_backward_final<
                        GridwiseReduceSecondHalfBatchNormBwdFinal_,
                        XDataType,
                        DyDataType,
                        DxDataType,
                        ScaleDataType,
                        DscaleDbiasDataType,
                        MeanVarDataType,
                        DyElementwiseOp,
                        XYGridDesc_M_K,
                        DscaleDbiasGridDesc_M_K,
                        MeanVarGridDesc_M,
                        ScaleBiasGridDesc_M>;

                index_t numDscaleDbiasBlockTileIteration =
                    (arg.blkGroupSize + KThreadClusterSize - 1) / KThreadClusterSize;

                avg_time += launch_and_time_kernel(
                    stream_config,
                    kern_welford_second_half_reduce_first_half,
                    dim3(arg.gridSize),
                    dim3(BlockSize),
                    0,
                    arg.x_grid_desc_m_k,
                    arg.dy_grid_desc_m_k,
                    arg.mean_var_grid_desc_m,
                    mean_var_count_grid_desc_m_k,
                    dscale_dbias_grid_desc_m_g,
                    arg.blkGroupSize,
                    arg.numBlockTileIteration,
                    numDscaleDbiasBlockTileIteration,
                    arg.epsilon_,
                    arg.haveSavedMeanInvVar_,
                    arg.haveSavedMeanInvVar_ ? arg.p_savedMean_ : nullptr,
                    arg.haveSavedMeanInvVar_ ? arg.p_savedInvVar_ : nullptr,
                    arg.haveSavedMeanInvVar_
                        ? nullptr
                        : static_cast<const MeanVarDataType*>(arg.workspace_mean),
                    arg.haveSavedMeanInvVar_
                        ? nullptr
                        : static_cast<const MeanVarDataType*>(arg.workspace_variance),
                    arg.haveSavedMeanInvVar_ ? nullptr
                                             : static_cast<const int32_t*>(arg.workspace_count),
                    arg.dy_elementwise_op_,
                    arg.haveSavedMeanInvVar_
                        ? nullptr
                        : static_cast<MeanVarDataType*>(arg.workspace_savedMean),
                    arg.haveSavedMeanInvVar_
                        ? nullptr
                        : static_cast<MeanVarDataType*>(arg.workspace_savedInvVar),
                    arg.p_x_,
                    arg.p_dy_,
                    static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
                    static_cast<DscaleDbiasDataType*>(arg.workspace_reduce_dbias));

                avg_time += launch_and_time_kernel(
                    stream_config,
                    kern_reduce_second_half_batchnorm_backward_final,
                    dim3(arg.gridSize),
                    dim3(BlockSize),
                    0,
                    arg.x_grid_desc_m_k,
                    arg.dy_grid_desc_m_k,
                    arg.dx_grid_desc_m_k,
                    dscale_dbias_grid_desc_m_k,
                    arg.mean_var_grid_desc_m,
                    arg.scale_grid_desc_m,
                    arg.dscale_dbias_grid_desc_m,
                    arg.blkGroupSize,
                    arg.reduce_length,
                    arg.numBlockTileIteration,
                    numDscaleDbiasBlockTileIteration,
                    static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dscale),
                    static_cast<const DscaleDbiasDataType*>(arg.workspace_reduce_dbias),
                    arg.haveSavedMeanInvVar_
                        ? arg.p_savedMean_
                        : static_cast<const MeanVarDataType*>(arg.workspace_savedMean),
                    arg.haveSavedMeanInvVar_
                        ? arg.p_savedInvVar_
                        : static_cast<const MeanVarDataType*>(arg.workspace_savedInvVar),
                    arg.p_x_,
                    arg.p_dy_,
                    arg.p_scale_,
                    arg.dy_elementwise_op_,
                    arg.p_dx_,
                    arg.p_dscale_,
                    arg.p_dbias_);
            }
            else
            {
                using GetReduceCountPerThreadFunctor =
                    GetReduceCountPerThreadForBlockwiseWelford<K_BlockTileSize, KThreadSliceSize>;

                GetReduceCountPerThreadFunctor get_reduce_count_per_thread(
                    arg.numBlockTileIteration, arg.reduce_length);

                using GridwiseBatchNormBackwardWithBlockwiseWelford_ =
                    GridwiseBatchNormBackwardWithBlockwiseWelford<XDataType,
                                                                  DyDataType,
                                                                  DxDataType,
                                                                  AccDataType,
                                                                  ScaleDataType,
                                                                  DscaleDbiasDataType,
                                                                  MeanVarDataType,
                                                                  DyElementwiseOp,
                                                                  XYGridDesc_M_K,
                                                                  ScaleBiasGridDesc_M,
                                                                  MeanVarGridDesc_M,
                                                                  GetReduceCountPerThreadFunctor,
                                                                  BlockSize,
                                                                  MThreadClusterSize,
                                                                  KThreadClusterSize,
                                                                  MThreadSliceSize,
                                                                  KThreadSliceSize,
                                                                  XDyDxVectorDim,
                                                                  XSrcVectorSize,
                                                                  DySrcVectorSize,
                                                                  DxDstVectorSize,
                                                                  ScaleSrcVectorSize,
                                                                  DscaleDbiasDstVectorSize,
                                                                  MeanVarSrcVectorSize>;

                const auto kern_batchnorm_bwd = kernel_batchnorm_backward_with_blockwise_welford<
                    GridwiseBatchNormBackwardWithBlockwiseWelford_,
                    XDataType,
                    DyDataType,
                    DxDataType,
                    AccDataType,
                    ScaleDataType,
                    DscaleDbiasDataType,
                    MeanVarDataType,
                    DyElementwiseOp,
                    XYGridDesc_M_K,
                    ScaleBiasGridDesc_M,
                    MeanVarGridDesc_M,
                    GetReduceCountPerThreadFunctor>;

                avg_time += launch_and_time_kernel(stream_config,
                                                   kern_batchnorm_bwd,
                                                   dim3(arg.gridSize),
                                                   dim3(BlockSize),
                                                   0,
                                                   arg.x_grid_desc_m_k,
                                                   arg.dy_grid_desc_m_k,
                                                   arg.dx_grid_desc_m_k,
                                                   arg.scale_grid_desc_m,
                                                   arg.dscale_dbias_grid_desc_m,
                                                   arg.mean_var_grid_desc_m,
                                                   get_reduce_count_per_thread,
                                                   arg.reduce_length,
                                                   arg.numBlockTileIteration,
                                                   arg.epsilon_,
                                                   arg.p_x_,
                                                   arg.p_dy_,
                                                   arg.p_scale_,
                                                   arg.haveSavedMeanInvVar_,
                                                   arg.p_savedMean_,
                                                   arg.p_savedInvVar_,
                                                   arg.dy_elementwise_op_,
                                                   arg.p_dx_,
                                                   arg.p_dscale_,
                                                   arg.p_dbias_);
            };

            return (avg_time);
        };

        float Run(const BaseArgument* pArg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(pArg), stream_config);
        };
    };

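    // Vectorized access requires unit stride along the chosen vector dimension,
    // lengths divisible by the vector sizes, and scale/bias/mean/var lengths that
    // match the invariant lengths of x.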
    bool IsSupportedArgument(const BaseArgument* pArg) override
    {
        const Argument* pArg_ = dynamic_cast<const Argument*>(pArg);

        if constexpr(XDyDxVectorDim == 0)
        {
            if(pArg_->xStrides_[NumInvariantDim - 1] != 1 ||
               pArg_->dyStrides_[NumInvariantDim - 1] != 1 ||
               pArg_->dxStrides_[NumInvariantDim - 1] != 1)
                return false;

            if(pArg_->xyLengths_[NumInvariantDim - 1] % XSrcVectorSize != 0 ||
               pArg_->xyLengths_[NumInvariantDim - 1] % DySrcVectorSize != 0 ||
               pArg_->xyLengths_[NumInvariantDim - 1] % DxDstVectorSize != 0)
                return false;
        }
        else
        {
            if(pArg_->xStrides_[Rank - 1] != 1 || pArg_->dyStrides_[Rank - 1] != 1 ||
               pArg_->dxStrides_[Rank - 1] != 1)
                return false;

            if(pArg_->xyLengths_[Rank - 1] % XSrcVectorSize != 0 ||
               pArg_->xyLengths_[Rank - 1] % DySrcVectorSize != 0 ||
               pArg_->xyLengths_[Rank - 1] % DxDstVectorSize != 0)
                return false;
        };

        if(pArg_->bnScaleStrides_[NumInvariantDim - 1] != 1 && ScaleSrcVectorSize != 1)
            return false;

        if(pArg_->bnDscaleDbiasStrides_[NumInvariantDim - 1] != 1 && DscaleDbiasDstVectorSize != 1)
            return false;

        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % ScaleSrcVectorSize != 0)
            return false;

        if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % DscaleDbiasDstVectorSize != 0)
            return false;

        if(pArg_->haveSavedMeanInvVar_)
        {
            if(pArg_->bnMeanVarStrides_[NumInvariantDim - 1] != 1 && MeanVarSrcVectorSize != 1)
                return false;

            if(pArg_->bnScaleBiasMeanVarLengths_[NumInvariantDim - 1] % MeanVarSrcVectorSize != 0)
                return false;
        };

        bool is_valid = true;

        static_for<0, NumInvariantDim, 1>{}([&](auto I) {
            if(pArg_->xyLengths_[I] != pArg_->bnScaleBiasMeanVarLengths_[I])
                is_valid = false;
        });

        if(!is_valid)
            return false;

        return true;
    };

    std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
                        const std::array<index_t, Rank> xStrides,
                        const std::array<index_t, Rank> dyStrides,
                        const std::array<index_t, Rank> dxStrides,
                        const std::array<int, NumBatchNormReduceDim> reduceDims,
                        const std::array<ck::index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                        const std::array<ck::index_t, NumInvariantDim> bnScaleStrides,
                        const std::array<ck::index_t, NumInvariantDim> bnDscaleDbiasStrides,
                        const std::array<ck::index_t, NumInvariantDim> bnMeanVarStrides,
                        const void* p_x,
                        const void* p_dy,
                        const void* p_scale,
                        const void* p_savedMean,
                        const void* p_savedInvVar,
                        double epsilon,
                        const DyElementwiseOp dy_elementwise_op,
                        void* p_dx,
                        void* p_dscale,
                        void* p_dbias) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          dyStrides,
                                          dxStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnDscaleDbiasStrides,
                                          bnMeanVarStrides,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const DyDataType*>(p_dy),
                                          static_cast<const ScaleDataType*>(p_scale),
                                          static_cast<const MeanVarDataType*>(p_savedMean),
                                          static_cast<const MeanVarDataType*>(p_savedInvVar),
                                          dy_elementwise_op,
                                          epsilon,
                                          static_cast<DxDataType*>(p_dx),
                                          static_cast<DscaleDbiasDataType*>(p_dscale),
                                          static_cast<DscaleDbiasDataType*>(p_dbias));
    };

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "DeviceBatchNormBwdImpl<" << BlockSize << ",";
        str << "M_C" << MThreadClusterSize << "_S" << MThreadSliceSize << ",";
        str << "K_C" << KThreadClusterSize << "_S" << KThreadSliceSize << ",";
        str << "XDyDxVectorDim_" << XDyDxVectorDim << ",";
        str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
        // clang-format on

        return str.str();
    }
}; // struct DeviceBatchNormBwdImpl

} // namespace device
} // namespace tensor_operation
} // namespace ck
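
For reference, the following is a minimal usage sketch, not part of the header. It instantiates the op for a packed NHWC float tensor (Rank 4, reducing over N, H, W), assuming ck::tensor_operation::element_wise::PassThrough as DyElementwiseOp, illustrative (untuned) tuning parameters, and device pointers (p_x, p_dy, p_scale, p_savedMean, p_savedInvVar, p_dx, p_dscale, p_dbias, p_workspace) allocated by the caller, e.g. with hipMalloc.

#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

using Pass = ck::tensor_operation::element_wise::PassThrough;

// fp32 NHWC instantiation; the tuning parameters below are illustrative only.
using DeviceBnormBwd = ck::tensor_operation::device::DeviceBatchNormBwdImpl<
    float, float, float, float, // XDataType, DxDataType, DyDataType, AccDataType
    float, float, float,        // ScaleDataType, DscaleDbiasDataType, MeanVarDataType
    Pass,                       // DyElementwiseOp
    4,                          // Rank
    3,                          // NumBatchNormReduceDim (N, H, W)
    true,                       // UseMultiblockInK
    256,                        // BlockSize == MThreadClusterSize * KThreadClusterSize
    128, 2,                     // MThreadClusterSize, KThreadClusterSize
    1, 1,                       // MThreadSliceSize, KThreadSliceSize
    0,                          // XDyDxVectorDim (0: vectorize along the invariant dim)
    1, 1, 1, 1, 1, 1>;          // x/dy/dx/scale/dscale-dbias/mean-var vector sizes

float run_batchnorm_bwd_nhwc(ck::index_t N, ck::index_t H, ck::index_t W, ck::index_t C,
                             const float* p_x, const float* p_dy, const float* p_scale,
                             const float* p_savedMean, const float* p_savedInvVar,
                             void* p_workspace, // at least GetWorkSpaceSize() bytes
                             float* p_dx, float* p_dscale, float* p_dbias)
{
    auto bnorm_bwd = DeviceBnormBwd{};

    auto argument_ptr = bnorm_bwd.MakeArgumentPointer(
        {N, H, W, C},             // xyLengths
        {H * W * C, W * C, C, 1}, // xStrides (packed NHWC)
        {H * W * C, W * C, C, 1}, // dyStrides
        {H * W * C, W * C, C, 1}, // dxStrides
        {0, 1, 2},                // reduceDims: N, H, W
        {C}, {1}, {1}, {1},       // scale/bias/mean/var lengths and strides
        p_x, p_dy, p_scale,
        p_savedMean, p_savedInvVar, // pass nullptr to recompute via Welford
        1e-5,                       // epsilon
        Pass{},
        p_dx, p_dscale, p_dbias);

    if(!bnorm_bwd.IsSupportedArgument(argument_ptr.get()))
        return -1.0f;

    // p_workspace must hold GetWorkSpaceSize(argument_ptr.get()) bytes.
    bnorm_bwd.SetWorkSpacePointer(argument_ptr.get(), p_workspace);

    auto invoker_ptr = bnorm_bwd.MakeInvokerPointer();
    return invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
}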