Composable Kernel: gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

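// Grid-level entry point: a thin wrapper kernel that forwards all of its
// arguments to GridwiseWelfordSecondHalfReduceFirstHalf_::Run() on the device.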
template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
          typename XDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename MeanVarCountGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_G>
__global__ void kernel_welford_second_half_reduce_first_half(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K dy_grid_desc_m_k,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
    const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
    index_t blkgroup_size,
    index_t num_xy_k_block_tile_iteration,
    index_t num_mean_var_count_k_block_tile_iteration,
    AccDataType epsilon,
    bool haveSavedMeanInvVar,
    const MeanVarDataType* const __restrict__ p_savedMean,
    const MeanVarDataType* const __restrict__ p_savedInvVar,
    const MeanVarDataType* const __restrict__ p_in_welford_mean,
    const MeanVarDataType* const __restrict__ p_in_welford_variance,
    const int32_t* const __restrict__ p_in_welford_count,
    const DyElementwiseOp dy_elementwise_op,
    MeanVarDataType* const __restrict__ p_out_welford_mean,
    MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
    const XDataType* const __restrict__ p_x,
    const DyDataType* const __restrict__ p_dy,
    DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
    DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
    GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
                                                   dy_grid_desc_m_k,
                                                   mean_var_grid_desc_m,
                                                   mean_var_count_grid_desc_m_k,
                                                   dscale_dbias_grid_desc_m_g,
                                                   blkgroup_size,
                                                   num_xy_k_block_tile_iteration,
                                                   num_mean_var_count_k_block_tile_iteration,
                                                   epsilon,
                                                   haveSavedMeanInvVar,
                                                   p_savedMean,
                                                   p_savedInvVar,
                                                   p_in_welford_mean,
                                                   p_in_welford_variance,
                                                   p_in_welford_count,
                                                   dy_elementwise_op,
                                                   p_out_welford_mean,
                                                   p_out_welford_inv_variance,
                                                   p_x,
                                                   p_dy,
                                                   p_reduce_dscale,
                                                   p_reduce_dbias);
};

template <typename XDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename MeanVarCountGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_G,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XDyVectorDim,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t MeanVarSrcVectorSize>
struct GridwiseWelfordSecondHalfReduceFirstHalf
{
    static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0) ||
                      (XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0);
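    // When X/Dy are vector-accessed along M (XDyVectorDim == 0), the thread
    // cluster is traversed with M as the fastest-varying index (Sequence<1, 0>
    // below), so consecutive threads touch adjacent M slices and the M-direction
    // vector accesses stay contiguous.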

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceSrcDesc_M_1 = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                          BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          reduce::Add,
                                                          false>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M,
                                                 reduce::Add,
                                                 false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    // clang-format off
    // Two of the steps of multiblock BatchNorm backward:
    // Step 1: second half of the Welford method: compute mean and variance, and inv-variance = 1/sqrt(epsilon + variance)
    // Step 2: first half of the reduction: dbias = sum(dy), dscale = sum(dy * (x - mean) * inv-variance)
    // clang-format on
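    // For reference: merging two Welford partial results (count n_a, mean m_a,
    // biased variance v_a) and (n_b, m_b, v_b) follows the standard parallel
    // update (Chan et al.):
    //   n = n_a + n_b,  d = m_b - m_a
    //   m = m_a + d * n_b / n
    //   v = (n_a * v_a + n_b * v_b + d * d * n_a * n_b / n) / n
    // ThreadwiseWelford and BlockwiseWelford apply this kind of merge to the
    // partial (mean, variance, count) triples produced by the first-half kernel.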
    __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                               const XYGridDesc_M_K& dy_grid_desc_m_k,
                               const MeanVarGridDesc_M& mean_var_grid_desc_m,
                               const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
                               const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
                               index_t blkgroup_size,
                               index_t num_xy_k_block_tile_iteration,
                               index_t num_mean_var_count_k_block_tile_iteration,
                               AccDataType epsilon,
                               bool haveSavedMeanInvVar,
                               const MeanVarDataType* const __restrict__ p_savedMean,
                               const MeanVarDataType* const __restrict__ p_savedInvVar,
                               const MeanVarDataType* const __restrict__ p_in_welford_mean,
                               const MeanVarDataType* const __restrict__ p_in_welford_variance,
                               const int32_t* const __restrict__ p_in_welford_count,
                               const DyElementwiseOp dy_elementwise_op,
                               MeanVarDataType* const __restrict__ p_out_welford_mean,
                               MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
                               const XDataType* const __restrict__ p_x,
                               const DyDataType* const __restrict__ p_dy,
                               DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
                               DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
    {
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            in_welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            in_welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            in_welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>& mean_thread_buf =
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>&
            inv_var_thread_buf = welford_var_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dy_thread_buf;

        // buffer of values of dy * (x - mean) * inv-variance, used as input of the blockwise reduction
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            tmp1_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            reduce_dscale_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            reduce_dbias_thread_buf;

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / blkgroup_size;
        const index_t block_local_id  = block_global_id % blkgroup_size;
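        // Example: with blkgroup_size == 4, blocks 0..3 form group 0 and share one
        // M block-tile; blkgroup_id selects the M tile, block_local_id selects this
        // block's slice of the K (reduction) dimension within the group.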

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));

        // clang-format off
        // Step 1: load the saved mean and inv-variance if provided; otherwise finish the Welford reduction of mean and variance and compute inv-variance = 1/sqrt(epsilon + variance)
        // clang-format on

        if(haveSavedMeanInvVar)
        {
            const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());

            const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());

            auto threadwise_mean_inv_var_load =
                ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                                 AccDataType,
                                                 MeanVarGridDesc_M,
                                                 decltype(thread_buffer_desc_m),
                                                 ThreadBufferLengths_M,
                                                 Sequence<0>,
                                                 0,
                                                 MeanVarSrcVectorSize,
                                                 1,
                                                 true>(
                    mean_var_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize));

            threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
                                             mean_global_buf,
                                             thread_buffer_desc_m,
                                             make_tuple(I0),
                                             mean_thread_buf);

            threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
                                             inv_var_global_buf,
                                             thread_buffer_desc_m,
                                             make_tuple(I0),
                                             inv_var_thread_buf);
        }
        else
        {
            const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            auto threadwise_mean_var_load_m_k =
                ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                                 AccDataType,
                                                 MeanVarCountGridDesc_M_K,
                                                 decltype(thread_buffer_desc_m_1),
                                                 ThreadBufferLengths_M_1,
                                                 Sequence<0, 1>,
                                                 1,
                                                 1,
                                                 1,
                                                 true>(
                    mean_var_count_grid_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * 1));

            auto threadwise_count_load_m_k =
                ThreadwiseTensorSliceTransfer_v2<int32_t,
                                                 int32_t,
                                                 MeanVarCountGridDesc_M_K,
                                                 decltype(thread_buffer_desc_m_1),
                                                 ThreadBufferLengths_M_1,
                                                 Sequence<0, 1>,
                                                 1,
                                                 1,
                                                 1,
                                                 true>(
                    mean_var_count_grid_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * 1));

            constexpr auto mean_var_count_thread_copy_step_m_k =
                make_multi_index(0, KThreadClusterSize * 1);

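            // Each iteration advances the source window by KThreadClusterSize
            // columns, so the cluster sweeps all partial (mean, var, count)
            // results assigned to this M tile along K.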
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_mean_thread_buf(I)  = type_convert<AccDataType>(0.0f);
                welford_var_thread_buf(I)   = type_convert<AccDataType>(0.0f);
                welford_count_thread_buf(I) = 0;
            });

            for(index_t reducedTiles = 0; reducedTiles < num_mean_var_count_k_block_tile_iteration;
                ++reducedTiles)
            {
                threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                                 welford_mean_global_buf,
                                                 thread_buffer_desc_m_1,
                                                 make_tuple(I0, I0),
                                                 in_welford_mean_thread_buf);

                threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                                 welford_var_global_buf,
                                                 thread_buffer_desc_m_1,
                                                 make_tuple(I0, I0),
                                                 in_welford_var_thread_buf);

                threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                              welford_count_global_buf,
                                              thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              in_welford_count_thread_buf);

                ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                       in_welford_var_thread_buf,
                                       in_welford_count_thread_buf,
                                       welford_mean_thread_buf,
                                       welford_var_thread_buf,
                                       welford_count_thread_buf);

                threadwise_mean_var_load_m_k.MoveSrcSliceWindow(
                    mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k);
                threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                             mean_var_count_thread_copy_step_m_k);
            }

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();

                BlockwiseWelford::Run(welford_mean_thread_buf(I),
                                      welford_var_thread_buf(I),
                                      welford_count_thread_buf(I));
            });

            // compute inv-variance as 1/sqrt(epsilon + variance), stored in place of the variance
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_var_thread_buf(I) =
                    type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
            });
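            // inv_var_thread_buf is bound to welford_var_thread_buf above, so the
            // in-place update makes the inverse standard deviation visible through
            // both names.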

            if(block_local_id == 0 && thread_k_cluster_id == 0)
            {
                auto threadwise_mean_inv_var_store =
                    ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                       MeanVarDataType,
                                                       decltype(thread_buffer_desc_m),
                                                       MeanVarGridDesc_M,
                                                       PassThroughOp,
                                                       ThreadBufferLengths_M,
                                                       Sequence<0>,
                                                       0,
                                                       1,
                                                       InMemoryDataOperationEnum::Set,
                                                       1,
                                                       true>(
                        mean_var_grid_desc_m,
                        make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize),
                        PassThroughOp{});

                auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());

                auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());

                threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                                  make_tuple(I0),
                                                  mean_thread_buf,
                                                  mean_var_grid_desc_m,
                                                  mean_global_buf);

                threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                                  make_tuple(I0),
                                                  inv_var_thread_buf,
                                                  mean_var_grid_desc_m,
                                                  inv_var_global_buf);
            };
        };

        const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;

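        // Each block in a group covers a contiguous K range of workSizePerBlock
        // elements; e.g. K_BlockTileSize = 64 with num_xy_k_block_tile_iteration = 8
        // gives each block 512 consecutive K elements starting at
        // workSizePerBlock * block_local_id.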
        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  XYGridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XDyVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             workSizePerBlock * block_local_id +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<DyDataType,
                                                                   AccDataType,
                                                                   XYGridDesc_M_K,
                                                                   decltype(thread_buffer_desc_m_k),
                                                                   ThreadBufferLengths_M_K,
                                                                   ThreadBufferDimAccessOrder,
                                                                   XDyVectorDim,
                                                                   DySrcVectorSize,
                                                                   1,
                                                                   true>(
            dy_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             workSizePerBlock * block_local_id +
                                 thread_k_cluster_id * KThreadSliceSize));

        const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy, dy_grid_desc_m_k.GetElementSpaceSize());

        constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
            reduce_dbias_thread_buf(I)  = type_convert<AccDataType>(0);
        });

        // clang-format off
        // Step 2: first half of the reduction: dbias = sum(dy), dscale = sum(dy * (x - mean) * inv-variance)
        // clang-format on

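        // Per M row, this block accumulates over its K range:
        //   dbias  += dy_elementwise_op(dy)
        //   dscale += dy_elementwise_op(dy) * (x - mean) * inv-variance
        // producing one partial (dscale, dbias) pair per (M row, block); a
        // follow-up kernel is expected to sum these partials over the G dimension.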
        for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_dy_load.Run(dy_grid_desc_m_k,
                                   dy_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    dy_elementwise_op(dy_thread_buf(Number<offset>{}),
                                      dy_thread_buf[Number<offset>{}]);

                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                         inv_var_thread_buf[iM];

                    tmp1_thread_buf(Number<offset>{}) = norm_x * dy_thread_buf[Number<offset>{}];
                });
            });

            ThreadwiseReduce::Reduce(tmp1_thread_buf, reduce_dscale_thread_buf);
            ThreadwiseReduce::Reduce(dy_thread_buf, reduce_dbias_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
        };

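        // reduce_work_buf (LDS) is reused for both reductions, hence the
        // block_sync_lds() between the dscale and dbias reductions and between
        // successive M slices.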
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
            block_sync_lds();
            BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
        });

        auto threadwise_dscale_dbias_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               DscaleDbiasDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               DscaleDbiasGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               1,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dscale_dbias_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());

        auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());

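        // Only the K-cluster leaders (thread_k_cluster_id == 0) hold complete
        // per-block partial sums; each writes one column (index block_local_id)
        // of the (M, G) dscale/dbias grids.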
        if(thread_k_cluster_id == 0)
        {
            threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              reduce_dscale_thread_buf,
                                              dscale_dbias_grid_desc_m_g,
                                              reduce_dscale_global_buf);

            threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              reduce_dbias_thread_buf,
                                              dscale_dbias_grid_desc_m_g,
                                              reduce_dbias_global_buf);
        };
    };
};

} // namespace ck