/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck_tile/ops/gemm/kernel/streamk_gemm_tile_partitioner_impl.hpp Source File
streamk_gemm_tile_partitioner_impl.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
6 namespace ck_tile {
7 
// NOTE(review): this text is a doc-site scrape. The constructor's qualified-name
// line (orig. line 9) and the statement that computes num_tiles_ (and possibly
// full_tiles_) at orig. line 14 were dropped by extraction — restore them from
// the repository before treating this as compilable source.
8 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
// Constructor: partitions the M x N macro-tile space between data-parallel (DP)
// tiles and Stream-K (SK) tiles for a launch of `grid` CTAs.
10  index_t m, index_t n, index_t k, index_t grid)
11  : grid_{grid}, n_{n}
12 {
// K-loop iterations needed to cover one macro tile along K.
13  iters_per_tile_ = integer_divide_ceil(k, KPerBlock);
15 
// Does the tile count exceed the grid, and does it divide evenly across it?
16  bool big_enough = num_tiles_ > grid_;
17  index_t remainder_tiles = num_tiles_ % grid_;
18 
19  if(remainder_tiles)
20  {
// Ragged split: the uneven portion of the tile space is handled Stream-K;
// sk_tiles_ is clamped so it never exceeds the total tile count.
21  sk_tiles_ = big_enough ? full_tiles_ * grid_ + (num_tiles_ % grid_) : num_tiles_;
22  sk_tiles_ = min(num_tiles_, sk_tiles_);
23  sk_ctas_ = grid_;
24  total_sk_iters_ = sk_tiles_ * iters_per_tile_;
25 
26  // If there still isn't enough work to saturate all CUs, then just revert to DP only.
27  if(total_sk_iters_ < grid_)
28  {
29  sk_tiles_ = 0;
30  sk_ctas_ = 0;
31  total_sk_iters_ = 0;
32  }
33  }
34  else // Full DP (i.e., no Stream-K)
35  {
36  sk_tiles_ = 0;
37  sk_ctas_ = 0;
38  total_sk_iters_ = 0;
39  }
40 
// Spread SK iterations evenly across SK CTAs; the first extra_iters_ CTAs each
// take one extra iteration (guard against division by zero when no SK CTAs).
41  iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0;
42  extra_iters_ = sk_ctas_ ? total_sk_iters_ % sk_ctas_ : 0;
43 
// Whatever is not assigned to Stream-K is processed data-parallel.
44  dp_tiles_ = num_tiles_ - sk_tiles_;
45  total_dp_iters_ = dp_tiles_ * iters_per_tile_;
46 }
47 
48 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
51  index_t acc_element_bytes) const noexcept
52 {
53  return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
54 }
55 
56 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
59  const noexcept
60 {
61  return sizeof(index_t) * sk_ctas_;
62 }
63 
64 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
65 CK_TILE_DEVICE void
// NOTE(review): the qualified function name (orig. line 66) was dropped by the
// doc scrape. The body computes the global [iter, iter_end) iteration range owned
// by SK CTA `cta_idx`: an even share of total_sk_iters_, offset past the DP
// iterations, with the first extra_iters_ CTAs each taking one extra iteration.
67  index_t& iter, index_t& iter_end, index_t cta_idx) const noexcept
68 {
// CTAs before me that carry one extra iteration shift my start forward.
69  index_t extra_iters_before_me = ck_tile::min(cta_idx, extra_iters_);
70  iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
71  iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
72 }
73 
74 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
// NOTE(review): declarator (orig. lines 75-76) dropped by the doc scrape.
// Maps a global iteration index to the macro-tile index that owns it.
77  index_t iter) const noexcept
78 {
79  return iter / iters_per_tile_;
80 }
81 
82 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
83 CK_TILE_DEVICE void
// NOTE(review): qualified name (orig. line 84) dropped by the doc scrape.
// Returns (via out-params) the global iteration span [tile_iter, tile_iter_end)
// covered by macro tile `tile_idx`.
85  index_t& tile_iter, index_t& tile_iter_end, index_t tile_idx) const noexcept
86 {
87  tile_iter = tile_idx * iters_per_tile_;
88  tile_iter_end = tile_iter + iters_per_tile_;
89 }
90 
91 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
92 CK_TILE_DEVICE /* static */ index_t
// NOTE(review): qualified name (orig. line 93) dropped by the doc scrape.
// Offset of global iteration `iter` within its tile, i.e. its distance from the
// tile's first iteration.
94  index_t iter, index_t tile_iter) noexcept
95 {
96  return iter - tile_iter;
97 }
98 
99 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
100 CK_TILE_DEVICE /* static */ index_t
// NOTE(review): qualified name (orig. line 101) dropped by the doc scrape.
// Number of iterations from tile_iter up to whichever ends first: the CTA's
// iter_end or the tile's end. Presumably tile_iter is already clamped to the
// CTA's own start by the caller — TODO confirm against call sites.
102  index_t tile_iter, index_t iter_end, index_t tile_iter_end) noexcept
103 {
104  return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
105 }
106 
107 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
108 CK_TILE_DEVICE auto
// NOTE(review): qualified name (orig. line 109) dropped by the doc scrape.
// Converts a linear tile index into (m, n) macro-tile coordinates, row-major
// over the N dimension; amd_wave_read_first_lane broadcasts lane 0's value so
// the coordinates are wave-uniform.
110  index_t tile_idx) const noexcept -> tuple<index_t, index_t>
111 {
112  const index_t n_macro_tiles = integer_divide_ceil(n_, NPerBlock);
113 
114  const index_t im = amd_wave_read_first_lane(tile_idx / n_macro_tiles);
115  const index_t in = amd_wave_read_first_lane(tile_idx - im * n_macro_tiles);
116  return make_tuple(im, in);
117 }
118 
119 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
// NOTE(review): declarator (orig. lines 120-121) dropped by the doc scrape.
// Total scratch workspace in bytes: the Reduction strategy needs the partials
// buffer plus the flags buffer; the Atomic strategy needs no workspace.
122  index_t acc_element_bytes) const noexcept
123 {
124  if constexpr(ReductionStrategy == StreamKReductionStrategy::Reduction)
125  {
126 
127  return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
128  }
129  else // ReductionStrategy is Atomics
130  {
131  return 0;
132  }
133 }
134 
// NOTE(review): the following run is a series of trivial accessors whose
// qualified names were all dropped by the doc scrape (orig. lines 136-137,
// 144-145, 151-152, 158-159, 165-166, 172-173, 180-181, 188-189, 196-197,
// 204-205, 212-213); each simply returns the named member. Restore the
// declarators from the repository.
// Accessor: total number of macro tiles (num_tiles_).
135 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
138  const noexcept
139 {
140  return num_tiles_;
141 }
142 
// Accessor: launch grid size used for partitioning (grid_).
143 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
146 {
147  return grid_;
148 }
149 
// Accessor: number of data-parallel tiles (dp_tiles_).
150 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
153 {
154  return dp_tiles_;
155 }
156 
// Accessor: number of Stream-K tiles (sk_tiles_).
157 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
160 {
161  return sk_tiles_;
162 }
163 
// Accessor: number of Stream-K CTAs (sk_ctas_).
164 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
167 {
168  return sk_ctas_;
169 }
170 
// Accessor: total Stream-K iterations (total_sk_iters_).
171 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
174  const noexcept
175 {
176  return total_sk_iters_;
177 }
178 
// Accessor: K iterations per macro tile (iters_per_tile_).
179 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
182  const noexcept
183 {
184  return iters_per_tile_;
185 }
186 
// Accessor: base share of iterations per SK CTA (iters_per_sk_cta_).
187 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
190  const noexcept
191 {
192  return iters_per_sk_cta_;
193 }
194 
// Accessor: leftover iterations distributed one-per-CTA (extra_iters_).
195 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
198  const noexcept
199 {
200  return extra_iters_;
201 }
202 
// Accessor: total data-parallel iterations (total_dp_iters_).
203 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
206  const noexcept
207 {
208  return total_dp_iters_;
209 }
210 
// Accessor: the N extent of the GEMM (n_).
211 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
214 {
215  return n_;
216 }
217 
218 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
// NOTE(review): declarator (orig. lines 219-220) dropped by the doc scrape.
// Returns an upper bound on how many workgroups may write to one macro tile of C.
221  const noexcept
222 {
223  // In the case of non-atomic reduction or data-parallel (DP) only, there will always be 1
224  // workgroup writing final results to a given macro tile in C.
225  int num_wgs_per_tile = 1;
226 
227  // Otherwise, for atomics, multiple workgroups may be writing to the same macro tile in C.
228  if(sk_ctas_ > 0 && ReductionStrategy == ck_tile::StreamKReductionStrategy::Atomic)
229  {
230  // If we have DP and SK tiles, this is DP+2TSK which guarantees at most 2 workgroups per
231  // tile. We only need to check that dp_tiles is greater than zero since we know we have SK
232  // workgroups.
233  if(dp_tiles_ > 0)
234  {
235  num_wgs_per_tile = 2;
236  }
237  else
238  {
// Guard against iters_per_sk_cta_ == 0 before dividing.
239  ck_tile::index_t iters_per_sk_cta_non_zero = ck_tile::max(iters_per_sk_cta_, 1);
240  // Estimate the number of workgroups per macro tile.
241  num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
242  ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
243  }
244  }
245 
// NOTE(review): std::max here vs ck_tile::max elsewhere in this file — fine if
// the (scrape-dropped) qualifier makes this accessor host-only; confirm, and
// prefer ck_tile::max if it can run on device.
246  return std::max(num_wgs_per_tile, 1);
247 }
248 
249 template <typename BlockGemmShapeType,
250  StreamKReductionStrategy ReductionStrategyType,
251  bool Persistent>
// NOTE(review): orig. line 252 — the declaration this template head belongs to
// (presumably the primary template of the derived partitioner struct) — was
// dropped by the doc scrape.
253 
254 // child class for Persistent Tile Partitioner
255 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
// NOTE(review): the constructor's qualified name (orig. lines 256-257) was
// dropped by the doc scrape. Delegates tile partitioning to the base, then
// splits DP tiles evenly across the grid; the remainder becomes extra DP tiles.
258  : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
259 { // inherit from base constructor
260  dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
261  extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
262 }
263 
264 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
265 CK_TILE_HOST auto
// NOTE(review): qualified name (orig. line 266) dropped by the doc scrape.
// Host-side launch-grid query for the persistent partitioner: one CTA per grid
// slot when DP tiles divide evenly, otherwise one CTA per macro tile.
267  -> dim3
268 {
269  if(extra_dp_tiles_ == 0)
270  {
271  return dim3(this->grid_, 1, 1);
272  }
273  else
274  {
275  return dim3(this->num_tiles_, 1, 1);
276  }
277 }
278 
// Accessor: DP tiles assigned to each CTA (dp_tiles_per_cta_). NOTE(review):
// qualified names in this run (orig. lines 280-281, 288-289) were dropped by
// the doc scrape.
279 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
282  const noexcept
283 {
284  return dp_tiles_per_cta_;
285 }
286 
// Accessor: leftover DP tiles after the even split (extra_dp_tiles_).
287 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
290  const noexcept
291 {
292  return extra_dp_tiles_;
293 }
294 
295 // child class for Non-Persistent Tile Partitioner
296 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
// NOTE(review): the constructor's qualified name (orig. lines 297-298) was
// dropped by the doc scrape. One CTA per DP tile; DP CTAs occupy block indices
// [0, dp_tiles), and SK CTAs start immediately after.
299  : StreamKTilePartitionerBase<BlockGemmShapeType, ReductionStrategyType>(m, n, k, grid)
300 { // inherit from base constructor
301  dp_ctas_ = this->dp_tiles_;
302  dp_start_block_idx_ = 0;
303  sk_start_block_idx_ = this->dp_tiles_;
304 }
305 
306 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
307 CK_TILE_HOST auto
// NOTE(review): qualified name (orig. line 308) dropped by the doc scrape.
// Host-side launch-grid query: one CTA per DP tile plus the Stream-K CTAs.
309  -> dim3
310 {
311  return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
312 }
313 
// Accessor: number of DP CTAs (dp_ctas_). NOTE(review): qualified names in this
// run (orig. lines 315-316, 323-324, 331-332) were dropped by the doc scrape.
314 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
317  const noexcept
318 {
319  return dp_ctas_;
320 }
321 
// Accessor: first block index assigned to DP work (dp_start_block_idx_).
322 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
325  const noexcept
326 {
327  return dp_start_block_idx_;
328 }
329 
// Accessor: first block index assigned to Stream-K work (sk_start_block_idx_).
330 template <typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
333  const noexcept
334 {
335  return sk_start_block_idx_;
336 }
337 
338 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:45
#define CK_TILE_HOST
Definition: config.hpp:44
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:46
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
StreamKReductionStrategy
Definition: streamk_common.hpp:10
int32_t index_t
Definition: integer.hpp:9
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
@ Atomic
Definition: block_to_ctile_map.hpp:1012
@ Reduction
Definition: block_to_ctile_map.hpp:1013
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:301
Stream-K tile partitioner base class.
Definition: streamk_gemm_tile_partitioner.hpp:24
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept
Calculates the total space needed for the flags buffer.
Definition: streamk_gemm_tile_partitioner_impl.hpp:58
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials buffer.
Definition: streamk_gemm_tile_partitioner_impl.hpp:50
index_t grid_
Definition: streamk_gemm_tile_partitioner.hpp:194
static constexpr index_t KPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:28
static constexpr index_t NPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:27
static constexpr index_t MPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:26
index_t num_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:193
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition: streamk_gemm_tile_partitioner_impl.hpp:9
index_t dp_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:195
Template for the Stream-K tile partitioner derived struct.
Definition: streamk_gemm_tile_partitioner.hpp:229
Definition: tuple.hpp:192