template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DXDataType,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t DYSrcVectorDim,
          index_t DYSrcVectorSize,
          index_t XSrcVectorDim,
          index_t XSrcVectorSize,
          index_t GammaSrcVectorDim,
          index_t GammaSrcVectorSize,
          index_t MeanInvStdSrcVectorDim,
          index_t MeanInvStdSrcVectorSize,
          index_t DXDstVectorDim,
          index_t DXDstVectorSize,
          bool SweepOnce>
struct GridwiseNormalizationBwdData_mk_to_mk
{
 
    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize == DYSrcVectorSize) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize == DYSrcVectorSize)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize == XSrcVectorSize) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize == XSrcVectorSize)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    static_assert(
        ((GammaSrcVectorDim == 0 && MThreadSliceSize == GammaSrcVectorSize) ||
         (GammaSrcVectorDim == 1 && KThreadSliceSize == GammaSrcVectorSize)),
        "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

    static_assert(
        ((MeanInvStdSrcVectorDim == 0 && MThreadSliceSize == MeanInvStdSrcVectorSize) ||
         (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize == MeanInvStdSrcVectorSize)),
        "Invalid thread slice sizes and/or mean/inv_std vector sizes configuration, please check!");

    static_assert(((DXDstVectorDim == 0 && MThreadSliceSize == DXDstVectorSize) ||
                   (DXDstVectorDim == 1 && KThreadSliceSize == DXDstVectorSize)),
                  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
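    // Each assert pins a tensor's vector access width to the thread slice extent of
    // the dimension being vectorized. Illustrative only (these values are an assumed
    // configuration, not defaults): with MThreadSliceSize = 1 and KThreadSliceSize = 8,
    // a K-contiguous dy must use DYSrcVectorDim == 1 and DYSrcVectorSize == 8, i.e.
    // each thread issues exactly one 8-wide vector load for its K-slice.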
 
    // ... (type aliases, thread-buffer descriptors and tile-size constants elided)

    __device__ static void Run(const GridDesc_M_K& dy_grid_desc_m_k,
                               const GridDesc_M_K& x_grid_desc_m_k,
                               const GridDesc_M_K& gamma_grid_desc_m_k,
                               const GridDesc_M_K& mean_grid_desc_m_k,
                               const GridDesc_M_K& inv_std_grid_desc_m_k,
                               const GridDesc_M_K& dx_grid_desc_m_k,
                               index_t num_k_block_tile_iteration,
                               const DYDataType* const __restrict__ p_dy_global,
                               const XDataType* const __restrict__ p_x_global,
                               const GammaDataType* const __restrict__ p_gamma_global,
                               const MeanInvStdDataType* const __restrict__ p_mean_global,
                               const MeanInvStdDataType* const __restrict__ p_inv_std_global,
                               DXDataType* const __restrict__ p_dx_global)
    {
        __shared__ ComputeDataType p_reduce_work_buffer[BlockSize];

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
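        // One LDS slot per thread: this scratch buffer is what the blockwise
        // reduction (BlockwiseReduce::Reduce, used further down) works through when
        // combining the per-thread ds/db partials across the K thread cluster.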
 
        const auto dy_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dy_global, dy_grid_desc_m_k.GetElementSpaceSize());

        const auto x_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_x_global, x_grid_desc_m_k.GetElementSpaceSize());

        const auto gamma_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_gamma_global, gamma_grid_desc_m_k.GetElementSpaceSize());

        const auto mean_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_mean_global, mean_grid_desc_m_k.GetElementSpaceSize());

        const auto inv_std_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_inv_std_global, inv_std_grid_desc_m_k.GetElementSpaceSize());

        auto dx_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_dx_global, dx_grid_desc_m_k.GetElementSpaceSize());
 
        // Thread-private (VGPR) tile buffers, one per tensor, each of extent
        // MThreadSliceSize * KThreadSliceSize:
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     ComputeDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            dy_thread_buf;

        // ... (x_thread_buf, gamma_thread_buf, mean_thread_buf, inv_std_thread_buf and
        // dx_thread_buf are declared the same way; ds_thread_buf and db_thread_buf hold
        // the per-M reduction results and have extent MThreadSliceSize)
 
        const index_t block_global_id = get_block_1d_id();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // ... (thread_copy_fwd_step_m_k / thread_copy_bwd_step_m_k, the forward and
        // backward K-block-tile window steps, are defined in elided lines)
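        // Illustrative only (sizes and arrange order are assumptions): with
        // MThreadClusterSize = 8, KThreadClusterSize = 32 (BlockSize = 256) and a
        // K-major arrange order, thread 35 decomposes to cluster index (m, k) = (1, 3),
        // so its slice origin within the block tile is
        // (1 * MThreadSliceSize, 3 * KThreadSliceSize).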
 
        auto threadwise_dy_load = ThreadwiseTensorSliceTransfer_v2<
            /* DYDataType, ComputeDataType, descriptors, DYSrcVectorDim,
               DYSrcVectorSize, ... (template arguments elided) */>(
            dy_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_x_load = ThreadwiseTensorSliceTransfer_v2<
            /* ... XSrcVectorDim, XSrcVectorSize ... */>(
            x_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_gamma_load = ThreadwiseTensorSliceTransfer_v2<
            /* ... GammaSrcVectorDim, GammaSrcVectorSize ... */>(
            gamma_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_mean_load = ThreadwiseTensorSliceTransfer_v2<
            /* ... */ MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize /* ... */>(
            mean_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_inv_std_load = ThreadwiseTensorSliceTransfer_v2<
            /* ... */ MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize /* ... */>(
            inv_std_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dx_store = ThreadwiseTensorSliceTransfer_v1r3<
            /* ... DXDstVectorDim, DXDstVectorSize ... */>(
            dx_grid_desc_m_k,
            make_multi_index(block_global_id * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize),
            PassThroughOp{});
 
        ComputeDataType reduce_size = type_convert<ComputeDataType>(
            dy_grid_desc_m_k.GetTransforms()[I2].GetUpperLengths()[I0]);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            ds_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
            db_thread_buf(I) = type_convert<ComputeDataType>(0.0f);
        });
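        // reduce_size is the K (normalized-dimension) length, read off the grid
        // descriptor's transform. ds(m) will accumulate sum_k(dy * gamma * x) and
        // db(m) sum_k(dy * gamma). SweepOnce selects the single-pass path below, in
        // which the entire K extent fits in one block tile, so every tile is loaded
        // exactly once; otherwise the kernel makes two passes over K.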
 
        if constexpr(SweepOnce)
        {
            threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf,
                                   thread_buffer_desc_m_k, make_tuple(I0, I0), dy_thread_buf);

            threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf,
                                  thread_buffer_desc_m_k, make_tuple(I0, I0), x_thread_buf);

            threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                      thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf);

            threadwise_mean_load.Run(mean_grid_desc_m_k, mean_global_val_buf,
                                     thread_buffer_desc_m_k, make_tuple(I0, I0), mean_thread_buf);

            threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, inv_std_global_val_buf,
                                        thread_buffer_desc_m_k, make_tuple(I0, I0), inv_std_thread_buf);

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m =
                    Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                    ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
                                               gamma_thread_buf[offset_m_k] *
                                               x_thread_buf[offset_m_k];

                    db_thread_buf(offset_m) +=
                        dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
                });
            });

            // ... (blockwise Reduce of ds_thread_buf and db_thread_buf across the K
            // thread cluster through reduce_work_buf, elided)

            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                constexpr auto offset_m =
                    Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset_m_k =
                        Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                    ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
                                        ds_thread_buf[offset_m];

                    b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                         inv_std_thread_buf[offset_m_k] / reduce_size;

                    ComputeDataType c = -b * mean_thread_buf(offset_m_k);

                    c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;

                    dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
                                                    gamma_thread_buf[offset_m_k] *
                                                    inv_std_thread_buf[offset_m_k] +
                                                b * x_thread_buf[offset_m_k] + c;
                });
            });
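            // With g = dy * gamma and s = inv_std, the two scalars per (m, k) are
            //   b = (db * mean - ds) * s^3 / reduce_size
            //   c = -b * mean - db * s / reduce_size
            // so dx = g * s + b * x + c, which is algebraically the familiar layernorm
            // backward-data formula
            //   dx = s * (g - mean_k(g) - x_hat * mean_k(g * x_hat)),
            // with x_hat = (x - mean) * s.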
 
            // ... (threadwise_dx_store.Run writes dx_thread_buf back to global, elided)
        }
        else
        {
            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf,
                                       thread_buffer_desc_m_k, make_tuple(I0, I0), dy_thread_buf);

                threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf,
                                      thread_buffer_desc_m_k, make_tuple(I0, I0), x_thread_buf);

                threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                          thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf);

                threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_fwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_fwd_step_m_k);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    constexpr auto offset_m =
                        Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                        ds_thread_buf(offset_m) += dy_thread_buf[offset_m_k] *
                                                   gamma_thread_buf[offset_m_k] *
                                                   x_thread_buf[offset_m_k];

                        db_thread_buf(offset_m) +=
                            dy_thread_buf[offset_m_k] * gamma_thread_buf[offset_m_k];
                    });
                });
            }

            // ... (blockwise Reduce of ds_thread_buf and db_thread_buf, elided)
 
            auto thread_copy_tail_m_k = (num_k_block_tile_iteration - 1) * thread_copy_fwd_step_m_k;

            threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k, thread_copy_bwd_step_m_k);
            threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k, thread_copy_tail_m_k);
            threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k, thread_copy_tail_m_k);
            threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_tail_m_k);
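            // Second-pass setup: after the forward sweep the dy/x/gamma windows sit
            // one tile past the end, so a single backward step lands them on the last
            // tile, while the mean/inv_std/dx windows have not moved yet and jump
            // straight to the tail. Illustrative only (sizes are assumptions): with
            // num_k_block_tile_iteration = 4 and K_BlockTileSize = 64, dy is at
            // k = 256 after the forward sweep; one backward step puts it at k = 192,
            // and the second pass then visits k = 192, 128, 64, 0.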
 
            for(index_t reducedTiles = 0; reducedTiles < num_k_block_tile_iteration; ++reducedTiles)
            {
                threadwise_dy_load.Run(dy_grid_desc_m_k, dy_global_val_buf,
                                       thread_buffer_desc_m_k, make_tuple(I0, I0), dy_thread_buf);

                threadwise_x_load.Run(x_grid_desc_m_k, x_global_val_buf,
                                      thread_buffer_desc_m_k, make_tuple(I0, I0), x_thread_buf);

                threadwise_gamma_load.Run(gamma_grid_desc_m_k, gamma_global_val_buf,
                                          thread_buffer_desc_m_k, make_tuple(I0, I0), gamma_thread_buf);

                threadwise_mean_load.Run(mean_grid_desc_m_k, mean_global_val_buf,
                                         thread_buffer_desc_m_k, make_tuple(I0, I0), mean_thread_buf);

                threadwise_inv_std_load.Run(inv_std_grid_desc_m_k, inv_std_global_val_buf,
                                            thread_buffer_desc_m_k, make_tuple(I0, I0), inv_std_thread_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    constexpr auto offset_m =
                        Number<thread_buffer_desc_m.CalculateOffset(make_tuple(iM))>{};

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset_m_k =
                            Number<thread_buffer_desc_m_k.CalculateOffset(make_tuple(iM, iK))>{};

                        ComputeDataType b = db_thread_buf[offset_m] * mean_thread_buf[offset_m_k] -
                                            ds_thread_buf[offset_m];

                        b *= inv_std_thread_buf[offset_m_k] * inv_std_thread_buf[offset_m_k] *
                             inv_std_thread_buf[offset_m_k] / reduce_size;

                        ComputeDataType c = -b * mean_thread_buf(offset_m_k);

                        c -= db_thread_buf[offset_m] * inv_std_thread_buf[offset_m_k] / reduce_size;

                        dx_thread_buf(offset_m_k) = dy_thread_buf[offset_m_k] *
                                                        gamma_thread_buf[offset_m_k] *
                                                        inv_std_thread_buf[offset_m_k] +
                                                    b * x_thread_buf[offset_m_k] + c;
                    });
                });

                // ... (threadwise_dx_store.Run writes this tile's dx, elided)

                threadwise_dy_load.MoveSrcSliceWindow(dy_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_x_load.MoveSrcSliceWindow(x_grid_desc_m_k, thread_copy_bwd_step_m_k);
                threadwise_gamma_load.MoveSrcSliceWindow(gamma_grid_desc_m_k,
                                                         thread_copy_bwd_step_m_k);
                threadwise_mean_load.MoveSrcSliceWindow(mean_grid_desc_m_k,
                                                        thread_copy_bwd_step_m_k);
                threadwise_inv_std_load.MoveSrcSliceWindow(inv_std_grid_desc_m_k,
                                                           thread_copy_bwd_step_m_k);
                threadwise_dx_store.MoveDstSliceWindow(dx_grid_desc_m_k, thread_copy_bwd_step_m_k);
            }
        }
    }
};
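// ---------------------------------------------------------------------------
// A minimal host-side sketch of the dx formula used above, for verification
// only. Everything here (function name, std::vector layout, the sample values
// in main) is an assumption for illustration, not part of the kernel's API.
#include <cassert>
#include <cmath>
#include <cstdio>
#include <vector>

// Reference layernorm backward-data for a single row of length K, using the
// same ds/db/b/c formulation as the kernel's inner loops.
std::vector<float> layernorm_bwd_data_ref(const std::vector<float>& dy,
                                          const std::vector<float>& x,
                                          const std::vector<float>& gamma,
                                          float mean,
                                          float inv_std)
{
    const std::size_t K = dy.size();
    assert(x.size() == K && gamma.size() == K);

    // Row-wide reductions: ds = sum(dy * gamma * x), db = sum(dy * gamma)
    float ds = 0.f;
    float db = 0.f;
    for(std::size_t k = 0; k < K; ++k)
    {
        ds += dy[k] * gamma[k] * x[k];
        db += dy[k] * gamma[k];
    }

    // b and c exactly as in the kernel's per-element computation
    const float b = (db * mean - ds) * inv_std * inv_std * inv_std / static_cast<float>(K);
    const float c = -b * mean - db * inv_std / static_cast<float>(K);

    std::vector<float> dx(K);
    for(std::size_t k = 0; k < K; ++k)
        dx[k] = dy[k] * gamma[k] * inv_std + b * x[k] + c;
    return dx;
}

int main()
{
    const std::vector<float> x{1.f, 2.f, 3.f, 4.f};
    const std::vector<float> dy{0.1f, -0.2f, 0.3f, -0.4f};
    const std::vector<float> gamma{1.f, 1.f, 1.f, 1.f};
    const float mean    = 2.5f;                   // mean of x
    const float inv_std = 1.f / std::sqrt(1.25f); // population variance of x is 1.25
    for(float v : layernorm_bwd_data_ref(dy, x, gamma, mean, inv_std))
        std::printf("%f\n", v);
    return 0;
}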
 