template <typename GridwiseMultiblockBatchNormForward_,
          typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarCountGridDesc_M_G,
          typename MeanVarCountGridDesc_M_K,
          typename ScaleBiasGridDesc_M,
          typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor>
__global__ void kernel_multiblock_batchnorm_forward(
    const XYGridDesc_M_K x_grid_desc_m_k,
    const XYGridDesc_M_K y_grid_desc_m_k,
    const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g,
    const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k,
    const ScaleBiasGridDesc_M scale_grid_desc_m,
    const ScaleBiasGridDesc_M bias_grid_desc_m,
    const MeanVarGridDesc_M mean_var_grid_desc_m,
    const GetReduceCountPerThreadFunctor get_reduce_count_per_thread,
    index_t num_k_block_tile_iteration,
    AccDataType epsilon,
    const XDataType* const __restrict__ p_x,
    MeanVarDataType* const __restrict__ p_welford_mean,
    MeanVarDataType* const __restrict__ p_welford_variance,
    int32_t* const __restrict__ p_welford_count,
    int32_t* const __restrict__ p_control,
    const ScaleDataType* const __restrict__ p_scale,
    const BiasDataType* const __restrict__ p_bias,
    const YElementwiseOp y_elementwise_op,
    YDataType* const __restrict__ p_y,
    bool updateMovingAverage,
    AccDataType averageFactor,
    MeanVarDataType* const __restrict__ resultRunningMean,
    MeanVarDataType* const __restrict__ resultRunningVariance,
    bool saveMeanInvVariance,
    MeanVarDataType* const __restrict__ resultSaveMean,
    MeanVarDataType* const __restrict__ resultSaveInvVariance)
{
    GridwiseMultiblockBatchNormForward_::Run(
        x_grid_desc_m_k, y_grid_desc_m_k,
        mean_var_count_grid_desc_m_g, mean_var_count_grid_desc_m_k,
        scale_grid_desc_m, bias_grid_desc_m, mean_var_grid_desc_m,
        get_reduce_count_per_thread, num_k_block_tile_iteration, epsilon,
        p_x, p_welford_mean, p_welford_variance, p_welford_count, p_control,
        p_scale, p_bias, y_elementwise_op, p_y,
        updateMovingAverage, averageFactor, resultRunningMean, resultRunningVariance,
        saveMeanInvVariance, resultSaveMean, resultSaveInvVariance);
}
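
// The wrapper above only forwards its arguments; the actual work happens in three phases
// inside GridwiseMultiblockBatchNormForward::Run():
//   1. each block computes partial Welford mean/variance/count for its K chunk and publishes
//      them to the p_welford_* workspace buffers,
//   2. the blocks of a block group synchronize through p_control (gms_barrier) and every
//      block Welford-merges all partial results into the final statistics,
//   3. each block normalizes its x tile into y and, optionally, updates the running
//      mean/variance and saves the batch mean / inverse variance.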
template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          typename XYGridDesc_M_K,
          typename MeanVarCountGridDesc_M_G,
          typename MeanVarCountGridDesc_M_K,
          typename ScaleBiasGridDesc_M,
          typename MeanVarGridDesc_M,
          typename GetReduceCountPerThreadFunctor,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t XSrcYDstVectorDim,
          index_t XSrcVectorSize,
          index_t YDstVectorSize,
          // ... (remaining vector-size parameters elided in this listing)
          index_t MeanVarSrcDstVectorSize>
struct GridwiseMultiblockBatchNormForward
{
    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static_assert((XSrcYDstVectorDim == 0 && MThreadSliceSize % YDstVectorSize == 0) ||
                      (XSrcYDstVectorDim == 1 && KThreadSliceSize % YDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
    // ... (member type aliases, tile-size constants, and the thread-cluster descriptor
    //      elided in this listing)

    __device__ static void Run(
        const XYGridDesc_M_K& x_grid_desc_m_k,
        const XYGridDesc_M_K& y_grid_desc_m_k,
        const MeanVarCountGridDesc_M_G& mean_var_count_grid_desc_m_g,
        const MeanVarCountGridDesc_M_K& mean_var_count_grid_desc_m_k,
        const ScaleBiasGridDesc_M& scale_grid_desc_m,
        const ScaleBiasGridDesc_M& bias_grid_desc_m,
        const MeanVarGridDesc_M& mean_var_grid_desc_m,
        const GetReduceCountPerThreadFunctor& get_reduce_count_per_thread,
        index_t num_k_block_tile_iteration,
        AccDataType epsilon,
        const XDataType* const __restrict__ p_x,
        MeanVarDataType* const __restrict__ p_welford_mean,
        MeanVarDataType* const __restrict__ p_welford_variance,
        int32_t* const __restrict__ p_welford_count,
        int32_t* const __restrict__ p_control,
        const ScaleDataType* const __restrict__ p_scale,
        const BiasDataType* const __restrict__ p_bias,
        const YElementwiseOp y_elementwise_op,
        YDataType* const __restrict__ p_y,
        bool updateMovingAverage,
        AccDataType averageFactor,
        MeanVarDataType* const __restrict__ resultRunningMean,
        MeanVarDataType* const __restrict__ resultRunningVariance,
        bool saveMeanInvVariance,
        MeanVarDataType* const __restrict__ resultSaveMean,
        MeanVarDataType* const __restrict__ resultSaveInvVariance)
    {
        using ck::math::sqrt;

        const index_t blkgroup_size = mean_var_count_grid_desc_m_g.GetLength(I1);

        // ... (reduceSizePerBlock and block_global_id setup elided in this listing)

        const index_t blkgroup_id    = block_global_id / blkgroup_size;
        const index_t block_local_id = block_global_id % blkgroup_size;

        if(block_local_id == 0)
            gms_init(BlockSize / warpSize * blkgroup_size, &p_control[blkgroup_id * 2]);

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];
        // ... (thread buffers, remaining thread buffer descriptors, slice-window step
        //      constants, and the threadwise x loader construction elided in this listing;
        //      the x loader reads ThreadBufferLengths_M_K elements per iteration from
        //      x_grid_desc_m_k into a buffer shaped by thread_buffer_desc_m_k, with its K
        //      slice origin at
        //      block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize)

        constexpr auto thread_buffer_desc_m =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x, x_grid_desc_m_k.GetElementSpaceSize());

        // ... (threadwise_welford_1 construction elided in this listing)
        threadwise_welford_1.max_count_ =
            get_reduce_count_per_thread(block_local_id, thread_k_cluster_id);
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I) = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)  = type_convert<AccDataType>(0.0f);
        });
        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
            threadwise_welford_1.Run(x_thread_buf, mean_thread_buf, var_thread_buf);
        }
        static_for<0, MThreadSliceSize, 1>{}(
            [&](auto I) { count_thread_buf(I) = threadwise_welford_1.cur_count_; });
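
        // Each block now holds a partial mean/variance and the element count for its own K
        // chunk (the in-block blockwise Welford merge is elided in this listing); the partials
        // are published to the global workspace below so the whole block group can combine them.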
        auto mean_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                p_welford_mean, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto var_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                p_welford_variance, mean_var_count_grid_desc_m_g.GetElementSpaceSize());

        auto count_global_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global, AmdBufferCoherenceEnum::GLC>(
                p_welford_count, mean_var_count_grid_desc_m_g.GetElementSpaceSize());
        // threadwise_mean_var_store_m_g / threadwise_count_store_m_g: threadwise slice
        // transfers writing ThreadBufferLengths_M_1 elements (per thread_buffer_desc_m_1)
        // into mean_var_count_grid_desc_m_g, with their M slice origin at
        //   thread_m_cluster_id * MThreadSliceSize (plus the block-group offset).
        // (class and template arguments elided in this listing)
        if(thread_k_cluster_id == 0)
        {
            threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              mean_thread_buf,
                                              mean_var_count_grid_desc_m_g,
                                              mean_global_val_buf);

            threadwise_mean_var_store_m_g.Run(thread_buffer_desc_m_1,
                                              make_tuple(I0, I0),
                                              var_thread_buf,
                                              mean_var_count_grid_desc_m_g,
                                              var_global_val_buf);

            threadwise_count_store_m_g.Run(thread_buffer_desc_m_1,
                                           make_tuple(I0, I0),
                                           count_thread_buf,
                                           mean_var_count_grid_desc_m_g,
                                           count_global_val_buf);
        }
        gms_barrier(&p_control[blkgroup_id * 2]);

        if(block_local_id == 0)
            gms_reset(&p_control[blkgroup_id * 2]);
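
        // Grid-level synchronization: the partial results above are written through GLC
        // (cache-bypassing) buffers, and p_control[blkgroup_id * 2] is the per-block-group
        // flag that gms_init/gms_barrier/gms_reset use to ensure every block of the group
        // has published its partials before any block reads them back.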
        // threadwise_mean_var_load_m_k / threadwise_count_load_m_k: threadwise slice
        // transfers reading ThreadBufferLengths_M_1 elements from
        // mean_var_count_grid_desc_m_k into thread_buffer_desc_m_1, with their slice origin
        // at (thread_m_cluster_id * MThreadSliceSize, thread_k_cluster_id * 1) within the
        // block group. (class and template arguments elided in this listing)
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            mean_thread_buf(I)  = type_convert<AccDataType>(0.0f);
            var_thread_buf(I)   = type_convert<AccDataType>(0.0f);
            count_thread_buf(I) = 0;
        });
        constexpr auto mean_var_count_read_fwd_step_m_k = make_multi_index(0, KThreadClusterSize);
        int32_t reducedSize = 0;
        while(reducedSize < blkgroup_size)
        {
            threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                             mean_global_val_buf,
                                             thread_buffer_desc_m_1,
                                             make_tuple(I0, I0),
                                             tmp_mean_thread_buf);

            threadwise_mean_var_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                             var_global_val_buf,
                                             thread_buffer_desc_m_1,
                                             make_tuple(I0, I0),
                                             tmp_var_thread_buf);

            threadwise_count_load_m_k.Run(mean_var_count_grid_desc_m_k,
                                          count_global_val_buf,
                                          thread_buffer_desc_m_1,
                                          make_tuple(I0, I0),
                                          tmp_count_thread_buf);

            // Welford-merge the loaded partials (tmp_mean/var/count_thread_buf) into this
            // thread's running mean/var/count buffers (merge call elided in this listing).

            reducedSize += KThreadClusterSize;

            threadwise_mean_var_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                            mean_var_count_read_fwd_step_m_k);
            threadwise_count_load_m_k.MoveSrcSliceWindow(mean_var_count_grid_desc_m_k,
                                                         mean_var_count_read_fwd_step_m_k);
        }
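
        // Every block of the group walks the whole row of blkgroup_size partial results and
        // merges them, so after this loop each block holds the same global mean, variance,
        // and count for its M slice and can normalize its own tile independently.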
        // threadwise_y_store: threadwise slice transfer writing ThreadBufferLengths_M_K
        // elements (per thread_buffer_desc_m_k) into y_grid_desc_m_k, with its destination
        // slice origin at
        //   (blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
        //    block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize).
        //
        // threadwise_scale_load / threadwise_bias_load: threadwise slice transfers reading
        // ThreadBufferLengths_M elements (per thread_buffer_desc_m) from the scale/bias
        // grids, with their M slice origin at thread_m_cluster_id * MThreadSliceSize plus
        // the block-group offset.
        // (class and template arguments elided in this listing)
        const auto scale_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_scale, scale_grid_desc_m.GetElementSpaceSize());

        const auto bias_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_bias, bias_grid_desc_m.GetElementSpaceSize());

        auto y_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_y, y_grid_desc_m_k.GetElementSpaceSize());
        threadwise_scale_load.Run(scale_grid_desc_m,
                                  scale_global_val_buf,
                                  thread_buffer_desc_m,
                                  make_tuple(I0),
                                  scale_thread_buf);

        threadwise_bias_load.Run(bias_grid_desc_m,
                                 bias_global_val_buf,
                                 thread_buffer_desc_m,
                                 make_tuple(I0),
                                 bias_thread_buf);
        threadwise_x_load.SetSrcSliceOrigin(
            x_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));
        for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
        {
            threadwise_x_load.Run(x_grid_desc_m_k,
                                  x_global_val_buf,
                                  thread_buffer_desc_m_k,
                                  make_tuple(I0, I0),
                                  x_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                AccDataType multiplier =
                    scale_thread_buf[Number<iM>{}] / sqrt(var_thread_buf[iM] + epsilon);

                AccDataType fused_mean_bias =
                    bias_thread_buf[Number<iM>{}] - mean_thread_buf[iM] * multiplier;

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK));

                    // ... (elided in this listing: write the normalized value
                    //      x_thread_buf[offset] * multiplier + fused_mean_bias into the
                    //      per-thread y buffer at the same offset)
                });
            });

            threadwise_y_store.Run(thread_buffer_desc_m_k,
                                   make_tuple(I0, I0),
                                   /* y thread buffer (name elided in this listing) */,
                                   y_grid_desc_m_k,
                                   y_global_val_buf);

            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, xy_copy_fwd_step_m_k);
            threadwise_y_store.MoveDstSliceWindow(y_grid_desc_m_k, xy_copy_fwd_step_m_k);
        }
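
        // The loop above is the standard batchnorm affine transform in fused form:
        //   y = scale * (x - mean) / sqrt(var + epsilon) + bias
        //     = x * multiplier + fused_mean_bias
        // with multiplier and fused_mean_bias precomputed once per M slice.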
        // Optionally fold the freshly computed batch statistics into the running
        // (moving-average) mean and variance; only one thread column of the first block in
        // the group does this.
        if(updateMovingAverage && block_local_id == 0 && thread_k_cluster_id == 0)
        {
            // running_mean_thread_buf / running_var_thread_buf: per-thread staging buffers
            // for the running statistics (declarations elided in this listing).

            auto running_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultRunningMean, mean_var_grid_desc_m.GetElementSpaceSize());

            auto running_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultRunningVariance, mean_var_grid_desc_m.GetElementSpaceSize());

            // threadwise_mean_var_load: threadwise slice transfer reading
            // ThreadBufferLengths_M elements (per thread_buffer_desc_m, vectorized by
            // MeanVarSrcDstVectorSize) from mean_var_grid_desc_m, with its M slice origin at
            // thread_m_cluster_id * MThreadSliceSize plus the block-group offset.
            // (class and template arguments elided in this listing)

            threadwise_mean_var_load.Run(mean_var_grid_desc_m,
                                         running_mean_global_buf,
                                         thread_buffer_desc_m,
                                         make_tuple(I0),
                                         running_mean_thread_buf);

            threadwise_mean_var_load.Run(mean_var_grid_desc_m,
                                         running_var_global_buf,
                                         thread_buffer_desc_m,
                                         make_tuple(I0),
                                         running_var_thread_buf);

            AccDataType oneMinusAverageFactor = type_convert<AccDataType>(1.0) - averageFactor;

            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                running_mean_thread_buf(I) = running_mean_thread_buf[I] * oneMinusAverageFactor +
                                             mean_thread_buf[I] * averageFactor;
                running_var_thread_buf(I) = running_var_thread_buf[I] * oneMinusAverageFactor +
                                            var_thread_buf[I] * averageFactor;
            });

            // threadwise_mean_var_store: same shape as the load above, writing back to
            // mean_var_grid_desc_m (class and template arguments elided in this listing).

            threadwise_mean_var_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          running_mean_thread_buf,
                                          mean_var_grid_desc_m,
                                          running_mean_global_buf);

            threadwise_mean_var_store.Run(thread_buffer_desc_m,
                                          make_tuple(I0),
                                          running_var_thread_buf,
                                          mean_var_grid_desc_m,
                                          running_var_global_buf);
        }
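
        // The update above is the usual exponential moving average:
        //   running_new = (1 - averageFactor) * running_old + averageFactor * batch_statistic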
        // Optionally save the batch mean and the inverse standard deviation for later use.
        if(saveMeanInvVariance && block_local_id == 0 && thread_k_cluster_id == 0)
        {
            auto result_mean_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultSaveMean, mean_var_grid_desc_m.GetElementSpaceSize());

            auto result_inv_var_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
                resultSaveInvVariance, mean_var_grid_desc_m.GetElementSpaceSize());

            // Convert the variance in place to the inverse standard deviation 1/sqrt(var+eps).
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                var_thread_buf(I) =
                    type_convert<AccDataType>(1.0f) / sqrt(epsilon + var_thread_buf[I]);
            });

            // threadwise_mean_inv_var_store: threadwise store of ThreadBufferLengths_M
            // elements into mean_var_grid_desc_m, vectorized by MeanVarSrcDstVectorSize
            // (class/template arguments and constructor elided in this listing).

            threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              mean_thread_buf,
                                              mean_var_grid_desc_m,
                                              result_mean_global_buf);

            threadwise_mean_inv_var_store.Run(thread_buffer_desc_m,
                                              make_tuple(I0),
                                              var_thread_buf,
                                              mean_var_grid_desc_m,
                                              result_inv_var_global_buf);
        }
    }
};
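
// Illustrative sketch (not part of this header): a plain scalar reference of the per-element
// computation the kernel performs, handy for checking a device result on the host. The
// function name is hypothetical and std::sqrt stands in for ck::math::sqrt; the kernel itself
// uses the fused multiplier / fused_mean_bias form shown above.
#if 0 // reference sketch only
#include <cmath>

template <typename T>
T batchnorm_forward_ref(T x, T mean, T var, T scale, T bias, T epsilon)
{
    const T multiplier      = scale / std::sqrt(var + epsilon); // scale times 1 / std-dev
    const T fused_mean_bias = bias - mean * multiplier;
    return x * multiplier + fused_mean_bias; // == scale * (x - mean) / sqrt(var + epsilon) + bias
}
#endif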