template <typename GridwisePutElementwise1dFunctor,
          typename InGrid1dDesc,
          typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation>
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
                                      const InDataType* __restrict__ p_in_global,
                                      const IndexDataType* __restrict__ p_indices_global,
                                      OutDataType* __restrict__ p_out_global,
                                      const ElementwiseOperation elementwise_op)
{
    GridwisePutElementwise1dFunctor::Run(
        in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op);
}
template <typename InGrid1dDesc,
          typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum MemOp,
          index_t InVectorSize>
struct GridwisePutElement_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<InVectorSize>{}));

    __device__ static void Run(const InGrid1dDesc& in_grid_1d_desc,
                               const InDataType* __restrict__ p_in_global,
                               const IndexDataType* __restrict__ p_indices_global,
                               OutDataType* __restrict__ p_out_global,
                               const ElementwiseOperation& elementwise_op)
    {
        // Global-memory views over the input values and the scatter indices
        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_1d_desc.GetElementSpaceSize());
        const auto indices_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_indices_global, in_grid_1d_desc.GetElementSpaceSize());

        // Per-thread register buffers holding one vector of values and one of indices
        StaticBuffer<AddressSpaceEnum::Vgpr, InDataType, InVectorSize, true> in_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, InVectorSize, true> indices_thread_buf;

        const index_t thread_global_id  = get_thread_global_1d_id();
        const auto thread_global_offset = make_multi_index(thread_global_id * InVectorSize);

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = in_grid_1d_desc.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * InVectorSize;
        const auto loop_step_index = make_multi_index(loop_step);

        auto in_global_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, InDataType, decltype(in_grid_1d_desc), decltype(thread_buffer_desc_m),
            Sequence<InVectorSize>, Sequence<0>, 0, InVectorSize, 1, false>{
            in_grid_1d_desc, thread_global_offset};

        auto indices_global_load = ThreadwiseTensorSliceTransfer_v2<
            IndexDataType, IndexDataType, decltype(in_grid_1d_desc), decltype(thread_buffer_desc_m),
            Sequence<InVectorSize>, Sequence<0>, 0, InVectorSize, 1, false>{
            in_grid_1d_desc, thread_global_offset};

        // Grid-stride loop; assumes M is a positive multiple of loop_step
        index_t num_iter = M / loop_step;
        do
        {
            in_global_load.Run(
                in_grid_1d_desc, in_global_buf, thread_buffer_desc_m, make_tuple(I0), in_thread_buf);
            in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);

            // Apply the elementwise operation to each loaded value in place
            static_for<0, InVectorSize, 1>{}(
                [&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); });

            indices_global_load.Run(
                in_grid_1d_desc, indices_global_buf, thread_buffer_desc_m, make_tuple(I0), indices_thread_buf);
            indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);

            // Scatter each value to p_out_global at its index; negative indices are skipped
            static_for<0, InVectorSize, 1>{}([&](auto iM) {
                if(indices_thread_buf[iM] >= 0)
                {
                    if constexpr(MemOp == InMemoryDataOperationEnum::Set)
                        *(p_out_global + indices_thread_buf[iM]) =
                            ck::type_convert<OutDataType>(in_thread_buf[iM]);
                    else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd)
                        atomic_add<OutDataType>(p_out_global + indices_thread_buf[iM],
                                                ck::type_convert<OutDataType>(in_thread_buf[iM]));
                    else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax)
                        atomic_max<OutDataType>(p_out_global + indices_thread_buf[iM],
                                                ck::type_convert<OutDataType>(in_thread_buf[iM]));
                    else if constexpr(MemOp == InMemoryDataOperationEnum::Add)
                        *(p_out_global + indices_thread_buf[iM]) +=
                            ck::type_convert<OutDataType>(in_thread_buf[iM]);
                }
            });
        } while(--num_iter);
    }
};
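// ---------------------------------------------------------------------------
// Hypothetical usage sketch, not part of this header: one way host code might
// instantiate GridwisePutElement_1D and launch kernel_put_element_1d for a
// float scatter-add. make_naive_tensor_descriptor_packed, make_tuple, and the
// PassThrough elementwise functor are existing CK utilities; the function
// name, pointer names, and launch geometry below are illustrative assumptions.
// ---------------------------------------------------------------------------
#include <hip/hip_runtime.h>
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

void put_element_atomic_add_example(
    const float* p_in, const int32_t* p_idx, float* p_out, ck::index_t M)
{
    using Pass = ck::tensor_operation::element_wise::PassThrough;

    // Packed 1-D descriptor over the M input elements
    const auto in_desc = ck::make_naive_tensor_descriptor_packed(ck::make_tuple(M));

    using Gridwise = ck::GridwisePutElement_1D<decltype(in_desc),
                                               float,   // InDataType
                                               int32_t, // IndexDataType
                                               float,   // OutDataType
                                               Pass,
                                               ck::InMemoryDataOperationEnum::AtomicAdd,
                                               1>;      // InVectorSize

    // The do/while loop in Run requires M to be a positive multiple of
    // grid_size * block_size * InVectorSize; with this geometry loop_step == M
    // and num_iter == 1, assuming M is a multiple of 256.
    constexpr int block_size = 256;
    const int grid_size      = M / block_size;

    ck::kernel_put_element_1d<Gridwise, decltype(in_desc), float, int32_t, float, Pass>
        <<<grid_size, block_size>>>(in_desc, p_in, p_idx, p_out, Pass{});
}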