17 template <
typename GridwiseReduction,
21 typename GridDesc_M_K>
23 const GridDesc_M_K out_grid_desc_m_k,
25 index_t num_k_block_tile_iteration,
27 const InDataType*
const __restrict__ p_in_value_global,
29 OutDataType*
const __restrict__ p_out_value_global)
31 GridwiseReduction::Run(in_grid_desc_m_k,
34 num_k_block_tile_iteration,
41 template <
typename InDataType,
44 typename GridDesc_M_K,
56 static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
57 (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
58 (KThreadSliceSize % OutDstVectorSize == 0),
59 "Invalid thread slice sizes and/or vector sizes configuration, please check!");
87 __device__
static void Run(
const GridDesc_M_K& in_grid_desc_m_k,
88 const GridDesc_M_K& out_grid_desc_m_k,
90 index_t num_k_block_tile_iteration,
92 const InDataType*
const __restrict__ p_in_value_global,
94 OutDataType*
const __restrict__ p_out_value_global)
96 if constexpr(SweepOnce)
98 num_k_block_tile_iteration = 1;
102 __shared__ AccDataType p_reduce_work_buffer[BlockSize];
104 auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
105 p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());
107 auto reduce_work_buf =
108 make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);
119 max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
125 accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
130 const index_t blkgroup_id = block_global_id / block_group_size;
131 const index_t block_local_id = block_global_id % block_group_size;
133 const auto thread_cluster_idx =
136 const auto thread_m_cluster_id = thread_cluster_idx[
I0];
137 const auto thread_k_cluster_id = thread_cluster_idx[
I1];
162 decltype(thread_buffer_desc),
172 block_local_id * reduceSizePerBlock +
173 thread_k_cluster_id * KThreadSliceSize));
178 decltype(thread_buffer_desc),
187 block_local_id * reduceSizePerBlock +
188 thread_k_cluster_id * KThreadSliceSize));
190 auto threadwise_dst_store =
193 decltype(thread_buffer_desc),
205 blkgroup_id *
M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
206 block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
209 constexpr
auto in_thread_copy_fwd_step =
211 constexpr
auto in_thread_copy_bwd_step =
226 using ThreadwiseMaxReduce =
234 const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
235 p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());
240 threadwise_src_load.Run(in_grid_desc_m_k,
246 ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);
248 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
251 }
while(reducedTiles < num_k_block_tile_iteration);
254 BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
258 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
272 using ThreadwiseSumReduce =
283 if constexpr(!SweepOnce)
285 threadwise_src_load.Run(in_grid_desc_m_k,
295 constexpr
auto offset = thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
301 ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);
303 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);
306 }
while(reducedTiles < num_k_block_tile_iteration);
310 BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
314 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
324 if constexpr(!SweepOnce)
326 threadwise_src_load.Run(in_grid_desc_m_k,
336 constexpr
auto offset =
337 thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
344 threadwise_dst_store.
Run(thread_buffer_desc,
350 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
351 threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
354 }
while(reducedTiles < num_k_block_tile_iteration);
360 MThreadSliceSize * KThreadSliceSize,
365 if constexpr(!SweepOnce)
367 threadwise_src_load.
Run(in_grid_desc_m_k,
373 threadwise_dst_load.
Run(out_grid_desc_m_k,
382 constexpr
auto offset =
383 thread_buffer_desc.CalculateOffset(
make_tuple(iM, iK));
391 threadwise_dst_store.
Run(thread_buffer_desc,
397 threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
398 threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
402 }
while(reducedTiles < num_k_block_tile_iteration);
__host__ T exp(T x)
Definition: math_v2.hpp:391
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__device__ index_t get_block_1d_id()
Definition: get_id.hpp:22
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:16
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k, const GridDesc_M_K out_grid_desc_m_k, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition: gridwise_softmax.hpp:22
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
__host__ constexpr __device__ auto make_cluster_descriptor(const Lengths &lengths, ArrangeOrder order=typename arithmetic_sequence_gen< 0, Lengths::Size(), 1 >::type{})
Definition: cluster_descriptor.hpp:13
Definition: gridwise_softmax.hpp:55
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadClusterArrangeOrder
Definition: gridwise_softmax.hpp:69
Sequence< MThreadClusterSize, KThreadClusterSize > ThreadClusterLengths_M_K
Definition: gridwise_softmax.hpp:63
static constexpr auto I0
Definition: gridwise_softmax.hpp:81
static __device__ void Run(const GridDesc_M_K &in_grid_desc_m_k, const GridDesc_M_K &out_grid_desc_m_k, index_t block_group_size, index_t num_k_block_tile_iteration, AccDataType alpha, const InDataType *const __restrict__ p_in_value_global, AccDataType beta, OutDataType *const __restrict__ p_out_value_global)
Definition: gridwise_softmax.hpp:87
static constexpr index_t M_BlockTileSize
Definition: gridwise_softmax.hpp:84
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}, Number< KThreadSliceSize >{}))) ThreadReduceSrcDesc_M_K
Definition: gridwise_softmax.hpp:75
static constexpr bool reorder_thread_cluster
Definition: gridwise_softmax.hpp:61
tensor_operation::element_wise::PassThrough PassThroughOp
Definition: gridwise_softmax.hpp:79
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number< MThreadSliceSize >{}))) ThreadReduceDstDesc_M
Definition: gridwise_softmax.hpp:77
static constexpr auto I1
Definition: gridwise_softmax.hpp:82
static constexpr auto thread_cluster_desc
Definition: gridwise_softmax.hpp:71
static constexpr index_t K_BlockTileSize
Definition: gridwise_softmax.hpp:85
typename conditional< reorder_thread_cluster, Sequence< 1, 0 >, Sequence< 0, 1 > >::type ThreadBufferDimAccessOrder
Definition: gridwise_softmax.hpp:66
Definition: reduction_functions_blockwise.hpp:28
Definition: sequence.hpp:43
Definition: static_buffer.hpp:16
Definition: reduction_functions_threadwise.hpp:23
Definition: threadwise_tensor_slice_transfer.hpp:39
Definition: threadwise_tensor_slice_transfer.hpp:214
__device__ void Run(const SrcDesc &src_desc, const SrcBuffer &src_buf, const DstDesc &, const DstSliceOriginIdx &, DstBuffer &dst_buf)
Definition: threadwise_tensor_slice_transfer.hpp:243
__device__ void MoveSrcSliceWindow(const SrcDesc &src_desc, const Index &src_slice_origin_step_idx)
Definition: threadwise_tensor_slice_transfer.hpp:355
Definition: functional.hpp:100
Definition: reduction_functions_accumulate.hpp:17
Definition: reduction_common.hpp:20
Definition: integral_constant.hpp:10
Definition: reduction_operator.hpp:37
Definition: reduction_operator.hpp:163
Definition: functional2.hpp:31
Definition: unary_element_wise_operation.hpp:241