17 template <
typename AccDistributedTensor_, 
typename ReduceFunc, 
bool WithBroadcast = true>
 
   19                                            const ReduceFunc& reduce_func,
 
   22     using Dstr             = 
typename AccDistributedTensor_::StaticTileDistribution;
 
   23     using DstrEncode       = 
typename Dstr::DstrEncode;
 
   24     using DstrEncodeDetail = 
typename DstrEncode::detail;
 
   26     constexpr 
index_t NDimP = Dstr::get_num_of_dimension_p();
 
   27     constexpr 
index_t NDimR = Dstr::get_num_of_dimension_r();
 
   29     constexpr 
index_t idim_p_lane = NDimP - 1;
 
   32     const auto rs_idx = acc_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx);
 
   34     constexpr 
index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
 
   37     static_for<0, thread_buf_size, 1>{}([&](
auto i) {
 
   38         auto v_local = acc_tensor.get_thread_buffer()[i];
 
   43         static_for<0, NDimR, 1>{}([&](
auto idim_r) {
 
   45             if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
 
   47                 constexpr 
index_t r_length = DstrEncode::rs_lengths_[idim_r];
 
   49                 constexpr 
index_t lid_over_rid_derivative =
 
   50                     DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
 
   53                               "wrong! only support power of 2 reduction");
 
   58                 static_for<0, nstage, 1>{}([&](
auto istage) {
 
   60                         lid_over_rid_derivative * (1 << (nstage - istage - 1));
 
   66                     v_local = reduce_func(v_local, v_remote);
 
   71         if constexpr(WithBroadcast)
 
   76             static_for<0, NDimR, 1>{}([&](
auto idim_r) {
 
   78                 if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
 
   80                     const index_t r_id = rs_idx[idim_r];
 
   82                     constexpr 
index_t r_length = DstrEncode::rs_lengths_[idim_r];
 
   84                     constexpr 
index_t lid_over_rid_derivative =
 
   85                         DstrEncodeDetail::ps_over_rs_derivative_[NDimP - 1][idim_r];
 
   88                                   "wrong! only support power of 2 reduction");
 
   93                     static_for<0, nstage, 1>{}([&](
auto istage) {
 
   95                         const bool do_i_hold_reduced_data = r_id < (1 << istage);
 
   97                         constexpr 
index_t lid_delta = lid_over_rid_derivative * (1 << istage);
 
  103                         v_local = do_i_hold_reduced_data ? v_local : v_remote;
 
  109         acc_tensor.get_thread_buffer()(i) = v_local;
 
  117 template <
typename AccDistributedTensor_, 
typename ReduceFunc>
 
  119                                                const ReduceFunc& reduce_func)
 
  121     using Dstr             = 
typename AccDistributedTensor_::StaticTileDistribution;
 
  122     using DstrEncode       = 
typename Dstr::DstrEncode;
 
  123     using DstrEncodeDetail = 
typename DstrEncode::detail;
 
  125     constexpr 
index_t NDimP = Dstr::get_num_of_dimension_p();
 
  126     constexpr 
index_t NDimR = Dstr::get_num_of_dimension_r();
 
  128     constexpr 
index_t idim_p_lane = NDimP - 1;
 
  130     constexpr 
index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
 
  134         auto v_local = acc_tensor.get_thread_buffer()[i];
 
  141             if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
 
  143                 constexpr 
index_t r_length = DstrEncode::rs_lengths_[idim_r];
 
  145                 constexpr 
index_t lid_over_rid_derivative =
 
  146                     DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
 
  149                               "wrong! only support power of 2 reduction");
 
  157                         __lane_id() ^ (
number<lid_over_rid_derivative << istage.
value>{}.value);
 
  163                     v_local = reduce_func(v_local, v_remote);
 
  168         acc_tensor.get_thread_buffer()(i) = v_local;
 
  173 template <
typename AccDistributedTensor_,
 
  174           typename InDistributedTensor_,
 
  178                                       const InDistributedTensor_& in_tensor,
 
  180                                       const ReduceFunc& reduce_func)
 
  186     constexpr 
auto in_reduce_dims = 
sequence<InReduceDims...>{};
 
  188     constexpr 
index_t ndim_in        = InDistributedTensor_::get_num_of_dimension();
 
  189     constexpr 
index_t ndim_in_reduce = in_reduce_dims.size();
 
  190     constexpr 
index_t ndim_in_free   = ndim_in - ndim_in_reduce;
 
  192     constexpr 
auto in_free_dims_arr = [&] {
 
  195         for(
index_t i = 0; i < ndim_reduce; i++)
 
  197             is_free_dims(in_reduce_dims[i]) = 
false;
 
  204         for(
index_t i = 0; i < ndim_in; i++)
 
  208                 in_free_dims(cnt) = i;
 
  217     constexpr 
auto in_free_dims = 
TO_SEQUENCE(is_free_dims_arr, ndim_in_free);
 
  220     constexpr 
auto spans = InDistributedTensor_::get_distributed_spans();
 
  225         constexpr 
auto acc_dstr_idx = 
make_tuple(dstr_idx_i0);
 
  227         auto acc = acc_tensor[acc_dstr_idx];
 
  231             constexpr 
auto in_dstr_idx = 
make_tuple(dstr_idx_i0, dstr_idx_i1);
 
  233             const auto in = in_tensor[in_dstr_idx];
 
  235             acc = reduce_func(acc, in);
 
  238         acc_tensor(acc_dstr_idx) = acc;
 
  247 template <
typename AccDataType_,
 
  248           typename InDistributedTensor_,
 
  251           typename InDataType_>
 
  254                                       const ReduceFunc& reduce_func,
 
  255                                       const InDataType_& reduce_init)
 
  257     using InDataType  = 
typename InDistributedTensor_::DataType;
 
  263     constexpr 
auto acc_dstr =
 
  265             InDistributedTensor_::get_tile_distribution().get_static_tile_distribution_encoding(),
 
  268     auto acc_tensor = make_static_distributed_tensor<AccDataType>(acc_dstr);
 
  284 template <
typename InDistributedTensor_>
 
  298         constexpr 
auto acc_dstr =
 
  300                 InDistributedTensor::get_tile_distribution()
 
  301                     .get_static_tile_distribution_encoding(),
 
  304         auto dst_ = make_static_distributed_tensor<InDataType>(acc_dstr);
 
  313         constexpr 
auto spans = InDistributedTensor::get_distributed_spans();
 
  322     template <
typename ReduceFunc,
 
  323               typename ReduceSyncFunc,
 
  326                                         const ReduceSyncFunc& reduce_sync_func,
 
  327                                         ReducePacksPerXDim = {}) 
const 
  329         constexpr 
auto spans = InDistributedTensor::get_distributed_spans();
 
  331         constexpr 
auto row_y_unpacks = [&]() {
 
  332             constexpr 
auto row_y_lengths = 
typename decltype(spans[
number<1>{}])::Impl{};
 
  333             constexpr 
auto row_y_size =
 
  335             constexpr 
auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
 
  337             static_assert(row_y_size % row_y_packs == 0);
 
  339             constexpr 
auto row_y_slice_size = row_y_size / row_y_packs;
 
  341             constexpr 
auto slice_info = 
slice_sequence(row_y_lengths, number<row_y_slice_size>{});
 
  342             constexpr 
auto unpacks    = slice_info[number<1>{}];
 
  351             constexpr 
auto acc_dstr_idx = 
make_tuple(dstr_idx_i0);
 
  353             auto acc = acc_tensor[acc_dstr_idx];
 
  357                 [&](
auto... dstr_idx_i1) {
 
  358                     acc = reduce_func(acc, 
t[
make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
 
  362             acc_tensor(acc_dstr_idx) = acc;
 
  371     template <
typename ReduceFunc>
 
  382 template <
typename T>
 
#define CK_TILE_DEVICE
Definition: config.hpp:40
 
#define CK_TILE_HOST_DEVICE_EXTERN
Definition: config.hpp:43
 
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
 
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
Definition: tile_distribution.hpp:22
 
constexpr CK_TILE_HOST_DEVICE auto make_reduce_tile_distribution_encoding(InDstr, sequence< InReduceDimXs... > reduce_dim_xs_in)
Definition: tile_distribution_encoding.hpp:864
 
Definition: cluster_descriptor.hpp:13
 
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func)
Definition: block_reduce.hpp:118
 
CK_TILE_DEVICE T warp_shuffle_up(const T &v_local, uint32_t lane_delta)
Definition: utility.hpp:31
 
__host__ __device__ multiplies() -> multiplies< void, void >
FIXME: create macro to replace 'host device' and nothing more.
 
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F &f, Unpacks={})
Definition: sweep_tile.hpp:37
 
constexpr CK_TILE_HOST_DEVICE bool is_power_of_two_integer(int32_t x)
Definition: math.hpp:462
 
CK_TILE_DEVICE T warp_shuffle(const T &v_local, uint32_t src_lane)
Definition: utility.hpp:63
 
CK_TILE_DEVICE void tile_elementwise_inout(const InOutElementFunc &inout_element_func, InOutDstrTensors &... inout_dstr_tensors)
Definition: tile_elementwise.hpp:23
 
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_ &acc_tensor, const ReduceFunc &reduce_func, bool_constant< WithBroadcast >={})
Definition: block_reduce.hpp:18
 
int32_t index_t
Definition: integer.hpp:9
 
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
 
CK_TILE_DEVICE T warp_shuffle_down(const T &v_local, uint32_t lane_delta)
Definition: utility.hpp:48
 
constexpr CK_TILE_HOST_DEVICE index_t reduce_on_sequence(Seq, Reduce f, number< Init >)
Definition: sequence.hpp:973
 
constexpr CK_TILE_HOST_DEVICE int32_t integer_log2_floor(int32_t x)
Definition: math.hpp:455
 
CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_ &acc_tensor, const InDistributedTensor_ &in_tensor, sequence< InReduceDims... >, const ReduceFunc &reduce_func)
Definition: block_reduce.hpp:177
 
CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F &f)
Definition: sweep_tile.hpp:20
 
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
 
constexpr auto slice_sequence(Seq, number< SliceSize >, Mask=typename uniform_sequence_gen< Seq::size(), 1 >::type{})
Definition: sequence.hpp:1240
 
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
 
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T &, const typename T::DataType &) -> BlockReduce2D< T >
 
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:1017
 
constexpr bool is_same_v
Definition: type.hpp:283
 
Definition: block_reduce.hpp:286
 
remove_cvref_t< InDistributedTensor_ > InDistributedTensor
Definition: block_reduce.hpp:287
 
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc &reduce_func, const ReduceSyncFunc &reduce_sync_func, ReducePacksPerXDim={}) const
Definition: block_reduce.hpp:325
 
InDataType reduce_init
Definition: block_reduce.hpp:378
 
constexpr CK_TILE_HOST_DEVICE auto MakeDstBlockTile() const
Definition: block_reduce.hpp:295
 
InDistributedTensor t
Definition: block_reduce.hpp:377
 
typename InDistributedTensor::DataType InDataType
Definition: block_reduce.hpp:288
 
CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor &t_, const InDataType &reduce_init_)
Definition: block_reduce.hpp:290
 
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc &reduce_func) const
Definition: block_reduce.hpp:372
 
constexpr CK_TILE_HOST_DEVICE auto get_reduce_length_y() const
Definition: block_reduce.hpp:311
 
A fixed-size array container similar to std::array with additional utilities.
Definition: array.hpp:43
 
Definition: integral_constant.hpp:13
 
static constexpr value_type value
Definition: integral_constant.hpp:16
 
Definition: sequence.hpp:52
 
Definition: functional.hpp:43
 
#define TO_SEQUENCE(a, n)
Definition: to_sequence.hpp:10