DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize > Struct Template Reference

DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize > Struct Template Reference

Composable Kernel: ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize > Struct Template Reference
ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize > Struct Template Reference

#include <device_batchnorm_backward_impl.hpp>

Inheritance diagram for ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >:
ck::tensor_operation::device::DeviceBatchNormBwd< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim > ck::tensor_operation::device::BaseOperator

Classes

struct  Argument
 
struct  Invoker
 

Public Types

using XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1))
 
using ScaleBiasGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}))
 
using MeanVarGridDesc_M = ScaleBiasGridDesc_M
 

Public Member Functions

size_t GetWorkSpaceSize (const BaseArgument *pArg) const override
 
void SetWorkSpacePointer (BaseArgument *pArg, void *p_workspace, const StreamConfig &=StreamConfig{}) const override
 
bool IsSupportedArgument (const BaseArgument *pArg) override
 
std::unique_ptr< BaseArgument > MakeArgumentPointer (const std::array< index_t, Rank > xyLengths, const std::array< index_t, Rank > xStrides, const std::array< index_t, Rank > dyStrides, const std::array< index_t, Rank > dxStrides, const std::array< int, NumBatchNormReduceDim > reduceDims, const std::array< ck::index_t, NumInvariantDim > bnScaleBiasMeanVarLengths, const std::array< ck::index_t, NumInvariantDim > bnScaleStrides, const std::array< ck::index_t, NumInvariantDim > bnDscaleDbiasStrides, const std::array< ck::index_t, NumInvariantDim > bnMeanVarStrides, const void *p_x, const void *p_dy, const void *p_scale, const void *p_savedMean, const void *p_savedInvVar, double epsilon, const DyElementwiseOp dy_elementwise_op, void *p_dx, void *p_dscale, void *p_dbias) override
 
std::unique_ptr< BaseInvoker > MakeInvokerPointer () override
 
std::string GetTypeString () const override
 
- Public Member Functions inherited from ck::tensor_operation::device::BaseOperator
 BaseOperator ()=default
 
 BaseOperator (const BaseOperator &)=default
 
BaseOperator & operator= (const BaseOperator &)=default
 
virtual std::string GetTypeIdName () const
 
virtual std::optional< std::string > GetObjectName () const
 
virtual std::optional< std::string > GetTemplateInfo () const
 
virtual std::string GetTypeIdHashCode () const
 
virtual ~BaseOperator ()
 

Static Public Member Functions

static auto MakeXY2dDescriptor (const std::array< index_t, Rank > &xyLengths, const std::array< index_t, Rank > &xyStrides, int blkGroupSize, int numBlockTileIteration)
 
static auto MakeMultiblockFirstReduceOutputMG2dDescriptor (int invariantLength, int blkGroupSize)
 
static auto MakeMultiblockFinalReduceInputMK2dDescriptor (int invariantLength, int blkGroupSize)
 
static auto MakeScaleBiasMeanVar1dDescriptor (const std::array< index_t, NumInvariantDim > &lengths, const std::array< index_t, NumInvariantDim > &strides)
 

Static Public Attributes

static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim
 
static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize
 
static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize
 
- Static Public Attributes inherited from ck::tensor_operation::device::DeviceBatchNormBwd< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim >
static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim
 

Member Typedef Documentation

◆ MeanVarGridDesc_M

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
using ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::MeanVarGridDesc_M = ScaleBiasGridDesc_M

◆ ScaleBiasGridDesc_M

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
using ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::ScaleBiasGridDesc_M = decltype(MakeScaleBiasMeanVar1dDescriptor({1}, {1}))

◆ XYGridDesc_M_K

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
using ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::XYGridDesc_M_K = decltype(MakeXY2dDescriptor({1}, {1}, 1, 1))

Member Function Documentation

◆ GetTypeString()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
std::string ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::GetTypeString ( ) const
inline override virtual

◆ GetWorkSpaceSize()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
size_t ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::GetWorkSpaceSize ( const BaseArgument * pArg) const
inline override virtual

◆ IsSupportedArgument()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
bool ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::IsSupportedArgument ( const BaseArgument * pArg)
inline override virtual

