18 template <
typename GridwiseElementwiseFunctor,
19 typename InGridDescTuple,
20 typename OutGridDescTuple,
21 typename InDataTypePointerTuple,
22 typename OutDataTypePointerTuple,
23 typename Block2TileMap,
24 typename ElementwiseOperation>
26 #if CK_USE_LAUNCH_BOUNDS
30 const OutGridDescTuple out_grid_desc_tuple,
31 const InDataTypePointerTuple p_in_global_tuple,
32 const OutDataTypePointerTuple p_out_global_tuple,
33 const Block2TileMap block_2_tile_map,
34 const ElementwiseOperation elementwise_op)
36 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
44 template <
typename GridwiseElementwiseFunctor,
45 typename InAGridDescTuple,
46 typename InBGridDescTuple,
47 typename OutAGridDescTuple,
48 typename OutBGridDescTuple,
49 typename InDataTypePointerTuple,
50 typename OutDataTypePointerTuple,
51 typename Block2TileMapA,
52 typename Block2TileMapB,
53 typename ElementwiseOperation>
55 #if CK_USE_LAUNCH_BOUNDS
59 const InBGridDescTuple in_grid_desc_tuple_b,
60 const OutAGridDescTuple out_grid_desc_tuple_a,
61 const OutBGridDescTuple out_grid_desc_tuple_b,
62 const InDataTypePointerTuple p_in_global_tuple_a,
63 const InDataTypePointerTuple p_in_global_tuple_b,
64 const OutDataTypePointerTuple p_out_global_tuple_a,
65 const OutDataTypePointerTuple p_out_global_tuple_b,
66 const Block2TileMapA block_2_tile_map_a,
67 const Block2TileMapB block_2_tile_map_b,
68 const ElementwiseOperation elementwise_op,
73 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_a,
74 out_grid_desc_tuple_a,
83 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple_b,
84 out_grid_desc_tuple_b,
93 template <
typename GridwiseElementwiseFunctor,
94 typename InGridDescTuple,
95 typename OutGridDescTuple,
96 typename InDataTypePointerTuple,
97 typename OutDataTypePointerTuple,
98 typename Block2TileMap,
99 typename ElementwiseOperation,
103 #if CK_USE_LAUNCH_BOUNDS
107 const OutGridDescTuple out_grid_desc_tuple,
108 const InDataTypePointerTuple p_in_global_tuple,
109 const OutDataTypePointerTuple p_out_global_tuple,
110 const Block2TileMap block_2_tile_map,
111 const ElementwiseOperation elementwise_op,
113 const std::array<index_t, NumInputs> input_batch_strides,
114 const std::array<index_t, NumOutputs> output_batch_strides)
116 static_assert(InGridDescTuple::Size() == NumInputs &&
117 InDataTypePointerTuple::Size() == NumInputs);
118 static_assert(OutGridDescTuple::Size() == NumOutputs &&
119 OutDataTypePointerTuple::Size() == NumOutputs);
121 const index_t num_blocks_per_batch =
122 __builtin_amdgcn_readfirstlane(
get_grid_size() / batch_count);
125 InDataTypePointerTuple p_in_global_with_offset_tuple;
126 OutDataTypePointerTuple p_out_global_with_offset_tuple;
128 static_for<0, InDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
129 p_in_global_with_offset_tuple(i) = p_in_global_tuple.At(i) + input_batch_strides[i] * g_idx;
132 static_for<0, OutDataTypePointerTuple::Size(), 1>{}([&](
auto i) {
133 p_out_global_with_offset_tuple(i) =
134 p_out_global_tuple.At(i) + output_batch_strides[i] * g_idx;
137 GridwiseElementwiseFunctor::Run(in_grid_desc_tuple,
139 p_in_global_with_offset_tuple,
140 p_out_global_with_offset_tuple,
145 template <
typename InGridDescTuple,
146 typename OutGridDescTuple,
147 typename InDataTypePointerTuple,
148 typename OutDataTypePointerTuple,
149 typename Block2TileMap,
150 typename ElementwiseOperation,
156 typename ThreadClusterArrangeOrder,
157 typename InScalarPerVectorSeq,
158 typename OutScalarPerVectorSeq,
166 static_assert(
NumInput == InScalarPerVectorSeq::Size() &&
167 NumOutput == OutScalarPerVectorSeq::Size() &&
169 "Tuple size is inconsistent with the number of in/out!");
174 static_assert((SrcVectorDim ==
I0 || SrcVectorDim ==
I1) &&
175 (DstVectorDim ==
I0 || DstVectorDim ==
I1),
176 "Vector dim must be equal to 0 or 1.");
180 __device__
static void Run(
const InGridDescTuple& in_grid_desc_tuple,
181 const OutGridDescTuple& out_grid_desc_tuple,
182 const InDataTypePointerTuple& p_in_global_tuple,
183 const OutDataTypePointerTuple& p_out_global_tuple,
184 const Block2TileMap& block_2_tile_map,
185 const ElementwiseOperation& elementwise_op,
191 using DataTypePointer =
remove_cvref_t<decltype(InDataTypePointerTuple{}[I])>;
200 using DataTypePointer =
remove_cvref_t<decltype(OutDataTypePointerTuple{}[I])>;
209 return make_dynamic_buffer<AddressSpaceEnum::Global>(
210 p_in_global_tuple[I], in_grid_desc_tuple[I].GetElementSpaceSize());
216 return make_dynamic_buffer<AddressSpaceEnum::Global>(
217 p_out_global_tuple[I], out_grid_desc_tuple[I].GetElementSpaceSize());
221 const auto block_work_idx =
224 const index_t m0_block_data_idx_on_grid =
225 __builtin_amdgcn_readfirstlane(block_work_idx[
I0] * M0PerBlock);
226 const index_t m1_block_data_idx_on_grid =
227 __builtin_amdgcn_readfirstlane(block_work_idx[
I1] * M1PerBlock);
230 return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
235 return make_multi_index(m0_block_data_idx_on_grid, m1_block_data_idx_on_grid);
245 using SrcDimAccessOrder =
246 std::conditional_t<SrcVectorDim == I1, Sequence<0, 1>,
Sequence<1, 0>>;
247 using DstDimAccessOrder =
248 std::conditional_t<DstVectorDim == I1, Sequence<0, 1>,
Sequence<1, 0>>;
250 using ThreadClusterLengths =
255 ElementwiseOperation,
258 ThreadClusterLengths,
259 ThreadClusterArrangeOrder,
268 InScalarPerVectorSeq,
269 OutScalarPerVectorSeq,
274 input_thread_grid_offset,
276 output_thread_grid_offset,
278 global_to_global_transfer.Run(
279 in_grid_desc_tuple, in_global_buf_tuple, out_grid_desc_tuple, out_global_buf_tuple,
I0);
#define CK_MIN_BLOCK_PER_CU
Definition: ck.hpp:34
#define CK_MAX_THREAD_PER_BLOCK
Definition: ck.hpp:33
__global__ void kernel_batched_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op, const index_t batch_count, const std::array< index_t, NumInputs > input_batch_strides, const std::array< index_t, NumOutputs > output_batch_strides)
Definition: gridwise_elementwise_2d.hpp:106
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__device__ index_t get_grid_size()
Definition: get_id.hpp:24
typename uniform_sequence_gen< NSize, I >::type uniform_sequence_gen_t
Definition: sequence.hpp:901
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
typename remove_pointer< T >::type remove_pointer_t
Definition: type.hpp:303
__global__ void kernel_elementwise_dual(const InBGridDescTuple in_grid_desc_tuple_a, const InBGridDescTuple in_grid_desc_tuple_b, const OutAGridDescTuple out_grid_desc_tuple_a, const OutBGridDescTuple out_grid_desc_tuple_b, const InDataTypePointerTuple p_in_global_tuple_a, const InDataTypePointerTuple p_in_global_tuple_b, const OutDataTypePointerTuple p_out_global_tuple_a, const OutDataTypePointerTuple p_out_global_tuple_b, const Block2TileMapA block_2_tile_map_a, const Block2TileMapB block_2_tile_map_b, const ElementwiseOperation elementwise_op, const index_t a_grid_size)
Definition: gridwise_elementwise_2d.hpp:58
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:22
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
typename remove_cv< T >::type remove_cv_t
Definition: type.hpp:298
__global__ void kernel_elementwise(const InGridDescTuple in_grid_desc_tuple, const OutGridDescTuple out_grid_desc_tuple, const InDataTypePointerTuple p_in_global_tuple, const OutDataTypePointerTuple p_out_global_tuple, const Block2TileMap block_2_tile_map, const ElementwiseOperation elementwise_op)
Definition: gridwise_elementwise_2d.hpp:29
Definition: gridwise_elementwise_2d.hpp:162
static constexpr index_t NumInput
Definition: gridwise_elementwise_2d.hpp:163
static constexpr auto I1
Definition: gridwise_elementwise_2d.hpp:172
static __device__ void Run(const InGridDescTuple &in_grid_desc_tuple, const OutGridDescTuple &out_grid_desc_tuple, const InDataTypePointerTuple &p_in_global_tuple, const OutDataTypePointerTuple &p_out_global_tuple, const Block2TileMap &block_2_tile_map, const ElementwiseOperation &elementwise_op, const index_t block_id=get_block_1d_id())
Definition: gridwise_elementwise_2d.hpp:180
static constexpr auto I0
Definition: gridwise_elementwise_2d.hpp:171
static constexpr index_t NumOutput
Definition: gridwise_elementwise_2d.hpp:164
Definition: sequence.hpp:43
Definition: thread_group.hpp:12
Blockwise data transfer.
Definition: thread_group_tensor_slice_transfer_v4r2.hpp:45
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31
Definition: unary_element_wise_operation.hpp:241