// Kernel entry point; the cross-reference section of this page lists it as
// `__global__ void kernel_multiple_reduce_threadwise(...)`.
// The visible body is one call that hands the kernel arguments to
// GridwiseMultipleReduction::Run.
//
// NOTE(review): this listing is incomplete. The embedded source numbers skip
// 17-18, 20, 25-26, 30, 32, 34 and 39-41, so some template parameters, the
// kernel header line, some function parameters, the opening brace and three
// of the forwarded call arguments are not shown here. The cross-reference
// entry shows the full parameter list, including alpha_values and
// beta_values of type Array<AccDataType, NumReduction>.
16 template <
typename GridwiseMultipleReduction,
19 typename OutDataTypePointerTuple,
21 typename InGridDesc_M_K,
22 typename OutGridDesc_M_Tuple,
23 typename InElementwiseOperationTuple,
24 typename AccElementwiseOperationTuple>
// Descriptors and operation tuples are taken by value (no reference
// qualifier on these parameters).
27 const OutGridDesc_M_Tuple out_grid_desc_m_tuple,
28 const InElementwiseOperationTuple in_elementwise_op_tuple,
29 const AccElementwiseOperationTuple acc_elementwise_op_tuple,
// Input pointer: const pointer to const data, marked __restrict__.
31 const InDataType*
const __restrict__ p_in_value_global,
33 OutDataTypePointerTuple p_out_value_global_tuple)
// Forward to the gridwise implementation; the arguments visible here are
// passed through unchanged and in the same order as the kernel parameters.
35 GridwiseMultipleReduction::Run(in_grid_desc_m_k,
36 out_grid_desc_m_tuple,
37 in_elementwise_op_tuple,
38 acc_elementwise_op_tuple,
42 p_out_value_global_tuple);
// Fragment of the class template that provides the static device function
// Run (the class declaration line itself is not part of this listing).
//
// NOTE(review): this listing omits many source lines -- the embedded numbers
// jump repeatedly -- including the class header, most type aliases, loop
// headers, conditionals and braces. The comments below describe only what
// the visible statements show; anything that depends on omitted lines is
// marked as an assumption.
47 typename OutDataTypePointerTuple,
49 typename InGridDesc_M_K,
50 typename OutGridDesc_M_Tuple,
51 typename ReduceOperation,
52 typename InElementwiseOperationTuple,
53 typename AccElementwiseOperationTuple,
61 typename OutDstVectorSizeSeq>
// Compile-time check: the per-thread slice length along the vectorized input
// dimension must be a multiple of InSrcVectorSize (MThreadSliceSize when
// InSrcVectorDim == 0, KThreadSliceSize when InSrcVectorDim == 1).
64 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
65 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)),
66 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
// Compile-time check: every per-reduction tuple/sequence must have exactly
// NumReduction entries.
68 static_assert(NumReduction == OutDataTypePointerTuple::Size() &&
69 NumReduction == OutGridDesc_M_Tuple::Size() &&
70 NumReduction == OutDstVectorSizeSeq::Size() &&
71 NumReduction == InElementwiseOperationTuple::Size() &&
72 NumReduction == AccElementwiseOperationTuple::Size(),
73 "All tuple should have the same size as the number of Reductions!");
// Run: static device-side entry point called by the kernel wrapper.
//
// Parameters, as used by the visible statements below:
//   in_grid_desc_m_k          input descriptor; its dimension-1 length bounds
//                             the reduction loop.
//   out_grid_desc_m_tuple     one output descriptor per reduction (indexed iR).
//   in_elementwise_op_tuple   per-reduction operation invoked on elements of
//                             the per-thread input buffer.
//   acc_elementwise_op_tuple  per-reduction operation applied in place to the
//                             accumulator elements.
//   p_in_value_global         input data pointer.
//   p_out_value_global_tuple  one output pointer per reduction (indexed iR).
// alpha_values and beta_values are used in the body (per-reduction scale
// factors) but their parameter lines are omitted from this listing; they
// appear in the cross-reference entry for Run.
97 __device__
static void Run(
const InGridDesc_M_K& in_grid_desc_m_k,
98 const OutGridDesc_M_Tuple& out_grid_desc_m_tuple,
99 const InElementwiseOperationTuple& in_elementwise_op_tuple,
100 const AccElementwiseOperationTuple& acc_elementwise_op_tuple,
102 const InDataType*
const __restrict__ p_in_value_global,
104 OutDataTypePointerTuple p_out_value_global_tuple)
// Identity element of the reduce operation in the accumulator type; used
// below to initialise every accumulator element.
106 const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
// Global-address-space buffer over the input, sized by the descriptor's
// element space. The last argument is the reduce identity expressed as
// InDataType -- presumably the value substituted for invalid reads; TODO
// confirm against make_dynamic_buffer.
108 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
110 in_grid_desc_m_k.GetElementSpaceSize(),
111 ReduceOperation::template GetIdentityValue<InDataType>());
// One global-address-space buffer per output pointer, sized by the matching
// output descriptor's element space.
114 return make_dynamic_buffer<AddressSpaceEnum::Global>(
115 p_out_value_global_tuple[iR], out_grid_desc_m_tuple[iR].GetElementSpaceSize());
// NOTE(review): looks like the element count of a per-thread input buffer;
// the declaration it belongs to is omitted from this listing.
127 MThreadSliceSize * KThreadSliceSize,
// Start every accumulator element of reduction iR at the identity value.
141 [&](
auto J) { accu_value_buf_tuple(iR)(J) = identityVal; });
// Length of dimension 1 of the input descriptor; the loop below runs until
// reducedLength reaches this value.
146 const auto toReduceLength = in_grid_desc_m_k.GetLength(
Number<1>{});
155 decltype(thread_buffer_desc),
// Source window origin for this thread: index
// thread_global_1d_id * MThreadSliceSize in dimension 0, index 0 in
// dimension 1.
162 in_grid_desc_m_k,
make_multi_index(thread_global_1d_id * MThreadSliceSize, 0));
// Load the current source window (remaining arguments omitted in listing).
169 threadwise_src_load.Run(in_grid_desc_m_k,
// Compile-time offset of element (iM, iK) inside the per-thread buffer.
179 constexpr
auto offset =
180 thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
// Invoke the iR-th input element-wise operation on that buffer element
// (the rest of the call is omitted in this listing).
181 in_elementwise_op_tuple[iR](in_thread_buf_tuple(iR)(
Number<offset>{}),
// Advance the source window by in_thread_copy_step and count the
// KThreadSliceSize elements just consumed.
189 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
191 reducedLength += KThreadSliceSize;
// do/while form: the loop body executes at least once and repeats while
// fewer than toReduceLength elements have been consumed.
192 }
while(reducedLength < toReduceLength);
// Pointer type of the iR-th output with const/volatile/reference removed.
197 using OutDataTypePointer =
remove_cvref_t<decltype(OutDataTypePointerTuple{}[iR])>;
// Apply the iR-th accumulator operation in place (same element is passed
// as both arguments).
201 acc_elementwise_op_tuple[iR](accu_value_buf_tuple(iR)(I),
202 accu_value_buf_tuple(iR)(I));
// Scale the accumulator by this reduction's alpha.
// NOTE(review): any condition guarding this step is not visible here.
204 accu_value_buf_tuple(iR)(I) *= alpha_values[iR];
// Transfer object that reads the existing destination values; its Run call
// below takes the output descriptor and the output global buffer as source.
212 auto threadwise_dst_load =
215 decltype(out_grid_desc_m_tuple[iR]),
216 decltype(reduced_data_desc),
220 OutDstVectorSizeSeq::At(iR),
223 out_grid_desc_m_tuple[iR],
226 threadwise_dst_load.Run(out_grid_desc_m_tuple[iR],
227 out_global_val_buf_tuple(iR),
// Blend in the prior destination value: convert it to AccDataType, multiply
// by this reduction's beta and add to the accumulator.
// NOTE(review): any condition guarding this step is not visible here.
233 accu_value_buf_tuple(iR)(I) +=
234 type_convert<AccDataType>(priorDstValueBuf[I]) * beta_values[iR];
// Transfer object that writes the accumulators to the iR-th output.
// OutMemoryDataOperation is passed as a template argument -- presumably it
// selects the store mode; TODO confirm.
238 auto threadwise_dst_store =
241 decltype(reduced_data_desc),
242 decltype(out_grid_desc_m_tuple[iR]),
247 OutDstVectorSizeSeq::At(iR),
248 OutMemoryDataOperation,
251 out_grid_desc_m_tuple[iR],
// Store: source is accu_value_buf_tuple[iR] described by reduced_data_desc,
// destination is the iR-th output global buffer described by
// out_grid_desc_m_tuple[iR].
255 threadwise_dst_store.Run(reduced_data_desc,
257 accu_value_buf_tuple[iR],
258 out_grid_desc_m_tuple[iR],
259 out_global_val_buf_tuple(iR));
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto generate_tuple(F &&f, Number< N >)
Definition: tuple_helper.hpp:15
InMemoryDataOperationEnum
Definition: ck.hpp:267
__global__ void kernel_multiple_reduce_threadwise(const InGridDesc_M_K in_grid_desc_m_k, const OutGridDesc_M_Tuple out_grid_desc_m_tuple, const InElementwiseOperationTuple in_elementwise_op_tuple, const AccElementwiseOperationTuple acc_elementwise_op_tuple, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:26
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__device__ index_t get_thread_global_1d_id()
Definition: get_id.hpp:18
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:63
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:83
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:78
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:81
static __device__ void Run(const InGridDesc_M_K &in_grid_desc_m_k, const OutGridDesc_M_Tuple &out_grid_desc_m_tuple, const InElementwiseOperationTuple &in_elementwise_op_tuple, const AccElementwiseOperationTuple &acc_elementwise_op_tuple, Array< AccDataType, NumReduction > alpha_values, const InDataType *const __restrict__ p_in_value_global, Array< AccDataType, NumReduction > beta_values, OutDataTypePointerTuple p_out_value_global_tuple)
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:97
static constexpr bool reorder_thread_cluster
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:75
tensor_operation::element_wise::PassThrough PassThroughOp
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:91
static constexpr auto I0
Definition: gridwise_2d_multiple_reduction_threadwise.hpp:93
Definition: sequence.hpp:43
Definition: static_buffer.hpp:16
Definition: reduction_functions_threadwise.hpp:23
static __device__ void Reduce(const SrcBufferType &src_buf, DstBufferType &dst_buf)
Definition: reduction_functions_threadwise.hpp:36
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: threadwise_tensor_slice_transfer.hpp:214
Definition: functional.hpp:100
Definition: reduction_functions_accumulate.hpp:28
Definition: reduction_common.hpp:20
Definition: integral_constant.hpp:10
Definition: functional2.hpp:31
Definition: unary_element_wise_operation.hpp:241