Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp Source File

gridwise_elementwise_2d.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

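// Grid-level entry point for a plain elementwise operation: every block forwards the input/output
// descriptor tuples, the global pointer tuples, the block-to-tile map and the elementwise functor
// to GridwiseElementwiseFunctor::Run.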
template <typename GridwiseElementwiseFunctor,
          typename InGridDescTuple,
          typename OutGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMap,
          typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
                       const OutGridDescTuple out_grid_desc_tuple,
                       const InDataTypePointerTuple p_in_global_tuple,
                       const OutDataTypePointerTuple p_out_global_tuple,
                       const Block2TileMap block_2_tile_map,
                       const ElementwiseOperation elementwise_op)
{
    GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
                                    out_grid_desc_tuple,
                                    p_in_global_tuple,
                                    p_out_global_tuple,
                                    block_2_tile_map,
                                    elementwise_op);
}

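// Entry point that fuses two independent elementwise problems (A and B) into a single launch:
// blocks with id < a_grid_size process problem A, the remaining blocks process problem B with a
// block id rebased by a_grid_size.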
template <typename GridwiseElementwiseFunctor,
          typename InAGridDescTuple,
          typename InBGridDescTuple,
          typename OutAGridDescTuple,
          typename OutBGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMapA,
          typename Block2TileMapB,
          typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
                            const InBGridDescTuple in_grid_desc_tuple_b,
                            const OutAGridDescTuple out_grid_desc_tuple_a,
                            const OutBGridDescTuple out_grid_desc_tuple_b,
                            const InDataTypePointerTuple p_in_global_tuple_a,
                            const InDataTypePointerTuple p_in_global_tuple_b,
                            const OutDataTypePointerTuple p_out_global_tuple_a,
                            const OutDataTypePointerTuple p_out_global_tuple_b,
                            const Block2TileMapA block_2_tile_map_a,
                            const Block2TileMapB block_2_tile_map_b,
                            const ElementwiseOperation elementwise_op,
                            const index_t a_grid_size)
{
    if(get_block_1d_id() < a_grid_size)
    {
        GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
                                        out_grid_desc_tuple_a,
                                        p_in_global_tuple_a,
                                        p_out_global_tuple_a,
                                        block_2_tile_map_a,
                                        elementwise_op,
                                        get_block_1d_id());
    }
    else
    {
        GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
                                        out_grid_desc_tuple_b,
                                        p_in_global_tuple_b,
                                        p_out_global_tuple_b,
                                        block_2_tile_map_b,
                                        elementwise_op,
                                        get_block_1d_id() - a_grid_size);
    }
}

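// Batched variant: the grid is split evenly across batch_count batches, so each batch owns
// get_grid_size() / batch_count consecutive blocks. Each block derives its batch index g_idx from
// its block id and offsets every input/output base pointer by batch_stride * g_idx before running
// the usual gridwise functor. For example, with a grid of 64 blocks and batch_count = 4, blocks
// 0-15 work on batch 0, blocks 16-31 on batch 1, and so on.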
template <typename GridwiseElementwiseFunctor,
          typename InGridDescTuple,
          typename OutGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMap,
          typename ElementwiseOperation,
          index_t NumInputs,
          index_t NumOutputs>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple,
                               const OutGridDescTuple out_grid_desc_tuple,
                               const InDataTypePointerTuple p_in_global_tuple,
                               const OutDataTypePointerTuple p_out_global_tuple,
                               const Block2TileMap block_2_tile_map,
                               const ElementwiseOperation elementwise_op,
                               const index_t batch_count,
                               const std::array<index_t, NumInputs> input_batch_strides,
                               const std::array<index_t, NumOutputs> output_batch_strides)
{
    static_assert(InGridDescTuple::Size() == NumInputs &&
                  InDataTypePointerTuple::Size() == NumInputs);
    static_assert(OutGridDescTuple::Size() == NumOutputs &&
                  OutDataTypePointerTuple::Size() == NumOutputs);

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    InDataTypePointerTuple p_in_global_with_offset_tuple;
    OutDataTypePointerTuple p_out_global_with_offset_tuple;

    static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
        p_in_global_with_offset_tuple(i) = p_in_global_tuple.At(i) + input_batch_strides[i] * g_idx;
    });

    static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
        p_out_global_with_offset_tuple(i) =
            p_out_global_tuple.At(i) + output_batch_strides[i] * g_idx;
    });

    GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
                                    out_grid_desc_tuple,
                                    p_in_global_with_offset_tuple,
                                    p_out_global_with_offset_tuple,
                                    block_2_tile_map,
                                    elementwise_op);
}

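// Block-level implementation of a 2D elementwise operation. The problem is viewed as an (M0, M1)
// grid of tiles; each block handles one M0PerBlock x M1PerBlock tile, arranging its threads as an
// (M0PerBlock / M0PerThread) x (M1PerBlock / M1PerThread) cluster and using
// ThreadGroupTensorSliceTransfer_v4r2 to perform a vectorized global-to-global copy that applies
// the elementwise functor on the fly.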
template <typename InGridDescTuple,
          typename OutGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMap,
          typename ElementwiseOperation,
          index_t BlockSize,
          index_t M0PerBlock,
          index_t M1PerBlock,
          index_t M0PerThread,
          index_t M1PerThread,
          typename ThreadClusterArrangeOrder,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq,
          index_t SrcVectorDim,
          index_t DstVectorDim>
struct GridwiseElementwise
{
    static constexpr index_t NumInput  = InDataTypePointerTuple::Size();
    static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();

    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
                      NumOutput == OutScalarPerVectorSeq::Size() &&
                      NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
                  "Tuple size is inconsistent with the number of in/out!");

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static_assert((SrcVectorDim == I0 || SrcVectorDim == I1) &&
                      (DstVectorDim == I0 || DstVectorDim == I1),
                  "Vector dim must be equal to 0 or 1.");

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    __device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
                               const OutGridDescTuple& out_grid_desc_tuple,
                               const InDataTypePointerTuple& p_in_global_tuple,
                               const OutDataTypePointerTuple& p_out_global_tuple,
                               const Block2TileMap& block_2_tile_map,
                               const ElementwiseOperation& elementwise_op,
                               const index_t block_id = get_block_1d_id())
    {

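        // Build value-typed dummies from the pointer tuples so the transfer below can deduce the
        // per-tensor source and destination data types via decltype.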
        constexpr auto src_datas = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_cv_t<remove_pointer_t<DataTypePointer>>;

                return DataType{};
            },
            Number<NumInput>{});

        constexpr auto dst_datas = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return DataType{};
            },
            Number<NumOutput>{});

        const auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto out_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumOutput>{});

        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));

        const index_t m0_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
        const index_t m1_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
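
        // All inputs and outputs share the same (M0, M1) tile origin for this block.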
        const auto input_thread_grid_offset = generate_tuple(
            [&](auto) {
                return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
            },
            Number<NumInput>{});
        const auto output_thread_grid_offset = generate_tuple(
            [&](auto) {
                return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
            },
            Number<NumOutput>{});

        // If src and dst have the same vector dim, then:
        //     M0 dim - for src and dst vector load/store
        // else:
        //     M0 dim - for dst vector store
        //     M1 dim - for src vector load
        using SrcDimAccessOrder =
            std::conditional_t<SrcVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;
        using DstDimAccessOrder =
            std::conditional_t<DstVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;

        using ThreadClusterLengths =
            Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;

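        // Block-wide vectorized copy from the global input buffers to the global output buffers,
        // applying elementwise_op on the fly.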
        auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
            ThisThreadBlock,
            ElementwiseOperation,
            ThreadClusterLengths,
            ThreadClusterArrangeOrder,
            decltype(src_datas),
            decltype(dst_datas),
            InGridDescTuple,
            OutGridDescTuple,
            SrcDimAccessOrder,
            DstDimAccessOrder,
            SrcVectorDim,
            DstVectorDim,
            InScalarPerVectorSeq,
            OutScalarPerVectorSeq,
            uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
                                                      input_thread_grid_offset,
                                                      out_grid_desc_tuple,
                                                      output_thread_grid_offset,
                                                      elementwise_op};
        global_to_global_transfer.Run(
            in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
    }
};

} // namespace ck
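
// ---------------------------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the original header). It assumes the caller has
// already built the descriptor tuples (in_descs/out_descs), device pointer tuples
// (in_ptrs/out_ptrs), a block-to-tile map (tile_map), an elementwise functor instance (op) and a
// matching grid_size; all of these names are placeholders, typically produced by one of the
// library's device-level wrappers rather than by hand, and the tuning numbers below are example
// values only.
//
//   using Gridwise = ck::GridwiseElementwise<decltype(in_descs),
//                                            decltype(out_descs),
//                                            decltype(in_ptrs),
//                                            decltype(out_ptrs),
//                                            decltype(tile_map),
//                                            decltype(op),
//                                            /*BlockSize*/ 256,
//                                            /*M0PerBlock*/ 128,
//                                            /*M1PerBlock*/ 128,
//                                            /*M0PerThread*/ 8,
//                                            /*M1PerThread*/ 8,
//                                            ck::Sequence<1, 0>, // ThreadClusterArrangeOrder
//                                            ck::Sequence<8>,    // InScalarPerVectorSeq, 1 input
//                                            ck::Sequence<8>,    // OutScalarPerVectorSeq, 1 output
//                                            /*SrcVectorDim*/ 1,
//                                            /*DstVectorDim*/ 1>;
//
//   // One block per (M0PerBlock, M1PerBlock) tile; 256 threads cover the 16 x 16 thread cluster.
//   ck::kernel_elementwise<Gridwise><<<grid_size, 256, 0, nullptr>>>(
//       in_descs, out_descs, in_ptrs, out_ptrs, tile_map, op);
// ---------------------------------------------------------------------------------------------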