template <typename GridwiseWelfordSecondHalfReduceFirstHalf_,
          typename XDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename MeanVarCountGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_G>
__global__ void kernel_welford_second_half_reduce_first_half(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K dy_grid_desc_m_k,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
    const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g,
    index_t blkgroup_size,
    index_t num_xy_k_block_tile_iteration,
    index_t num_mean_var_count_k_block_tile_iteration,
    AccDataType epsilon,
    bool haveSavedMeanInvVar,
    const MeanVarDataType* const __restrict__ p_savedMean,
    const MeanVarDataType* const __restrict__ p_savedInvVar,
    const MeanVarDataType* const __restrict__ p_in_welford_mean,
    const MeanVarDataType* const __restrict__ p_in_welford_variance,
    const int32_t* const __restrict__ p_in_welford_count,
    const DyElementwiseOp dy_elementwise_op,
    MeanVarDataType* const __restrict__ p_out_welford_mean,
    MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
    const XDataType* const __restrict__ p_x,
    const DyDataType* const __restrict__ p_dy,
    DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
    DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
{
    GridwiseWelfordSecondHalfReduceFirstHalf_::Run(x_grid_desc_m_k,
                                                   dy_grid_desc_m_k,
                                                   mean_var_grid_desc_m,
                                                   mean_var_count_grid_desc_m_k,
                                                   dscale_dbias_grid_desc_m_g,
                                                   blkgroup_size,
                                                   num_xy_k_block_tile_iteration,
                                                   num_mean_var_count_k_block_tile_iteration,
                                                   epsilon,
                                                   haveSavedMeanInvVar,
                                                   p_savedMean,
                                                   p_savedInvVar,
                                                   p_in_welford_mean,
                                                   p_in_welford_variance,
                                                   p_in_welford_count,
                                                   dy_elementwise_op,
                                                   p_out_welford_mean,
                                                   p_out_welford_inv_variance,
                                                   p_x,
                                                   p_dy,
                                                   p_reduce_dscale,
                                                   p_reduce_dbias);
}
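// GridwiseWelfordSecondHalfReduceFirstHalf fuses two half-steps of the
// multiblock batchnorm-backward pipeline into one dispatch:
//   1) Welford "second half": merge the per-block partial (mean, variance,
//      count) triples produced by an earlier kernel into the final mean and
//      inverse standard deviation 1/sqrt(variance + epsilon).
//   2) Reduction "first half": compute this block-group's partial sums of
//      dscale = sum(dy * norm(x)) and dbias = sum(dy), to be finalized by a
//      follow-up reduction kernel.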
template <typename XDataType,
          typename DyDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename DscaleDbiasDataType,
          typename MeanVarDataType,
          typename DyElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarGridDesc_M,
          typename MeanVarCountGridDesc_M_K,
          typename DscaleDbiasGridDesc_M_G,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XDyVectorDim,
          index_t XSrcVectorSize,
          index_t DySrcVectorSize,
          index_t MeanVarSrcVectorSize,
          index_t DscaleDbiasDstVectorSize>
struct GridwiseWelfordSecondHalfReduceFirstHalf
{
    static_assert((XDyVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0 &&
                   MThreadSliceSize % DySrcVectorSize == 0) ||
                      (XDyVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0 &&
                       KThreadSliceSize % DySrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (XDyVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceSrcDesc_M_1 = decltype(
        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using ThreadwiseWelford =
        ThreadwiseWelfordMerge<AccDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;

    using BlockwiseWelford = BlockwiseWelford<AccDataType,
                                              BlockSize,
                                              ThreadClusterLengths_M_K,
                                              ThreadClusterArrangeOrder>;

    using ThreadwiseReduce = ThreadwiseReduction<AccDataType,
                                                 ThreadReduceSrcDesc_M_K,
                                                 ThreadReduceDstDesc_M,
                                                 reduce::Add,
                                                 false>;

    using BlockwiseReduce = PartitionedBlockwiseReduction<AccDataType,
                                                          BlockSize,
                                                          ThreadClusterLengths_M_K,
                                                          ThreadClusterArrangeOrder,
                                                          reduce::Add,
                                                          false>;

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;
    __device__ static void Run(const XYGridDesc_M_K& x_grid_desc_m_k,
                               const XYGridDesc_M_K& dy_grid_desc_m_k,
                               const MeanVarGridDesc_M& mean_var_grid_desc_m,
                               const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
                               const DscaleDbiasGridDesc_M_G& dscale_dbias_grid_desc_m_g,
                               index_t blkgroup_size,
                               index_t num_xy_k_block_tile_iteration,
                               index_t num_mean_var_count_k_block_tile_iteration,
                               AccDataType epsilon,
                               bool haveSavedMeanInvVar,
                               const MeanVarDataType* const __restrict__ p_savedMean,
                               const MeanVarDataType* const __restrict__ p_savedInvVar,
                               const MeanVarDataType* const __restrict__ p_in_welford_mean,
                               const MeanVarDataType* const __restrict__ p_in_welford_variance,
                               const int32_t* const __restrict__ p_in_welford_count,
                               const DyElementwiseOp dy_elementwise_op,
                               MeanVarDataType* const __restrict__ p_out_welford_mean,
                               MeanVarDataType* const __restrict__ p_out_welford_inv_variance,
                               const XDataType* const __restrict__ p_x,
                               const DyDataType* const __restrict__ p_dy,
                               DscaleDbiasDataType* const __restrict__ p_reduce_dscale,
                               DscaleDbiasDataType* const __restrict__ p_reduce_dbias)
    {
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            in_welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            in_welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            in_welford_count_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_mean_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            welford_var_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, int32_t, MThreadSliceSize, true>
            welford_count_thread_buf;

        // After the merge, the welford buffers hold the final mean and (once
        // converted) the inverse variance; alias them for readability.
        auto& mean_thread_buf    = welford_mean_thread_buf;
        auto& inv_var_thread_buf = welford_var_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            x_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            dy_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            reduce_dscale_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true>
            reduce_dbias_thread_buf;
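        // Map the flat thread id onto the (M, K) thread cluster, and split
        // the grid into block-groups: each block-group covers one M tile,
        // and the blocks within it partition the K (reduction) dimension.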
        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();

        const index_t blkgroup_id    = block_global_id / blkgroup_size;
        const index_t block_local_id = block_global_id % blkgroup_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];
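        // Compile-time descriptors for the per-thread register tiles used by
        // the threadwise transfers below.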
        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;
        using ThreadBufferLengths_M_K = Sequence<MThreadSliceSize, KThreadSliceSize>;
        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;

        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_k = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));
        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
        if(haveSavedMeanInvVar)
        {
            // Mean and inverse variance were saved by the forward pass; load
            // them directly instead of re-merging Welford partials.
            const auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_savedMean, mean_var_grid_desc_m.GetElementSpaceSize());

            const auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_savedInvVar, mean_var_grid_desc_m.GetElementSpaceSize());

            auto threadwise_mean_inv_var_load =
                ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                                 AccDataType,
                                                 MeanVarGridDesc_M,
                                                 decltype(thread_buffer_desc_m),
                                                 ThreadBufferLengths_M,
                                                 Sequence<0>,
                                                 0,
                                                 MeanVarSrcVectorSize,
                                                 1,
                                                 true>(
                    mean_var_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize));

            threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
                                             mean_global_buf,
                                             thread_buffer_desc_m,
                                             make_tuple(I0),
                                             mean_thread_buf);

            threadwise_mean_inv_var_load.Run(mean_var_grid_desc_m,
                                             inv_var_global_buf,
                                             thread_buffer_desc_m,
                                             make_tuple(I0),
                                             inv_var_thread_buf);
        }
        else
        {
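            // No saved statistics: merge the per-block partial (mean, var,
            // count) triples written by the first-half Welford kernel, first
            // within each thread along K, then across the thread block.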
            const auto welford_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_mean, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            const auto welford_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_variance, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            const auto welford_count_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                p_in_welford_count, mean_var_count_grid_desc_m_k.GetElementSpaceSize());

            auto threadwise_mean_var_load_m_k =
                ThreadwiseTensorSliceTransfer_v2<MeanVarDataType,
                                                 AccDataType,
                                                 MeanVarCountGridDesc_M_K,
                                                 decltype(thread_buffer_desc_m_1),
                                                 ThreadBufferLengths_M_1,
                                                 Sequence<0, 1>,
                                                 1,
                                                 1,
                                                 1,
                                                 true>(
                    mean_var_count_grid_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * 1));

            auto threadwise_count_load_m_k =
                ThreadwiseTensorSliceTransfer_v2<int32_t,
                                                 int32_t,
                                                 MeanVarCountGridDesc_M_K,
                                                 decltype(thread_buffer_desc_m_1),
                                                 ThreadBufferLengths_M_1,
                                                 Sequence<0, 1>,
                                                 1,
                                                 1,
                                                 1,
                                                 true>(
                    mean_var_count_grid_desc_m_k,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * 1));
            constexpr auto mean_var_count_thread_copy_step_m_k =
                make_multi_index(0, KThreadClusterSize * 1);

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_mean_thread_buf(I)  = type_convert<AccDataType>(0.0f);
                welford_var_thread_buf(I)   = type_convert<AccDataType>(0.0f);
                welford_count_thread_buf(I) = 0;
            });
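            // Thread-local Welford merge: each iteration loads one
            // (mean, var, count) column per thread and folds it into the
            // running per-thread aggregate.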
            for(index_t reducedTiles = 0;
                reducedTiles < num_mean_var_count_k_block_tile_iteration;
                ++reducedTiles)
            {
                threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                                 welford_mean_global_buf,
                                                 thread_buffer_desc_m_1,
                                                 make_tuple(I0, I0),
                                                 in_welford_mean_thread_buf);

                threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                                 welford_var_global_buf,
                                                 thread_buffer_desc_m_1,
                                                 make_tuple(I0, I0),
                                                 in_welford_var_thread_buf);

                threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                              welford_count_global_buf,
                                              thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              in_welford_count_thread_buf);

                ThreadwiseWelford::Run(in_welford_mean_thread_buf,
                                       in_welford_var_thread_buf,
                                       in_welford_count_thread_buf,
                                       welford_mean_thread_buf,
                                       welford_var_thread_buf,
                                       welford_count_thread_buf);

                threadwise_mean_var_load_m_k.MoveSrcSliceWindow(
                    mean_var_count_grid_desc_m_k, mean_var_count_thread_copy_step_m_k);
                threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                             mean_var_count_thread_copy_step_m_k);
            }
            // Cross-thread Welford merge through LDS, then convert variance
            // into inverse standard deviation.
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                if constexpr(I > 0)
                    block_sync_lds();

                BlockwiseWelford::Run(welford_mean_thread_buf(I),
                                      welford_var_thread_buf(I),
                                      welford_count_thread_buf(I));
            });

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                welford_var_thread_buf(I) =
                    type_convert<AccDataType>(1.0) / sqrt(welford_var_thread_buf[I] + epsilon);
            });
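            // Only the first block of the group and the k == 0 lane publish
            // the final mean / inverse variance, so each M row is written
            // exactly once.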
            if(block_local_id == 0 && thread_k_cluster_id == 0)
            {
                auto threadwise_mean_inv_var_store =
                    ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                       MeanVarDataType,
                                                       decltype(thread_buffer_desc_m),
                                                       MeanVarGridDesc_M,
                                                       PassThroughOp,
                                                       ThreadBufferLengths_M,
                                                       Sequence<0>,
                                                       0,
                                                       MeanVarSrcVectorSize,
                                                       InMemoryDataOperationEnum::Set,
                                                       1,
                                                       true>(
                        mean_var_grid_desc_m,
                        make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize),
                        PassThroughOp{});

                auto mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_welford_mean, mean_var_grid_desc_m.GetElementSpaceSize());

                auto inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_welford_inv_variance, mean_var_grid_desc_m.GetElementSpaceSize());

                threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                                  make_tuple(I0),
                                                  mean_thread_buf,
                                                  mean_var_grid_desc_m,
                                                  mean_global_buf);

                threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                                  make_tuple(I0),
                                                  inv_var_thread_buf,
                                                  mean_var_grid_desc_m,
                                                  inv_var_global_buf);
            }
        }
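        // Phase 2: stream the x / dy tiles owned by this block, apply the dy
        // elementwise op, and accumulate per-thread partial sums of
        // dscale = sum(dy * norm_x) and dbias = sum(dy) over K.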
        const index_t workSizePerBlock = K_BlockTileSize * num_xy_k_block_tile_iteration;

        auto threadwise_x_load =
            ThreadwiseTensorSliceTransfer_v2<XDataType,
                                             AccDataType,
                                             XYGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             XDyVectorDim,
                                             XSrcVectorSize,
                                             1,
                                             true>(
                x_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 workSizePerBlock * block_local_id +
                                     thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dy_load =
            ThreadwiseTensorSliceTransfer_v2<DyDataType,
                                             AccDataType,
                                             XYGridDesc_M_K,
                                             decltype(thread_buffer_desc_m_k),
                                             ThreadBufferLengths_M_K,
                                             ThreadBufferDimAccessOrder,
                                             XDyVectorDim,
                                             DySrcVectorSize,
                                             1,
                                             true>(
                dy_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 workSizePerBlock * block_local_id +
                                     thread_k_cluster_id * KThreadSliceSize));

        const auto x_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        const auto dy_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy, dy_grid_desc_m_k.GetElementSpaceSize());

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            reduce_dscale_thread_buf(I) = type_convert<AccDataType>(0);
            reduce_dbias_thread_buf(I)  = type_convert<AccDataType>(0);
        });
        constexpr auto xy_thread_copy_step_m_k = make_multi_index(0, K_BlockTileSize);

        for(index_t reducedTiles = 0; reducedTiles < num_xy_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_dy_load.Run(dy_grid_desc_m_k,
                                   dy_global_buf,
                                   thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   dy_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset =
                        thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    dy_elementwise_op(dy_thread_buf(Number<offset>{}),
                                      dy_thread_buf[Number<offset>{}]);

                    // Normalize x with the statistics produced in phase 1.
                    AccDataType norm_x = (x_thread_buf[Number<offset>{}] - mean_thread_buf[iM]) *
                                         inv_var_thread_buf[iM];

                    reduce_dscale_thread_buf(iM) += norm_x * dy_thread_buf[Number<offset>{}];
                    reduce_dbias_thread_buf(iM) += dy_thread_buf[Number<offset>{}];
                });
            });

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_thread_copy_step_m_k);
            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, xy_thread_copy_step_m_k);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if constexpr(I > 0)
                block_sync_lds();

            BlockwiseReduce::Reduce(reduce_work_buf, reduce_dscale_thread_buf(I));
            block_sync_lds();
            BlockwiseReduce::Reduce(reduce_work_buf, reduce_dbias_thread_buf(I));
        });
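        // Each block-group writes one (dscale, dbias) partial per M row into
        // column block_local_id of the M x G grid; the second-half reduction
        // kernel then sums over G.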
        auto threadwise_dscale_dbias_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               DscaleDbiasDataType,
                                               decltype(thread_buffer_desc_m_1),
                                               DscaleDbiasGridDesc_M_G,
                                               PassThroughOp,
                                               ThreadBufferLengths_M_1,
                                               Sequence<0, 1>,
                                               1,
                                               DscaleDbiasDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                dscale_dbias_grid_desc_m_g,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id),
                PassThroughOp{});

        auto reduce_dscale_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dscale, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());

        auto reduce_dbias_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_reduce_dbias, dscale_dbias_grid_desc_m_g.GetElementSpaceSize());

        if(thread_k_cluster_id == 0)
        {
            threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              reduce_dscale_thread_buf,
                                              dscale_dbias_grid_desc_m_g,
                                              reduce_dscale_global_buf);

            threadwise_dscale_dbias_store.Run(thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              reduce_dbias_thread_buf,
                                              dscale_dbias_grid_desc_m_g,
                                              reduce_dbias_global_buf);
        }
    }
};
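For reference, the statistic merge performed by the Welford half is Chan's
parallel update rule: two partial aggregates over n_a and n_b samples combine
into one aggregate over n_a + n_b samples. A minimal scalar sketch of that
rule (illustrative only, not part of this header; plain C++ in place of CK's
buffer abstractions, with 'var' holding the mean of squared deviations, i.e.
the quantity that 1/sqrt(var + epsilon) is applied to above):

// Merge two Welford partial results, in the spirit of ThreadwiseWelfordMerge
// and BlockwiseWelford. Names here are hypothetical.
struct WelfordPartial
{
    double mean = 0.0;
    double var  = 0.0; // E[(x - mean)^2] over 'count' samples
    int count   = 0;
};

inline WelfordPartial welford_merge(const WelfordPartial& a, const WelfordPartial& b)
{
    WelfordPartial out;
    out.count = a.count + b.count;
    if(out.count == 0)
        return out;

    const double delta = b.mean - a.mean;
    const double na    = static_cast<double>(a.count);
    const double nb    = static_cast<double>(b.count);
    const double n     = static_cast<double>(out.count);

    out.mean = a.mean + delta * (nb / n);
    // Sum the two squared-deviation masses plus the between-part correction
    // delta^2 * na * nb / n, then renormalize by the combined count.
    out.var = (a.var * na + b.var * nb + delta * delta * na * nb / n) / n;
    return out;
}

// Example: merging {mean=1, var=0, count=2} with {mean=3, var=0, count=2}
// yields {mean=2, var=1, count=4}, the statistics of the samples {1, 1, 3, 3}.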