/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File
gemm_tile_partitioner.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
9 #pragma once
10 
11 #include "ck_tile/core.hpp"
12 
13 namespace ck_tile {
14 
19 template <typename BlockGemmShapeType>
21 {
23 
24  static constexpr index_t MPerBlock = BlockGemmShape::kM;
25  static constexpr index_t NPerBlock = BlockGemmShape::kN;
26  static constexpr index_t KPerBlock = BlockGemmShape::kK;
27 
30  [[maybe_unused]] index_t N) noexcept;
31 
39  CK_TILE_HOST static auto
40  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
41  {
42  const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
43  const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
44  return dim3(GridDimX, GridDimY, 1);
45  }
46 
53  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
54  {
55  return integer_divide_ceil(K, KPerBlock);
56  }
57 
72  CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept
74  {
75  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
76  const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
77  return make_tuple(iM, iN);
78  }
79 };
80 
86 template <typename BlockGemmShape_>
88 {
90 
91  static constexpr index_t MPerBlock = BlockGemmShape::kM;
92  static constexpr index_t NPerBlock = BlockGemmShape::kN;
93  static constexpr index_t KPerBlock = BlockGemmShape::kK;
94 
96 
104  {
105  N_ = N;
106  }
107 
115  CK_TILE_HOST static auto
116  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
117  {
118  const index_t GridDimX = (M + MPerBlock - 1) / MPerBlock;
119  const index_t GridDimY = (N + NPerBlock - 1) / NPerBlock;
120  return GridDimX * GridDimY;
121  }
122 
129  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
130  {
131  return integer_divide_ceil(K, KPerBlock);
132  }
133 
140  CK_TILE_DEVICE static auto GetOutputTileIndex(index_t blockIdx) noexcept
141  -> const tuple<index_t, index_t>
142  {
143  const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
144 
145  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
146  const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
147  return make_tuple(iM, iN);
148  }
149 
150  private:
151  CK_TILE_DEVICE static index_t N_;
152 };
153 
158 template <typename, typename = void>
160 {
161 };
162 
168 template <typename T>
169 struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
171 {
172 };
173 
180 template <typename TilePartitioner,
181  typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
183 {
191  [[nodiscard]] CK_TILE_DEVICE static auto
192  GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept
193  -> const tuple<index_t, index_t>
194  {
195  const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
196  return make_tuple(iM, iN);
197  }
198 
207  [[nodiscard]] CK_TILE_DEVICE static auto
208  GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept
209  -> const tuple<index_t, index_t>
210  {
211  const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(block_idx - block_start);
212  return make_tuple(iM, iN);
213  }
214 };
215 
227 template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
229 {
231 
232  static constexpr index_t MPerBlock = BlockGemmShape::kM;
233  static constexpr index_t NPerBlock = BlockGemmShape::kN;
234  static constexpr index_t KPerBlock = BlockGemmShape::kK;
235 
238  : M(M_), N(N_)
239  {
240  }
241 
249  CK_TILE_HOST_DEVICE static auto
250  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
251  {
252  const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
253  const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
254  return GridDimX * GridDimY;
255  }
256 
263  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
264  {
265  return integer_divide_ceil(K, KPerBlock);
266  }
267 
274  CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept
275  -> const tuple<index_t, index_t>
276  {
277  const auto M0 = integer_divide_ceil(M, MPerBlock);
278  const auto N0 = integer_divide_ceil(N, NPerBlock);
279 
280  if(M0 == 1)
281  {
282  return make_tuple(0, block_1d_id);
283  }
284  else if(N0 == 1)
285  {
286  return make_tuple(block_1d_id, 0);
287  }
288  // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
289  else
290  {
291  const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
292  const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
293  const auto group_id_y = block_1d_id / GroupNum;
294  const auto group_id_x = block_1d_id - group_id_y * GroupNum;
295  const auto remap_block_1d_id =
296  group_id_x <= big_group_num
297  ? group_id_x * group_size + group_id_y
298  : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
299 
300  const index_t idx_M0 = remap_block_1d_id / N0;
301  const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
302 
303  const index_t M0_tmp = M0 / M01;
304  const index_t M0_mod_M01 = M0 - M0_tmp * M01;
305 
306  const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
307 
308  const index_t idx_M00 = idx_M0 / M01;
309  const index_t idx_M01 = idx_M0 - idx_M00 * M01;
310  const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
311 
356  const index_t N_out = idx_N0_M01_local / M01_adapt;
357  const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
358 
359  return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
360  }
361  }
362 
363  private:
364  index_t M;
365  index_t N;
366 };
367 
368 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:337
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Class mapping 1D block index into 2D output tile space.
Definition: gemm_tile_partitioner.hpp:229
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:232
static CK_TILE_HOST_DEVICE auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:250
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:234
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:263
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept=delete
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:230
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:233
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:274
Class providing a mapping from a 1D workgroup (WGP) index into the 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:88
static CK_TILE_HOST auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:116
CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept=delete
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:129
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:140
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:89
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:91
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:92
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:93
Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
Definition: gemm_tile_partitioner.hpp:21
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple< index_t, index_t >
Returns the 2D output tile index (iM, iN) for the given 2D workgroup index.
Definition: gemm_tile_partitioner.hpp:72
static CK_TILE_HOST auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> dim3
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:40
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:22
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:53
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:25
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:26
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:24
CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept=delete
Primary (std::false_type) case of the trait detecting a single-argument GetOutputTileIndex (as provided by GemmTile1DPartitioner); selected when that expression is ill-formed.
Definition: gemm_tile_partitioner.hpp:160
Struct used to calculate offset tile indices (raw block index minus a start offset).
Definition: gemm_tile_partitioner.hpp:183
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from the raw 1D block index (blockIdx.x).
Definition: gemm_tile_partitioner.hpp:192
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from a given block index.
Definition: gemm_tile_partitioner.hpp:208
Definition: tuple.hpp:192