◆ MakeArgumentPointer()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
std::unique_ptr<BaseArgument> ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::MakeArgumentPointer ( const std::array< index_t, Rank >  xyLengths,
const std::array< index_t, Rank >  xStrides,
const std::array< index_t, Rank >  dyStrides,
const std::array< index_t, Rank >  dxStrides,
const std::array< int, NumBatchNormReduceDim >  reduceDims,
const std::array< ck::index_t, NumInvariantDim >  bnScaleBiasMeanVarLengths,
const std::array< ck::index_t, NumInvariantDim >  bnScaleStrides,
const std::array< ck::index_t, NumInvariantDim >  bnDscaleDbiasStrides,
const std::array< ck::index_t, NumInvariantDim >  bnMeanVarStrides,
const void *  p_x,
const void *  p_dy,
const void *  p_scale,
const void *  p_savedMean,
const void *  p_savedInvVar,
double  epsilon,
const DyElementwiseOp  dy_elementwise_op,
void *  p_dx,
void *  p_dscale,
void *  p_dbias 
)
inline override virtual

◆ MakeInvokerPointer()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
std::unique_ptr<BaseInvoker> ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::MakeInvokerPointer ( )
inline override virtual

◆ MakeMultiblockFinalReduceInputMK2dDescriptor()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
static auto ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::MakeMultiblockFinalReduceInputMK2dDescriptor ( int  invariantLength,
int  blkGroupSize 
)
inline static

◆ MakeMultiblockFirstReduceOutputMG2dDescriptor()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
static auto ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::MakeMultiblockFirstReduceOutputMG2dDescriptor ( int  invariantLength,
int  blkGroupSize 
)
inline static

◆ MakeScaleBiasMeanVar1dDescriptor()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
static auto ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::MakeScaleBiasMeanVar1dDescriptor ( const std::array< index_t, NumInvariantDim > &  lengths,
const std::array< index_t, NumInvariantDim > &  strides 
)
inline static

◆ MakeXY2dDescriptor()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
static auto ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::MakeXY2dDescriptor ( const std::array< index_t, Rank > &  xyLengths,
const std::array< index_t, Rank > &  xyStrides,
int  blkGroupSize,
int  numBlockTileIteration 
)
inline static

◆ SetWorkSpacePointer()

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
void ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::SetWorkSpacePointer ( BaseArgument *  pArg,
void *  p_workspace,
const StreamConfig &  = StreamConfig{} 
) const
inline override virtual

Member Data Documentation

◆ K_BlockTileSize

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
constexpr index_t ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::K_BlockTileSize = KThreadClusterSize * KThreadSliceSize
static constexpr

◆ M_BlockTileSize

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
constexpr index_t ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::M_BlockTileSize = MThreadClusterSize * MThreadSliceSize
static constexpr

◆ NumInvariantDim

template<typename XDataType , typename DxDataType , typename DyDataType , typename AccDataType , typename ScaleDataType , typename DscaleDbiasDataType , typename MeanVarDataType , typename DyElementwiseOp , index_t Rank, index_t NumBatchNormReduceDim, bool UseMultiblockInK, index_t BlockSize, index_t MThreadClusterSize, index_t KThreadClusterSize, index_t MThreadSliceSize, index_t KThreadSliceSize, index_t XDyDxVectorDim, index_t XSrcVectorSize, index_t DySrcVectorSize, index_t DxDstVectorSize, index_t ScaleSrcVectorSize, index_t DscaleDbiasDstVectorSize, index_t MeanVarSrcVectorSize>
constexpr index_t ck::tensor_operation::device::DeviceBatchNormBwdImpl< XDataType, DxDataType, DyDataType, AccDataType, ScaleDataType, DscaleDbiasDataType, MeanVarDataType, DyElementwiseOp, Rank, NumBatchNormReduceDim, UseMultiblockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XDyDxVectorDim, XSrcVectorSize, DySrcVectorSize, DxDstVectorSize, ScaleSrcVectorSize, DscaleDbiasDstVectorSize, MeanVarSrcVectorSize >::NumInvariantDim = Rank - NumBatchNormReduceDim
static constexpr

The documentation for this struct was generated from the following file:
  • /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.1/include/ck/tensor_operation/gpu/device/impl/device_batchnorm_backward_impl.hpp