template <typename Problem_, typename Policy_ = void>
struct BlockNormReduce
{
    // ...
    static constexpr bool kFastFDiv = Problem::kFastFDiv;
    static constexpr bool kWelford  = Problem::kWelford;

    // ...

    // Accumulate x_tensor into the per-thread mean/var tiles, either with Welford
    // updates (kWelford) or with plain sum / sum-of-squares accumulation.
    template <typename XDistributedTensor_,
              typename MeanDistributedTensor_,
              typename VarDistributedTensor_>
    CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
                                   MeanDistributedTensor_& mean_tensor,
                                   VarDistributedTensor_& var_tensor,
                                   int& cur_count_,
                                   const int& max_count_)
    {
        // ...
        constexpr auto spans = XDistributedTensor_::get_distributed_spans();
        // ... (sweep the distributed spans)
            if(cur_count_ < max_count_)
            {
                // ...
                    constexpr auto in_dstr_idx  = make_tuple(dstr_idx_i0, dstr_idx_i1);
                    constexpr auto out_dstr_idx = make_tuple(dstr_idx_i0);

                    auto x = ck_tile::type_convert<ComputeDataType>(x_tensor[in_dstr_idx]);

                    if constexpr(kWelford)
                    {
                        welford_update(mean_tensor(out_dstr_idx),
                                       var_tensor(out_dstr_idx),
                                       x, cur_count_, bool_constant<kFastFDiv>{});
                    }
                    else
                    {
                        mean_tensor(out_dstr_idx) += x;
                        var_tensor(out_dstr_idx) += x * x;
                    }
                // ...
            }
        // ...
    }

    template <typename XDistributedTensor_>
    static CK_TILE_DEVICE auto MakeMeanVarBlockTile()
    {
        static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
        // ... (dstr is built from:)
                XDistributedTensor_::get_tile_distribution()
                    .get_static_tile_distribution_encoding(),
        // ...
        auto tensor = make_static_distributed_tensor<ComputeDataType>(dstr);
        // ...
    }

    template <typename XDistributedTensor_>
    CK_TILE_DEVICE auto
    operator()(const XDistributedTensor_& x_tensor, int& cur_count_, const int& max_count_)
    {
        auto mean_tensor = MakeMeanVarBlockTile<XDistributedTensor_>();
        auto var_tensor  = MakeMeanVarBlockTile<XDistributedTensor_>();
        // ...
        (*this)(x_tensor, mean_tensor, var_tensor, cur_count_, max_count_);
        // ...
    }
};
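For reference, a minimal scalar sketch (plain C++, outside ck_tile) of the two accumulation modes the kWelford flag selects between; the distributed-tensor sweep and the kFastFDiv reciprocal path are omitted:

// Scalar model of one accumulation step. With Welford, the running mean and the sum
// of squared deviations (M2) are updated incrementally; otherwise plain sums of x
// and x*x are kept and turned into mean/variance later.
struct NormAccumulator
{
    float mean = 0.f; // running mean (Welford) or running sum of x (plain)
    float var  = 0.f; // running M2 (Welford) or running sum of x^2 (plain)
    int count  = 0;

    void update_welford(float x)
    {
        count += 1;
        const float delta = x - mean;
        mean += delta / count;     // kFastFDiv would replace this division by a reciprocal
        var += delta * (x - mean); // accumulates the sum of squared deviations (M2)
    }

    void update_plain(float x)
    {
        count += 1;
        mean += x;    // sum of x
        var += x * x; // sum of x^2
    }
};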
 
// Intra-warp reduction: each lane merges its partial mean/var (and count, when
// kWelford) with the other lanes holding the same rows, using warp shuffles.
template <typename Problem_, typename Policy_ = void>
struct BlockNormReduceSync
{
    // ...
    static constexpr bool kFastFDiv = Problem::kFastFDiv;
    static constexpr bool kWelford  = Problem::kWelford;

    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
    CK_TILE_DEVICE void
    operator()(MeanDistributedTensor_& mean_tensor, VarDistributedTensor_& var_tensor, int& count)
    {
        using Dstr             = typename MeanDistributedTensor_::StaticTileDistribution;
        using DstrEncode       = typename Dstr::DstrEncode;
        using DstrEncodeDetail = typename DstrEncode::detail;

        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
                      "...");

        constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
        constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

        constexpr index_t idim_p_lane = NDimP - 1;
        // ...
        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());

        const int original_count = count;

        static_for<0, thread_buf_size, 1>{}([&](auto i) {
            auto v_local_mean  = mean_tensor.get_thread_buffer()[i];
            auto v_local_var   = var_tensor.get_thread_buffer()[i];
            auto v_local_count = original_count;

            // loop over the replicated (R) dimensions owned by the lane P-dimension
            static_for<0, NDimR, 1>{}([&](auto idim_r) {
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];

                    constexpr index_t lid_over_rid_derivative =
                        DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];

                    static_assert(is_power_of_two_integer(r_length),
                                  "wrong! only support power of 2 reduction");

                    // one shuffle stage per bit of r_length
                    static_for<0, integer_log2_floor(r_length), 1>{}([&](auto istage) {
                        // partner lane for this butterfly stage
                        const index_t src_lane =
                            get_lane_id() ^
                            (number<lid_over_rid_derivative << istage.value>{}.value);

                        const auto v_remote_mean = warp_shuffle(v_local_mean, src_lane);
                        const auto v_remote_var  = warp_shuffle(v_local_var, src_lane);

                        if constexpr(kWelford)
                        {
                            const auto v_remote_count = warp_shuffle(v_local_count, src_lane);
                            // ...
                            welford_merge(v_local_mean,
                                          /* ...remaining merge arguments elided... */);
                        }
                        else
                        {
                            v_local_mean += v_remote_mean;
                            v_local_var += v_remote_var;
                        }
                    });
                }
            });

            mean_tensor.get_thread_buffer()(i) = v_local_mean;
            var_tensor.get_thread_buffer()(i)  = v_local_var;
            if constexpr(kWelford)
            {
                count = v_local_count;
            }
        });
    }
};
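The argument list of welford_merge is elided in the excerpt above. A scalar sketch of the merge each shuffle stage performs, using the standard parallel-Welford (Chan et al.) combination formulas, would look like this (names are illustrative, not the library's):

// Merge two partial Welford states (mean, M2, count) into the first one.
// This mirrors what each stage does with the values pulled from src_lane.
inline void welford_merge_scalar(
    float& mean_a, float& var_a, int& count_a, float mean_b, float var_b, int count_b)
{
    const int count = count_a + count_b;
    if(count == 0)
        return; // both partials are empty; nothing to merge
    const float delta = mean_b - mean_a;
    mean_a  = mean_a + delta * count_b / count;
    var_a   = var_a + var_b + delta * delta * count_a * count_b / count;
    count_a = count;
}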
 
// Cross-warp reduction: partial results are exchanged through LDS (smem) and merged
// across the warps that own the same rows.
template <typename Problem_, typename Policy_ = void>
struct BlockNormReduceCrossWarpSync
{
    // ...
    static constexpr bool kFastFDiv = Problem::kFastFDiv;
    static constexpr bool kWelford  = Problem::kWelford;
    using smem_dtype                = std::conditional_t<kWelford, fp32x4_t, fp32x2_t>;

    // number of warps that cooperate on one reduction
    template <typename MeanDistributedTensor_>
    static constexpr CK_TILE_DEVICE index_t GetReduceWarps()
    {
        constexpr index_t num_reduce_warps = [&]() {
            using Dstr             = typename MeanDistributedTensor_::StaticTileDistribution;
            using DstrEncode       = typename Dstr::DstrEncode;
            using DstrEncodeDetail = typename DstrEncode::detail;

            constexpr index_t NDimR = Dstr::get_num_of_dimension_r();

            constexpr index_t idim_p_warp = 0;
            // ...
                if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r])
                {
                    constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
                    // ... (accumulate r_length into the warp count)
                }
            // ...
        }();
        return num_reduce_warps;
    }

    // LDS scratch size: one 4-float slot per warp and per thread-buffer element
    template <typename MeanDistributedTensor_>
    static constexpr CK_TILE_HOST_DEVICE index_t GetSmemSize()
    {
        // ...
        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
        // ...
        return num_warps * 4 * thread_buf_size * sizeof(float);
    }

    template <typename MeanDistributedTensor_, typename VarDistributedTensor_>
    CK_TILE_DEVICE void operator()(MeanDistributedTensor_& mean_tensor,
                                   VarDistributedTensor_& var_tensor,
                                   int& count,
                                   void* smem)
    {
        using DataType = typename MeanDistributedTensor_::DataType;
        using Dstr     = typename MeanDistributedTensor_::StaticTileDistribution;
        // ...
        static_assert(std::is_same_v<Dstr, typename VarDistributedTensor_::StaticTileDistribution>,
                      "...");

        constexpr index_t thread_buf_size = MeanDistributedTensor_::get_thread_buffer_size();
        static_assert(thread_buf_size == VarDistributedTensor_::get_thread_buffer_size());
        // ...
        constexpr auto num_reduce_warps = GetReduceWarps<MeanDistributedTensor_>();

        const index_t warp_id           = get_warp_id();
        const index_t smem_offset       = warp_id;

        // nothing to merge across warps when a single warp owns the whole row
        if constexpr(num_reduce_warps == 1)
        {
            // ...
        }

        // pack this warp's partial mean/var (and count) into LDS
        // ... (smem_ptr points into the LDS scratch passed in via smem)
            static_for<0, thread_buf_size, 1>{}([&](auto i) {
                smem_dtype local_scratch_;
                local_scratch_[0] = bit_cast<float>(mean_tensor.get_thread_buffer()[i]);
                local_scratch_[1] = bit_cast<float>(var_tensor.get_thread_buffer()[i]);
                if constexpr(kWelford)
                {
                    local_scratch_[2] = bit_cast<float>(count);
                }
                smem_ptr[smem_offset + i * num_warps] = local_scratch_;
            });

        block_sync_lds();

        // read back the scratch of every warp in the same reduce group
        index_t local_warp_id = warp_id / num_reduce_warps;
        index_t local_smem_os = local_warp_id * num_reduce_warps;
        smem_dtype all_scratch[thread_buf_size * num_reduce_warps];
        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
                all_scratch[i_0 * num_reduce_warps + i_1] =
                    smem_ptr[i_0 * num_warps + local_smem_os + i_1];
            });
        });
        // ...

        // merge the per-warp partials
        static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
            auto v_local      = all_scratch[i_0 * num_reduce_warps];
            auto v_local_mean = bit_cast<DataType>(v_local[0]);
            auto v_local_var  = bit_cast<DataType>(v_local[1]);
            int v_local_count = kWelford ? bit_cast<int>(v_local[2]) : 0;

            static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
                // ... (i_1 indexes the remaining num_reduce_warps - 1 slots)
                const smem_dtype v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
                const auto v_remote_mean  = bit_cast<DataType>(v_remote[0]);
                const auto v_remote_var   = bit_cast<DataType>(v_remote[1]);

                if constexpr(kWelford)
                {
                    const auto v_remote_count = bit_cast<int>(v_remote[2]);
                    // ...
                    welford_merge(v_local_mean,
                                  /* ...remaining merge arguments elided... */);
                }
                else
                {
                    v_local_mean += v_remote_mean;
                    v_local_var += v_remote_var;
                }
            });

            mean_tensor.get_thread_buffer()(i_0) = v_local_mean;
            var_tensor.get_thread_buffer()(i_0)  = v_local_var;
            if constexpr(kWelford)
            {
                count = v_local_count;
            }
        });
    }
};
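For a concrete feel of the LDS footprint: every warp stores one smem_dtype slot (four floats when kWelford: mean, var, count, and one apparently unused) per thread-buffer element, so the scratch grows as num_warps * 4 * thread_buf_size * sizeof(float). A standalone sketch with assumed example sizes:

#include <cstddef>
#include <cstdio>

int main()
{
    // Hypothetical sizes; the real values come from Problem::BlockShape and the tile type.
    constexpr std::size_t num_warps       = 4;
    constexpr std::size_t thread_buf_size = 2;

    // Mirrors the GetSmemSize() expression: one 4-float slot per (warp, buffer element).
    constexpr std::size_t smem_bytes = num_warps * 4 * thread_buf_size * sizeof(float);

    std::printf("cross-warp LDS scratch: %zu bytes\n", smem_bytes); // 4 * 4 * 2 * 4 = 128
    return 0;
}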
 
// How many elements of a row each thread accumulates (the bound passed to the
// accumulation step as max_count_).
template <typename BlockShape>
constexpr CK_TILE_DEVICE index_t block_tile_welford_calculate_max_count(int row_size)
{
    using S                   = BlockShape;
    index_t LastloopN         = row_size % S::Block_N == 0 ? S::Block_N : row_size % S::Block_N;
    constexpr index_t NThread = S::WarpPerBlock_N * S::ThreadPerWarp_N;
    // ... (iNLane: this thread's position along the N dimension)
    index_t iN0               = LastloopN / (S::Vector_N * S::ThreadPerWarp_N);
    index_t iN1               = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) / S::Vector_N;
    index_t N2                = (LastloopN % (S::Vector_N * S::ThreadPerWarp_N)) % S::Vector_N;
    index_t iN3               = iNLane < iN1 ? S::Vector_N : iNLane == iN1 ? N2 : 0;
    return iN0 * S::Vector_N + iN3;
    // ... (alternative, preprocessor-selected counting path)
    using S_                            = BlockShape;
    constexpr index_t ThreadsPerBlock_N = S_::WarpPerBlock_N * S_::ThreadPerWarp_N;
    // ...
    const index_t element_per_row = row_size / S_::Vector_N;
    // ... (step through the row in strides of ThreadsPerBlock_N and count the
    //      vector slots this thread covers)
        index_t _a = lane_id_n < element_per_row ? 1 : 0;
        // ...
        lane_id_n += ThreadsPerBlock_N;
    // ...
    return cnt * S_::Vector_N;
}
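A worked example of the first counting path, with hypothetical shape parameters (Block_N = 256, Vector_N = 4, ThreadPerWarp_N = 16) and a 1000-element row; iNLane is treated here as the thread's position along N, whose derivation is elided in the excerpt:

#include <cstdio>

int main()
{
    // Hypothetical shape parameters (not taken from the source):
    const int Block_N = 256, Vector_N = 4, ThreadPerWarp_N = 16;
    const int row_size = 1000;

    // Length of the last, possibly partial, Block_N tile of the row: 1000 % 256 = 232.
    const int LastloopN = row_size % Block_N == 0 ? Block_N : row_size % Block_N;

    const int iN0 = LastloopN / (Vector_N * ThreadPerWarp_N);              // 232 / 64 = 3
    const int iN1 = (LastloopN % (Vector_N * ThreadPerWarp_N)) / Vector_N; // 40 / 4  = 10
    const int N2  = (LastloopN % (Vector_N * ThreadPerWarp_N)) % Vector_N; // 40 % 4  = 0

    for(int iNLane = 0; iNLane < ThreadPerWarp_N; ++iNLane)
    {
        const int iN3       = iNLane < iN1 ? Vector_N : iNLane == iN1 ? N2 : 0;
        const int max_count = iN0 * Vector_N + iN3; // 16 for iNLane 0..9, 12 for 10..15
        std::printf("iNLane %2d -> max_count %d\n", iNLane, max_count);
    }
    return 0;
}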
 
// Divide each element of the variance tile by count, using the hardware reciprocal
// when FastFdiv_ is set and DataType is float.
template <typename VarDistributedTensor_, bool FastFdiv_ = false>
constexpr CK_TILE_DEVICE void block_tile_welford_post_scale_var(VarDistributedTensor_& var_tensor,
                                                                int count,
                                                                bool_constant<FastFdiv_> = {})
{
    using DataType = typename VarDistributedTensor_::DataType;
    tile_elementwise_inout(
        [&count](auto& x) {
            if(FastFdiv_ && std::is_same_v<DataType, float>)
            {
                x = x * __builtin_amdgcn_rcpf(type_convert<DataType>(count));
            }
            else
            {
                x = x / type_convert<DataType>(count);
            }
        },
        var_tensor);
}
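In scalar form, the finalization amounts to the following; the Welford branch matches block_tile_welford_post_scale_var (M2 / count), while the plain-sum branch is the textbook E[x^2] - E[x]^2 finalization implied by the sums accumulated earlier, which the excerpt does not show explicitly:

// Finalize mean/variance from the accumulated state (scalar model, count > 0 assumed).
inline void finalize(float& mean, float& var, int count, bool welford)
{
    if(welford)
    {
        var = var / count; // var held M2 = sum of squared deviations
    }
    else
    {
        mean = mean / count;             // mean held sum(x)
        var  = var / count - mean * mean; // var held sum(x^2)
    }
}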
 