Composable Kernel — gridwise_ab_transfer_wave_tiles.hpp (generated source listing)
Path: include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_wave_tiles.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
// Wave-tile A/B transfer helper: distributes the global->LDS copy tiles of a
// block across its waves for WMMA pipelines.
// NOTE(review): this span is a Doxygen page extraction; the struct declaration
// line (hpp:22) and the ThisThreadBlock alias (hpp:33) are missing from the
// capture -- recover them from the original header before editing.
12 template <typename ABLayout,
13  typename ABMajorLayout,
14  typename LDSTypeAB,
15  index_t BlockSize,
16  index_t MNPerBlock,
17  index_t KPerBlock,
18  index_t MNPerWmma,
19  index_t KPack,
20  index_t ABK1Value,
21  index_t WaveSize>
23 {
// pk_i4_t packs two 4-bit values per byte, which this element-wise wave-tile
// copy path does not handle.
24  static_assert(!(is_same_v<remove_cvref_t<LDSTypeAB>, pk_i4_t>),
25  "wave tile transfer method does not support pk_i4_t");
// Compile-time index constants used for descriptor/tuple addressing.
26  static constexpr auto I0 = Number<0>{};
27  static constexpr auto I1 = Number<1>{};
28  static constexpr auto I2 = Number<2>{};
29  static constexpr auto I3 = Number<3>{};
30 
// Each tile is split into two subtile rows (see the "2 16x8 subtile"
// comments in MakeGridDescriptor below; MNKRow indexes 0-1).
31  static constexpr index_t MNKRow = 2;
32 
34 
35  // Tiles distribution for global memory loading
36  // Notes: support for not power of 2 needs to be reviewed later on
37  // The tiles are distributed along the non-contiguous matrix dimension
38  // Example 4 waves A row-major MPerBlock = 64, KPerBlock = 64
39  // MRepeat = 1, KRepeat = 4
40  // -------------
41  // |W0| | | |
42  // -------------
43  // |W1| | | |
44  // -------------
45  // |W2| | | |
46  // -------------
47  // |W3| | | |
48  // -------------
49  // Example 4 waves A column-major MPerBlock = 64, KPerBlock = 64
50  // MRepeat = 4, KRepeat = 1
51  // -------------
52  // |W0|W1|W2|W3|
53  // -------------
54  // | | | | |
55  // -------------
56  // | | | | |
57  // -------------
58  // | | | | |
59  // -------------
60  static constexpr index_t NumberOfWaves = BlockSize / WaveSize;
// Waves placed along the MN (resp. K) major dimension: as many as divide the
// tile count evenly, otherwise fall back to 2 (even tile count) or 1.
61  static constexpr index_t MNMajorWaves_ =
62  MNPerBlock / MNPerWmma % std::min(MNPerBlock / MNPerWmma, NumberOfWaves) == 0
63  ? std::min(MNPerBlock / MNPerWmma, NumberOfWaves)
64  : (MNPerBlock / MNPerWmma % 2 == 0 ? 2 : 1);
65  static constexpr index_t KMajorWaves_ =
66  KPerBlock / KPack % std::min(KPerBlock / KPack, NumberOfWaves) == 0
67  ? std::min(KPerBlock / KPack, NumberOfWaves)
68  : (KPerBlock / KPack % 2 == 0 ? 2 : 1);
69 
// Transpose on load is required whenever the runtime layout differs from the
// LDS-major layout.
70  static constexpr bool ABDoTranspose = !is_same_v<ABLayout, ABMajorLayout>;
71 
// NOTE(review): the initializers of MNWaves_ and KWaves_ (hpp:73-74) are
// missing from this extraction; given KRepeat_/MNRepeat_ below they
// presumably select between the major-wave counts above based on
// ABDoTranspose -- confirm against the original header.
72  static constexpr index_t MNWaves_ =
75  static constexpr index_t KRepeat_ = KPerBlock / (KWaves_ * KPack);
76  static constexpr index_t MNRepeat_ = MNPerBlock / (MNWaves_ * MNPerWmma);
77 
78  template <bool PadMN, bool PadK, typename GridDescriptorBase>
// Transforms a raw MN x K grid descriptor into the tiled
// MNTiles-KTiles-MNKRow-LaneLocal view consumed by the wave-tile loader.
// sizeMN/sizeK are the raw extents; the unnamed index_t parameters are
// presumably kept for signature compatibility with the thread-tile variant
// -- confirm against callers.
// NOTE(review): several interior lines of the transforms (hpp:93-99,
// 110-120, 124-136, 141-151, 154-166) are missing from this extraction;
// do not edit this body without the original header.
79  __host__ __device__ static auto MakeGridDescriptor(GridDescriptorBase& base_desc,
80  index_t sizeMN,
81  index_t,
82  index_t sizeK,
83  index_t,
84  index_t,
85  index_t)
86  {
87  // Notes: padding is currently not supported
88  static_assert(!PadMN && !PadK, "padding is currently not supported");
89 
90  // Divide the base descriptor MN_K into tiles
91  const auto ab_grid_desc_mntiles_ktiles = transform_tensor_descriptor(
92  base_desc,
93  make_tuple(
97  Number<KPack>{}))),
100 
101  // The distinction is needed to get the same global indices for both layouts
102  // Divide each tile into two 16x8 subtiles
103  // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
104  // MNKRow = 0-1
105  // LaneLocal = 0-15
106  // VectorSize must be 8
107  if constexpr(!ABDoTranspose)
108  {
109  const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
111  ab_grid_desc_mntiles_ktiles,
118  make_tuple(Number<MNKRow>{}, Number<KPack / MNKRow>{}))),
121 
122  // Freeze VectorSize to first element of the loading chunk (for convenience)
123  // Swap MNPerWmma and MNKRow for consistency with transpose descriptor
125  ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
126  make_tuple(
133  make_tuple(
135  make_tuple(
137  }
138  else
139  {
140  const auto ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1 =
142  ab_grid_desc_mntiles_ktiles,
148  make_tuple(Number<MNKRow>{}, Number<MNPerWmma / MNKRow>{})),
152 
153  // Freeze VectorSize to first element of the loading chunk (for convenience)
155  ab_grid_desc_mntiles_ktiles_lanegroup_lanelocal_abk1,
156  make_tuple(
163  make_tuple(
165  make_tuple(
167  }
168  }
169 
// Builds the LDS (block-scope) descriptor with the layout documented below.
// NOTE(review): the naive-descriptor lengths/strides (hpp:177-186) and the
// freezing transform (hpp:191-199) are missing from this extraction.
170  __device__ static constexpr auto GetBlockDescriptor()
171  {
172  // LDS memory layouts:
173  // lanes within tiles stored contiguously in chunks of 8 elements
174  // tiles are then stored first in K dimension
175  // MNTiles - KTiles - MNKRow - LaneLocal - VectorSize
176  const auto a_grid_desc_mraw_kraw = [&]() {
180  Number<MNKRow>{},
187  I1));
188  }();
189 
190  // Freeze VectorSize to first element of the chunk (for convenience)
192  a_grid_desc_mraw_kraw,
200  }
201 
// Decomposes the flat thread id into per-wave coordinates; the result is
// consumed as (I0 = MN-wave, I1 = K-wave) in GetBlockTransfer.
// NOTE(review): the adaptor's transform arguments (hpp:207-209) are missing
// from this extraction -- confirm the dimension order against the original
// header.
202  __device__ static auto GetWaveIdx()
203  {
204  const index_t thread_id = ThisThreadBlock::GetThreadId();
205 
206  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
210 
211  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
212  }
213 
// Decomposes the hardware lane id into (subtile row, lane-within-subtile)
// coordinates for LDS addressing. A subtile row spans KPack lanes when the
// data is transposed, MNPerWmma lanes otherwise.
214  __device__ static auto GetBlockLaneIdx()
215  {
216  const index_t lane_id = __lane_id();
217 
218  constexpr index_t LanesPerSubTile = ABDoTranspose ? KPack : MNPerWmma;
219 
// NOTE(review): the adaptor's dimension-id sequences (hpp:222-223) are
// missing from this extraction.
220  constexpr auto laneid_to_block_lane_idx_adaptor = make_single_stage_tensor_adaptor(
221  make_tuple(make_merge_transform(make_tuple(MNKRow, LanesPerSubTile))),
224 
225  return laneid_to_block_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
226  }
227 
228  template <typename ABDataType>
// Decomposes the hardware lane id into (subtile row, flattened column lane)
// for global-memory addressing. SubTilesCol = 4 / sizeof(ABDataType)
// suggests a 4-byte-per-lane packing and assumes sizeof(ABDataType) divides
// 4 (1-, 2- or 4-byte elements) -- TODO confirm.
// NOTE(review): the adaptor construction (hpp:242-244) is missing from this
// extraction.
229  __device__ static auto GetGridLaneIdx()
230  {
231  const index_t lane_id = __lane_id();
232 
233  constexpr index_t SubTilesRow = MNKRow;
234  constexpr index_t SubTilesCol = 4 / sizeof(ABDataType);
235  constexpr index_t LanesPerSubTile =
236  ABDoTranspose ? KPack / SubTilesCol : MNPerWmma / SubTilesCol;
// The decomposition order swaps row/column for the transposed layout so the
// branches below can read the components uniformly.
237  constexpr auto dims_tuple = ABDoTranspose
238  ? make_tuple(SubTilesCol, SubTilesRow, LanesPerSubTile)
239  : make_tuple(SubTilesRow, SubTilesCol, LanesPerSubTile);
240 
241  constexpr auto laneid_to_grid_lane_idx_adaptor =
245 
246  const auto indices =
247  laneid_to_grid_lane_idx_adaptor.CalculateBottomIndex(make_multi_index(lane_id));
248 
// Re-flatten (subtile-column, lane-in-subtile) into a single local index;
// the row component is returned first in both layouts.
249  if constexpr(!ABDoTranspose)
250  {
251  return make_multi_index(indices[I0], indices[I1] * LanesPerSubTile + indices[I2]);
252  }
253  else
254  {
255  return make_multi_index(indices[I1], indices[I0] * LanesPerSubTile + indices[I2]);
256  }
257  }
258 
259  template <typename GridDescriptor,
260  typename BlockDescriptor,
261  typename ABsDataType,
262  typename ABElementwiseOperation,
263  index_t GlobalBufferNum>
// Builds the ThreadGroupTransferGlobal object that copies this wave's tiles
// from global memory into LDS; block_mn_id selects the block's MN position
// within the grid.
// NOTE(review): hpp:276 (presumably an ABDataType alias pulled out of
// ABsDataType, used below) and hpp:295-297 (three template arguments of
// ThreadGroupTransferGlobal) are missing from this extraction.
264  __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
265  BlockDescriptor& block_descriptor,
266  ABElementwiseOperation& ab_element_op,
267  const index_t block_mn_id)
268  {
269  // Note: GlobalBufferNum is currently not used but it will be needed
270  // once we add other pipelines. It is currently needed only for
271  // consistency with the thread tiles approach
272  static_assert(GlobalBufferNum == 1, "single global buffer is only supported");
273  constexpr index_t NumABTensor = ABsDataType::Size();
274  static_assert(NumABTensor == 1, "multiAB currently not supported");
275 
277 
// Per-wave and per-lane coordinates; grid-side and block-side lane indices
// are computed separately because the global layout packs lanes differently
// (see GetGridLaneIdx vs GetBlockLaneIdx).
278  const auto wave_idx = GetWaveIdx();
279  index_t wave_idK = wave_idx[I1];
280  index_t wave_idMN = wave_idx[I0];
281 
282  const auto grid_lane_id = GetGridLaneIdx<ABDataType>();
283  index_t lane_group_grid = grid_lane_id[I0];
284  index_t lane_local_id_grid = grid_lane_id[I1];
285 
286  const auto block_lane_id = GetBlockLaneIdx();
287  index_t lane_group_block = block_lane_id[I0];
288  index_t lane_local_id_block = block_lane_id[I1];
289 
290  return ThreadGroupTransferGlobal<decltype(grid_descriptor[I0]),
291  BlockDescriptor,
292  ABDataType,
293  ABDataType,
294  ABElementwiseOperation,
298  ABK1Value,
299  ABDoTranspose>(
300  grid_descriptor[I0],
301  block_descriptor,
// Grid-side start: offset the MN-tile coordinate by this block's position.
302  make_multi_index(block_mn_id * (MNRepeat_ * MNWaves_) + wave_idMN,
303  wave_idK,
304  lane_group_grid,
305  lane_local_id_grid),
306  make_multi_index(wave_idMN, wave_idK, lane_group_block, lane_local_id_block),
307  ab_element_op);
308  }
309 
310  template <index_t MNRepeat, index_t MNWaves>
311  __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
312  {
313  // This is a block descriptor used to read LDS memory into register
314  // It's defined in a way consistent with the existing implementation to
315  // avoid changes in the pipelines
// NOTE(review): most of the descriptor construction (hpp:316-317, 320-326)
// is missing from this extraction; only three of its length arguments are
// visible below.
318  Number<MNWaves>{},
319  Number<MNKRow>{},
327  I1));
328  }
329 
330  __device__ static constexpr auto GetBlockStep()
331  {
332  // Grid descriptor step (MoveSrcSliceWindow)
333  return make_multi_index(I0, KWaves_ * KRepeat_, I0, I0);
334  }
335 
336  template <typename GridDescriptor>
337  __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
338  {
339  return grid_desc.GetLength(I1) * KPack;
340  }
341 };
342 
343 } // namespace ck
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
auto grid_desc(MatrixPadder< GemmSpec, MPerTileType, NPerTileType, KPerTileType > matrix_padder, CDesc_MRaw_NRaw conv_desc)
Definition: matrix_padder.hpp:190
Definition: ck.hpp:268
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
__host__ constexpr __device__ auto make_freeze_transform(const LowerIndex &low_idx)
Definition: multi_index_transform_helper.hpp:151
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:299
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: gridwise_ab_transfer_wave_tiles.hpp:23
static __device__ auto GetWaveIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:202
__host__ static __device__ auto MakeGridDescriptor(GridDescriptorBase &base_desc, index_t sizeMN, index_t, index_t sizeK, index_t, index_t, index_t)
Definition: gridwise_ab_transfer_wave_tiles.hpp:79
static constexpr index_t MNRepeat_
Definition: gridwise_ab_transfer_wave_tiles.hpp:76
static __device__ auto GetGridLaneIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:229
static constexpr __device__ auto GetBlockDescriptor()
Definition: gridwise_ab_transfer_wave_tiles.hpp:170
static constexpr auto I2
Definition: gridwise_ab_transfer_wave_tiles.hpp:28
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: gridwise_ab_transfer_wave_tiles.hpp:33
static constexpr index_t KWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:74
static __device__ auto GetBlockTransfer(GridDescriptor &grid_descriptor, BlockDescriptor &block_descriptor, ABElementwiseOperation &ab_element_op, const index_t block_mn_id)
Definition: gridwise_ab_transfer_wave_tiles.hpp:264
__host__ static constexpr __device__ auto MakeWmmaTileDescriptor()
Definition: gridwise_ab_transfer_wave_tiles.hpp:311
static constexpr __device__ index_t GetKDimension(const GridDescriptor &grid_desc)
Definition: gridwise_ab_transfer_wave_tiles.hpp:337
static constexpr index_t KMajorWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:65
static constexpr index_t MNMajorWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:61
static constexpr auto I1
Definition: gridwise_ab_transfer_wave_tiles.hpp:27
static constexpr auto I3
Definition: gridwise_ab_transfer_wave_tiles.hpp:29
static constexpr index_t MNKRow
Definition: gridwise_ab_transfer_wave_tiles.hpp:31
static constexpr auto I0
Definition: gridwise_ab_transfer_wave_tiles.hpp:26
static constexpr bool ABDoTranspose
Definition: gridwise_ab_transfer_wave_tiles.hpp:70
static constexpr index_t MNWaves_
Definition: gridwise_ab_transfer_wave_tiles.hpp:72
static constexpr __device__ auto GetBlockStep()
Definition: gridwise_ab_transfer_wave_tiles.hpp:330
static constexpr index_t KRepeat_
Definition: gridwise_ab_transfer_wave_tiles.hpp:75
static constexpr index_t NumberOfWaves
Definition: gridwise_ab_transfer_wave_tiles.hpp:60
static __device__ auto GetBlockLaneIdx()
Definition: gridwise_ab_transfer_wave_tiles.hpp:214
Definition: sequence.hpp:43
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
Definition: thread_group_tensor_slice_transfer_global.hpp:26
Definition: integral_constant.hpp:20
Definition: data_type.hpp:187