Composable Kernel: include/ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math_v2.hpp"
12 
14 
15 namespace ck {
16 
17 template <typename GridwiseMultiblockBatchNormForward_,
18  typename XDataType,
19  typename YDataType,
20  typename AccDataType,
21  typename ScaleDataType,
22  typename BiasDataType,
23  typename MeanVarDataType,
24  typename YElementwiseOp,
25  typename XYGridDesc_M_K,
26  typename MeanVarCountGridDesc_M_G,
27  typename MeanVarCountGridDesc_M_K,
28  typename ScaleBiasGridDesc_M,
29  typename MeanVarGridDesc_M,
 30  typename GetReduceCountPerThreadFunctor>
 31 __global__ void kernel_multiblock_batchnorm_forward(
 32  const XYGridDesc_M_K x_grid_desc_m_k,
33  const XYGridDesc_M_K y_grid_desc_m_k,
34  const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
35  const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
36  const ScaleBiasGridDesc_M scale_grid_desc_m,
37  const ScaleBiasGridDesc_M bias_grid_desc_m,
38  const MeanVarGridDesc_M mean_var_grid_desc_m,
39  const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
40  index_t num_k_block_tile_iteration,
41  AccDataType epsilon,
42  const XDataType* const __restrict__ p_x,
43  MeanVarDataType* const __restrict__ p_welford_mean,
44  MeanVarDataType* const __restrict__ p_welford_variance,
45  int32_t* const __restrict__ p_welford_count,
46  int32_t* const __restrict__ p_control,
47  const ScaleDataType* const __restrict__ p_scale,
48  const BiasDataType* const __restrict__ p_bias,
49  const YElementwiseOp y_elementwise_op,
50  YDataType* const __restrict__ p_y,
51  bool updateMovingAverage,
52  AccDataType averageFactor,
53  MeanVarDataType* const __restrict__ resultRunningMean,
54  MeanVarDataType* const __restrict__ resultRunningVariance,
55  bool saveMeanInvVariance,
56  MeanVarDataType* const __restrict__ resultSaveMean,
57  MeanVarDataType* const __restrict__ resultSaveInvVariance)
58 {
59  GridwiseMultiblockBatchNormForward_::Run(x_grid_desc_m_k,
60  y_grid_desc_m_k,
61  mean_var_count_grid_desc_m_g,
62  mean_var_count_grid_desc_m_k,
63  scale_grid_desc_m,
64  bias_grid_desc_m,
65  mean_var_grid_desc_m,
66  get_reduce_count_per_thread,
67  num_k_block_tile_iteration,
68  epsilon,
69  p_x,
70  p_welford_mean,
71  p_welford_variance,
72  p_welford_count,
73  p_control,
74  p_scale,
75  p_bias,
76  y_elementwise_op,
77  p_y,
78  updateMovingAverage,
79  averageFactor,
80  resultRunningMean,
81  resultRunningVariance,
82  saveMeanInvVariance,
83  resultSaveMean,
84  resultSaveInvVariance);
85 };
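// Annotation (not part of the original source): kernel_multiblock_batchnorm_forward is a thin
// __global__ wrapper that forwards all descriptors and pointers to
// GridwiseMultiblockBatchNormForward_::Run(). In the multiblock scheme, the K (reduction)
// dimension of each M block-tile is split across blkgroup_size cooperating workgroups; their
// partial Welford results are exchanged through the p_welford_mean / p_welford_variance /
// p_welford_count workspace buffers, and the workgroups of one group synchronize through the
// p_control flags (gms_init / gms_barrier / gms_reset).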
86 
87 template <typename XDataType,
88  typename YDataType,
89  typename AccDataType,
90  typename ScaleDataType,
91  typename BiasDataType,
92  typename MeanVarDataType,
93  typename YElementwiseOp,
94  typename XYGridDesc_M_K,
95  typename MeanVarCountGridDesc_M_G,
96  typename MeanVarCountGridDesc_M_K,
97  typename ScaleBiasGridDesc_M,
98  typename MeanVarGridDesc_M,
99  typename GetReduceCountPerThreadFunctor,
100  index_t BlockSize,
101  index_t MThreadClusterSize,
102  index_t KThreadClusterSize,
103  index_t MThreadSliceSize,
104  index_t KThreadSliceSize,
105  index_t XSrcYDstVectorDim,
106  index_t XSrcVectorSize,
107  index_t YDstVectorSize,
108  index_t ScaleSrcVectorSize,
109  index_t BiasSrcVectorSize,
 110  index_t MeanVarSrcDstVectorSize>
 111 struct GridwiseMultiblockBatchNormForward
 112 {
113  static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
114  (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
115  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
116 
117  static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
118  (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
119  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
120 
121  static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
122 
 122 
 123  using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;
 124 
 125  using ThreadBufferDimAccessOrder =
 126  typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
 127 
 128  using ThreadClusterArrangeOrder =
 129  typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
 130 
 131  static constexpr auto thread_cluster_desc =
 132  make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});
 133 
 134  using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
 135  make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
 136  using ThreadReduceDstDesc_M =
 137  decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
 138 
 139  using ThreadReduceSrcDesc_M_1 = decltype(make_naive_tensor_descriptor_packed(
 140  make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
 141 
 142  using ThreadwiseWelford1 =
 143  ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;
 144 
 147 
 148  using BlockwiseWelford1 = BlockwiseWelford<AccDataType,
 149  BlockSize,
 150  ThreadClusterLengths_M_K,
 151  ThreadClusterArrangeOrder,
 152  false>;
 153 
 154  using BlockwiseWelford2 = BlockwiseWelford<AccDataType,
 155  BlockSize,
 156  ThreadClusterLengths_M_K,
 157  ThreadClusterArrangeOrder,
 158  true>;
 159 
 160  using PassThroughOp = tensor_operation::element_wise::PassThrough;
 161 
162  static constexpr auto I0 = Number<0>{};
163  static constexpr auto I1 = Number<1>{};
164 
165  static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
166  static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
167 
168  __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
169  const XYGridDesc_M_K& y_grid_desc_m_k,
170  const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
171  const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
172  const ScaleBiasGridDesc_M& scale_grid_desc_m,
173  const ScaleBiasGridDesc_M& bias_grid_desc_m,
174  const MeanVarGridDesc_M& mean_var_grid_desc_m,
175  const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
176  index_t num_k_block_tile_iteration,
177  AccDataType epsilon,
178  const XDataType* const __restrict__ p_x,
179  MeanVarDataType* const __restrict__ p_welford_mean,
180  MeanVarDataType* const __restrict__ p_welford_variance,
181  int32_t* const __restrict__ p_welford_count,
182  int32_t* const __restrict__ p_control,
183  const ScaleDataType* const __restrict__ p_scale,
184  const BiasDataType* const __restrict__ p_bias,
185  const YElementwiseOp y_elementwise_op,
186  YDataType* const __restrict__ p_y,
187  bool updateMovingAverage,
188  AccDataType averageFactor,
189  MeanVarDataType* const __restrict__ resultRunningMean,
190  MeanVarDataType* const __restrict__ resultRunningVariance,
191  bool saveMeanInvVariance,
192  MeanVarDataType* const __restrict__ resultSaveMean,
193  MeanVarDataType* const __restrict__ resultSaveInvVariance)
194  {
195  using ck::math::sqrt;
196 
197  const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);
198 
199  const index_t thread_local_id = get_thread_local_1d_id();
200  const index_t block_global_id = get_block_1d_id();
201  const index_t blkgroup_id = block_global_id / blkgroup_size;
202  const index_t block_local_id = block_global_id % blkgroup_size;
203 
204  if(block_local_id == 0)
205  gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]);
206 
207  const auto thread_cluster_idx =
208  thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
209 
210  const auto thread_m_cluster_id = thread_cluster_idx[I0];
211  const auto thread_k_cluster_id = thread_cluster_idx[I1];
212 
213  using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
214  using ThreadBufferLengths_M = Sequence<MThreadSliceSize>;
215  using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
216 
 217  constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
 218  make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
 219  constexpr auto thread_buffer_desc_m =
 220  make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
 221  constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
 222  make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
 223 
225  x_thread_buf;
226 
230 
232  tmp_mean_thread_buf;
234  tmp_var_thread_buf;
236 
237  const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;
238 
239  auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
240  AccDataType,
241  XYGridDesc_M_K,
242  decltype(thread_buffer_desc_m_k),
243  ThreadBufferLengths_M_K,
245  XSrcYDstVectorDim,
246  XSrcVectorSize,
247  1,
248  true>(
249  x_grid_desc_m_k,
250  make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
251  block_local_id * reduceSizePerBlock +
252  thread_k_cluster_id * KThreadSliceSize));
253 
254  constexpr auto xy_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileSize);
255 
256  const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
257  p_x, x_grid_desc_m_k.GetElementSpaceSize());
258 
259  // Step 1: each workgroup does local welford reduction
260 
261  auto threadwise_welford_1 = ThreadwiseWelford1();
262  threadwise_welford_1.max_count_ =
263  get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
264 
265  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
266  mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
267  var_thread_buf(I) = type_convert<AccDataType>(0.0f);
268  });
269 
270  for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
271  {
272  threadwise_x_load.Run(x_grid_desc_m_k,
273  x_global_val_buf,
274  thread_buffer_desc_m_k,
275  make_tuple(I0, I0),
276  x_thread_buf);
277 
278  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
279  threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
280  }
281 
282  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
283  if constexpr(I > 0)
284  block_sync_lds();
285 
286  count_thread_buf(I) = threadwise_welford_1.cur_count_;
287  BlockwiseWelford1::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
288  });
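// Illustrative sketch (not part of the original source), assuming the usual single-pass Welford
// formulation for the per-element update performed by ThreadwiseWelford1 above:
//     count += 1;
//     delta  = x - mean;
//     mean  += delta / count;
//     m2    += delta * (x - mean);   // running sum of squared deviations
// so that the variance can be recovered as m2 / count once the tile loop and the blockwise
// reduction complete.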
289 
290  // Step 2: each workgroup writes its local welford result to workspace memory
291 
292  auto mean_global_val_buf =
293  make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
294  p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
295 
296  auto var_global_val_buf =
297  make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
298  p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
299 
300  auto count_global_val_buf =
301  make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
302  p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
303 
304  auto threadwise_mean_var_store_m_g =
306  MeanVarDataType,
307  decltype(thread_buffer_desc_m_1),
308  MeanVarCountGridDesc_M_G,
310  ThreadBufferLengths_M_1,
312  0,
313  1,
315  1,
316  true>(
317  mean_var_count_grid_desc_m_g,
318  make_multi_index(blkgroup_id * M_BlockTileSize +
319  thread_m_cluster_id * MThreadSliceSize,
320  block_local_id),
321  PassThroughOp{});
322 
323  auto threadwise_count_store_m_g =
325  int32_t,
326  decltype(thread_buffer_desc_m_1),
327  MeanVarCountGridDesc_M_G,
329  ThreadBufferLengths_M_1,
331  0,
332  1,
334  1,
335  true>(
336  mean_var_count_grid_desc_m_g,
337  make_multi_index(blkgroup_id * M_BlockTileSize +
338  thread_m_cluster_id * MThreadSliceSize,
339  block_local_id),
340  PassThroughOp{});
341 
342  if(thread_k_cluster_id == 0)
343  {
344  threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
345  make_tuple(I0, I0),
346  mean_thread_buf,
347  mean_var_count_grid_desc_m_g,
348  mean_global_val_buf);
349 
350  threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
351  make_tuple(I0, I0),
352  var_thread_buf,
353  mean_var_count_grid_desc_m_g,
354  var_global_val_buf);
355 
356  threadwise_count_store_m_g.Run(thread_buffer_desc_m_1,
357  make_tuple(I0, I0),
358  count_thread_buf,
359  mean_var_count_grid_desc_m_g,
360  count_global_val_buf);
361  };
362 
363  gms_barrier(&p_control[blkgroup_id * 2]);
364 
365  if(block_local_id == 0)
366  gms_reset(&p_control[blkgroup_id * 2]);
367 
368  // Step 3: each workgroup reads welford results from workspace memory and does final welford
369  // reduction
370 
371  auto threadwise_mean_var_load_m_k =
372  ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
373  AccDataType,
374  MeanVarCountGridDesc_M_K,
375  decltype(thread_buffer_desc_m_1),
376  ThreadBufferLengths_M_1,
378  0,
379  1,
380  1,
381  true>(
382  mean_var_count_grid_desc_m_k,
383  make_multi_index(blkgroup_id * M_BlockTileSize +
384  thread_m_cluster_id * MThreadSliceSize,
385  thread_k_cluster_id * 1));
386 
387  auto threadwise_count_load_m_k =
389  int32_t,
390  MeanVarCountGridDesc_M_K,
391  decltype(thread_buffer_desc_m_1),
392  ThreadBufferLengths_M_1,
394  0,
395  1,
396  1,
397  true>(
398  mean_var_count_grid_desc_m_k,
399  make_multi_index(blkgroup_id * M_BlockTileSize +
400  thread_m_cluster_id * MThreadSliceSize,
401  thread_k_cluster_id * 1));
402 
403  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
404  mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
405  var_thread_buf(I) = type_convert<AccDataType>(0.0f);
406  count_thread_buf(I) = 0;
407  });
408 
409  constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize);
410 
411  int32_t reducedSize = 0;
412  while(reducedSize < blkgroup_size)
413  {
414  threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
415  mean_global_val_buf,
416  thread_buffer_desc_m_1,
417  make_tuple(I0, I0),
418  tmp_mean_thread_buf);
419 
420  threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
421  var_global_val_buf,
422  thread_buffer_desc_m_1,
423  make_tuple(I0, I0),
424  tmp_var_thread_buf);
425 
426  threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
427  count_global_val_buf,
428  thread_buffer_desc_m_1,
429  make_tuple(I0, I0),
430  tmp_count_thread_buf);
431 
432  ThreadwiseWelford2::Run(tmp_mean_thread_buf,
433  tmp_var_thread_buf,
434  tmp_count_thread_buf,
435  mean_thread_buf,
436  var_thread_buf,
437  count_thread_buf);
438 
439  reducedSize += KThreadClusterSize;
440 
441  threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
442  mean_var_count_read_fwd_step_m_k);
443  threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
444  mean_var_count_read_fwd_step_m_k);
445  };
446 
447  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
448  if constexpr(I > 0)
449  block_sync_lds();
450 
451  BlockwiseWelford2::Run(mean_thread_buf(I), var_thread_buf(I), count_thread_buf(I));
452  });
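// Illustrative sketch (not part of the original source) of the pairwise Welford merge used in
// Steps 2/3 to combine two partial results (mean_a, m2_a, n_a) and (mean_b, m2_b, n_b), following
// the standard parallel-variance formula:
//     n     = n_a + n_b;
//     delta = mean_b - mean_a;
//     mean  = mean_a + delta * n_b / n;
//     m2    = m2_a + m2_b + delta * delta * n_a * n_b / n;
// The trailing false/true template arguments of BlockwiseWelford1/BlockwiseWelford2 presumably
// control whether the accumulated m2 is converted into the actual variance at the end.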
453 
454  // Step 4: do normalization using the mean/variance
455 
457 
459 
461  y_thread_buf;
462 
463  auto threadwise_y_store =
465  YDataType,
466  decltype(thread_buffer_desc_m_k),
467  XYGridDesc_M_K,
468  YElementwiseOp,
469  ThreadBufferLengths_M_K,
471  XSrcYDstVectorDim,
472  YDstVectorSize,
474  1,
475  true>(
 476  y_grid_desc_m_k,
 477  make_multi_index(
 478  blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
479  block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
480  y_elementwise_op);
481 
482  auto threadwise_scale_load =
483  ThreadwiseTensorSliceTransfer_v2<ScaleDataType,
484  AccDataType,
485  ScaleBiasGridDesc_M,
486  decltype(thread_buffer_desc_m),
487  ThreadBufferLengths_M,
488  Sequence<0>,
489  0,
490  ScaleSrcVectorSize,
491  1,
492  true>(
493  scale_grid_desc_m,
494  make_multi_index(blkgroup_id * M_BlockTileSize +
495  thread_m_cluster_id * MThreadSliceSize));
496 
497  auto threadwise_bias_load = ThreadwiseTensorSliceTransfer_v2<BiasDataType,
498  AccDataType,
499  ScaleBiasGridDesc_M,
500  decltype(thread_buffer_desc_m),
501  ThreadBufferLengths_M,
502  Sequence<0>,
503  0,
504  BiasSrcVectorSize,
505  1,
506  true>(
507  bias_grid_desc_m,
508  make_multi_index(blkgroup_id * M_BlockTileSize +
509  thread_m_cluster_id * MThreadSliceSize));
510 
511  const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
512  p_scale, scale_grid_desc_m.GetElementSpaceSize());
513 
514  const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
515  p_bias, bias_grid_desc_m.GetElementSpaceSize());
516 
517  auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
518  p_y, y_grid_desc_m_k.GetElementSpaceSize());
519 
520  threadwise_scale_load.Run(scale_grid_desc_m,
521  scale_global_val_buf,
522  thread_buffer_desc_m,
523  make_tuple(I0),
524  scale_thread_buf);
525 
526  threadwise_bias_load.Run(bias_grid_desc_m,
527  bias_global_val_buf,
528  thread_buffer_desc_m,
529  make_tuple(I0),
530  bias_thread_buf);
531 
532  threadwise_x_load.SetSrcSliceOrigin(
533  x_grid_desc_m_k,
534  make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
535  block_local_id * reduceSizePerBlock +
536  thread_k_cluster_id * KThreadSliceSize));
537 
538  for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
539  {
540  threadwise_x_load.Run(x_grid_desc_m_k,
541  x_global_val_buf,
542  thread_buffer_desc_m_k,
543  make_tuple(I0, I0),
544  x_thread_buf);
545 
546  static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
547  AccDataType multiplier =
548  scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);
549 
550  AccDataType fused_mean_bias =
551  bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;
552 
553  static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
554  constexpr auto offset =
555  thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));
556 
557  // normalize
558  y_thread_buf(Number<offset>{}) =
559  x_thread_buf[Number<offset>{}] * multiplier + fused_mean_bias;
560  });
561  });
562 
563  threadwise_y_store.Run(thread_buffer_desc_m_k,
564  make_tuple(I0, I0),
565  y_thread_buf,
566  y_grid_desc_m_k,
567  y_global_val_buf);
568 
569  threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
570  threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k);
571  }
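// Annotation (not part of the original source): the loop above applies the standard batchnorm
// transform y = scale * (x - mean) / sqrt(var + epsilon) + bias, refactored per channel as
//     multiplier      = scale / sqrt(var + epsilon)
//     fused_mean_bias = bias - mean * multiplier
//     y               = x * multiplier + fused_mean_bias
// so each element needs only one multiply-add once the per-channel constants are computed.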
572 
573  // Step 5: update the moving average of mean and variance (optional)
574 
575  if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
576  {
578  running_mean_thread_buf;
580  running_var_thread_buf;
581 
582  auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
583  resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());
584 
585  auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
586  resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());
587 
588  auto threadwise_mean_var_load =
589  ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
590  AccDataType,
591  MeanVarGridDesc_M,
592  decltype(thread_buffer_desc_m),
593  ThreadBufferLengths_M,
594  Sequence<0>,
595  0,
596  MeanVarSrcDstVectorSize,
597  1,
598  true>(
599  mean_var_grid_desc_m,
600  make_multi_index(blkgroup_id * M_BlockTileSize +
601  thread_m_cluster_id * MThreadSliceSize));
602 
603  threadwise_mean_var_load.Run(mean_var_grid_desc_m,
604  running_mean_global_buf,
605  thread_buffer_desc_m,
606  make_tuple(I0),
607  running_mean_thread_buf);
608 
609  threadwise_mean_var_load.Run(mean_var_grid_desc_m,
610  running_var_global_buf,
611  thread_buffer_desc_m,
612  make_tuple(I0),
613  running_var_thread_buf);
614 
615  AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;
616 
617  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
618  running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
619  mean_thread_buf[I] * averageFactor;
620  running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
621  var_thread_buf[I] * averageFactor;
622  });
623 
624  auto threadwise_mean_var_store =
626  MeanVarDataType,
627  decltype(thread_buffer_desc_m),
628  MeanVarGridDesc_M,
630  ThreadBufferLengths_M,
631  Sequence<0>,
632  0,
633  MeanVarSrcDstVectorSize,
635  1,
636  true>(
637  mean_var_grid_desc_m,
638  make_multi_index(blkgroup_id * M_BlockTileSize +
639  thread_m_cluster_id * MThreadSliceSize),
640  PassThroughOp{});
641 
642  threadwise_mean_var_store.Run(thread_buffer_desc_m,
643  make_tuple(I0),
644  running_mean_thread_buf,
645  mean_var_grid_desc_m,
646  running_mean_global_buf);
647 
648  threadwise_mean_var_store.Run(thread_buffer_desc_m,
649  make_tuple(I0),
650  running_var_thread_buf,
651  mean_var_grid_desc_m,
652  running_var_global_buf);
653  };
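// Annotation (not part of the original source): Step 5 applies the usual exponential moving
// average of the batch statistics,
//     running_stat = (1 - averageFactor) * running_stat + averageFactor * batch_stat,
// and is executed only when block_local_id == 0 and thread_k_cluster_id == 0, so the running
// mean/variance is updated without duplicate global writes.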
654 
655  // Step 6: save mean and inv-variance (optional)
656 
657  if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
658  {
659  auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
660  resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());
661 
662  auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
663  resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());
664 
665  // calculate inv-variance as 1/sqrt(epsilon+variance), stored in place of variance
666  static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
667  var_thread_buf(I) =
668  type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
669  });
670 
671  auto threadwise_mean_inv_var_store =
673  MeanVarDataType,
674  decltype(thread_buffer_desc_m),
675  MeanVarGridDesc_M,
677  ThreadBufferLengths_M,
678  Sequence<0>,
679  0,
680  MeanVarSrcDstVectorSize,
682  1,
683  true>(
684  mean_var_grid_desc_m,
685  make_multi_index(blkgroup_id * M_BlockTileSize +
686  thread_m_cluster_id * MThreadSliceSize),
687  PassThroughOp{});
688 
689  threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
690  make_tuple(I0),
691  mean_thread_buf,
692  mean_var_grid_desc_m,
693  result_mean_global_buf);
694 
695  threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
696  make_tuple(I0),
697  var_thread_buf,
698  mean_var_grid_desc_m,
699  result_inv_var_global_buf);
700  };
701  }
 702 };
703 
704 } // namespace ck