/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp Source File
device_batchnorm_forward_impl.hpp
Go to the documentation of this file.
13 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_batchnorm_forward.hpp"
14 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
15 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp"
815 str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << BiasSrcVectorSize << "_mean_var_" << MeanVarSrcDstVectorSize << "_Y" << YDstVectorSize << ">";
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
__global__ void kernel_multiblock_welford_first_half(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x, MeanVarDataType *const p_welford_mean, MeanVarDataType *const p_welford_variance, int32_t *const p_welford_count)
Definition: gridwise_multiblock_welford_first_half.hpp:21
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__global__ void kernel_welford_second_half_batchnorm_forward_final(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, AccDataType epsilon, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition: gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:27
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__global__ void kernel_multiblock_batchnorm_forward(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, MeanVarDataType *const __restrict__ p_welford_mean, MeanVarDataType *const __restrict__ p_welford_variance, int32_t *const __restrict__ p_welford_count, int32_t *const __restrict__ p_control, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition: gridwise_multiblock_batchnorm_forward.hpp:31
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
__global__ void kernel_batchnorm_forward_with_blockwise_welford(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K y_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const ScaleDataType *const __restrict__ p_scale, const BiasDataType *const __restrict__ p_bias, const YElementwiseOp y_elementwise_op, YDataType *const __restrict__ p_y, bool updateMovingAverage, AccDataType averageFactor, MeanVarDataType *const __restrict__ resultRunningMean, MeanVarDataType *const __restrict__ resultRunningVariance, bool saveMeanInvVariance, MeanVarDataType *const __restrict__ resultSaveMean, MeanVarDataType *const __restrict__ resultSaveInvVariance)
Definition: gridwise_batchnorm_forward_blockwise_welford.hpp:27
Definition: stream_config.hpp:10
Definition: gridwise_batchnorm_forward_blockwise_welford.hpp:94
Definition: gridwise_multiblock_batchnorm_forward.hpp:112
Definition: gridwise_multiblock_welford_first_half.hpp:55
Definition: gridwise_multiblock_welford_second_half_batchnorm_forward_final_obsolete.hpp:102
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:256
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_batchnorm_forward.hpp:26
Definition: device_batchnorm_forward_impl.hpp:190
MeanVarDataType * resultRunningMean_
Definition: device_batchnorm_forward_impl.hpp:305
long_index_t reduce_length_
Definition: device_batchnorm_forward_impl.hpp:309
const ScaleDataType * p_scale_
Definition: device_batchnorm_forward_impl.hpp:297
bool updateMovingAverage_
Definition: device_batchnorm_forward_impl.hpp:284
ScaleBiasMeanVarGridDesc_M scale_grid_desc_m_
Definition: device_batchnorm_forward_impl.hpp:317
std::array< index_t, Rank > xStrides_
Definition: device_batchnorm_forward_impl.hpp:288
const XDataType * p_x_
Definition: device_batchnorm_forward_impl.hpp:296
std::array< index_t, Rank > xyLengths_
Definition: device_batchnorm_forward_impl.hpp:287
int blkGroupSize_
Definition: device_batchnorm_forward_impl.hpp:311
XYGridDesc_M_K x_grid_desc_m_k_
Definition: device_batchnorm_forward_impl.hpp:315
XYGridDesc_M_K y_grid_desc_m_k_
Definition: device_batchnorm_forward_impl.hpp:316
ScaleBiasMeanVarGridDesc_M bias_grid_desc_m_
Definition: device_batchnorm_forward_impl.hpp:318
bool saveMeanInvVariance_
Definition: device_batchnorm_forward_impl.hpp:285
long_index_t invariant_length_
Definition: device_batchnorm_forward_impl.hpp:308
MeanVarDataType * resultRunningVariance_
Definition: device_batchnorm_forward_impl.hpp:306
Argument(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > yStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides, const XDataType *p_x, const ScaleDataType *p_scale, const BiasDataType *p_bias, const YElementwiseOp y_elementwise_op, double epsilon, YDataType *p_y, MeanVarDataType *resultSaveMean, MeanVarDataType *resultSaveInvVariance, double averageFactor, MeanVarDataType *resultRunningMean, MeanVarDataType *resultRunningVariance)
Definition: device_batchnorm_forward_impl.hpp:191
AccDataType averageFactor_
Definition: device_batchnorm_forward_impl.hpp:282
const BiasDataType * p_bias_
Definition: device_batchnorm_forward_impl.hpp:298
AccDataType epsilon_
Definition: device_batchnorm_forward_impl.hpp:281
ScaleBiasMeanVarGridDesc_M mean_var_grid_desc_m_
Definition: device_batchnorm_forward_impl.hpp:319
std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides_
Definition: device_batchnorm_forward_impl.hpp:293
int numBlockTileIteration_
Definition: device_batchnorm_forward_impl.hpp:312
void * workspace_count_
Definition: device_batchnorm_forward_impl.hpp:323
const YElementwiseOp y_elementwise_op_
Definition: device_batchnorm_forward_impl.hpp:299
void * workspace_mean_
Definition: device_batchnorm_forward_impl.hpp:321
std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides_
Definition: device_batchnorm_forward_impl.hpp:294
void * control_
Definition: device_batchnorm_forward_impl.hpp:325
MeanVarDataType * resultSaveMean_
Definition: device_batchnorm_forward_impl.hpp:302
YDataType * p_y_
Definition: device_batchnorm_forward_impl.hpp:300
size_t gridSize_
Definition: device_batchnorm_forward_impl.hpp:313
MeanVarDataType * resultSaveInvVariance_
Definition: device_batchnorm_forward_impl.hpp:303
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths_
Definition: device_batchnorm_forward_impl.hpp:291
void * workspace_variance_
Definition: device_batchnorm_forward_impl.hpp:322
std::array< index_t, Rank > yStrides_
Definition: device_batchnorm_forward_impl.hpp:289
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides_
Definition: device_batchnorm_forward_impl.hpp:292
Definition: device_batchnorm_forward_impl.hpp:403
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_batchnorm_forward_impl.hpp:404
float Run(const BaseArgument *pArg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_batchnorm_forward_impl.hpp:699
Definition: device_batchnorm_forward_impl.hpp:56
static auto MakeXY2dDescriptor(const std::array< index_t, Rank > &xyLengths, const std::array< index_t, Rank > &xyStrides, int blkGroupSize, int numBlockTileIteration)
Definition: device_batchnorm_forward_impl.hpp:70
bool IsSupportedArgument(const BaseArgument *pArg) override
Definition: device_batchnorm_forward_impl.hpp:706
static constexpr index_t K_BlockTileSize
Definition: device_batchnorm_forward_impl.hpp:68
std::string GetTypeString() const override
Definition: device_batchnorm_forward_impl.hpp:806
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig & = StreamConfig{}) const override
Definition: device_batchnorm_forward_impl.hpp:357
decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)) XYGridDesc_M_K
Definition: device_batchnorm_forward_impl.hpp:186
static constexpr index_t M_BlockTileSize
Definition: device_batchnorm_forward_impl.hpp:67
decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})) ScaleBiasMeanVarGridDesc_M
Definition: device_batchnorm_forward_impl.hpp:187
static auto MakeScaleBiasMeanVar1dDescriptor(const std::array< index_t, NumInvariantDim > &lengths, const std::array< index_t, NumInvariantDim > &strides)
Definition: device_batchnorm_forward_impl.hpp:157
static constexpr index_t NumInvariantDim
Definition: device_batchnorm_forward_impl.hpp:65
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition: device_batchnorm_forward_impl.hpp:328
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > yStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths, const std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnBiasStrides, const std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides, const void *p_x, const void *p_scale, const void *p_bias, double epsilon, const YElementwiseOp y_elementwise_op, void *p_y, void *resultSaveMean, void *resultSaveInvVariance, double averageFactor, void *resultRunningMean, void *resultRunningVariance) override
Definition: device_batchnorm_forward_impl.hpp:759
static auto MakeMeanVarCountInputMK2dDescriptor(int invariantLength, int blkGroupSize)
Definition: device_batchnorm_forward_impl.hpp:135
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_batchnorm_forward_impl.hpp:801
static auto MakeMeanVarCountOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
Definition: device_batchnorm_forward_impl.hpp:117
Definition: welford_helper.hpp:44