13 template <
typename GridwiseElementwise1dFunctor,
 
   14           typename InGrid1dDescTuple,
 
   15           typename OutGrid1dDescTuple,
 
   16           typename InDataTypePointerTuple,
 
   17           typename OutDataTypePointerTuple,
 
   18           typename ElementwiseOperation,
 
   19           typename UnaryOperation,
 
   22                                       const OutGrid1dDescTuple out_grid_1d_desc_tuple,
 
   23                                       const InDataTypePointerTuple p_in_global_tuple,
 
   24                                       const OutDataTypePointerTuple p_out_global_tuple,
 
   25                                       const ElementwiseOperation elementwise_op,
 
   26                                       const UnaryOperation unary_op,
 
   29     GridwiseElementwise1dFunctor::Run(in_grid_1d_desc_tuple,
 
   30                                       out_grid_1d_desc_tuple,
 
   38 template <
typename InGrid1dDescTuple,
 
   39           typename OutGrid1dDescTuple,
 
   40           typename InDataTypePointerTuple,
 
   41           typename OutDataTypePointerTuple,
 
   42           typename ElementwiseOperation,
 
   43           typename UnaryOperation,
 
   46           typename InScalarPerVectorSeq,
 
   47           typename OutScalarPerVectorSeq>
 
   53     static_assert(
NumInput == InScalarPerVectorSeq::Size() &&
 
   54                       NumOutput == OutScalarPerVectorSeq::Size() &&
 
   55                       NumInput == InGrid1dDescTuple::Size() &&
 
   57                   "Tuple size is inconsistent with the number of in/out!");
 
   66     __device__ 
static void Run(
const InGrid1dDescTuple in_grid_1d_desc_tuple,
 
   67                                const OutGrid1dDescTuple out_grid_1d_desc_tuple,
 
   68                                const InDataTypePointerTuple p_in_global_tuple,
 
   69                                const OutDataTypePointerTuple p_out_global_tuple,
 
   70                                const ElementwiseOperation elementwise_op,
 
   71                                const UnaryOperation unary_op,
 
   78                 using DataTypePointer = 
remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
 
   87                 using DataTypePointer = 
remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
 
   96                 static_assert(in_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
 
   98                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
 
   99                     p_in_global_tuple[I], in_grid_1d_desc_tuple[I].GetElementSpaceSize());
 
  105                 static_assert(out_grid_1d_desc_tuple[I].GetNumOfDimension() == 1);
 
  107                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  108                     p_out_global_tuple[I], out_grid_1d_desc_tuple[I].GetElementSpaceSize());
 
  112         const auto thread_global_offset = 
make_multi_index(thread_global_id * MPerThread);
 
  116         const auto M               = in_grid_1d_desc_tuple[
I0].GetLength(
I0);
 
  117         const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
 
  122                 using DataTypePointer = 
remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
 
  127                                                         decltype(in_grid_1d_desc_tuple[I]),
 
  132                                                         InScalarPerVectorSeq::At(
 
  135                                                         false>{in_grid_1d_desc_tuple[I],
 
  136                                                                thread_global_offset};
 
  142                 using DataTypePointer = 
remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
 
  148                                                           decltype(out_grid_1d_desc_tuple[I]),
 
  153                                                           OutScalarPerVectorSeq::At(I),
 
  157                     out_grid_1d_desc_tuple[I], thread_global_offset, 
PassThroughOp{});
 
  161         index_t num_iter = M / (loop_step);
 
  165                 in_global_load_tuple(I).Run(in_grid_1d_desc_tuple[I],
 
  166                                             in_global_buf_tuple[I],
 
  169                                             in_thread_buf_tuple(I));
 
  171                 in_global_load_tuple(I).MoveSrcSliceWindow(in_grid_1d_desc_tuple[I],
 
  179                     [&](
auto I) -> 
auto& { 
return in_thread_buf_tuple(I)(iM); },
 
  185                     [&](
auto I) -> 
auto& { 
return out_thread_buf_tuple(I)(iM); },
 
  188                 unpack2(unary_op, uop_data_refs, uop_data_refs);
 
  192                     [&](
auto I) -> 
auto& { 
return in_thread_buf_tuple(I)(iM); },
 
  197                     [&](
auto I) -> 
auto& { 
return in_thread_buf_tuple(I)(iM); },
 
  200                 unpack2(scale_op, sop_out_data_refs, sop_in_data_refs);
 
  204                     [&](
auto I) -> 
const auto& { 
return in_thread_buf_tuple(I)(iM); },
 
  207                 unpack2(elementwise_op, out_data_refs, in_data_refs);
 
  213                                               out_thread_buf_tuple[I],
 
  214                                               out_grid_1d_desc_tuple[I],
 
  215                                               out_global_buf_tuple(I));
 
  217                 out_global_store_tuple(I).MoveDstSliceWindow(out_grid_1d_desc_tuple[I],
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto unpack2(F &&f, X &&x, Y &&y)
Definition: functional4.hpp:55
 
__device__ index_t get_grid_size()
Definition: get_id.hpp:27
 
__host__ constexpr __device__ auto generate_tie(F &&f, Number< N >)
Definition: tuple_helper.hpp:34
 
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
 
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:300
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__device__ index_t get_block_size()
Definition: get_id.hpp:29
 
__global__ void kernel_elementwise_1d(const InGrid1dDescTuple in_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const ElementwiseOperation elementwise_op, const UnaryOperation unary_op, const Scale scale_op)
Definition: gridwise_elementwise_1d_scale.hpp:21
 
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:21
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
int32_t index_t
Definition: ck.hpp:297
 
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
 
Definition: gridwise_elementwise_1d_scale.hpp:49
 
tensor_operation::element_wise::PassThrough PassThroughOp
Definition: gridwise_elementwise_1d_scale.hpp:64
 
static constexpr index_t NumOutput
Definition: gridwise_elementwise_1d_scale.hpp:51
 
static __device__ void Run(const InGrid1dDescTuple in_grid_1d_desc_tuple, const OutGrid1dDescTuple out_grid_1d_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const ElementwiseOperation elementwise_op, const UnaryOperation unary_op, const Scale scale_op)
Definition: gridwise_elementwise_1d_scale.hpp:66
 
static constexpr auto thread_buffer_desc_m
Definition: gridwise_elementwise_1d_scale.hpp:61
 
static constexpr auto I0
Definition: gridwise_elementwise_1d_scale.hpp:59
 
static constexpr index_t NumInput
Definition: gridwise_elementwise_1d_scale.hpp:50
 
Definition: sequence.hpp:43
 
Definition: static_buffer.hpp:16
 
Definition: threadwise_tensor_slice_transfer.hpp:39
 
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
 
Definition: integral_constant.hpp:20
 
Definition: functional2.hpp:33
 
Definition: unary_element_wise_operation.hpp:308