gridwise_elementwise_2d.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// NOTE: the include block was elided in the original listing; the headers below
// are reconstructed from this file's cross-references and may not be the exact
// original set.
#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r2.hpp"
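// This header defines the grid-level entry points of CK's 2D elementwise
// pipeline: a plain kernel wrapper, a "dual" wrapper that fuses two independent
// problems into one launch, batched variants of both, and the
// GridwiseElementwise functor that performs the tiled, vectorized
// global-to-global transfer.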
namespace ck {

template <typename GridwiseElementwiseFunctor,
          typename InGridDescTuple,
          typename OutGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMap,
          typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_elementwise(const InGridDescTuple in_grid_desc_tuple,
                       const OutGridDescTuple out_grid_desc_tuple,
                       const InDataTypePointerTuple p_in_global_tuple,
                       const OutDataTypePointerTuple p_out_global_tuple,
                       const Block2TileMap block_2_tile_map,
                       const ElementwiseOperation elementwise_op)
{
    GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
                                    out_grid_desc_tuple,
                                    p_in_global_tuple,
                                    p_out_global_tuple,
                                    block_2_tile_map,
                                    elementwise_op);
}
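
// Host-side launch sketch (illustrative, not part of this header): CK's device
// ops normally build the descriptor/pointer tuples and derive the grid size from
// the block-to-tile map before launching. `GridwiseOp` (an instantiation of the
// GridwiseElementwise functor below) and the `*_` host variables are
// hypothetical names; `launch_and_time_kernel` is CK's host launch helper.
//
//   const auto kernel = kernel_elementwise<GridwiseOp,
//                                          InGridDescTuple,
//                                          OutGridDescTuple,
//                                          InDataTypePointerTuple,
//                                          OutDataTypePointerTuple,
//                                          Block2TileMap,
//                                          ElementwiseOperation>;
//   launch_and_time_kernel(StreamConfig{},
//                          kernel,
//                          dim3(grid_size_), // e.g. from block_2_tile_map_.CalculateGridSize(...)
//                          dim3(BlockSize),
//                          0, // no LDS
//                          in_grid_desc_tuple_,
//                          out_grid_desc_tuple_,
//                          p_in_global_tuple_,
//                          p_out_global_tuple_,
//                          block_2_tile_map_,
//                          elementwise_op_);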

template <typename GridwiseElementwiseFunctorA,
          typename GridwiseElementwiseFunctorB,
          typename InAGridDescTuple,
          typename InBGridDescTuple,
          typename OutAGridDescTuple,
          typename OutBGridDescTuple,
          typename InADataTypePointerTuple,
          typename InBDataTypePointerTuple,
          typename OutADataTypePointerTuple,
          typename OutBDataTypePointerTuple,
          typename Block2TileMapA,
          typename Block2TileMapB,
          typename ElementwiseOperation>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a,
                            const InBGridDescTuple in_grid_desc_tuple_b,
                            const OutAGridDescTuple out_grid_desc_tuple_a,
                            const OutBGridDescTuple out_grid_desc_tuple_b,
                            const InADataTypePointerTuple p_in_global_tuple_a,
                            const InBDataTypePointerTuple p_in_global_tuple_b,
                            const OutADataTypePointerTuple p_out_global_tuple_a,
                            const OutBDataTypePointerTuple p_out_global_tuple_b,
                            const Block2TileMapA block_2_tile_map_a,
                            const Block2TileMapB block_2_tile_map_b,
                            const ElementwiseOperation elementwise_op,
                            const index_t a_grid_size)
{
    if(get_block_1d_id() < a_grid_size)
    {
        GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
                                         out_grid_desc_tuple_a,
                                         p_in_global_tuple_a,
                                         p_out_global_tuple_a,
                                         block_2_tile_map_a,
                                         elementwise_op,
                                         get_block_1d_id());
    }
    else
    {
        GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
                                         out_grid_desc_tuple_b,
                                         p_in_global_tuple_b,
                                         p_out_global_tuple_b,
                                         block_2_tile_map_b,
                                         elementwise_op,
                                         get_block_1d_id() - a_grid_size);
    }
}
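
// The dual kernel fuses two independent elementwise problems into one launch:
// blocks [0, a_grid_size) run functor A, and the remaining blocks run functor B
// with their block id re-based to zero. For example (hypothetical sizes): with
// a_grid_size = 120 and a 200-block launch, block 45 runs A on tile 45 while
// block 150 runs B on tile 150 - 120 = 30. The host is expected to launch
// a_grid_size + b_grid_size blocks in total.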

template <typename GridwiseElementwiseFunctorA,
          typename GridwiseElementwiseFunctorB,
          typename InAGridDescTuple,
          typename InBGridDescTuple,
          typename OutAGridDescTuple,
          typename OutBGridDescTuple,
          typename InADataTypePointerTuple,
          typename InBDataTypePointerTuple,
          typename OutADataTypePointerTuple,
          typename OutBDataTypePointerTuple,
          typename Block2TileMapA,
          typename Block2TileMapB,
          typename ElementwiseOperation,
          index_t NumInputsA,
          index_t NumInputsB,
          index_t NumOutputsA,
          index_t NumOutputsB>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a,
                                    const InBGridDescTuple in_grid_desc_tuple_b,
                                    const OutAGridDescTuple out_grid_desc_tuple_a,
                                    const OutBGridDescTuple out_grid_desc_tuple_b,
                                    const InADataTypePointerTuple p_in_global_tuple_a,
                                    const InBDataTypePointerTuple p_in_global_tuple_b,
                                    const OutADataTypePointerTuple p_out_global_tuple_a,
                                    const OutBDataTypePointerTuple p_out_global_tuple_b,
                                    const Block2TileMapA block_2_tile_map_a,
                                    const Block2TileMapB block_2_tile_map_b,
                                    const ElementwiseOperation elementwise_op,
                                    const index_t a_grid_size,
                                    const index_t batch_count_a,
                                    const index_t batch_count_b,
                                    const std::array<index_t, NumInputsA> input_batch_strides_a,
                                    const std::array<index_t, NumInputsB> input_batch_strides_b,
                                    const std::array<index_t, NumOutputsA> output_batch_strides_a,
                                    const std::array<index_t, NumOutputsB> output_batch_strides_b)
{
    static_assert(InAGridDescTuple::Size() == NumInputsA &&
                  InADataTypePointerTuple::Size() == NumInputsA);
    static_assert(OutAGridDescTuple::Size() == NumOutputsA &&
                  OutADataTypePointerTuple::Size() == NumOutputsA);
    static_assert(InBGridDescTuple::Size() == NumInputsB &&
                  InBDataTypePointerTuple::Size() == NumInputsB);
    static_assert(OutBGridDescTuple::Size() == NumOutputsB &&
                  OutBDataTypePointerTuple::Size() == NumOutputsB);

    const index_t block_id = __builtin_amdgcn_readfirstlane(get_block_1d_id());

    if(block_id < a_grid_size)
    {
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a);
        const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch);

        InADataTypePointerTuple p_in_global_with_offset_tuple;
        OutADataTypePointerTuple p_out_global_with_offset_tuple;

        static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_in_global_with_offset_tuple(i) =
                p_in_global_tuple_a.At(i) +
                type_convert<long_index_t>(input_batch_strides_a[i]) * g_idx;
        });

        static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_out_global_with_offset_tuple(i) =
                p_out_global_tuple_a.At(i) +
                type_convert<long_index_t>(output_batch_strides_a[i]) * g_idx;
        });

        GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
                                         out_grid_desc_tuple_a,
                                         p_in_global_with_offset_tuple,
                                         p_out_global_with_offset_tuple,
                                         block_2_tile_map_a,
                                         elementwise_op,
                                         block_id);
    }
    else
    {
        const index_t num_blocks_per_batch =
            __builtin_amdgcn_readfirstlane((get_grid_size() - a_grid_size) / batch_count_b);
        const index_t g_idx =
            __builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch);

        InBDataTypePointerTuple p_in_global_with_offset_tuple;
        OutBDataTypePointerTuple p_out_global_with_offset_tuple;

        static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_in_global_with_offset_tuple(i) =
                p_in_global_tuple_b.At(i) +
                type_convert<long_index_t>(input_batch_strides_b[i]) * g_idx;
        });

        static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](auto i) {
            p_out_global_with_offset_tuple(i) =
                p_out_global_tuple_b.At(i) +
                type_convert<long_index_t>(output_batch_strides_b[i]) * g_idx;
        });

        GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
                                         out_grid_desc_tuple_b,
                                         p_in_global_with_offset_tuple,
                                         p_out_global_with_offset_tuple,
                                         block_2_tile_map_b,
                                         elementwise_op,
                                         block_id - a_grid_size);
    }
}
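
// Batch-offset sketch: each batch owns num_blocks_per_batch consecutive blocks,
// and every global pointer is advanced by batch_stride * g_idx before the 2D
// functor runs. With hypothetical numbers: a_grid_size = 60 and batch_count_a = 4
// give num_blocks_per_batch = 15, so block 47 works on batch g_idx = 47 / 15 = 3
// and reads input i from p_in_global_tuple_a.At(i) + input_batch_strides_a[i] * 3.
// The stride is widened to long_index_t first so the element offset cannot
// overflow 32-bit arithmetic for large batches.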

template <typename GridwiseElementwiseFunctor,
          typename InGridDescTuple,
          typename OutGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMap,
          typename ElementwiseOperation,
          index_t NumInputs,
          index_t NumOutputs>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
    kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple,
                               const OutGridDescTuple out_grid_desc_tuple,
                               const InDataTypePointerTuple p_in_global_tuple,
                               const OutDataTypePointerTuple p_out_global_tuple,
                               const Block2TileMap block_2_tile_map,
                               const ElementwiseOperation elementwise_op,
                               const index_t batch_count,
                               const std::array<index_t, NumInputs> input_batch_strides,
                               const std::array<index_t, NumOutputs> output_batch_strides)
{
    static_assert(InGridDescTuple::Size() == NumInputs &&
                  InDataTypePointerTuple::Size() == NumInputs);
    static_assert(OutGridDescTuple::Size() == NumOutputs &&
                  OutDataTypePointerTuple::Size() == NumOutputs);

    const index_t num_blocks_per_batch =
        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);

    InDataTypePointerTuple p_in_global_with_offset_tuple;
    OutDataTypePointerTuple p_out_global_with_offset_tuple;

    static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](auto i) {
        p_in_global_with_offset_tuple(i) =
            p_in_global_tuple.At(i) + type_convert<long_index_t>(input_batch_strides[i]) * g_idx;
    });

    static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](auto i) {
        p_out_global_with_offset_tuple(i) =
            p_out_global_tuple.At(i) + type_convert<long_index_t>(output_batch_strides[i]) * g_idx;
    });

    GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
                                    out_grid_desc_tuple,
                                    p_in_global_with_offset_tuple,
                                    p_out_global_with_offset_tuple,
                                    block_2_tile_map,
                                    elementwise_op);
}
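
// __builtin_amdgcn_readfirstlane marks the batch index as wave-uniform, letting
// the compiler keep it in a scalar register so the per-batch pointer bumps
// become scalar adds rather than per-lane vector arithmetic. The grid is assumed
// to be launched as batch_count * num_blocks_per_batch blocks; e.g.
// (hypothetical) batch_count = 8 with 12 blocks per batch gives a 96-block grid
// in which blocks 24..35 all compute g_idx = 2.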

template <typename InGridDescTuple,
          typename OutGridDescTuple,
          typename InDataTypePointerTuple,
          typename OutDataTypePointerTuple,
          typename Block2TileMap,
          typename ElementwiseOperation,
          index_t BlockSize,
          index_t M0PerBlock,
          index_t M1PerBlock,
          index_t M0PerThread,
          index_t M1PerThread,
          typename ThreadClusterArrangeOrder,
          typename InScalarPerVectorSeq,
          typename OutScalarPerVectorSeq,
          index_t SrcVectorDim,
          index_t DstVectorDim>
struct GridwiseElementwise
{
    static constexpr index_t NumInput  = InDataTypePointerTuple::Size();
    static constexpr index_t NumOutput = OutDataTypePointerTuple::Size();

    static_assert(NumInput == InScalarPerVectorSeq::Size() &&
                      NumOutput == OutScalarPerVectorSeq::Size() &&
                      NumInput == InGridDescTuple::Size() && NumOutput == OutGridDescTuple::Size(),
                  "Tuple size is inconsistent with the number of in/out!");

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static_assert((SrcVectorDim == I0 || SrcVectorDim == I1) &&
                      (DstVectorDim == I0 || DstVectorDim == I1),
                  "Vector dim must be equal to 0 or 1.");

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    __device__ static void Run(const InGridDescTuple& in_grid_desc_tuple,
                               const OutGridDescTuple& out_grid_desc_tuple,
                               const InDataTypePointerTuple& p_in_global_tuple,
                               const OutDataTypePointerTuple& p_out_global_tuple,
                               const Block2TileMap& block_2_tile_map,
                               const ElementwiseOperation& elementwise_op,
                               const index_t block_id = get_block_1d_id())
    {
        constexpr auto src_datas = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return DataType{};
            },
            Number<NumInput>{});

        constexpr auto dst_datas = generate_tuple(
            [&](auto I) {
                using DataTypePointer = remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
                using DataType        = remove_pointer_t<DataTypePointer>;

                return DataType{};
            },
            Number<NumOutput>{});

        const auto in_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumInput>{});

        auto out_global_buf_tuple = generate_tuple(
            [&](auto I) {
                return make_dynamic_buffer<AddressSpaceEnum::Global>(
                    p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
            },
            Number<NumOutput>{});

        const auto block_work_idx =
            block_2_tile_map.CalculateBottomIndex(make_multi_index(block_id));

        const index_t m0_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I0] * M0PerBlock);
        const index_t m1_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_work_idx[I1] * M1PerBlock);
        const auto input_thread_grid_offset = generate_tuple(
            [&](auto) {
                return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
            },
            Number<NumInput>{});
        const auto output_thread_grid_offset = generate_tuple(
            [&](auto) {
                return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
            },
            Number<NumOutput>{});

        // Each side traverses its own vector dim innermost: if src and dst share
        // the same vector dim, a single traversal order serves both the vectorized
        // src load and the vectorized dst store; otherwise the src load and the
        // dst store are vectorized along different dims and get separate access
        // orders.
        using SrcDimAccessOrder =
            std::conditional_t<SrcVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;
        using DstDimAccessOrder =
            std::conditional_t<DstVectorDim == I1, Sequence<0, 1>, Sequence<1, 0>>;

        using ThreadClusterLengths =
            Sequence<Number<M0PerBlock / M0PerThread>{}, Number<M1PerBlock / M1PerThread>{}>;

        auto global_to_global_transfer = ThreadGroupTensorSliceTransfer_v4r2<
            ThisThreadBlock<BlockSize>,
            ElementwiseOperation,
            uniform_sequence_gen_t<NumOutput,
                                   static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
            Sequence<M0PerBlock, M1PerBlock>,
            ThreadClusterLengths,
            ThreadClusterArrangeOrder,
            decltype(src_datas),
            decltype(dst_datas),
            InGridDescTuple,
            OutGridDescTuple,
            SrcDimAccessOrder,
            DstDimAccessOrder,
            SrcVectorDim,
            DstVectorDim,
            InScalarPerVectorSeq,
            OutScalarPerVectorSeq,
            uniform_sequence_gen_t<NumInput, 1>,
            uniform_sequence_gen_t<NumOutput, 1>,
            uniform_sequence_gen_t<NumInput, false>,
            uniform_sequence_gen_t<NumOutput, false>>{in_grid_desc_tuple,
                                                      input_thread_grid_offset,
                                                      out_grid_desc_tuple,
                                                      output_thread_grid_offset,
                                                      elementwise_op};
        global_to_global_transfer.Run(
            in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple, I0);
    }
};
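
// Instantiation sketch (hypothetical parameter values, shown for orientation
// only): a 256-thread block covering a 128x128 tile, each thread moving 8x8
// elements, with one input and one output vectorized 8-wide along the M1 axis.
// The thread cluster works out to Sequence<16, 16>, i.e. 16 * 16 = 256 threads,
// matching BlockSize.
//
//   using GridwiseOp = GridwiseElementwise<InGridDescTuple,
//                                          OutGridDescTuple,
//                                          InDataTypePointerTuple,
//                                          OutDataTypePointerTuple,
//                                          Block2TileMap,
//                                          ElementwiseOperation,
//                                          256,            // BlockSize
//                                          128,            // M0PerBlock
//                                          128,            // M1PerBlock
//                                          8,              // M0PerThread
//                                          8,              // M1PerThread
//                                          Sequence<1, 0>, // ThreadClusterArrangeOrder
//                                          Sequence<8>,    // InScalarPerVectorSeq (1 input)
//                                          Sequence<8>,    // OutScalarPerVectorSeq (1 output)
//                                          1,              // SrcVectorDim
//                                          1>;             // DstVectorDim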

} // namespace ck