   17 template <typename GridwiseReduction,
 
   21           typename GridDesc_M_K>
 
   23                                const GridDesc_M_K out_grid_desc_m_k,
 
   25                                index_t num_k_block_tile_iteration,
 
   27                                const InDataType* const __restrict__ p_in_value_global,
 
   29                                OutDataType* const __restrict__ p_out_value_global)
 
   31     GridwiseReduction::Run(in_grid_desc_m_k,
 
   34                            num_k_block_tile_iteration,
 
   41 template <typename InDataType,
 
   44           typename GridDesc_M_K,
 
   56     static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
 
   57                    (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
 
   58                       (KThreadSliceSize % OutDstVectorSize == 0),
 
   59                   "Invalid thread slice sizes and/or vector sizes configuration, please check!");
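
As a concrete illustration of this constraint, the following standalone snippet (hypothetical values, not taken from any shipped instance) satisfies it: with vectorized access along the K dimension (InSrcVectorDim == 1), the per-thread K slice must be divisible by both the load and the store vector widths.

// Hypothetical configuration used only to illustrate the static_assert above.
constexpr int InSrcVectorDim   = 1;  // vector access along the K dimension
constexpr int InSrcVectorSize  = 4;  // 8 % 4 == 0, load constraint holds
constexpr int OutDstVectorSize = 4;  // 8 % 4 == 0, store constraint holds
constexpr int MThreadSliceSize = 4;
constexpr int KThreadSliceSize = 8;  // changing this to 6 would trip the assertion

static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
               (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                  (KThreadSliceSize % OutDstVectorSize == 0),
              "Invalid thread slice sizes and/or vector sizes configuration, please check!");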
 
   87     __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
 
   88                                const GridDesc_M_K& out_grid_desc_m_k,
 
   90                                index_t num_k_block_tile_iteration,
 
   92                                const InDataType* const __restrict__ p_in_value_global,
 
   94                                OutDataType* const __restrict__ p_out_value_global)
 
   96         if constexpr(SweepOnce)
 
   98             num_k_block_tile_iteration = 1;
 
  102         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
 
  104         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  105             p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());
 
  107         auto reduce_work_buf =
 
  108             make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
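        // The BlockSize-element LDS array above backs reduce_work_buf, the per-workgroup
        // scratch through which BlockwiseMaxReduce and BlockwiseSumReduce (used further down)
        // combine the per-thread partial row maxima and partial sums of exponentials.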
 
  119             max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
 
  125             accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
 
  130         const index_t blkgroup_id     = block_global_id / block_group_size;
 
  131         const index_t block_local_id  = block_global_id % block_group_size;
 
  133         const auto thread_cluster_idx =
 
  136         const auto thread_m_cluster_id = thread_cluster_idx[I0];
 
  137         const auto thread_k_cluster_id = thread_cluster_idx[I1];
 
  162                                                                     decltype(thread_buffer_desc),
 
  172                              block_local_id * reduceSizePerBlock +
 
  173                                  thread_k_cluster_id * KThreadSliceSize));
 
  178                                                                     decltype(thread_buffer_desc),
 
  187                              block_local_id * reduceSizePerBlock +
 
  188                                  thread_k_cluster_id * KThreadSliceSize));
 
  190         auto threadwise_dst_store =
 
  193                                                decltype(thread_buffer_desc),
 
  205                     blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
 
  206                     block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
 
  209         constexpr auto in_thread_copy_fwd_step =
 
  211         constexpr auto in_thread_copy_bwd_step =
 
  226         using ThreadwiseMaxReduce =
 
  234         const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
 
  235             p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());
 
  240             threadwise_src_load.Run(in_grid_desc_m_k,
 
  246             ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
 
  248             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
 
  251         } while(reducedTiles < num_k_block_tile_iteration);
 
  254             BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
 
  258         threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
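        // End of the first pass: each thread now holds the block-wide row maximum in
        // max_value_buf, and the source slice window (which stepped forward one tile per
        // iteration) is stepped back so the same K tiles can be revisited; the second pass
        // below accumulates the sum of exp(x - max) into accu_value_buf in the same fashion.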
 
  272         using ThreadwiseSumReduce =
 
  283             if constexpr(!SweepOnce)
 
  285                 threadwise_src_load.Run(in_grid_desc_m_k,
 
  295                     constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
  301             ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);
 
  303             threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
 
  306         } while(reducedTiles < num_k_block_tile_iteration);
 
  310             BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
 
  314         threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
 
  324                 if constexpr(!SweepOnce)
 
  326                     threadwise_src_load.Run(in_grid_desc_m_k,
 
  336                         constexpr auto offset =

  337                             thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
  344                 threadwise_dst_store.Run(thread_buffer_desc,
 
  350                 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
 
  351                 threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
 
  354             } while(reducedTiles < num_k_block_tile_iteration);
 
  360                          MThreadSliceSize * KThreadSliceSize,
 
  365                 if constexpr(!SweepOnce)
 
  367                     threadwise_src_load.Run(in_grid_desc_m_k,
 
  373                 threadwise_dst_load.Run(out_grid_desc_m_k,
 
  382                         constexpr auto offset =

  383                             thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
 
  391                 threadwise_dst_store.Run(thread_buffer_desc,
 
  397                 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
 
  398                 threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
 
  402             } while(reducedTiles < num_k_block_tile_iteration);
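
The three loops above are the row-wise softmax: a max reduction over the K tiles, a sum-of-exp(x - max) reduction over the same tiles, and a final sweep that writes the normalized values (with a variant that also reads the prior output when blending is requested). As a purely illustrative host-side reference of that computation, not the CK implementation, and assuming the common convention out = alpha * softmax(in) + beta * out suggested by the alpha/beta parameters in the kernel signature:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference softmax over the K (innermost) dimension of a row-major M x K buffer.
// Assumed semantics: out[m*K + k] = alpha * exp(in[m*K + k] - rowmax) / rowsum + beta * out[m*K + k]
void softmax_mk_reference(const std::vector<float>& in,
                          std::vector<float>& out,
                          std::size_t M,
                          std::size_t K,
                          float alpha,
                          float beta)
{
    for(std::size_t m = 0; m < M; ++m)
    {
        // Pass 1: row maximum (mirrors the first max-reduce loop above).
        float row_max = in[m * K];
        for(std::size_t k = 1; k < K; ++k)
            row_max = std::max(row_max, in[m * K + k]);

        // Pass 2: sum of exponentials of the max-shifted values.
        float row_sum = 0.f;
        for(std::size_t k = 0; k < K; ++k)
            row_sum += std::exp(in[m * K + k] - row_max);

        // Pass 3: normalize, scale by alpha, and blend the prior output with beta.
        for(std::size_t k = 0; k < K; ++k)
            out[m * K + k] = alpha * std::exp(in[m * K + k] - row_max) / row_sum +
                             beta * out[m * K + k];
    }
}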
 
__host__ T exp(T x)
Definition: math_v2.hpp:391
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
 
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:25
 
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
 
int32_t index_t
Definition: ck.hpp:300
 
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:19
 
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, const GridDesc_M_K out_grid_desc_m_k, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition: gridwise_softmax.hpp:22
 
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
 
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
 
Definition: gridwise_softmax.hpp:55
 
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition: gridwise_softmax.hpp:69
 
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition: gridwise_softmax.hpp:63
 
static constexpr auto I0
Definition: gridwise_softmax.hpp:81
 
static __device__ void Run(const GridDesc_M_K &in_grid_desc_m_k, const GridDesc_M_K &out_grid_desc_m_k, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition: gridwise_softmax.hpp:87
 
static constexpr index_t M_BlockTileSize
Definition: gridwise_softmax.hpp:84
 
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition: gridwise_softmax.hpp:75
 
static constexpr bool reorder_thread_cluster
Definition: gridwise_softmax.hpp:61
 
tensor_operation::element_wise::PassThrough PassThroughOp
Definition: gridwise_softmax.hpp:79
 
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition: gridwise_softmax.hpp:77
 
static constexpr auto I1
Definition: gridwise_softmax.hpp:82
 
static constexpr auto thread_cluster_desc
Definition: gridwise_softmax.hpp:71
 
static constexpr index_t K_BlockTileSize
Definition: gridwise_softmax.hpp:85
 
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition: gridwise_softmax.hpp:66
 
Helper structure that facilitates transfer of source (grid) data to destination threads.
Definition: threadwise_tensor_slice_transfer.hpp:234
 
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition: threadwise_tensor_slice_transfer.hpp:276
 
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer.hpp:389
 