Composable Kernel documentation — source listing for
include/ck/tensor_operation/gpu/grid/gridwise_put_element_1d.hpp (docs-6.4.3 build).

gridwise_put_element_1d.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 
9 namespace ck {
10 
/// @brief Grid-level entry point for the 1-D put-element (scatter) operation.
///
/// Thin __global__ wrapper: all work is delegated to the static
/// GridwisePutElementwise1dFunctor::Run. Launch configuration (grid/block
/// size) is the caller's responsibility; the functor internally derives its
/// per-thread offsets from get_thread_global_1d_id()/get_grid_size().
///
/// @tparam GridwisePutElementwise1dFunctor gridwise functor providing a static Run()
/// @tparam InGrid1dDesc   1-D tensor descriptor type of the input
/// @tparam InDataType     element type of the input buffer
/// @tparam IndexDataType  element type of the index buffer
/// @tparam OutDataType    element type of the output buffer
/// @tparam ElementwiseOperation elementwise functor applied to each input value
///
/// @param in_grid_1d_desc   descriptor of the 1-D input grid (passed by value to device)
/// @param p_in_global       input values in global memory
/// @param p_indices_global  destination indices in global memory (parallel to input)
/// @param p_out_global      output buffer in global memory, written at the given indices
/// @param elementwise_op    elementwise transform applied to input before scatter
template <typename GridwisePutElementwise1dFunctor,
          typename InGrid1dDesc,
          typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation>
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc,
                                      const InDataType* __restrict__ p_in_global,
                                      const IndexDataType* __restrict__ p_indices_global,
                                      OutDataType* __restrict__ p_out_global,
                                      const ElementwiseOperation elementwise_op)
{
    GridwisePutElementwise1dFunctor::Run(
        in_grid_1d_desc, p_in_global, p_indices_global, p_out_global, elementwise_op);
}
26 
27 // output[indices] = input
// output[indices] = elementwise_op(input)
//
// Gridwise functor performing a 1-D scatter: each thread loads a vector of
// InVectorSize input values and their destination indices, applies
// elementwise_op in registers, then writes each value to
// p_out_global[index] using the memory operation selected by MemOp
// (Set / Add / AtomicAdd / AtomicMax). Entries with a negative index are
// skipped, so negative indices can be used as a "masked out" marker.
//
// Preconditions (from the visible logic):
//  - M (length of the input) must be an exact multiple of
//    gridSize * blockSize * InVectorSize: the do/while below runs
//    M / loop_step iterations with no tail handling, and `do` executes at
//    least once even if M < loop_step.
//  - For MemOp == Set or MemOp == Add the caller must guarantee all indices
//    are distinct (non-atomic read-modify-write).
//
// NOTE(review): several lines of this block were lost in the documentation
// extraction and have been reconstructed (marked below) — verify against the
// upstream composable_kernel repository before relying on them.
template <typename InGrid1dDesc,
          typename InDataType,
          typename IndexDataType,
          typename OutDataType,
          typename ElementwiseOperation,
          InMemoryDataOperationEnum MemOp, // NOTE(review): reconstructed template parameter
          index_t InVectorSize>
struct GridwisePutElement_1D // NOTE(review): struct name reconstructed from upstream CK
{
    static constexpr auto I0 = Number<0>{};

