/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp Source File
device_batchnorm_backward_impl.hpp
Go to the documentation of this file.
13 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_first_half.hpp"
14 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp"
15 #include "ck/tensor_operation/gpu/grid/batchnorm_multiblock/gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp"
123 static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
867 str << "VectorSize_X" << XSrcVectorSize << "_scale_" << ScaleSrcVectorSize << "_bias_" << DscaleDbiasDstVectorSize << "_mean_var_" << MeanVarSrcVectorSize << "_Dx_" << DxDstVectorSize << ">";
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
__host__ constexpr __device__ auto integer_least_multiple(X x, Y y)
Definition: math.hpp:78
Definition: ck.hpp:264
__global__ void kernel_multiblock_welford_first_half(const XGridDesc_M_K x_grid_desc_m_k, const MeanVarCountGridDesc_M_G mean_var_count_grid_desc_m_g, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, index_t num_k_block_tile_iteration, const XDataType *const __restrict__ p_x, MeanVarDataType *const p_welford_mean, MeanVarDataType *const p_welford_variance, int32_t *const p_welford_count)
Definition: gridwise_multiblock_welford_first_half.hpp:21
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__global__ void kernel_welford_second_half_reduce_first_half(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const MeanVarCountGridDesc_M_K mean_var_count_grid_desc_m_k, const DscaleDbiasGridDesc_M_G dscale_dbias_grid_desc_m_g, index_t blkgroup_size, index_t num_xy_k_block_tile_iteration, index_t num_mean_var_count_k_block_tile_iteration, AccDataType epsilon, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const MeanVarDataType *const __restrict__ p_in_welford_mean, const MeanVarDataType *const __restrict__ p_in_welford_variance, const int32_t *const __restrict__ p_in_welford_count, const DyElementwiseOp dy_elementwise_op, MeanVarDataType *const __restrict__ p_out_welford_mean, MeanVarDataType *const __restrict__ p_out_welford_inv_variance, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, DscaleDbiasDataType *const __restrict__ p_reduce_dscale, DscaleDbiasDataType *const __restrict__ p_reduce_dbias)
Definition: gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:27
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__global__ void kernel_batchnorm_backward_with_blockwise_welford(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M dscale_dbias_grid_desc_m, const MeanVarGridDesc_M mean_var_grid_desc_m, const GetReduceCountPerThreadFunctor get_reduce_count_per_thread, long_index_t reduce_size, index_t num_k_block_tile_iteration, AccDataType epsilon, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, bool haveSavedMeanInvVar, const MeanVarDataType *const __restrict__ p_savedMean, const MeanVarDataType *const __restrict__ p_savedInvVar, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition: gridwise_batchnorm_backward_blockwise_welford.hpp:31
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__global__ void kernel_reduce_second_half_batchnorm_backward_final(const XYGridDesc_M_K x_grid_desc_m_k, const XYGridDesc_M_K dy_grid_desc_m_k, const XYGridDesc_M_K dx_grid_desc_m_k, const DscaleDbiasGridDesc_M_K dscale_dbias_grid_desc_m_k, const MeanVarGridDesc_M mean_var_grid_desc_m, const ScaleBiasGridDesc_M scale_grid_desc_m, const ScaleBiasGridDesc_M bias_grid_desc_m, index_t blkgroup_size, long_index_t reduce_size, index_t num_xy_k_block_tile_iteration, index_t num_dscale_dbias_k_block_tile_iteration, const DscaleDbiasDataType *const __restrict__ p_reduce_dscale, const DscaleDbiasDataType *const __restrict__ p_reduce_dbias, const MeanVarDataType *const __restrict__ p_mean, const MeanVarDataType *const __restrict__ p_inv_var, const XDataType *const __restrict__ p_x, const DyDataType *const __restrict__ p_dy, const ScaleDataType *const __restrict__ p_scale, const DyElementwiseOp dy_elementwise_op, DxDataType *const __restrict__ p_dx, DscaleDbiasDataType *const __restrict__ p_dscale, DscaleDbiasDataType *const __restrict__ p_dbias)
Definition: gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:26
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
__host__ constexpr __device__ auto make_right_pad_transform(const LowLength &low_length, const RightPadLength &right_pad, integral_constant< bool, SkipIsValidCheck > = integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:37
Definition: stream_config.hpp:10
Definition: gridwise_batchnorm_backward_blockwise_welford.hpp:100
Definition: gridwise_multiblock_welford_first_half.hpp:55
Definition: gridwise_multiblock_reduce_second_half_batchnorm_backward_final.hpp:99
Definition: gridwise_multiblock_welford_second_half_multiblock_reduce_first_half.hpp:96
Definition: sequence.hpp:43
typename conditional< kHasContent, type0, type1 >::type type
Definition: sequence.hpp:256
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_batchnorm_backward.hpp:27
Definition: device_batchnorm_backward_impl.hpp:197
std::array< index_t, Rank > dyStrides_
Definition: device_batchnorm_backward_impl.hpp:295
XYGridDesc_M_K x_grid_desc_m_k
Definition: device_batchnorm_backward_impl.hpp:320
AccDataType epsilon_
Definition: device_batchnorm_backward_impl.hpp:289
DscaleDbiasDataType * p_dscale_
Definition: device_batchnorm_backward_impl.hpp:310
std::array< index_t, Rank > xStrides_
Definition: device_batchnorm_backward_impl.hpp:294
std::array< index_t, Rank > xyLengths_
Definition: device_batchnorm_backward_impl.hpp:293
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleStrides_
Definition: device_batchnorm_backward_impl.hpp:299
bool haveSavedMeanInvVar_
Definition: device_batchnorm_backward_impl.hpp:291
const MeanVarDataType * p_savedMean_
Definition: device_batchnorm_backward_impl.hpp:306
int blkGroupSize
Definition: device_batchnorm_backward_impl.hpp:316
std::array< index_t, Rank > dxStrides_
Definition: device_batchnorm_backward_impl.hpp:296
ScaleBiasGridDesc_M dscale_dbias_grid_desc_m
Definition: device_batchnorm_backward_impl.hpp:324
std::array< index_t, Rank - NumBatchNormReduceDim > bnMeanVarStrides_
Definition: device_batchnorm_backward_impl.hpp:301
void * workspace_reduce_dbias
Definition: device_batchnorm_backward_impl.hpp:335
const ScaleDataType * p_scale_
Definition: device_batchnorm_backward_impl.hpp:305
long_index_t reduce_length
Definition: device_batchnorm_backward_impl.hpp:314
const DyDataType * p_dy_
Definition: device_batchnorm_backward_impl.hpp:304
Argument(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const XDataType *p_x, const DyDataType *p_dy, const ScaleDataType *p_scale, const MeanVarDataType *p_savedMean, const MeanVarDataType *p_savedInvVar, const DyElementwiseOp dy_elementwise_op, double epsilon, DxDataType *p_dx, DscaleDbiasDataType *p_dscale, DscaleDbiasDataType *p_dbias)
Definition: device_batchnorm_backward_impl.hpp:198
ScaleBiasGridDesc_M scale_grid_desc_m
Definition: device_batchnorm_backward_impl.hpp:323
size_t gridSize
Definition: device_batchnorm_backward_impl.hpp:318
DxDataType * p_dx_
Definition: device_batchnorm_backward_impl.hpp:309
void * workspace_variance
Definition: device_batchnorm_backward_impl.hpp:328
MeanVarGridDesc_M mean_var_grid_desc_m
Definition: device_batchnorm_backward_impl.hpp:325
const XDataType * p_x_
Definition: device_batchnorm_backward_impl.hpp:303
void * workspace_savedMean
Definition: device_batchnorm_backward_impl.hpp:331
int numBlockTileIteration
Definition: device_batchnorm_backward_impl.hpp:317
void * workspace_mean
Definition: device_batchnorm_backward_impl.hpp:327
void * workspace_savedInvVar
Definition: device_batchnorm_backward_impl.hpp:332
long_index_t invariant_length
Definition: device_batchnorm_backward_impl.hpp:313
DscaleDbiasDataType * p_dbias_
Definition: device_batchnorm_backward_impl.hpp:311
std::array< index_t, Rank - NumBatchNormReduceDim > bnScaleBiasMeanVarLengths_
Definition: device_batchnorm_backward_impl.hpp:298
std::array< index_t, Rank - NumBatchNormReduceDim > bnDscaleDbiasStrides_
Definition: device_batchnorm_backward_impl.hpp:300
void * workspace_count
Definition: device_batchnorm_backward_impl.hpp:329
XYGridDesc_M_K dy_grid_desc_m_k
Definition: device_batchnorm_backward_impl.hpp:321
const MeanVarDataType * p_savedInvVar_
Definition: device_batchnorm_backward_impl.hpp:307
const DyElementwiseOp dy_elementwise_op_
Definition: device_batchnorm_backward_impl.hpp:308
void * workspace_reduce_dscale
Definition: device_batchnorm_backward_impl.hpp:334
XYGridDesc_M_K dx_grid_desc_m_k
Definition: device_batchnorm_backward_impl.hpp:322
Definition: device_batchnorm_backward_impl.hpp:436
float Run(const Argument &arg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_batchnorm_backward_impl.hpp:437
float Run(const BaseArgument *pArg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_batchnorm_backward_impl.hpp:742
Definition: device_batchnorm_backward_impl.hpp:58
std::string GetTypeString() const override
Definition: device_batchnorm_backward_impl.hpp:858
static constexpr index_t NumInvariantDim
Definition: device_batchnorm_backward_impl.hpp:71
static constexpr index_t M_BlockTileSize
Definition: device_batchnorm_backward_impl.hpp:73
std::unique_ptr< BaseArgument > MakeArgumentPointer(const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const void *p_x, const void *p_dy, const void *p_scale, const void *p_savedMean, const void *p_savedInvVar, double epsilon, const DyElementwiseOp dy_elementwise_op, void *p_dx, void *p_dscale, void *p_dbias) override
Definition: device_batchnorm_backward_impl.hpp:812
bool IsSupportedArgument(const BaseArgument *pArg) override
Definition: device_batchnorm_backward_impl.hpp:749
static auto MakeMultiblockFirstReduceOutputMG2dDescriptor(int invariantLength, int blkGroupSize)
Definition: device_batchnorm_backward_impl.hpp:123
ScaleBiasGridDesc_M MeanVarGridDesc_M
Definition: device_batchnorm_backward_impl.hpp:194
static constexpr index_t K_BlockTileSize
Definition: device_batchnorm_backward_impl.hpp:74
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_batchnorm_backward_impl.hpp:853
static auto MakeXY2dDescriptor(const std::array< index_t, Rank > &xyLengths, const std::array< index_t, Rank > &xyStrides, int blkGroupSize, int numBlockTileIteration)
Definition: device_batchnorm_backward_impl.hpp:76
decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1})) ScaleBiasGridDesc_M
Definition: device_batchnorm_backward_impl.hpp:193
static auto MakeScaleBiasMeanVar1dDescriptor(const std::array< index_t, NumInvariantDim > &lengths, const std::array< index_t, NumInvariantDim > &strides)
Definition: device_batchnorm_backward_impl.hpp:163
size_t GetWorkSpaceSize(const BaseArgument *pArg) const override
Definition: device_batchnorm_backward_impl.hpp:338
static auto MakeMultiblockFinalReduceInputMK2dDescriptor(int invariantLength, int blkGroupSize)
Definition: device_batchnorm_backward_impl.hpp:141
void SetWorkSpacePointer(BaseArgument *pArg, void *p_workspace, const StreamConfig & = StreamConfig{}) const override
Definition: device_batchnorm_backward_impl.hpp:379
decltype(MakeXY2dDescriptor({1}, {1}, 1, 1)) XYGridDesc_M_K
Definition: device_batchnorm_backward_impl.hpp:192
Definition: welford_helper.hpp:44