template <typename GridwiseReduction,
          bool OutputIndex,
          bool HaveIndexInput,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
                                         const OutGridDesc_M out_grid_desc_m,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         index_t block_group_size,
                                         index_t num_k_block_tile_iteration,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_value_global,
                                         const IndexDataType* const __restrict__ p_in_index_global,
                                         AccDataType beta,
                                         OutDataType* const __restrict__ p_out_value_global,
                                         IndexDataType* const __restrict__ p_out_index_global)
{
    if constexpr(!OutputIndex)
    {
        // the index buffers are unused when only the reduced values are requested
        (void)p_in_index_global;
        (void)p_out_index_global;

        GridwiseReduction::Run(in_grid_desc_m_k, out_grid_desc_m, in_elementwise_op,
                               acc_elementwise_op, block_group_size, num_k_block_tile_iteration,
                               alpha, p_in_value_global, beta, p_out_value_global);
    }
    else
    {
        GridwiseReduction::template RunWithIndex<HaveIndexInput>(
            in_grid_desc_m_k, out_grid_desc_m, in_elementwise_op, acc_elementwise_op,
            num_k_block_tile_iteration, alpha, p_in_value_global, p_in_index_global,
            beta, p_out_value_global, p_out_index_global);
    }
}
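// Dispatch note: OutputIndex selects between the value-only path (Run) and the index-tracking
// path (RunWithIndex). For example, a Sum/Avg-style reduction uses OutputIndex == false, while an
// ArgMax/ArgMin-style reduction that also writes the winning positions uses OutputIndex == true;
// HaveIndexInput == true additionally consumes precomputed indices from p_in_index_global, e.g.
// when this kernel runs as the second stage of a two-stage indexed reduction. Only Run takes
// block_group_size: the value-only path can split one output row's K range across a group of
// blocks and combine the partial results in memory, which is not done when an index has to be
// carried along with each value.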
template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation
          /* , ...non-type tuning parameters (OutMemoryDataOperation, BlockSize, the M/K
               thread-cluster and thread-slice sizes, InSrcVectorDim, InSrcVectorSize,
               OutDstVectorSize) elided in this listing... */>
struct GridwiseReduction_mk_to_m_multiblock
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");
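    // Member aliases and compile-time constants referenced by Run() and RunWithIndex(); the
    // initializers marked "assumed" are illustrative sketches of the usual composable_kernel
    // conventions, not verbatim definitions.
    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0); // assumed initializer

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{}); // assumed initializer

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{}; // assumed initializer
    static constexpr auto I1 = Number<1>{}; // assumed initializer

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize; // assumed initializer
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize; // assumed initializer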
    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        // LDS work buffer for the cross-thread (blockwise) reduction
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
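        // The third argument passed to make_dynamic_buffer for the input buffer is the value an
        // out-of-range read returns, so elements of a padded K dimension come back as the
        // reduction's identity (e.g. 0 for an add-based reduction) and leave the result unchanged.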
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        // threadwise transfer that loads this thread's MThreadSliceSize x KThreadSliceSize slice
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K,
            decltype(thread_buffer_desc) /* , ...remaining template arguments elided... */>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k, in_global_val_buf, thread_buffer_desc,
                                    make_tuple(I0, I0), in_thread_buf);

            // apply the input element-wise operation to every element of the loaded tile
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    in_elementwise_op(in_thread_buf(Number<offset>{}), in_thread_buf(Number<offset>{}));
                });
            });

            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf); // per-thread partial reduction

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
            ++reducedTiles;
        } while(reducedTiles < num_k_block_tile_iteration);
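        // Worked example (illustrative numbers): with K_BlockTileSize = KThreadClusterSize *
        // KThreadSliceSize = 64 * 4 = 256 and num_k_block_tile_iteration = 4, each block in a
        // group covers reduceSizePerBlock = 256 * 4 = 1024 consecutive K elements, so a reduction
        // length of K = 4096 would use block_group_size = 4 blocks per output row; the per-block
        // partial results are then combined through OutMemoryDataOperation (e.g. an atomic add)
        // when they are written to p_out_value_global.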
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
                accu_value_buf(I) *= alpha;
            }
        });

        // only one thread per M row writes the reduced result back to global memory
        if(thread_k_cluster_id == 0)
        {
            // blend in the prior output when beta != 0 (the beta != 0 guard is elided here)
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                OutDataType, OutDataType, OutGridDesc_M,
                decltype(reduced_data_desc) /* , ...remaining template arguments elided... */>(
                out_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize));
            threadwise_dst_load.Run(out_grid_desc_m, out_global_val_buf, reduced_data_desc,
                                    make_tuple(I0), priorDstValueBuf);
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
            });

            // write back, letting OutMemoryDataOperation combine results across the block group
            auto threadwise_dst_store = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType, OutDataType, decltype(reduced_data_desc), OutGridDesc_M, PassThroughOp,
                /* ... */ OutMemoryDataOperation /* , ... */>(
                out_grid_desc_m,
                make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});
            threadwise_dst_store.Run(reduced_data_desc, make_tuple(I0), accu_value_buf,
                                     out_grid_desc_m, out_global_val_buf);
        }
    }
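    // Epilogue semantics of Run(): for each output element m this block contributes
    //     out[m] = alpha * acc_elementwise_op(reduce_k(in_elementwise_op(in[m, k]))) + beta * out_prev[m],
    // where reduce_k applies ReduceOperation over the K range assigned to this block (threadwise
    // partials first, then the LDS blockwise combine across the K thread cluster).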
    template <bool HaveIndexInput>
    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation in_elementwise_op,
                                        const AccElementwiseOperation acc_elementwise_op,
                                        index_t num_k_block_tile_iteration,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_value_global,
                                        const IndexDataType* const __restrict__ p_in_index_global,
                                        AccDataType beta,
                                        OutDataType* const __restrict__ p_out_value_global,
                                        IndexDataType* const __restrict__ p_out_index_global)
    {
        using BlockwiseReduceWithIndex =
            /* blockwise (value, index) reduction over the thread cluster; type elided in this listing */;

        (void)in_elementwise_op;

        // LDS work buffers for the cross-thread reduction of (value, index) pairs
        __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
        __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

        auto reduce_work_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
        auto reduce_work_idx_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);
        // thread-private buffers for the loaded tile and the per-M (value, index) accumulators
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true> in_thread_val_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize * KThreadSliceSize, true> in_thread_idx_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, IndexDataType, MThreadSliceSize, true> accu_index_buf;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2<
            InDataType, AccDataType, InGridDesc_M_K,
            decltype(thread_buffer_desc) /* , ...remaining template arguments elided... */>(
            in_grid_desc_m_k,
            make_multi_index(get_block_1d_id() * M_BlockTileSize +
                                 thread_m_cluster_id * MThreadSliceSize,
                             thread_k_cluster_id * KThreadSliceSize));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = identityVal;
            accu_index_buf(I) = 0;
        });
        index_t reducedTiles = 0;

        if constexpr(HaveIndexInput)
        {
            // the input already carries indices (e.g. when this is the second stage of a
            // two-stage indexed reduction), so load them alongside the values
            auto threadwise_src_idx_load = ThreadwiseTensorSliceTransfer_v2<
                IndexDataType, IndexDataType, InGridDesc_M_K,
                decltype(thread_buffer_desc) /* , ...remaining template arguments elided... */>(
                in_grid_desc_m_k,
                make_multi_index(get_block_1d_id() * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k, in_global_val_buf, thread_buffer_desc,
                                            make_tuple(I0, I0), in_thread_val_buf);
                threadwise_src_idx_load.Run(in_grid_desc_m_k, in_global_idx_buf, thread_buffer_desc,
                                            make_tuple(I0, I0), in_thread_idx_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    AccDataType tmpValue   = identityVal;
                    IndexDataType tmpIndex = 0;

                    // reduce this thread's K slice into one (value, index) pair
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        AccumulationWithIndex::Calculate(tmpValue, in_thread_val_buf[Number<offset>{}],
                                                         tmpIndex, in_thread_idx_buf[Number<offset>{}]);
                    });

                    // combine across the K thread cluster through LDS, then into the accumulator
                    BlockwiseReduceWithIndex::Reduce(
                        reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
                    AccumulationWithIndex::Calculate(
                        accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
                });

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
                threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                ++reducedTiles;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
        else
        {
            // no index input: generate the K indices on the fly, tile by tile
            index_t indexOffset = 0;

            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k, in_global_val_buf, thread_buffer_desc,
                                            make_tuple(I0, I0), in_thread_val_buf);

                // tag every loaded element with its position along K
                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        in_thread_idx_buf(Number<offset>{}) =
                            indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
                    });
                });

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    AccDataType tmpValue   = identityVal;
                    IndexDataType tmpIndex = 0;

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        AccumulationWithIndex::Calculate(tmpValue, in_thread_val_buf[Number<offset>{}],
                                                         tmpIndex, in_thread_idx_buf[Number<offset>{}]);
                    });

                    BlockwiseReduceWithIndex::Reduce(
                        reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);
                    AccumulationWithIndex::Calculate(
                        accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
                });

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexOffset += K_BlockTileSize;
                ++reducedTiles;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
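        // Index generation example (illustrative numbers): with KThreadSliceSize = 2 and
        // K_BlockTileSize = 128, the thread with thread_k_cluster_id = 3 tags its elements of the
        // first tile with indices 6 and 7, of the second tile with 134 and 135, and so on, since
        // indexOffset advances by K_BlockTileSize after every tile; the winning index is carried
        // along with the winning value through every Calculate/Reduce step above.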
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));
                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            // blend in the prior output when beta != 0 (the beta != 0 guard is elided here)
            auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<
                OutDataType, OutDataType, OutGridDesc_M,
                decltype(reduced_data_desc) /* , ...remaining template arguments elided... */>(
                out_grid_desc_m,
                make_multi_index(get_block_1d_id() * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize));
            threadwise_dst_load.Run(out_grid_desc_m, out_global_val_buf, reduced_data_desc,
                                    make_tuple(I0), priorDstValueBuf);
            static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
            });

            // separate stores for the reduced values and the winning indices
            auto threadwise_dst_val_store = ThreadwiseTensorSliceTransfer_v1r3<
                AccDataType, OutDataType, decltype(reduced_data_desc), OutGridDesc_M,
                PassThroughOp /* , ...remaining template arguments elided... */>(
                out_grid_desc_m,
                make_multi_index(get_block_1d_id() * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});
            auto threadwise_dst_idx_store = ThreadwiseTensorSliceTransfer_v1r3<
                IndexDataType, IndexDataType, decltype(reduced_data_desc), OutGridDesc_M,
                PassThroughOp /* , ...remaining template arguments elided... */>(
                out_grid_desc_m,
                make_multi_index(get_block_1d_id() * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize),
                PassThroughOp{});

            threadwise_dst_val_store.Run(reduced_data_desc, make_tuple(I0), accu_value_buf,
                                         out_grid_desc_m, out_global_val_buf);
            threadwise_dst_idx_store.Run(reduced_data_desc, make_tuple(I0), accu_index_buf,
                                         out_grid_desc_m, out_global_idx_buf);
        }
    }
};
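// The (value, index) accumulation used above keeps, for each output row, the best value seen so
// far together with the position at which it occurred. The following standalone sketch only
// illustrates those semantics for a max-style reduction; the names ArgMaxAccumulator and
// arg_max_example are hypothetical and are not the composable_kernel implementation found in
// reduction_functions_accumulate.hpp.
#include <cstdint>
#include <limits>
#include <utility>
#include <vector>

template <typename AccDataType, typename IndexDataType>
struct ArgMaxAccumulator
{
    // mirrors the Calculate(acc_value, new_value, acc_index, new_index) call pattern:
    // the accumulator is updated only when the new value wins the comparison
    static void Calculate(AccDataType& acc_value,
                          AccDataType new_value,
                          IndexDataType& acc_index,
                          IndexDataType new_index)
    {
        if(new_value > acc_value)
        {
            acc_value = new_value;
            acc_index = new_index;
        }
    }
};

// reduce one "row" of K values to its maximum and the position of that maximum
inline std::pair<float, int32_t> arg_max_example(const std::vector<float>& row)
{
    float value   = std::numeric_limits<float>::lowest(); // identity value for a max reduction
    int32_t index = 0;

    for(int32_t k = 0; k < static_cast<int32_t>(row.size()); ++k)
        ArgMaxAccumulator<float, int32_t>::Calculate(value, row[k], index, k);

    return {value, index};
}
// Note on the sketch: with a strict ">" comparison, ties keep the earliest index seen.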