    // Per-thread register-tile descriptor: a packed 1-D tile of InVectorSize
    // elements. NOTE(review): right-hand side reconstructed — confirm upstream.
    static constexpr auto thread_buffer_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<InVectorSize>{}));

    // Performs the scatter. See struct-level comment for the contract.
    __device__ static void Run(const InGrid1dDesc& in_grid_1d_desc,
                               const InDataType* __restrict__ p_in_global,
                               const IndexDataType* __restrict__ p_indices_global,
                               OutDataType* __restrict__ p_out_global,
                               const ElementwiseOperation& elementwise_op)
    {
        // Global-memory buffer views over the raw pointers.
        const auto in_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_global, in_grid_1d_desc.GetElementSpaceSize());

        // NOTE(review): third argument (invalid-element value) reconstructed —
        // the scrape lost this line; upstream passes a sentinel for
        // out-of-bounds index reads. Verify the exact expression.
        const auto indices_global_buf =
            make_dynamic_buffer<AddressSpaceEnum::Global>(p_indices_global,
                                                          in_grid_1d_desc.GetElementSpaceSize(),
                                                          NumericLimits<IndexDataType>::Lowest());

        // VGPR (register) staging buffers, one vector per thread.
        // NOTE(review): both declarations reconstructed from upstream CK.
        StaticBuffer<AddressSpaceEnum::Vgpr, InDataType, InVectorSize, true> in_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, InVectorSize, true> indices_thread_buf;

        // Thread id, block id and per-iteration stride: each thread owns a
        // contiguous InVectorSize-wide slice; the whole grid advances by
        // gridSize * blockSize * InVectorSize elements per loop iteration.
        const index_t thread_global_id   = get_thread_global_1d_id();
        const auto thread_global_offset  = make_multi_index(thread_global_id * InVectorSize);
        const index_t blockSize          = get_block_size();
        const index_t blockPerGrid       = get_grid_size();
        const auto M                     = in_grid_1d_desc.GetLength(I0);
        const index_t loop_step          = blockPerGrid * blockSize * InVectorSize;
        const auto loop_step_index       = make_multi_index(loop_step);

        // Vectorized global->register loader for the input values.
        // NOTE(review): transfer-class head line reconstructed.
        auto in_global_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             InDataType,
                                             decltype(in_grid_1d_desc),
                                             decltype(thread_buffer_desc_m),
                                             Sequence<InVectorSize>, // SliceLengths
                                             Sequence<0>,            // DimAccessOrder
                                             0,                      // SrcVectorDim
                                             InVectorSize,           // ScalarPerVector
                                             1,                      // SrcScalarStrideInVector
                                             false>{in_grid_1d_desc, thread_global_offset};

        // Vectorized global->register loader for the destination indices.
        // NOTE(review): transfer-class head line reconstructed.
        auto indices_global_load =
            ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                             IndexDataType,
                                             decltype(in_grid_1d_desc),
                                             decltype(thread_buffer_desc_m),
                                             Sequence<InVectorSize>, // SliceLengths
                                             Sequence<0>,            // DimAccessOrder
                                             0,                      // SrcVectorDim
                                             InVectorSize,           // ScalarPerVector
                                             1,                      // SrcScalarStrideInVector
                                             false>{in_grid_1d_desc, thread_global_offset};

        // Assumes M % loop_step == 0; `do` runs at least once regardless.
        index_t num_iter = M / loop_step;
        do
        {
            // Load this thread's input vector into registers.
            in_global_load.Run(in_grid_1d_desc,
                               in_global_buf,
                               thread_buffer_desc_m, // NOTE(review): arg reconstructed
                               make_tuple(I0),
                               in_thread_buf);

            in_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);

            // Apply the elementwise op in place on the register tile.
            // NOTE(review): static_for head reconstructed.
            static_for<0, InVectorSize, 1>{}(
                [&](auto iM) { elementwise_op(in_thread_buf(iM), in_thread_buf[iM]); });

            // Load this thread's index vector into registers.
            indices_global_load.Run(in_grid_1d_desc,
                                    indices_global_buf,
                                    thread_buffer_desc_m, // NOTE(review): arg reconstructed
                                    make_tuple(I0),
                                    indices_thread_buf);

            indices_global_load.MoveSrcSliceWindow(in_grid_1d_desc, loop_step_index);

            // Scatter: negative indices are skipped (mask convention).
            static_for<0, InVectorSize, 1>{}([&](auto iM) {
                if(indices_thread_buf[iM] >= 0)
                {
                    if constexpr(MemOp == InMemoryDataOperationEnum::Set)
                    {
                        // User should guarantee each index in p_indices_global is different
                        *(p_out_global + indices_thread_buf[iM]) =
                            ck::type_convert<OutDataType>(in_thread_buf[iM]);
                    }
                    else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicAdd)
                    {
                        atomic_add<OutDataType>(p_out_global + indices_thread_buf[iM],
                                                ck::type_convert<OutDataType>(in_thread_buf[iM]));
                    }
                    else if constexpr(MemOp == InMemoryDataOperationEnum::AtomicMax)
                    {
                        atomic_max<OutDataType>(p_out_global + indices_thread_buf[iM],
                                                ck::type_convert<OutDataType>(in_thread_buf[iM]));
                    }
                    else if constexpr(MemOp == InMemoryDataOperationEnum::Add)
                    {
                        // User should guarantee each index in p_indices_global is different
                        *(p_out_global + indices_thread_buf[iM]) +=
                            ck::type_convert<OutDataType>(in_thread_buf[iM]);
                    }
                    else
                    {
                        // Dependent-false guard: any other MemOp is a compile error.
                        // NOTE(review): static_assert tail reconstructed.
                        static_assert(MemOp == InMemoryDataOperationEnum::Set ||
                                          MemOp == InMemoryDataOperationEnum::AtomicAdd ||
                                          MemOp == InMemoryDataOperationEnum::AtomicMax ||
                                          MemOp == InMemoryDataOperationEnum::Add,
                                      "wrong! InMemoryDataOperationEnum not supported");
                    }
                }
            });

        } while(--num_iter);
    }
};
154 
155 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition: get_id.hpp:24
InMemoryDataOperationEnum
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__device__ index_t get_block_size()
Definition: get_id.hpp:26
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:18
__global__ void kernel_put_element_1d(const InGrid1dDesc in_grid_1d_desc, const InDataType *__restrict__ p_in_global, const IndexDataType *__restrict__ p_indices_global, OutDataType *__restrict__ p_out_global, const ElementwiseOperation elementwise_op)
Definition: gridwise_put_element_1d.hpp:17
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
Definition: gridwise_put_element_1d.hpp:36
static __device__ void Run(const InGrid1dDesc &in_grid_1d_desc, const InDataType *__restrict__ p_in_global, const IndexDataType *__restrict__ p_indices_global, OutDataType *__restrict__ p_out_global, const ElementwiseOperation &elementwise_op)
Definition: gridwise_put_element_1d.hpp:42
static constexpr auto thread_buffer_desc_m
Definition: gridwise_put_element_1d.hpp:39
static constexpr auto I0
Definition: gridwise_put_element_1d.hpp:37
Definition: data_type.hpp:2831
Definition: sequence.hpp:43
Definition: static_buffer.hpp:16
Definition: threadwise_tensor_slice_transfer.hpp:214
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31