Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_softmax.hpp Source File

gridwise_softmax.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck {

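// kernel_softmax is a thin __global__ entry point: it forwards its arguments to
// GridwiseReduction::Run, where the blockwise softmax over the K dimension of the (M, K)
// input view is actually carried out.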
template <typename GridwiseReduction,
          typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename GridDesc_M_K>
__global__ void kernel_softmax(const GridDesc_M_K in_grid_desc_m_k,
                               const GridDesc_M_K out_grid_desc_m_k,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
{
    GridwiseReduction::Run(in_grid_desc_m_k,
                           out_grid_desc_m_k,
                           block_group_size,
                           num_k_block_tile_iteration,
                           alpha,
                           p_in_value_global,
                           beta,
                           p_out_value_global);
};

template <typename InDataType,
          typename OutDataType,
          typename AccDataType,
          typename GridDesc_M_K,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          index_t InSrcVectorDim,
          index_t InSrcVectorSize,
          index_t OutDstVectorSize,
          bool SweepOnce>
struct GridwiseSoftmax_mk_to_mk
{
    static_assert(((InSrcVectorDim == 0 && MThreadSliceSize % InSrcVectorSize == 0) ||
                   (InSrcVectorDim == 1 && KThreadSliceSize % InSrcVectorSize == 0)) &&
                      (KThreadSliceSize % OutDstVectorSize == 0),
                  "Invalid thread slice sizes and/or vector sizes configuration, please check!");

    static constexpr bool reorder_thread_cluster = (InSrcVectorDim == 0);

    using ThreadClusterLengths_M_K = Sequence<MThreadClusterSize, KThreadClusterSize>;

    using ThreadBufferDimAccessOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    using ThreadClusterArrangeOrder =
        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;

    static constexpr auto thread_cluster_desc =
        make_cluster_descriptor(ThreadClusterLengths_M_K{}, ThreadClusterArrangeOrder{});

    using ThreadReduceSrcDesc_M_K = decltype(make_naive_tensor_descriptor_packed(
        make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{})));
    using ThreadReduceDstDesc_M =
        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));

    using PassThroughOp = tensor_operation::element_wise::PassThrough;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

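    // Work decomposition: each workgroup covers an M_BlockTileSize x K_BlockTileSize tile per
    // iteration. Its BlockSize threads form an (MThreadClusterSize x KThreadClusterSize) cluster
    // and every thread owns an (MThreadSliceSize x KThreadSliceSize) slice of the tile. When
    // reorder_thread_cluster is set (vector access along M), the cluster arrangement and the
    // thread-buffer access order are transposed.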
    __device__ static void Run(const GridDesc_M_K& in_grid_desc_m_k,
                               const GridDesc_M_K& out_grid_desc_m_k,
                               index_t block_group_size,
                               index_t num_k_block_tile_iteration,
                               AccDataType alpha,
                               const InDataType* const __restrict__ p_in_value_global,
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
    {
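        // Numerically stable softmax in (up to) three sweeps over the K tiles owned by this
        // workgroup:
        //   1) reduce the per-row maximum,
        //   2) reduce the per-row sum of exp(x - max),
        //   3) write out alpha * exp(x - max) / sum, optionally blended with beta * prior output.
        // With SweepOnce == true a row fits into a single tile, so the data loaded in sweep 1 is
        // reused and no further global loads are issued.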
        if constexpr(SweepOnce)
        {
            num_k_block_tile_iteration = 1;
        }

        // LDS
        __shared__ AccDataType p_reduce_work_buffer[BlockSize];

        auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_out_value_global, out_grid_desc_m_k.GetElementSpaceSize());

        auto reduce_work_buf =
            make_dynamic_buffer<AddressSpaceEnum::Lds>(p_reduce_work_buffer, BlockSize);

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            in_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize * KThreadSliceSize, true>
            out_thread_buf;

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> max_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            max_value_buf(I) = reduce::Max::template GetIdentityValue<AccDataType>();
        });

        StaticBuffer<AddressSpaceEnum::Vgpr, AccDataType, MThreadSliceSize, true> accu_value_buf;

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            accu_value_buf(I) = reduce::Add::template GetIdentityValue<AccDataType>();
        });

        const index_t thread_local_id = get_thread_local_1d_id();
        const index_t block_global_id = get_block_1d_id();
        const index_t blkgroup_id     = block_global_id / block_group_size;
        const index_t block_local_id  = block_global_id % block_group_size;

        const auto thread_cluster_idx =
            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));

        const auto thread_m_cluster_id = thread_cluster_idx[I0];
        const auto thread_k_cluster_id = thread_cluster_idx[I1];

        const index_t reduceSizePerBlock = K_BlockTileSize * num_k_block_tile_iteration;

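        // block_group_size workgroups share one tile of M rows: blkgroup_id selects the row tile
        // while block_local_id selects this workgroup's contiguous segment of reduceSizePerBlock
        // elements along K.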
        using ThreadBufferLengths = Sequence<MThreadSliceSize, KThreadSliceSize>;
        constexpr auto thread_buffer_desc = make_naive_tensor_descriptor_packed(
            make_tuple(Number<MThreadSliceSize>{}, Number<KThreadSliceSize>{}));

        // Normally, 0 as the invalid element value is adequate, since 0 makes no contribution to
        // the accumulated result. However, in stable softmax all values, zero or not, have
        // value_max subtracted from them. As the numbers become non-zero, this effectively allows
        // invalid values to slip through and contribute to the accumulated result.
        //
        // The trick here is to leverage the fact that many math functions (add, sub, exp, ...)
        // propagate NaNs when any operand is NaN. By initializing invalid elements with NaN, an
        // invalid value remains NaN through any math manipulation and can still be identified as
        // invalid. We can then discard, during accumulation, the values that originally failed the
        // bound check. This allows us to ignore values that failed the bound check even after
        // multiple math manipulations.
        //
        // NOTE: reset the coordinate after every step, because the same threadwise copy will sweep
        // through global memory three times, back and forth.
        auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2<InDataType,
                                                                    AccDataType,
                                                                    GridDesc_M_K,
                                                                    decltype(thread_buffer_desc),
                                                                    ThreadBufferLengths,
                                                                    ThreadBufferDimAccessOrder,
                                                                    InSrcVectorDim,
                                                                    InSrcVectorSize,
                                                                    1,
                                                                    true /* ResetCoordAfterRun */,
                                                                    true /* InvalidElementAsNaN */>(
            in_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<OutDataType,
                                                                    AccDataType,
                                                                    GridDesc_M_K,
                                                                    decltype(thread_buffer_desc),
                                                                    ThreadBufferLengths,
                                                                    ThreadBufferDimAccessOrder,
                                                                    InSrcVectorDim,
                                                                    InSrcVectorSize,
                                                                    1,
                                                                    false>(
            out_grid_desc_m_k,
            make_multi_index(blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                             block_local_id * reduceSizePerBlock +
                                 thread_k_cluster_id * KThreadSliceSize));

        auto threadwise_dst_store =
            ThreadwiseTensorSliceTransfer_v1r3<AccDataType,
                                               OutDataType,
                                               decltype(thread_buffer_desc),
                                               GridDesc_M_K,
                                               PassThroughOp,
                                               ThreadBufferLengths,
                                               ThreadBufferDimAccessOrder,
                                               InSrcVectorDim,
                                               OutDstVectorSize,
                                               InMemoryDataOperationEnum::Set,
                                               1,
                                               true>(
                out_grid_desc_m_k,
                make_multi_index(
                    blkgroup_id * M_BlockTileSize + thread_m_cluster_id * MThreadSliceSize,
                    block_local_id * reduceSizePerBlock + thread_k_cluster_id * KThreadSliceSize),
                PassThroughOp{});

        constexpr auto in_thread_copy_fwd_step = make_multi_index(0, SweepOnce ? 0 : K_BlockTileSize);
        constexpr auto in_thread_copy_bwd_step = make_multi_index(0, SweepOnce ? 0 : -K_BlockTileSize);

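        // Sweep 1: compute the per-row maximum. Each thread scans its K tiles and keeps a running
        // maximum per M slice (ThreadwiseMaxReduce); the partial maxima are then combined across
        // the workgroup through LDS (BlockwiseMaxReduce). With SweepOnce both slice-window steps
        // above are zero, so the copy window never moves.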
        using BlockwiseMaxReduce =
            PartitionedBlockwiseReduction<AccDataType,
                                          BlockSize,
                                          ThreadClusterLengths_M_K,
                                          ThreadClusterArrangeOrder,
                                          reduce::Max,
                                          false, // param ignored
                                          detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;

        using ThreadwiseMaxReduce =
            ThreadwiseReduction<AccDataType,
                                ThreadReduceSrcDesc_M_K,
                                ThreadReduceDstDesc_M,
                                reduce::Max,
                                false, // param ignored
                                detail::AccumulateWithNanIgnore<reduce::Max, AccDataType>>;

        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_in_value_global, in_grid_desc_m_k.GetElementSpaceSize());

        index_t reducedTiles = 0;
        do
        {
            threadwise_src_load.Run(in_grid_desc_m_k,
                                    in_global_val_buf,
                                    thread_buffer_desc,
                                    make_tuple(I0, I0),
                                    in_thread_buf);

            ThreadwiseMaxReduce::Reduce(in_thread_buf, max_value_buf);

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            BlockwiseMaxReduce::Reduce(reduce_work_buf, max_value_buf(I));
            block_sync_lds();
        });

        threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);

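        // Sweep 2: walk the K tiles backwards and accumulate sum(exp(x - row_max)) per M slice,
        // first per thread, then across the workgroup through LDS. Under SweepOnce the tile
        // loaded in sweep 1 is still resident in registers, so the global load is skipped.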
        using BlockwiseSumReduce =
            PartitionedBlockwiseReduction<AccDataType,
                                          BlockSize,
                                          ThreadClusterLengths_M_K,
                                          ThreadClusterArrangeOrder,
                                          reduce::Add,
                                          false, // ignored
                                          detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

        using ThreadwiseSumReduce =
            ThreadwiseReduction<AccDataType,
                                ThreadReduceSrcDesc_M_K,
                                ThreadReduceDstDesc_M,
                                reduce::Add,
                                false, // ignored
                                detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;

        reducedTiles = 0;
        do
        {
            if constexpr(!SweepOnce)
            {
                threadwise_src_load.Run(in_grid_desc_m_k,
                                        in_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_thread_buf);
            }

            // do element-wise pre-reduction operation
            static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                    constexpr auto offset = thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                    out_thread_buf(Number<offset>{}) =
                        math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM));
                });
            });

            ThreadwiseSumReduce::Reduce(out_thread_buf, accu_value_buf);

            threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_bwd_step);

            reducedTiles++;
        } while(reducedTiles < num_k_block_tile_iteration);

        block_sync_lds(); // wait for reads to complete before writing to LDS
        static_for<0, MThreadSliceSize, 1>{}([&](auto I) {
            BlockwiseSumReduce::Reduce(reduce_work_buf, accu_value_buf(I));
            block_sync_lds();
        });

        threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);

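        // Sweep 3: normalize and write the result. If beta == 0 the previous output is never
        // read; otherwise the prior destination values are loaded and blended in as
        // beta * prior_out.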
        reducedTiles = 0;
        if(float_equal_zero{}(beta))
        {
            do
            {
                if constexpr(!SweepOnce)
                {
                    threadwise_src_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_buf);
                }

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // out = alpha * exp(x - max(x)) / sum(exp(x - max(x)))
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        out_thread_buf(Number<offset>{}) =
                            alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                            accu_value_buf(iM);
                    });
                });

                threadwise_dst_store.Run(thread_buffer_desc,
                                         make_tuple(I0, I0),
                                         out_thread_buf,
                                         out_grid_desc_m_k,
                                         out_global_val_buf);

                threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);

                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
        else
        {
            StaticBuffer<AddressSpaceEnum::Vgpr,
                         AccDataType,
                         MThreadSliceSize * KThreadSliceSize,
                         true>
                in_prior_dst_buf;
            do
            {
                if constexpr(!SweepOnce)
                {
                    threadwise_src_load.Run(in_grid_desc_m_k,
                                            in_global_val_buf,
                                            thread_buffer_desc,
                                            make_tuple(I0, I0),
                                            in_thread_buf);
                }
                threadwise_dst_load.Run(out_grid_desc_m_k,
                                        out_global_val_buf,
                                        thread_buffer_desc,
                                        make_tuple(I0, I0),
                                        in_prior_dst_buf);

                static_for<0, MThreadSliceSize, 1>{}([&](auto iM) {
                    // out = alpha * exp(x - max(x)) / sum(exp(x - max(x))) + beta * prior_out
                    static_for<0, KThreadSliceSize, 1>{}([&](auto iK) {
                        constexpr auto offset =
                            thread_buffer_desc.CalculateOffset(make_tuple(iM, iK));
                        out_thread_buf(Number<offset>{}) =
                            alpha * math::exp(in_thread_buf(Number<offset>{}) - max_value_buf(iM)) /
                                accu_value_buf(iM) +
                            beta * in_prior_dst_buf(Number<offset>{});
                    });
                });

                threadwise_dst_store.Run(thread_buffer_desc,
                                         make_tuple(I0, I0),
                                         out_thread_buf,
                                         out_grid_desc_m_k,
                                         out_global_val_buf);

                threadwise_src_load.MoveSrcSliceWindow(in_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_store.MoveDstSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);
                threadwise_dst_load.MoveSrcSliceWindow(out_grid_desc_m_k, in_thread_copy_fwd_step);

                reducedTiles++;
            } while(reducedTiles < num_k_block_tile_iteration);
        }
    }
};

} // namespace ck
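
This header only defines the gridwise struct and the kernel_softmax entry point; building the grid descriptors and choosing the launch configuration is normally done by the library's device-level softmax implementation. The sketch below is a rough, untested illustration of how the pieces above could be wired together for a packed row-major M x K float problem. The block/thread-cluster sizes, block_group_size = 1, and the assumption that M and K are exact multiples of the block tile sizes are illustrative choices, not values prescribed by this file.

// Hypothetical host-side wiring (sketch only, not part of this header).
using namespace ck;

static constexpr index_t BlockSize      = 256; // assumed launch configuration
static constexpr index_t MThreadCluster = 8;
static constexpr index_t KThreadCluster = 32;  // 8 * 32 == BlockSize
static constexpr index_t MThreadSlice   = 1;
static constexpr index_t KThreadSlice   = 8;

// Descriptor type for a packed (M, K) view with runtime lengths.
using GridDesc_M_K =
    decltype(make_naive_tensor_descriptor_packed(make_tuple(index_t{1}, index_t{1})));

using Gridwise = GridwiseSoftmax_mk_to_mk<float, float, float, GridDesc_M_K,
                                          BlockSize,
                                          MThreadCluster, KThreadCluster,
                                          MThreadSlice, KThreadSlice,
                                          /*InSrcVectorDim*/ 1, /*InSrcVectorSize*/ 1,
                                          /*OutDstVectorSize*/ 1, /*SweepOnce*/ false>;

void launch_softmax(const float* p_in, float* p_out, index_t M, index_t K, hipStream_t stream)
{
    const auto desc = make_naive_tensor_descriptor_packed(make_tuple(M, K));

    // One workgroup per tile of M rows; each workgroup iterates over the whole K extent,
    // assuming M and K divide the block tile sizes evenly.
    const index_t block_group_size = 1;
    const index_t k_block_tile     = KThreadCluster * KThreadSlice;
    const index_t num_k_iterations = K / k_block_tile;
    const index_t grid_size        = M / (MThreadCluster * MThreadSlice);

    kernel_softmax<Gridwise, float, float, float, GridDesc_M_K>
        <<<grid_size, BlockSize, 0, stream>>>(desc,
                                              desc,
                                              block_group_size,
                                              num_k_iterations,
                                              1.0f, // alpha
                                              p_in,
                                              0.0f, // beta
                                              p_out);
}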