template <typename GridwiseReduction,
          bool OutputIndex,
          bool HaveIndexInput,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename InElementwiseOperation,
          typename AccElementwiseOperation>
__global__ void kernel_reduce_multiblock(const InGridDesc_M_K in_grid_desc_m_k,
                                         const OutGridDesc_M out_grid_desc_m,
                                         const InElementwiseOperation in_elementwise_op,
                                         const AccElementwiseOperation acc_elementwise_op,
                                         index_t block_group_size,
                                         index_t num_k_block_tile_iteration,
                                         AccDataType alpha,
                                         const InDataType* const __restrict__ p_in_value_global,
                                         const IndexDataType* const __restrict__ p_in_index_global,
                                         AccDataType beta,
                                         OutDataType* const __restrict__ p_out_value_global,
                                         IndexDataType* const __restrict__ p_out_index_global)
{
    if constexpr(!OutputIndex)
    {
        // Index buffers are unused on the value-only path
        (void)p_in_index_global;
        (void)p_out_index_global;

        GridwiseReduction::Run(in_grid_desc_m_k,
                               out_grid_desc_m,
                               in_elementwise_op,
                               acc_elementwise_op,
                               block_group_size,
                               num_k_block_tile_iteration,
                               alpha,
                               p_in_value_global,
                               beta,
                               p_out_value_global);
    }
    else
    {
        GridwiseReduction::template RunWithIndex<HaveIndexInput>(in_grid_desc_m_k,
                                                                 out_grid_desc_m,
                                                                 in_elementwise_op,
                                                                 acc_elementwise_op,
                                                                 num_k_block_tile_iteration,
                                                                 alpha,
                                                                 p_in_value_global,
                                                                 p_in_index_global,
                                                                 beta,
                                                                 p_out_value_global,
                                                                 p_out_index_global);
    }
}
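// The kernel above only dispatches; the interesting part is how a "block
// group" of block_group_size blocks shares one reduced row. A minimal,
// self-contained CUDA-style sketch of that partitioning follows. It is
// illustrative only, not the CK implementation: the name multiblock_row_sum,
// the fixed 256-thread block, and the atomicAdd combine are assumptions.
__global__ void multiblock_row_sum(const float* in, float* out,
                                   int M, int K, int block_group_size)
{
    const int blkgroup_id    = blockIdx.x / block_group_size; // which row m
    const int block_local_id = blockIdx.x % block_group_size; // which K-chunk

    // Each block of the group owns one contiguous chunk of the K axis
    const int chunk   = (K + block_group_size - 1) / block_group_size;
    const int k_begin = block_local_id * chunk;
    const int k_end   = k_begin + chunk < K ? k_begin + chunk : K;

    float acc = 0.0f; // identity of the sum reduction
    for(int k = k_begin + threadIdx.x; k < k_end; k += blockDim.x)
        acc += in[blkgroup_id * K + k];

    // Blockwise tree reduction through shared memory (power-of-two blockDim
    // up to 256 assumed)
    __shared__ float lds[256];
    lds[threadIdx.x] = acc;
    __syncthreads();
    for(int s = blockDim.x / 2; s > 0; s /= 2)
    {
        if(threadIdx.x < s)
            lds[threadIdx.x] += lds[threadIdx.x + s];
        __syncthreads();
    }

    // One atomic per block folds the group's partial results together; out
    // must be zero-initialized first. CK instead parameterizes this combine
    // through OutMemoryDataOperation.
    if(threadIdx.x == 0)
        atomicAdd(&out[blkgroup_id], lds[0]);
}
// Launched with M * block_group_size blocks, mirroring how the CK kernel
// recovers blkgroup_id and block_local_id from get_block_1d_id().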
 
template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename IndexDataType,
          typename InGridDesc_M_K,
          typename OutGridDesc_M,
          typename ReduceOperation,
          typename InElementwiseOperation,
          typename AccElementwiseOperation,
          InMemoryDataOperationEnum OutMemoryDataOperation,
          bool PropagateNan,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize>
struct GridwiseReduction_mk_to_m_multiblock
{
    // Vector loads must evenly tile the thread slice along the vectorized
    // dimension, and vector stores must evenly tile the M slice
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (MThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    // (member aliases elided: thread_cluster_desc, ThreadReduceSrcDesc_M_K /
    //  ThreadReduceDstDesc_M, ThreadwiseReduce / BlockwiseReduce, PassThroughOp,
    //  I0 / I1, M_BlockTileSize / K_BlockTileSize)
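    // A concrete reading of the assertion with hypothetical values (the ex_*
    // constants are illustrative, not from the source): a 4 x 8 thread slice
    // vectorized along K admits 4-wide loads, and 2-wide stores tile M.
    static constexpr index_t ex_MThreadSliceSize = 4;
    static constexpr index_t ex_KThreadSliceSize = 8;
    static constexpr index_t ex_InSrcVectorDim   = 1; // vectorize loads along K
    static constexpr index_t ex_InSrcVectorSize  = 4; // 8 % 4 == 0 -> loads OK
    static constexpr index_t ex_OutDstVectorSize = 2; // 4 % 2 == 0 -> stores OK

    static_assert((ex_InSrcVectorDim == 0 && ex_MThreadSliceSize % ex_InSrcVectorSize == 0) ||
                      (ex_InSrcVectorDim == 1 && ex_KThreadSliceSize % ex_InSrcVectorSize == 0),
                  "example: vector loads tile the slice");
    static_assert(ex_MThreadSliceSize % ex_OutDstVectorSize == 0,
                  "example: vector stores tile the M slice");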
 
    __device__ static void Run(const InGridDesc_M_K& in_grid_desc_m_k,
                               const OutGridDesc_M& out_grid_desc_m,
                               const InElementwiseOperation& in_elementwise_op,
                               const AccElementwiseOperation& acc_elementwise_op,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        // LDS workspace for the blockwise (cross-thread) reduction
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        // Out-of-bounds reads return the reduction identity, so padding never
        // perturbs the result
        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        // Thread-private input tile and per-M accumulators
        StaticBuffer<AddressSpaceEnum::Vgpr,
                     AccDataType,
                     MThreadSliceSize * KThreadSliceSize,
                     true>
            in_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) { accu_value_buf(I) = identityVal; });

        // Each group of block_group_size blocks reduces one M tile; each
        // member block owns one contiguous K-slice of that tile
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

        auto threadwise_src_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             /* ... vector-access parameters ... */>(
                in_grid_desc_m_k,
                make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 block_local_id * reduceSizePerBlock +
                                     thread_k_cluster_id * KThreadSliceSize));

        constexpr auto in_thread_copy_step = make_multi_index(0, K_BlockTileSize);

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            // Apply the input element-wise op in registers before reducing
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    in_elementwise_op(in_thread_buf(Number<offset>{}),
                                      in_thread_buf(Number<offset>{}));
                });
            });

            ThreadwiseReduce::Reduce(in_thread_buf, accu_value_buf);

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

            ++reducedTiles;
        } while(reducedTiles < num_k_block_tile_iteration);

        constexpr auto reduced_data_desc = ThreadReduceDstDesc_M{};

        // Fold the per-thread partials across the K columns of the cluster
        static_for<0, MThreadSliceSize, 1>{}(
            [&](auto I) { BlockwiseReduce::Reduce(reduce_work_buf, accu_value_buf(I)); });

        // Only the k == 0 column of the thread cluster holds final row results
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if(!float_equal_zero{}(beta))
            {
                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                     OutDataType,
                                                     OutGridDesc_M,
                                                     decltype(reduced_data_desc),
                                                     /* ... */>(
                        out_grid_desc_m,
                        make_multi_index(blkgroup_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize));

                threadwise_dst_load.Run(out_grid_desc_m,
                                        out_global_val_buf,
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValueBuf);

                // Blend the previously stored output back in, scaled by beta
                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
                });
            }

            // Partials from the block group meet in global memory according to
            // OutMemoryDataOperation (e.g. atomic add for multiblock)
            auto threadwise_dst_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   PassThroughOp,
                                                   /* ... */
                                                   OutMemoryDataOperation,
                                                   /* ... */>(
                    out_grid_desc_m,
                    make_multi_index(blkgroup_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_dst_store.Run(reduced_data_desc,
                                     make_tuple(I0),
                                     accu_value_buf,
                                     out_grid_desc_m,
                                     out_global_val_buf);
        }
    }
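    // Taken together, Run computes per reduced row
    //   out[m] = alpha * acc_op(reduce_k(in_op(in[m][k]))) + beta * out[m],
    // with the cross-block combine delegated to OutMemoryDataOperation. A
    // host-side reference of just that arithmetic; illustrative only,
    // ReferenceReduceRow is not part of CK.
    template <typename InOp, typename ReduceOp, typename AccOp>
    static void ReferenceReduceRow(const float* in, float* out, int M, int K,
                                   float alpha, float beta, float identity,
                                   InOp in_op, ReduceOp reduce_op, AccOp acc_op)
    {
        for(int m = 0; m < M; ++m)
        {
            float acc = identity;
            for(int k = 0; k < K; ++k)
                acc = reduce_op(acc, in_op(in[m * K + k])); // map, then reduce
            out[m] = alpha * acc_op(acc) + beta * out[m];
        }
    }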
 
    template <bool HaveIndexInput>
    __device__ static void RunWithIndex(const InGridDesc_M_K& in_grid_desc_m_k,
                                        const OutGridDesc_M& out_grid_desc_m,
                                        const InElementwiseOperation in_elementwise_op,
                                        const AccElementwiseOperation acc_elementwise_op,
                                        index_t num_k_block_tile_iteration,
                                        AccDataType alpha,
                                        const InDataType* const __restrict__ p_in_value_global,
                                        const IndexDataType* const __restrict__ p_in_index_global,
                                        AccDataType beta,
                                        OutDataType* const __restrict__ p_out_value_global,
                                        IndexDataType* const __restrict__ p_out_index_global)
    {
        using BlockwiseReduceWithIndex =
            PartitionedBlockwiseReductionWithIndex</* ... value/index types, thread
                                                      cluster shape, ReduceOperation ... */>;
        // (AccumulationWithIndex: per-element value+index accumulator from
        //  reduction_functions_accumulate.hpp; alias elided)

        // The input element-wise op is unused on the with-index paths
        (void)in_elementwise_op;

        // Paired LDS workspaces: candidate values and their indices
        __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
        __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];

        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global,
            in_grid_desc_m_k.GetElementSpaceSize(),
            ReduceOperation::template GetIdentityValue<InDataType>());
        const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
        auto out_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_index_global, out_grid_desc_m.GetElementSpaceSize());

        auto reduce_work_val_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_val_buffer, BlockSize);
        auto reduce_work_idx_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_idx_buffer, BlockSize);

        // Thread-private tiles of MThreadSliceSize * KThreadSliceSize values
        // and indices, plus per-M accumulators (declarations elided;
        // thread_buffer_desc, reduced_data_desc and in_thread_copy_step as in
        // Run above)

        const index_t block_global_id = get_block_1d_id();

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        auto threadwise_src_val_load =
            ThreadwiseTensorSliceTransfer_v2<InDataType,
                                             AccDataType,
                                             InGridDesc_M_K,
                                             decltype(thread_buffer_desc),
                                             /* ... */>(
                in_grid_desc_m_k,
                make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize,
                                 thread_k_cluster_id * KThreadSliceSize));

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = identityVal;
            accu_index_buf(I) = 0;
        });

        index_t reducedTiles = 0;

        if constexpr(HaveIndexInput)
        {
            // Second stage of a two-stage reduction: indices arrive with the input
            auto threadwise_src_idx_load =
                ThreadwiseTensorSliceTransfer_v2<IndexDataType,
                                                 IndexDataType,
                                                 InGridDesc_M_K,
                                                 decltype(thread_buffer_desc),
                                                 /* ... */>(
                    in_grid_desc_m_k,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize,
                                     thread_k_cluster_id * KThreadSliceSize));

            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);
                threadwise_src_idx_load.Run(in_grid_desc_m_k,
                                            in_global_idx_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_idx_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    AccDataType tmpValue   = identityVal;
                    IndexDataType tmpIndex = 0;

                    // Reduce the thread slice to one (value, index) candidate
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        AccumulationWithIndex::Calculate(tmpValue,
                                                         in_thread_val_buf[Number<offset>{}],
                                                         tmpIndex,
                                                         in_thread_idx_buf[Number<offset>{}]);
                    });

                    // Then reduce the candidates across the block
                    BlockwiseReduceWithIndex::Reduce(
                        reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

                    AccumulationWithIndex::Calculate(
                        accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
                });

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);
                threadwise_src_idx_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                ++reducedTiles;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
        else
        {
            // First stage: synthesize each element's index from its global K position
            index_t indexOffset = 0;

            do
            {
                threadwise_src_val_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_val_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        in_thread_idx_buf(Number<offset>{}) =
                            indexOffset + thread_k_cluster_id * KThreadSliceSize + iK();
                    });
                });

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    AccDataType tmpValue   = identityVal;
                    IndexDataType tmpIndex = 0;

                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));

                        AccumulationWithIndex::Calculate(tmpValue,
                                                         in_thread_val_buf[Number<offset>{}],
                                                         tmpIndex,
                                                         in_thread_idx_buf[Number<offset>{}]);
                    });

                    BlockwiseReduceWithIndex::Reduce(
                        reduce_work_val_buf, reduce_work_idx_buf, tmpValue, tmpIndex);

                    AccumulationWithIndex::Calculate(
                        accu_value_buf(iM), tmpValue, accu_index_buf(iM), tmpIndex);
                });

                threadwise_src_val_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_step);

                indexOffset += K_BlockTileSize;
                ++reducedTiles;
            } while(reducedTiles < num_k_block_tile_iteration);
        }

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            if(thread_k_cluster_id == 0)
            {
                acc_elementwise_op(accu_value_buf(I), accu_value_buf(I));

                accu_value_buf(I) *= alpha;
            }
        });

        if(thread_k_cluster_id == 0)
        {
            if(!float_equal_zero{}(beta))
            {
                StaticBuffer<AddressSpaceEnum::Vgpr, OutDataType, MThreadSliceSize, true>
                    priorDstValueBuf;

                auto threadwise_dst_load =
                    ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                     OutDataType,
                                                     OutGridDesc_M,
                                                     decltype(reduced_data_desc),
                                                     /* ... */>(
                        out_grid_desc_m,
                        make_multi_index(block_global_id * M_BlockTileSize +
                                         thread_m_cluster_id * MThreadSliceSize));

                threadwise_dst_load.Run(out_grid_desc_m,
                                        out_global_val_buf,
                                        reduced_data_desc,
                                        make_tuple(I0),
                                        priorDstValueBuf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
                    accu_value_buf(I) += type_convert<AccDataType>(priorDstValueBuf[I]) * beta;
                });
            }

            // Values and indices are written out by two separate copiers
            auto threadwise_dst_val_store =
                ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                                   OutDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   /* ... */>(
                    out_grid_desc_m,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            auto threadwise_dst_idx_store =
                ThreadwiseTensorSliceTransfer_v1r3<IndexDataType,
                                                   IndexDataType,
                                                   decltype(reduced_data_desc),
                                                   OutGridDesc_M,
                                                   /* ... */>(
                    out_grid_desc_m,
                    make_multi_index(block_global_id * M_BlockTileSize +
                                     thread_m_cluster_id * MThreadSliceSize),
                    PassThroughOp{});

            threadwise_dst_val_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_value_buf,
                                         out_grid_desc_m,
                                         out_global_val_buf);
            threadwise_dst_idx_store.Run(reduced_data_desc,
                                         make_tuple(I0),
                                         accu_index_buf,
                                         out_grid_desc_m,
                                         out_global_idx_buf);
        }
    }
};
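// RunWithIndex accumulates a (value, index) pair instead of a bare value:
// first within the thread slice, then across the block via the paired LDS
// buffers. With HaveIndexInput == false the indices are synthesized from the
// global K coordinate (first stage of a two-stage reduction); with true,
// previously produced indices are carried through. A standalone sketch of
// such a paired update rule for an argmax-style reduction; MaxWithIndex and
// its smaller-index tie-break are illustrative assumptions, not the CK
// accumulator.
struct MaxWithIndex
{
    __host__ __device__ static void
    Calculate(float& accVal, float candVal, int& accIdx, int candIdx)
    {
        // Keep the larger value; on ties keep the smaller index, so folding
        // candidates in any grouping picks the same winner
        if(candVal > accVal || (candVal == accVal && candIdx < accIdx))
        {
            accVal = candVal;
            accIdx = candIdx;
        }
    }
};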
 