7 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
23 total_sk_iters_ = sk_tiles_ * iters_per_tile_;
26 if(total_sk_iters_ <
grid_)
40 iters_per_sk_cta_ = sk_ctas_ ? total_sk_iters_ / sk_ctas_ : 0;
41 extra_iters_ = sk_ctas_ ? total_sk_iters_ % sk_ctas_ : 0;
44 total_dp_iters_ =
dp_tiles_ * iters_per_tile_;
47 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
50 index_t acc_element_bytes)
const noexcept
52 return MPerBlock * NPerBlock * acc_element_bytes * sk_ctas_;
55 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
60 return sizeof(
index_t) * sk_ctas_;
63 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
69 iter = total_dp_iters_ + cta_idx * iters_per_sk_cta_ + extra_iters_before_me;
70 iter_end = iter + iters_per_sk_cta_ + (cta_idx < extra_iters_);
73 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
78 return iter / iters_per_tile_;
81 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
86 tile_iter = tile_idx * iters_per_tile_;
87 tile_iter_end = tile_iter + iters_per_tile_;
90 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
95 return iter - tile_iter;
98 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
103 return ck_tile::min(iter_end, tile_iter_end) - tile_iter;
106 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
118 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
121 index_t acc_element_bytes)
const noexcept
126 return get_partials_buffer_size(acc_element_bytes) + get_flags_buffer_size();
134 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
142 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
149 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
156 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
163 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
170 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
175 return total_sk_iters_;
178 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
183 return iters_per_tile_;
186 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
191 return iters_per_sk_cta_;
194 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
202 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
207 return total_dp_iters_;
210 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
217 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
224 int num_wgs_per_tile = 1;
231 num_wgs_per_tile = (iters_per_tile_ / iters_per_sk_cta_non_zero) +
232 ((iters_per_tile_ % iters_per_sk_cta_non_zero) != 0);
235 return std::max(num_wgs_per_tile, 1);
238 template <
typename BlockGemmShapeType,
244 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
252 dp_tiles_per_cta_ = this->dp_tiles_ / this->grid_;
253 extra_dp_tiles_ = this->dp_tiles_ % this->grid_;
256 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
259 const noexcept -> dim3
261 if(extra_dp_tiles_ == 0)
263 return dim3(this->grid_, 1, 1);
267 return dim3(this->num_tiles_, 1, 1);
271 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
276 return dp_tiles_per_cta_;
279 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
284 return extra_dp_tiles_;
288 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
296 dp_ctas_ = this->dp_tiles_;
297 dp_start_block_idx_ = 0;
298 sk_start_block_idx_ = this->dp_tiles_;
301 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
304 const noexcept -> dim3
306 return dim3(dp_ctas_ + this->get_sk_ctas(), 1, 1);
309 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
317 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
322 return dp_start_block_idx_;
325 template <
typename BlockGemmShapeType, StreamKReductionStrategy ReductionStrategyType>
330 return sk_start_block_idx_;
#define CK_TILE_DEVICE
Definition: config.hpp:41
#define CK_TILE_HOST
Definition: config.hpp:40
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:42
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
StreamKReductionStrategy
Definition: streamk_common.hpp:10
int32_t index_t
Definition: integer.hpp:9
ck_tile::index_t estimate_num_wgs_per_tile(index_t sk_ctas, index_t iters_per_sk_cta, index_t iters_per_tile)
Estimates the number of Stream-K workgroups per macro tile in the C tensor.
Definition: streamk_common.hpp:27
constexpr CK_TILE_HOST_DEVICE T min(T x)
Definition: math.hpp:210
constexpr CK_TILE_HOST_DEVICE T max(T x)
Definition: math.hpp:161
@ Atomic
Definition: block_to_ctile_map.hpp:1012
@ Reduction
Definition: block_to_ctile_map.hpp:1013
__device__ uint32_t amd_wave_read_first_lane(uint32_t value)
Definition: amd_wave_read_first_lane.hpp:100
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:299
Template for the Stream-K tile partitioner derived struct.
Definition: streamk_gemm_tile_partitioner.hpp:230
Stream-K tile partitioner base class.
Definition: streamk_gemm_tile_partitioner.hpp:24
CK_TILE_HOST_DEVICE index_t get_flags_buffer_size() const noexcept
Calculates the total space needed for the flags buffer.
Definition: streamk_gemm_tile_partitioner_impl.hpp:57
CK_TILE_HOST_DEVICE index_t get_partials_buffer_size(index_t acc_element_bytes) const noexcept
Calculates the total space needed for the partials buffer.
Definition: streamk_gemm_tile_partitioner_impl.hpp:49
index_t grid_
Definition: streamk_gemm_tile_partitioner.hpp:195
static constexpr index_t KPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:29
static constexpr index_t NPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:28
static constexpr index_t MPerBlock
Definition: streamk_gemm_tile_partitioner.hpp:27
index_t num_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:194
StreamKTilePartitionerBase(index_t m, index_t n, index_t k, index_t grid)
Definition: streamk_gemm_tile_partitioner_impl.hpp:8
index_t dp_tiles_
Definition: streamk_gemm_tile_partitioner.hpp:196
Definition: tuple.hpp:192