Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_elementwise_layernorm_welford_variance.hpp Source File

gridwise_elementwise_layernorm_welford_variance.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// NOTE: the original #include list was dropped by the page extraction; the headers below
// are an assumption based on the symbols this file uses.
#include "ck/utility/math.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_welford.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_welford.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

namespace ck {
// X = Elementwise(input1, input2, input3, ...)
// Y = Normalization(X, beta, gamma)
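//
// Each block handles a tile of M rows; for every row the kernel
//   1) applies XElementwiseOperation to the inputs to form X,
//   2) computes mean/variance of X over K with Welford's update
//      (ThreadwiseWelford per thread, merged by BlockwiseWelford),
//   3) writes Y = gamma * (X - mean) / sqrt(variance + epsilon) + beta,
//      with YElementwiseOperation applied on the store path.
// When SweepOnce is false, X is staged in LDS (p_x_lds_) during the first pass
// so the second pass does not re-run the elementwise op.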
template <typename InDataTypePointerTuple,
          typename XDataType,
          typename GammaDataType,
          typename BetaDataType,
          typename YDataType,
          typename AccDataType,
          typename XElementwiseOperation,
          typename YElementwiseOperation,
          typename InGrid2dDescTuple,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t BetaSrcVectorDim,
          index_t BetaSrcVectorSize,
          index_t YDstVectorDim,
          index_t YDstVectorSize,
          bool SweepOnce>
// NOTE: the struct declaration line was dropped by the page extraction; the name below is
// reconstructed from the Composable Kernel sources.
struct GridwiseElementwiseLayernormWelfordVariance_mk_to_mk
{
    static_assert((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((YDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (YDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr index_t NumInput = InDataTypePointerTuple::Size();

    static constexpr bool reorder_thread_cluster = (XSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelford<AccDataType, ThreadReduceSrcDesc_M_K, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t M_BlockTileSize     = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize     = KThreadClusterSize * KThreadSliceSize;
    static constexpr index_t K_BlockTileStepSize = KThreadClusterSize * XSrcVectorSize;

    static constexpr auto XThreadBufferNumber     = Number<KThreadSliceSize / XSrcVectorSize>{};
    static constexpr auto GammaThreadBufferNumber = Number<KThreadSliceSize / GammaSrcVectorSize>{};
    static constexpr auto BetaThreadBufferNumber  = Number<KThreadSliceSize / BetaSrcVectorSize>{};
    static constexpr auto YThreadBufferNumber     = Number<KThreadSliceSize / YDstVectorSize>{};

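    // GetKPerThread returns how many K elements this thread feeds into its
    // ThreadwiseWelford (used as max_count_), including its share of the partial
    // tile at the end of K when K is not a multiple of K_BlockTileSize.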
    __device__ static int GetKPerThread(const GridDesc_M_K& x_grid_desc_m_k,
                                        int thread_k_cluster_id)
    {
        int kPerBlock = x_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0];
        int kPerThread =
            kPerBlock < K_BlockTileSize ? 0 : KThreadSliceSize * (kPerBlock / K_BlockTileSize);
        int kPerBlockTail = kPerBlock - kPerThread * KThreadClusterSize;

        if(kPerBlockTail > 0)
        {
            static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
                int thread_max_len =
                    (thread_k_cluster_id + 1) * XSrcVectorSize + K_BlockTileStepSize * i;
                int delta = thread_max_len - kPerBlockTail;
                delta     = math::clamp(thread_max_len - kPerBlockTail, 0, XSrcVectorSize);
                kPerThread += XSrcVectorSize - delta;
            });
        }

        return kPerThread;
    }

    __device__ static void Run(const InGrid2dDescTuple in_grid_2d_desc_tuple,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& beta_grid_desc_m_k,
                               const GridDesc_M_K& y_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               AccDataType epsilon,
                               const InDataTypePointerTuple p_in_global_tuple,
                               XDataType* const __restrict__ p_x_lds_,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const BetaDataType* const __restrict__ p_beta_global,
                               YDataType* const __restrict__ p_y_global,
                               const XElementwiseOperation x_elementwise_op,
                               const YElementwiseOperation y_elementwise_op)
    {
        if constexpr(SweepOnce)
        {
            num_k_block_tile_iteration = 1;
        }

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t grid_size       = get_grid_size();

        auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                static_assert(in_grid_2d_desc_tuple[I].GetNumOfDimension() ==
                              2); // matrix dimension

                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_2d_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y_global, y_grid_desc_m_k.GetElementSpaceSize());

        auto x_lds_val_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
            p_x_lds_, x_grid_desc_m_k.GetElementSpaceSize() / grid_size);

        auto in_thread_buf_tuple = generate_tuple(
            [&](auto) {
                return generate_tuple(
                    [&](auto) {
                        return StaticBuffer<AddressSpaceEnum::Vgpr,
                                            AccDataType,
                                            MThreadSliceSize * XSrcVectorSize,
                                            true>{};
                    },
                    Number<NumInput>{});
            },
            Number<XThreadBufferNumber>{});

        auto x_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * XSrcVectorSize,
                                    true>{};
            },
            Number<XThreadBufferNumber>{});

        auto gamma_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * GammaSrcVectorSize,
                                    true>{};
            },
            Number<GammaThreadBufferNumber>{});

        auto beta_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * BetaSrcVectorSize,
                                    true>{};
            },
            Number<BetaThreadBufferNumber>{});

        auto y_thread_buf = generate_tuple(
            [&](auto) {
                return StaticBuffer<AddressSpaceEnum::Vgpr,
                                    AccDataType,
                                    MThreadSliceSize * YDstVectorSize,
                                    true>{};
            },
            Number<YThreadBufferNumber>{});

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> var_thread_buf;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, XSrcVectorSize>;

        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<XSrcVectorSize>{}));

        auto in_global_load_tuple = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return ThreadwiseTensorSliceTransfer_v2<DataType,
                                                        AccDataType,
                                                        decltype(in_grid_2d_desc_tuple[I]),
                                                        decltype(thread_buffer_desc_m_k),
                                                        ThreadBufferLengths_M_K,
                                                        ThreadBufferDimAccessOrder,
                                                        XSrcVectorDim,
                                                        XSrcVectorSize,
                                                        1,
                                                        false>{
                    in_grid_2d_desc_tuple[I],
                    make_multi_index(block_global_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * XSrcVectorSize)};
            },
            Number<NumInput>{});

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<XDataType,
                                                                  AccDataType,
                                                                  GridDesc_M_K,
                                                                  decltype(thread_buffer_desc_m_k),
                                                                  ThreadBufferLengths_M_K,
                                                                  ThreadBufferDimAccessOrder,
                                                                  XSrcVectorDim,
                                                                  XSrcVectorSize,
                                                                  1,
                                                                  true>(
            x_grid_desc_m_k,
            make_multi_index(thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * XSrcVectorSize));

        auto threadwise_gamma_load =
            ThreadwiseTensorSliceTransfer_v2<GammaDataType,
                                             AccDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             GammaSrcVectorDim,
                                             GammaSrcVectorSize,
                                             1,
                                             true>(
                gamma_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * GammaSrcVectorSize));

        auto threadwise_beta_load =
            ThreadwiseTensorSliceTransfer_v2<BetaDataType,
                                             AccDataType,
                                             GridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             BetaSrcVectorDim,
                                             BetaSrcVectorSize,
                                             1,
                                             true>(
                beta_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * BetaSrcVectorSize));

        using PassThrough = tensor_operation::element_wise::PassThrough;
        PassThrough pass_through_op;
        auto threadwise_x_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               XDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               GridDesc_M_K,
                                               PassThrough,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               XSrcVectorDim,
                                               XSrcVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                x_grid_desc_m_k,
                make_multi_index(thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * XSrcVectorSize),
                pass_through_op);

        auto threadwise_y_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               YDataType,
                                               decltype(thread_buffer_desc_m_k),
                                               GridDesc_M_K,
                                               YElementwiseOperation,
                                               ThreadBufferLengths_M_K,
                                               ThreadBufferDimAccessOrder,
                                               YDstVectorDim,
                                               YDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                y_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * YDstVectorSize),
                y_elementwise_op);

        // X is read forward over K on the first pass and (unless SweepOnce) re-read
        // from LDS on the second pass, which walks the K tiles backward.
        constexpr auto thread_copy_fwd_step_m_k = make_multi_index(0, K_BlockTileStepSize);
        constexpr auto thread_copy_bwd_step_m_k =
            make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto beta_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_beta_global, beta_grid_desc_m_k.GetElementSpaceSize());

        auto threadwise_welford       = ThreadwiseWelford();
        threadwise_welford.max_count_ = GetKPerThread(x_grid_desc_m_k, thread_k_cluster_id);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });

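        // First sweep over K (forward): apply the elementwise op to form X, accumulate
        // per-thread Welford mean/variance, and (unless SweepOnce) stage X in LDS.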
        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                static_for<0, NumInput, 1>{}([&](auto I) { // input load loop
                    in_global_load_tuple(I).Run(in_grid_2d_desc_tuple[I],
                                                in_global_buf_tuple[I],
                                                thread_buffer_desc_m_k,
                                                make_tuple(I0, I0),
                                                in_thread_buf_tuple(iK0)(I));

                    in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_2d_desc_tuple[I],
                                                               thread_copy_fwd_step_m_k);
                });

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) { // input add loop
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // get reference to in data
                        const auto in_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto I) -> const auto& {
                                return in_thread_buf_tuple(iK0)(I)(Number<offset_m_k>{});
                            },
                            Number<NumInput>{});

                        // get reference to dst data
                        auto out_data_refs = generate_tie(
                            // return type should be lvalue
                            [&](auto) -> auto& { return x_thread_buf(iK0)(Number<offset_m_k>{}); },
                            I1);

                        unpack2(x_elementwise_op, out_data_refs, in_data_refs);
                    });
                });
                threadwise_welford.Run(x_thread_buf[iK0], mean_thread_buf, var_thread_buf);

                if constexpr(!SweepOnce)
                {
                    threadwise_x_store.Run(thread_buffer_desc_m_k,
                                           make_tuple(I0, I0),
                                           x_thread_buf(iK0),
                                           x_grid_desc_m_k,
                                           x_lds_val_buf);
                    threadwise_x_store.MoveDstSliceWindow(x_grid_desc_m_k,
                                                          thread_copy_fwd_step_m_k);
                }
            });
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            int count = threadwise_welford.cur_count_;
            BlockwiseWelford::Run(mean_thread_buf(I), var_thread_buf(I), count);
        });

        auto thread_copy_tail_m_k =
            (num_k_block_tile_iteration - 1) * XThreadBufferNumber * thread_copy_fwd_step_m_k;

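        // Second sweep over K (starting at the last tile and stepping backward between
        // tiles): reload X (from LDS unless SweepOnce), apply gamma and beta, store Y.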
        if constexpr(!SweepOnce)
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k, thread_copy_tail_m_k);
        threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_tail_m_k);

        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            if constexpr(!SweepOnce)
            {
                static_for<0, XThreadBufferNumber, 1>{}([&](auto i) {
                    threadwise_x_load.Run(x_grid_desc_m_k,
                                          x_lds_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          x_thread_buf(i));
                    threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                });
            }

            static_for<0, GammaThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_gamma_load.Run(gamma_grid_desc_m_k,
                                          gamma_global_val_buf,
                                          thread_buffer_desc_m_k,
                                          make_tuple(I0, I0),
                                          gamma_thread_buf(i));
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                auto divisor = 1 / ck::math::sqrt(var_thread_buf(iM) + epsilon);
                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // normalize
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            (x_thread_buf(iK0)(Number<offset_m_k>{}) - mean_thread_buf(iM)) *
                            divisor;

                        // gamma
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) *
                            gamma_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, BetaThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_beta_load.Run(beta_grid_desc_m_k,
                                         beta_global_val_buf,
                                         thread_buffer_desc_m_k,
                                         make_tuple(I0, I0),
                                         beta_thread_buf(i));
                threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                        thread_copy_fwd_step_m_k);
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, XThreadBufferNumber, 1>{}([&](auto iK0) {
                    static_for<0, XSrcVectorSize, 1>{}([&](auto iK1) {
                        constexpr auto offset_m_k =
                            thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK1));

                        // beta
                        y_thread_buf(iK0)(Number<offset_m_k>{}) =
                            y_thread_buf(iK0)(Number<offset_m_k>{}) +
                            beta_thread_buf(iK0)(Number<offset_m_k>{});
                    });
                });
            });

            static_for<0, YThreadBufferNumber, 1>{}([&](auto i) {
                threadwise_y_store.Run(thread_buffer_desc_m_k,
                                       make_tuple(I0, I0),
                                       y_thread_buf(i),
                                       y_grid_desc_m_k,
                                       y_global_val_buf);
                threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, thread_copy_fwd_step_m_k);
            });

            if constexpr(!SweepOnce)
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                     2 * thread_copy_bwd_step_m_k);
            threadwise_beta_load.MoveSrcSliceWindow(beta_grid_desc_m_k,
                                                    2 * thread_copy_bwd_step_m_k);
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, 2 * thread_copy_bwd_step_m_k);
        }
    }
};

} // namespace ck
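
For reference, the per-row math that this gridwise kernel parallelizes can be written as a small host-side sketch. This is not part of the Composable Kernel API; the function name and the use of float/std::vector are illustrative assumptions.

// Reference-only sketch: layer normalization of one row with Welford's online
// mean/variance, mirroring the math the gridwise kernel above implements in parallel.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> layernorm_row_welford(const std::vector<float>& x,
                                         const std::vector<float>& gamma,
                                         const std::vector<float>& beta,
                                         float epsilon)
{
    // Welford update: running mean and sum of squared deviations (M2).
    float mean = 0.f;
    float m2   = 0.f;
    for(std::size_t k = 0; k < x.size(); ++k)
    {
        const float delta = x[k] - mean;
        mean += delta / static_cast<float>(k + 1);
        m2 += delta * (x[k] - mean);
    }
    const float variance = m2 / static_cast<float>(x.size()); // biased, as in layernorm
    const float inv_std  = 1.f / std::sqrt(variance + epsilon);

    // y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
    std::vector<float> y(x.size());
    for(std::size_t k = 0; k < x.size(); ++k)
        y[k] = gamma[k] * (x[k] - mean) * inv_std + beta[k];
    return y;
}

Welford's update is used because it yields mean and variance in a single pass over K with better numerical stability than accumulating sum and sum-of-squares, which is what allows the kernel to fuse the statistics computation with the elementwise pass and merge partial results across threads with BlockwiseWelford.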