18 template <
typename GridwiseElementwiseFunctor,
19 typename InGridDescTuple,
20 typename OutGridDescTuple,
21 typename InDataTypePointerTuple,
22 typename OutDataTypePointerTuple,
23 typename Block2TileMap,
24 typename ElementwiseOperation>
26 #if CK_USE_LAUNCH_BOUNDS
30 const OutGridDescTuple out_grid_desc_tuple,
31 const InDataTypePointerTuple p_in_global_tuple,
32 const OutDataTypePointerTuple p_out_global_tuple,
33 const Block2TileMap block_2_tile_map,
34 const ElementwiseOperation elementwise_op)
36 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
44 template <
typename GridwiseElementwiseFunctorA,
45 typename GridwiseElementwiseFunctorB,
46 typename InAGridDescTuple,
47 typename InBGridDescTuple,
48 typename OutAGridDescTuple,
49 typename OutBGridDescTuple,
50 typename InADataTypePointerTuple,
51 typename InBDataTypePointerTuple,
52 typename OutADataTypePointerTuple,
53 typename OutBDataTypePointerTuple,
54 typename Block2TileMapA,
55 typename Block2TileMapB,
56 typename ElementwiseOperation>
58 #if CK_USE_LAUNCH_BOUNDS
62 const InBGridDescTuple in_grid_desc_tuple_b,
63 const OutAGridDescTuple out_grid_desc_tuple_a,
64 const OutBGridDescTuple out_grid_desc_tuple_b,
65 const InADataTypePointerTuple p_in_global_tuple_a,
66 const InBDataTypePointerTuple p_in_global_tuple_b,
67 const OutADataTypePointerTuple p_out_global_tuple_a,
68 const OutBDataTypePointerTuple p_out_global_tuple_b,
69 const Block2TileMapA block_2_tile_map_a,
70 const Block2TileMapB block_2_tile_map_b,
71 const ElementwiseOperation elementwise_op,
76 GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
77 out_grid_desc_tuple_a,
86 GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
87 out_grid_desc_tuple_b,
96 template <
typename GridwiseElementwiseFunctorA,
97 typename GridwiseElementwiseFunctorB,
98 typename InAGridDescTuple,
99 typename InBGridDescTuple,
100 typename OutAGridDescTuple,
101 typename OutBGridDescTuple,
102 typename InADataTypePointerTuple,
103 typename InBDataTypePointerTuple,
104 typename OutADataTypePointerTuple,
105 typename OutBDataTypePointerTuple,
106 typename Block2TileMapA,
107 typename Block2TileMapB,
108 typename ElementwiseOperation,
114 #if CK_USE_LAUNCH_BOUNDS
118 const InAGridDescTuple in_grid_desc_tuple_a,
119 const InBGridDescTuple in_grid_desc_tuple_b,
120 const OutAGridDescTuple out_grid_desc_tuple_a,
121 const OutBGridDescTuple out_grid_desc_tuple_b,
122 const InADataTypePointerTuple p_in_global_tuple_a,
123 const InBDataTypePointerTuple p_in_global_tuple_b,
124 const OutADataTypePointerTuple p_out_global_tuple_a,
125 const OutBDataTypePointerTuple p_out_global_tuple_b,
126 const Block2TileMapA block_2_tile_map_a,
127 const Block2TileMapB block_2_tile_map_b,
128 const ElementwiseOperation elementwise_op,
132 const std::array<index_t, NumInputsA> input_batch_strides_a,
133 const std::array<index_t, NumInputsB> input_batch_strides_b,
134 const std::array<index_t, NumOutputsA> output_batch_strides_a,
135 const std::array<index_t, NumOutputsB> output_batch_strides_b)
137 static_assert(InAGridDescTuple::Size() == NumInputsA &&
138 InADataTypePointerTuple::Size() == NumInputsA);
139 static_assert(OutAGridDescTuple::Size() == NumOutputsA &&
140 OutADataTypePointerTuple::Size() == NumOutputsA);
141 static_assert(InBGridDescTuple::Size() == NumInputsB &&
142 InBDataTypePointerTuple::Size() == NumInputsB);
143 static_assert(OutBGridDescTuple::Size() == NumOutputsB &&
144 OutBDataTypePointerTuple::Size() == NumOutputsB);
148 if(block_id < a_grid_size)
150 const index_t num_blocks_per_batch =
151 __builtin_amdgcn_readfirstlane(a_grid_size / batch_count_a);
152 const index_t g_idx = __builtin_amdgcn_readfirstlane(block_id / num_blocks_per_batch);
154 InADataTypePointerTuple p_in_global_with_offset_tuple;
155 OutADataTypePointerTuple p_out_global_with_offset_tuple;
157 static_for<0, InADataTypePointerTuple::Size(), 1>{}([&](
auto i) {
158 p_in_global_with_offset_tuple(i) =
159 p_in_global_tuple_a.At(i) +
160 type_convert<long_index_t>(input_batch_strides_a[i]) * g_idx;
163 static_for<0, OutADataTypePointerTuple::Size(), 1>{}([&](
auto i) {
164 p_out_global_with_offset_tuple(i) =
165 p_out_global_tuple_a.At(i) +
166 type_convert<long_index_t>(output_batch_strides_a[i]) * g_idx;
169 GridwiseElementwiseFunctorA::Run(in_grid_desc_tuple_a,
170 out_grid_desc_tuple_a,
171 p_in_global_with_offset_tuple,
172 p_out_global_with_offset_tuple,
179 const index_t num_blocks_per_batch =
180 __builtin_amdgcn_readfirstlane((
get_grid_size() - a_grid_size) / batch_count_b);
182 __builtin_amdgcn_readfirstlane((block_id - a_grid_size) / num_blocks_per_batch);
184 InBDataTypePointerTuple p_in_global_with_offset_tuple;
185 OutBDataTypePointerTuple p_out_global_with_offset_tuple;
187 static_for<0, InBDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
188 p_in_global_with_offset_tuple(i) =
189 p_in_global_tuple_b.At(i) +
190 type_convert<long_index_t>(input_batch_strides_b[i]) * g_idx;
193 static_for<0, OutBDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
194 p_out_global_with_offset_tuple(i) =
195 p_out_global_tuple_b.At(i) +
196 type_convert<long_index_t>(output_batch_strides_b[i]) * g_idx;
199 GridwiseElementwiseFunctorB::Run(in_grid_desc_tuple_b,
200 out_grid_desc_tuple_b,
201 p_in_global_with_offset_tuple,
202 p_out_global_with_offset_tuple,
205 block_id - a_grid_size);
209 template <
typename GridwiseElementwiseFunctor,
210 typename InGridDescTuple,
211 typename OutGridDescTuple,
212 typename InDataTypePointerTuple,
213 typename OutDataTypePointerTuple,
214 typename Block2TileMap,
215 typename ElementwiseOperation,
219 #if CK_USE_LAUNCH_BOUNDS
223 const OutGridDescTuple out_grid_desc_tuple,
224 const InDataTypePointerTuple p_in_global_tuple,
225 const OutDataTypePointerTuple p_out_global_tuple,
226 const Block2TileMap block_2_tile_map,
227 const ElementwiseOperation elementwise_op,
229 const std::array<index_t, NumInputs> input_batch_strides,
230 const std::array<index_t, NumOutputs> output_batch_strides)
232 static_assert(InGridDescTuple::Size() == NumInputs &&
233 InDataTypePointerTuple::Size() == NumInputs);
234 static_assert(OutGridDescTuple::Size() == NumOutputs &&
235 OutDataTypePointerTuple::Size() == NumOutputs);
237 const index_t num_blocks_per_batch =
238 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
241 InDataTypePointerTuple p_in_global_with_offset_tuple;
242 OutDataTypePointerTuple p_out_global_with_offset_tuple;
244 static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
245 p_in_global_with_offset_tuple(i) =
246 p_in_global_tuple.At(i) + type_convert<long_index_t>(input_batch_strides[i]) * g_idx;
249 static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
250 p_out_global_with_offset_tuple(i) =
251 p_out_global_tuple.At(i) + type_convert<long_index_t>(output_batch_strides[i]) * g_idx;
254 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
256 p_in_global_with_offset_tuple,
257 p_out_global_with_offset_tuple,
262 template <
typename InGridDescTuple,
263 typename OutGridDescTuple,
264 typename InDataTypePointerTuple,
265 typename OutDataTypePointerTuple,
266 typename Block2TileMap,
267 typename ElementwiseOperation,
273 typename ThreadClusterArrangeOrder,
274 typename InScalarPerVectorSeq,
275 typename OutScalarPerVectorSeq,
283 static_assert(
NumInput == InScalarPerVectorSeq::Size() &&
284 NumOutput == OutScalarPerVectorSeq::Size() &&
286 "Tuple size is inconsistent with the number of in/out!");
291 static_assert((SrcVectorDim ==
I0 || SrcVectorDim ==
I1) &&
292 (DstVectorDim ==
I0 || DstVectorDim ==
I1),
293 "Vector dim must be equal to 0 or 1.");
297 __device__
static void Run(
const InGridDescTuple& in_grid_desc_tuple,
298 const OutGridDescTuple& out_grid_desc_tuple,
299 const InDataTypePointerTuple& p_in_global_tuple,
300 const OutDataTypePointerTuple& p_out_global_tuple,
301 const Block2TileMap& block_2_tile_map,
302 const ElementwiseOperation& elementwise_op,
308 using DataTypePointer =
remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
317 using DataTypePointer =
remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
326 return make_dynamic_buffer<AddressSpaceEnum::Global>(
327 p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
333 return make_dynamic_buffer<AddressSpaceEnum::Global>(
334 p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
338 const auto block_work_idx =
341 const index_t m0_block_data_idx_on_grid =
342 __builtin_amdgcn_readfirstlane(block_work_idx[
I0] * M0PerBlock);
343 const index_t m1_block_data_idx_on_grid =
344 __builtin_amdgcn_readfirstlane(block_work_idx[
I1] * M1PerBlock);
347 return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
352 return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
362 using SrcDimAccessOrder =
363 std::conditional_t<SrcVectorDim == I1, Sequence<0, 1>,
Sequence<1, 0>>;
364 using DstDimAccessOrder =
365 std::conditional_t<DstVectorDim == I1, Sequence<0, 1>,
Sequence<1, 0>>;
367 using ThreadClusterLengths =
372 ElementwiseOperation,
375 ThreadClusterLengths,
376 ThreadClusterArrangeOrder,
385 InScalarPerVectorSeq,
386 OutScalarPerVectorSeq,
391 input_thread_grid_offset,
393 output_thread_grid_offset,
395 global_to_global_transfer.Run(
396 in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple,
I0);
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:30
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:29
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition: gridwise_elementwise_2d.hpp:222
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition: get_id.hpp:27
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:928
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:21
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:300
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
int32_t index_t
Definition: ck.hpp:300
__global__ void kernel_elementwise_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition: gridwise_elementwise_2d.hpp:61
__global__ void kernel_elementwise_batched_dual(const InAGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InADataTypePointerTuple p_in_global_tuple_a, const InBDataTypePointerTuple p_in_global_tuple_b, const OutADataTypePointerTuple p_out_global_tuple_a, const OutBDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size, const index_t batch_count_a, const index_t batch_count_b, const std::array< index_t, NumInputsA > input_batch_strides_a, const std::array< index_t, NumInputsB > input_batch_strides_b, const std::array< index_t, NumOutputsA > output_batch_strides_a, const std::array< index_t, NumOutputsB > output_batch_strides_b)
Definition: gridwise_elementwise_2d.hpp:117
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:295
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition: gridwise_elementwise_2d.hpp:29
Definition: gridwise_elementwise_2d.hpp:279
static constexpr index_t NumInput
Definition: gridwise_elementwise_2d.hpp:280
static constexpr auto I1
Definition: gridwise_elementwise_2d.hpp:289
static __device__ void Run(const InGridDescTuple &in_grid_desc_tuple, const OutGridDescTuple &out_grid_desc_tuple, const InDataTypePointerTuple &p_in_global_tuple, const OutDataTypePointerTuple &p_out_global_tuple, const Block2TileMap &block_2_tile_map, const ElementwiseOperation &elementwise_op, const index_t block_id=get_block_1d_id())
Definition: gridwise_elementwise_2d.hpp:297
static constexpr auto I0
Definition: gridwise_elementwise_2d.hpp:288
static constexpr index_t NumOutput
Definition: gridwise_elementwise_2d.hpp:281
Definition: sequence.hpp:43
Definition: thread_group.hpp:12
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r2.hpp:45
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33
Definition: unary_element_wise_operation.hpp:308