/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/core/algorithm/static_encoding_pattern.hpp Source File
static_encoding_pattern.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
71 #pragma once
72 
74 #include "ck_tile/core/config.hpp"
80 
81 namespace ck_tile {
82 
// tile_distribution_pattern: enumeration describing the static tile distribution patterns.
// The primary 2D pattern template below takes it as its DistributionPattern parameter.
// NOTE(review): this listing omits the enum's declaration line (87) and lines 89-97.
// thread_raked appears in the definitions list but not in the visible enum body;
// presumably it sits on one of the elided lines — confirm.
88 {
98  warp_raked, // Warp raked pattern.
103  block_raked, // Block raked pattern - aka linear.
104 };
105 
// Empty base struct. Its declaration line (106) is not shown in this listing.
// NOTE(review): the thread/warp/block raked specializations derive from
// TileDistributionEncodingPattern, which is presumably this struct — confirm.
107 {
108 };
109 
// Primary template of the class that creates a 2D static tile distribution with different
// load/store patterns. Its body is empty. Make2DStaticTileDistribution() and
// MakeShuffled2DStaticTileDistribution() are defined only in the thread-raked, warp-raked
// and block-raked partial specializations that follow.
//
// Template parameters, as the specializations use them:
//   BlockSize           - threads per workgroup; num_warps = BlockSize / get_warp_size().
//   YPerTile, XPerTile  - extents of the 2D tile being distributed.
//   VecSize             - requested vector width along X. XPerTile must be a multiple of it.
//                         It is clamped to the number of tile elements per thread.
//   DistributionPattern - presumably selects the matching specialization (the specializations'
//                         pattern-argument lines are not shown in this listing).
//   NumWaveGroups       - defaults to 1. Only the thread-raked specialization visibly uses it.
// NOTE(review): the struct's name line (128) is not shown in this listing.
122 template <index_t BlockSize,
123  index_t YPerTile,
124  index_t XPerTile,
125  index_t VecSize,
126  tile_distribution_pattern DistributionPattern,
127  index_t NumWaveGroups = 1>
129 {
130 };
131 
132 // Thread raked
133 template <index_t BlockSize,
134  index_t YPerTile,
135  index_t XPerTile,
136  index_t VecSize,
137  index_t NumWaveGroups>
139  YPerTile,
140  XPerTile,
141  VecSize,
143  NumWaveGroups> : public TileDistributionEncodingPattern
144 {
145 
146  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
147  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
148  static constexpr index_t warp_size = get_warp_size();
149  static constexpr index_t num_warps = BlockSize / get_warp_size();
150  static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
151  static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
152  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
153 
154  // # of rows in Y dim accessed by single wavefront in one iteration
155  static constexpr index_t Y1 = warp_size / X0;
156  static_assert(X0 * Y1 == warp_size, "X0 * Y1 must cover whole wavefront!");
157 
158  static constexpr index_t Y0 = num_warps / NumWaveGroups;
159  // YPerWarp = YPerTile / Y0;
160  // Y2 = YPerWarp / Y1;
161  static constexpr index_t Y2 = YPerTile / (Y1 * Y0); // # of iters within wavefront
162 
163  static_assert(X0 * Y1 * Y0 * NumWaveGroups == BlockSize,
164  "X0 * warp_ys * Y0 must cover whole workgroup!");
165  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
166 
168  {
169  if constexpr(NumWaveGroups != 1)
170  {
175  tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
177  sequence<1, 1>>{}); // -> <Y2, X1>
178  }
179  else
180  {
185  tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
187  sequence<2, 1>>{}); // -> <Y2, X1>
188  }
189  }
190 
192  {
193  if constexpr(NumWaveGroups != 1)
194  {
199  tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
201  sequence<1, 1>>{}); // -> <X1, Y2>
202  }
203  else
204  {
209  tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
211  sequence<1, 2>>{}); // -> <X1, Y2>
212  }
213  }
214 };
215 
// Warp-raked specialization, per the "Warp raked" marker below. The listing omits
// line 222 (struct name) and line 226 (pattern argument).
//
// Derived distribution factors:
//   X1 - VecSize clamped to LargestVec, the number of tile elements per thread.
//   X0 - threads along X: XPerTile / X1.
//   Y2 - rows covered by one wavefront: warp_size / X0.
//   Y0 - num_warps.
//   Y1 - iterations per wavefront along Y: YPerTile / (Y2 * Y0).
// NOTE(review): NumWaveGroups does not appear in any visible derivation or assertion here.
// Confirm whether values other than 1 are meant to be supported by this pattern.
216 // Warp raked
217 template <index_t BlockSize,
218  index_t YPerTile,
219  index_t XPerTile,
220  index_t VecSize,
221  index_t NumWaveGroups>
223  YPerTile,
224  XPerTile,
225  VecSize,
227  NumWaveGroups> : public TileDistributionEncodingPattern
228 {
229 
230  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
231  static constexpr index_t warp_size = get_warp_size();
232  static constexpr index_t num_warps = BlockSize / get_warp_size();
// NOTE(review): if BlockSize < warp_size, num_warps is 0 and the next line divides by zero.
// If the tile has fewer elements than BlockSize, LargestVec (and so X1) is 0 and X0 divides
// by zero. Both cases fail as non-constant expressions, not as a readable static_assert.
233  static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
234  static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
// NOTE(review): nothing asserts X0 * X1 == XPerTile. When X1 = LargestVec does not divide
// XPerTile, X0 truncates. Example (warp_size 64): BlockSize=64, YPerTile=32, XPerTile=5,
// VecSize=5 passes every assertion here with X0 * X1 = 4 < XPerTile.
235  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
236 
237  static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
238  static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
239 
240  static constexpr index_t Y0 = num_warps;
// NOTE(review): the message below says "X0 * Y2 * Y1", but the condition checks X0 * Y2 * Y0.
// Y1 is only defined afterwards.
241  static_assert(X0 * Y2 * Y0 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
242 
243  static constexpr index_t Y1 = YPerTile / (Y2 * Y0); // # of iters within wavefront
244  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
245 
// Make2DStaticTileDistribution(). Its signature line (246) is elided; see the definitions list.
// The trailing "// ->" comments name the derived factors each index sequence selects.
247  {
252  tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
254  sequence<1, 1>>{}); // -> <Y1, X1>
255  }
256 
// MakeShuffled2DStaticTileDistribution(). Its signature line (257) is elided. Per the
// trailing comments, it selects the per-thread factors as <X1, Y1> instead of <Y1, X1>.
258  {
263  tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
265  sequence<1, 1>>{}); // -> <X1, Y1>
266  }
267 };
268 
// Block-raked specialization, per the "Block raked" marker below. The listing omits
// line 275 (struct name) and line 279 (pattern argument).
//
// Derived distribution factors:
//   X1 - VecSize clamped to LargestVec, the number of tile elements per thread.
//   X0 - threads along X: XPerTile / X1.
//   Y2 - rows covered by one wavefront: warp_size / X0.
//   Y1 - num_warps.
//   Y0 - iterations along Y: YPerTile / (Y2 * Y1).
// NOTE(review): NumWaveGroups does not appear in any visible derivation or assertion here.
// Confirm whether values other than 1 are meant to be supported by this pattern.
269 // Block raked
270 template <index_t BlockSize,
271  index_t YPerTile,
272  index_t XPerTile,
273  index_t VecSize,
274  index_t NumWaveGroups>
276  YPerTile,
277  XPerTile,
278  VecSize,
280  NumWaveGroups> : public TileDistributionEncodingPattern
281 {
282 
283  // TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
284  static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
285  static constexpr index_t warp_size = get_warp_size();
286  static constexpr index_t num_warps = BlockSize / get_warp_size();
// NOTE(review): if BlockSize < warp_size, num_warps is 0 and the next line divides by zero.
// If the tile has fewer elements than BlockSize, LargestVec (and so X1) is 0 and X0 divides
// by zero. Both cases fail as non-constant expressions, not as a readable static_assert.
287  static constexpr index_t LargestVec = (XPerTile * YPerTile) / (num_warps * warp_size);
288  static constexpr index_t X1 = VecSize > LargestVec ? LargestVec : VecSize;
// NOTE(review): nothing asserts X0 * X1 == XPerTile. When X1 = LargestVec does not divide
// XPerTile, X0 truncates. Example (warp_size 64): BlockSize=64, YPerTile=32, XPerTile=5,
// VecSize=5 passes every assertion here with X0 * X1 = 4 < XPerTile.
289  static constexpr index_t X0 = XPerTile / X1; // # of threads in X dim
290  static constexpr index_t Y2 = warp_size / X0; // # of rows in Y dim to cover whole wavefront
291  static_assert(X0 * Y2 == warp_size, "X0 * Y2 must cover whole wavefront!");
292  static constexpr index_t Y1 = num_warps;
293  static_assert(X0 * Y2 * Y1 == BlockSize, "X0 * Y2 * Y1 must cover whole workgroup!");
294  static constexpr index_t Y0 = YPerTile / (Y2 * Y1); // # of iters
295  static_assert(Y0 * Y1 * Y2 == YPerTile, "Y0, Y1, Y2 must cover whole YPerTile");
296 
// Make2DStaticTileDistribution(). Its signature line (297) is elided; see the definitions list.
// The trailing "// ->" comments name the derived factors each index sequence selects.
298  {
303  tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
305  sequence<0, 1>>{}); // -> <Y0, X1>
306  }
307 
// MakeShuffled2DStaticTileDistribution(). Its signature line (308) is elided. Per the
// trailing comments, it selects the per-thread factors as <X1, Y0> instead of <Y0, X1>.
309  {
314  tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
316  sequence<1, 0>>{}); // -> <X1, Y0>
317  }
318 };
319 
320 } // namespace ck_tile
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE index_t get_warp_size()
Definition: arch.hpp:51
int32_t index_t
Definition: integer.hpp:9
tile_distribution_pattern
Enumeration describing static tile distribution patterns.
Definition: static_encoding_pattern.hpp:88
@ block_raked
Block raked pattern - aka linear.
@ thread_raked
Thread raked pattern.
@ warp_raked
Warp raked pattern.
constexpr CK_TILE_HOST_DEVICE auto make_static_tile_distribution(StaticTileDistributionEncoding_)
Definition: tile_distribution.hpp:498
static constexpr CK_TILE_HOST_DEVICE auto Make2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:246
static constexpr CK_TILE_HOST_DEVICE auto MakeShuffled2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:257
static constexpr CK_TILE_HOST_DEVICE auto Make2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:167
static constexpr CK_TILE_HOST_DEVICE auto MakeShuffled2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:191
static constexpr CK_TILE_HOST_DEVICE auto MakeShuffled2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:308
static constexpr CK_TILE_HOST_DEVICE auto Make2DStaticTileDistribution()
Definition: static_encoding_pattern.hpp:297
Class creating 2D static tile distribution with different load/store patterns.
Definition: static_encoding_pattern.hpp:129
Definition: static_encoding_pattern.hpp:107
Definition: sequence.hpp:52
Definition: tile_distribution_encoding.hpp:26
Definition: tuple.hpp:192