/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.1.0/include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp Source File
gemm_tile_partitioner.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
9 #pragma once
10 
11 #include "ck_tile/core.hpp"
12 
13 namespace ck_tile {
14 
19 template <typename BlockGemmShapeType>
21 {
23 
// Tile extents along M, N and K, copied from BlockGemmShape's kM / kN / kK constants.
// GridSize and GetLoopNum below divide the problem sizes by these values.
24  static constexpr index_t MPerBlock = BlockGemmShape::kM;
25  static constexpr index_t NPerBlock = BlockGemmShape::kN;
26  static constexpr index_t KPerBlock = BlockGemmShape::kK;
27 
30  [[maybe_unused]] index_t N) noexcept;
31 
/// @brief Calculates the 2D GEMM kernel grid size for an M x N output.
///
/// Each extent is the problem size divided by the tile size, rounded up, so a partially
/// filled edge tile still gets its own grid entry. The round-up goes through
/// integer_divide_ceil, the same helper GetLoopNum uses, rather than a hand-written
/// (x + y - 1) / y.
///
/// @param M GEMM's M dimension.
/// @param N GEMM's N dimension.
/// @return dim3 with x = ceil(M / MPerBlock), y = ceil(N / NPerBlock), z = 1.
///
/// @note The exception specification is noexcept(noexcept(expr)): the inner operator only
///       asks whether evaluating the comparison could throw; it does not test the tile
///       sizes against zero.
39  CK_TILE_HOST static auto
40  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> dim3
41  {
42  const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
43  const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
44  return dim3(GridDimX, GridDimY, 1);
45  }
46 
/// @brief Calculates the number of loop iterations over GEMM's K dimension.
///
/// @param K GEMM's K dimension.
/// @return integer_divide_ceil(K, KPerBlock), i.e. K divided by the K tile size, rounded up.
53  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
54  {
55  return integer_divide_ceil(K, KPerBlock);
56  }
57 
/// @brief Returns the 2D output tile index for a pair of workgroup indices.
///
/// No remapping is applied: the first argument becomes iM and the second becomes iN.
///
/// @param blockIdx Index used as the tile index along M.
/// @param blockIdy Index used as the tile index along N.
/// @return tuple (iM, iN).
///
/// @note The parameter named `blockIdx` hides any outer name `blockIdx` inside this function.
72  CK_TILE_DEVICE static auto
73  GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple<index_t, index_t>
74  {
// Both values pass through __builtin_amdgcn_readfirstlane before being returned.
// NOTE(review): presumably this keeps the indices wave-uniform - confirm against the
// AMDGPU builtin documentation.
75  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx);
76  const index_t iN = __builtin_amdgcn_readfirstlane(blockIdy);
77  return make_tuple(iM, iN);
78  }
79 };
80 
86 template <typename BlockGemmShape_>
88 {
90 
// Tile extents along M, N and K, copied from BlockGemmShape's kM / kN / kK constants.
// GridSize, GetLoopNum and GetOutputTileIndex below divide problem sizes by these values.
91  static constexpr index_t MPerBlock = BlockGemmShape::kM;
92  static constexpr index_t NPerBlock = BlockGemmShape::kN;
93  static constexpr index_t KPerBlock = BlockGemmShape::kK;
94 
96 
104  {
105  N_ = N;
106  }
107 
/// @brief Calculates the 1D GEMM kernel grid size for an M x N output.
///
/// The count is the number of tiles along M times the number of tiles along N, each
/// rounded up so partially filled edge tiles are included. The round-up goes through
/// integer_divide_ceil, the same helper GetLoopNum and GetOutputTileIndex use, rather
/// than a hand-written (x + y - 1) / y.
///
/// @param M GEMM's M dimension.
/// @param N GEMM's N dimension.
/// @return ceil(M / MPerBlock) * ceil(N / NPerBlock).
///
/// @note The exception specification is noexcept(noexcept(expr)): the inner operator only
///       asks whether evaluating the comparison could throw; it does not test the tile
///       sizes against zero.
115  CK_TILE_HOST_DEVICE static auto
116  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
117  {
118  const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
119  const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
120  return GridDimX * GridDimY;
121  }
122 
/// @brief Calculates the number of loop iterations over GEMM's K dimension.
///
/// @param K GEMM's K dimension.
/// @return integer_divide_ceil(K, KPerBlock), i.e. K divided by the K tile size, rounded up.
129  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
130  {
131  return integer_divide_ceil(K, KPerBlock);
132  }
133 
/// @brief Maps a 1D workgroup index onto a 2D output tile index (iM, iN).
///
/// NOTE(review): listing line 141, which carries the function name and parameter list,
/// is missing from this rendering. The body uses a single index named `blockIdx`.
///
/// The index is split row-major: NBlocks tiles make up one row along N, iM is the row
/// and iN is the position inside that row.
///
/// @return tuple (iM, iN).
///
/// @note N_ is a static data member of the class (declared under `private:` below), so
///       the value read here is whatever was last stored in it.
140  CK_TILE_DEVICE static auto
142  {
// Number of tiles along N, rounded up.
143  const index_t NBlocks = integer_divide_ceil(N_, NPerBlock);
144 
// iN is the remainder of blockIdx / NBlocks, written as a subtraction.
// NOTE(review): the readfirstlane builtin presumably keeps both values wave-uniform -
// confirm against the AMDGPU builtin documentation.
145  const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx / NBlocks);
146  const index_t iN = __builtin_amdgcn_readfirstlane(blockIdx - iM * NBlocks);
147  return make_tuple(iM, iN);
148  }
149 
150  private:
151  CK_TILE_DEVICE static index_t N_;
152 };
153 
158 template <typename, typename = void>
160 {
161 };
162 
168 template <typename T>
169 struct HasFnOneArgImpl<T, std::void_t<decltype(std::declval<T>().GetOutputTileIndex(1))>>
171 {
172 };
173 
180 template <typename TilePartitioner,
181  typename = typename std::enable_if_t<HasFnOneArgImpl<TilePartitioner>{}>>
183 {
/// @brief Returns the output tile index for the current workgroup after subtracting an offset.
///
/// Constructs a TilePartitioner from (M, N) and asks it for the tile that belongs to
/// blockIdx.x - block_start.
///
/// @param block_start Value subtracted from blockIdx.x before the lookup.
/// @param M Forwarded to the TilePartitioner constructor.
/// @param N Forwarded to the TilePartitioner constructor.
/// @return tuple (iM, iN) as produced by TilePartitioner::GetOutputTileIndex.
191  [[nodiscard]] CK_TILE_DEVICE static auto GetOffsetedTileIndex(
192  index_t block_start, index_t M, index_t N) noexcept -> const tuple<index_t, index_t>
193  {
// The result is unpacked and repacked, so the return value is built by make_tuple here.
194  const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(blockIdx.x - block_start);
195  return make_tuple(iM, iN);
196  }
197 
/// @brief Returns the output tile index for a caller-supplied block index after subtracting an offset.
///
/// Same as the three-argument overload, except that the index comes from `block_idx`
/// instead of blockIdx.x.
///
/// @param block_start Value subtracted from block_idx before the lookup.
/// @param M Forwarded to the TilePartitioner constructor.
/// @param N Forwarded to the TilePartitioner constructor.
/// @param block_idx Block index to translate.
/// @return tuple (iM, iN) as produced by TilePartitioner::GetOutputTileIndex.
206  [[nodiscard]] CK_TILE_DEVICE static auto
207  GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept
208  -> const tuple<index_t, index_t>
209  {
// The result is unpacked and repacked, so the return value is built by make_tuple here.
210  const auto [iM, iN] = TilePartitioner{M, N}.GetOutputTileIndex(block_idx - block_start);
211  return make_tuple(iM, iN);
212  }
213 };
214 
226 template <typename BlockGemmShapeType, index_t GroupNum, index_t M01>
228 {
230 
// Tile extents along M, N and K, copied from BlockGemmShape's kM / kN / kK constants.
// GridSize, GetLoopNum and GetOutputTileIndex below divide problem sizes by these values.
231  static constexpr index_t MPerBlock = BlockGemmShape::kM;
232  static constexpr index_t NPerBlock = BlockGemmShape::kN;
233  static constexpr index_t KPerBlock = BlockGemmShape::kK;
234 
237  : M(M_), N(N_)
238  {
239  }
240 
/// @brief Calculates the 1D GEMM kernel grid size for an M x N output.
///
/// @param M GEMM's M dimension.
/// @param N GEMM's N dimension.
/// @return ceil(M / MPerBlock) * ceil(N / NPerBlock).
///
/// @note The exception specification is noexcept(noexcept(expr)): the inner operator only
///       asks whether evaluating the comparison could throw; it does not test the tile
///       sizes against zero.
248  CK_TILE_HOST_DEVICE static auto
249  GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock != 0 && NPerBlock != 0)) -> index_t
250  {
251  const index_t GridDimX = integer_divide_ceil(M, MPerBlock);
252  const index_t GridDimY = integer_divide_ceil(N, NPerBlock);
253  return GridDimX * GridDimY;
254  }
255 
/// @brief Calculates the number of loop iterations over GEMM's K dimension.
///
/// @param K GEMM's K dimension.
/// @return integer_divide_ceil(K, KPerBlock), i.e. K divided by the K tile size, rounded up.
262  CK_TILE_HOST_DEVICE static auto GetLoopNum(index_t K) noexcept -> index_t
263  {
264  return integer_divide_ceil(K, KPerBlock);
265  }
266 
/// @brief Maps a 1D workgroup index onto a 2D output tile index, reordering tiles on the way.
///
/// The mapping has three stages:
///   1. Degenerate grids (one tile row or one tile column) map straight through.
///   2. Otherwise the index is remapped across GroupNum groups.
///   3. The remapped index is split row-major and then re-walked inside bands of up to
///      M01 tile rows, M first.
///
/// @param block_1d_id 1D workgroup index.
/// @return tuple (tile index along M, tile index along N).
///
/// @note Listing lines 311-354 are missing from this rendering (NOTE(review): presumably
///       a comment block - confirm against the header).
273  CK_TILE_DEVICE auto
274  GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple<index_t, index_t>
275  {
// Tile counts along M and N, rounded up.
276  const auto M0 = integer_divide_ceil(M, MPerBlock);
277  const auto N0 = integer_divide_ceil(N, NPerBlock);
278 
// With a single tile row or column there is nothing to reorder.
279  if(M0 == 1)
280  {
281  return make_tuple(0, block_1d_id);
282  }
283  else if(N0 == 1)
284  {
285  return make_tuple(block_1d_id, 0);
286  }
287  // block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
288  else
289  {
// group_size: tiles per group, rounded up.
// big_group_num: GroupNum minus the slack (group_size * GroupNum - M0 * N0).
// NOTE(review): presumably the number of groups that hold the full group_size.
290  const auto group_size = integer_divide_ceil(M0 * N0, GroupNum);
291  const auto big_group_num = GroupNum - (group_size * GroupNum - M0 * N0);
// group_id_x is block_1d_id modulo GroupNum, written as a subtraction.
292  const auto group_id_y = block_1d_id / GroupNum;
293  const auto group_id_x = block_1d_id - group_id_y * GroupNum;
// Start of group x plus the offset inside it. Both branches give the same value
// when group_id_x == big_group_num.
294  const auto remap_block_1d_id =
295  group_id_x <= big_group_num
296  ? group_id_x * group_size + group_id_y
297  : group_id_x * group_size + big_group_num - group_id_x + group_id_y;
298 
// Row-major split of the remapped index over N0 tile columns.
299  const index_t idx_M0 = remap_block_1d_id / N0;
300  const index_t idx_N0 = remap_block_1d_id - idx_M0 * N0;
301 
// M0_mod_M01 is M0 modulo M01: the height of the last, possibly shorter, band.
302  const index_t M0_tmp = M0 / M01;
303  const index_t M0_mod_M01 = M0 - M0_tmp * M01;
304 
// Band height for this row: M01 for full bands, the remainder for the trailing band.
305  const auto M01_adapt = (idx_M0 < M0 - M0_mod_M01) ? M01 : M0_mod_M01;
306 
// idx_M00: band number; idx_M01: row inside the band;
// idx_N0_M01_local: linear position inside the band.
307  const index_t idx_M00 = idx_M0 / M01;
308  const index_t idx_M01 = idx_M0 - idx_M00 * M01;
309  const index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;
310 
// Re-split the in-band position with M01_adapt as the fast dimension, so consecutive
// positions step along M inside the band before moving to the next N column.
355  const index_t N_out = idx_N0_M01_local / M01_adapt;
356  const index_t idx_loc_mod_M01 = idx_N0_M01_local - N_out * M01_adapt;
357 
358  return make_tuple(idx_loc_mod_M01 + idx_M00 * M01, N_out);
359  }
360  }
361 
362  private:
363  index_t M;
364  index_t N;
365 };
366 
367 } // namespace ck_tile
#define CK_TILE_DEVICE
Definition: config.hpp:40
#define CK_TILE_HOST
Definition: config.hpp:39
#define CK_TILE_HOST_DEVICE
Definition: config.hpp:41
Definition: cluster_descriptor.hpp:13
constexpr CK_TILE_HOST_DEVICE auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:149
int32_t index_t
Definition: integer.hpp:9
remove_cv_t< std::remove_reference_t< T > > remove_cvref_t
Definition: type_traits.hpp:21
constexpr CK_TILE_HOST_DEVICE auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:343
bool_constant< false > false_type
Definition: integral_constant.hpp:63
bool_constant< true > true_type
Definition: integral_constant.hpp:62
Class mapping 1D block index into 2D output tile space.
Definition: gemm_tile_partitioner.hpp:228
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:231
static CK_TILE_HOST_DEVICE auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:249
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:233
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:262
CK_TILE_HOST_DEVICE GemmSpatiallyLocalTilePartitioner() noexcept=delete
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:229
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:232
CK_TILE_DEVICE auto GetOutputTileIndex(index_t block_1d_id) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:274
Class providing 1D WGP index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:88
CK_TILE_HOST_DEVICE GemmTile1DPartitioner() noexcept=delete
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:129
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx) noexcept -> const tuple< index_t, index_t >
Calculate workgroup 1D index mapping into 2D output C-tile space.
Definition: gemm_tile_partitioner.hpp:141
remove_cvref_t< BlockGemmShape_ > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:89
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:91
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:92
static CK_TILE_HOST_DEVICE auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> index_t
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:116
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:93
Class providing 2D workgroup index mapping into 2D output GEMM C-tile space.
Definition: gemm_tile_partitioner.hpp:21
static CK_TILE_DEVICE auto GetOutputTileIndex(index_t blockIdx, index_t blockIdy) noexcept -> const tuple< index_t, index_t >
Returns the 2D output tile index (iM, iN) for the given pair of workgroup indices.
Definition: gemm_tile_partitioner.hpp:73
static CK_TILE_HOST auto GridSize(index_t M, index_t N) noexcept(noexcept(MPerBlock !=0 &&NPerBlock !=0)) -> dim3
Calculates GEMM kernel grid size.
Definition: gemm_tile_partitioner.hpp:40
remove_cvref_t< BlockGemmShapeType > BlockGemmShape
Definition: gemm_tile_partitioner.hpp:22
static CK_TILE_HOST_DEVICE auto GetLoopNum(index_t K) noexcept -> index_t
Calculate number of loop iterations over GEMM's K dimension.
Definition: gemm_tile_partitioner.hpp:53
static constexpr index_t NPerBlock
Definition: gemm_tile_partitioner.hpp:25
static constexpr index_t KPerBlock
Definition: gemm_tile_partitioner.hpp:26
static constexpr index_t MPerBlock
Definition: gemm_tile_partitioner.hpp:24
CK_TILE_HOST_DEVICE GemmTile2DPartitioner() noexcept=delete
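Primary template of HasFnOneArgImpl (the std::false_type case), selected when calling a partitioner's GetOutputTileIndex with a single argument, as GemmTile1DPartitioner::GetOutputTileIndex allows, is not a valid expression.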
Definition: gemm_tile_partitioner.hpp:160
Struct used to calculate offsetted tile indices.
Definition: gemm_tile_partitioner.hpp:183
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from the raw 1D block index (blockIdx.x) and returns the resulting tile index.
Definition: gemm_tile_partitioner.hpp:191
static CK_TILE_DEVICE auto GetOffsetedTileIndex(index_t block_start, index_t M, index_t N, index_t block_idx) noexcept -> const tuple< index_t, index_t >
The function subtracts the block's start (offset) from a given block index.
Definition: gemm_tile_partitioner.hpp:207
Definition: tuple.hpp:192