Composable Kernel: include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7r2.hpp"

namespace ck {

template <typename ABLayout,
          typename ABMajorLayout,
          typename LDSTypeAB,
          index_t BlockSize,
          index_t MNPerBlock,
          index_t KPerBlock,
          index_t MNPerWmma,
          index_t ABK1Value,
          bool UseBlockPaddingAB,
          bool PermuteAB,
          typename ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
          typename ABBlockTransferThreadClusterArrangeOrder,
          typename ABBlockTransferSrcAccessOrder,
          index_t ABBlockTransferSrcVectorDim,
          index_t ABBlockTransferSrcScalarPerVector,
          index_t ABBlockTransferDstScalarPerVector_ABK1,
          bool ABThreadTransferSrcResetCoordinateAfterRun>
struct GridwiseABTransferThreadTiles // name inferred from the file name
{
    static constexpr auto ABK0Number = Number<KPerBlock / ABK1Value>{};
    static constexpr auto ABK1Number = Number<ABK1Value>{};

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    static constexpr index_t ABPackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<LDSTypeAB>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;
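
    // Builds the (ABK0, MN, ABK1) grid descriptor consumed by the blockwise copy,
    // right-padding MN and/or K at compile time as requested by PadMN / PadK.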
    template <bool PadMN, bool PadK, typename GridDescriptorBase>
    __host__ __device__ static auto MakeGridDescriptor(const GridDescriptorBase& ab_grid_desc,
                                                       index_t MN,
                                                       index_t MNPad,
                                                       index_t K,
                                                       index_t KPad,
                                                       index_t StrideAB,
                                                       index_t ABK0)
    {
        if constexpr(PadMN && PadK)
        {
            // pad both MN and K
            const auto ab_grid_desc_n_k =
                transform_tensor_descriptor(ab_grid_desc,
                                            make_tuple(make_right_pad_transform(MN, MNPad - MN),
                                                       make_right_pad_transform(K, KPad - K)),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}),
                                            make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_pass_through_transform(MNPad)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ab_grid_desc_bk0_n_bk1;
        }
        else if constexpr(PadMN && !PadK)
        {
            // pad MN, but not K
            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_right_pad_transform(MN, MNPad - MN)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ab_grid_desc_bk0_n_bk1;
        }
        else if constexpr(!PadMN && PadK)
        {
            // pad K, but not MN
            const auto ab_grid_desc_n_k = transform_tensor_descriptor(
                ab_grid_desc,
                make_tuple(make_pass_through_transform(MN),
                           make_right_pad_transform(K, KPad - K)),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));

            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                           make_pass_through_transform(MN)),
                make_tuple(Sequence<1>{}, Sequence<0>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

            return ab_grid_desc_bk0_n_bk1;
        }
        else
        {
            if constexpr(!PermuteAB)
            {
                // not pad MN or K
                const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                    ab_grid_desc,
                    make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),
                               make_pass_through_transform(MN)),
                    make_tuple(Sequence<1>{}, Sequence<0>{}),
                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}));

                return ab_grid_desc_bk0_n_bk1;
            }
            else
            {
                // Pre-shuffled weight:
                // BGlobal[K / KPerBlock, MN, KPerBlock / K1, K1] -> BTile[K / K1, MN, K1]
                constexpr index_t ABK01 = KPerBlock / ABK1Value;
                const index_t ABK0_ = StrideAB / ABK1Value;
                const index_t ABK00 = ABK0_ / ABK01;

                const auto ab_grid_desc_abk00_mn_abk01_abk1_permute =
                    make_naive_tensor_descriptor_packed(make_tuple(ABK00, MN, ABK01, ABK1Value));

                const auto ab_grid_desc_abk0_mn_abk1_permute = transform_tensor_descriptor(
                    ab_grid_desc_abk00_mn_abk01_abk1_permute,
                    make_tuple(make_merge_transform(make_tuple(ABK00, ABK01)),
                               make_pass_through_transform(MN),
                               make_pass_through_transform(ABK1Value)),
                    make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}),
                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

                return ab_grid_desc_abk0_mn_abk1_permute;
            }
        }
    }
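
    // Returns the LDS block descriptor seen by the blockwise copy, logically
    // (ABK0, MNPerBlock, ABK1); the physical layout below trades bank conflicts
    // against address-computation cost and register pressure.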
    __device__ static constexpr auto GetBlockDescriptor()
    {
        // AB tile in LDS memory, dst of the blockwise copy
        if constexpr(UseBlockPaddingAB)
        {
            // Writing into LDS causes bank conflicts, but the whole main loop is available
            // to hide them in v4, and this layout may save some VALU address computation.
            // (strides below follow the analogous padded-LDS descriptor in CK's wmma gemm)
            return make_naive_tensor_descriptor(
                make_tuple(ABK0Number, Number<MNPerBlock>{}, ABK1Number),
                make_tuple(Number<MNPerBlock + 1>{} * ABK1Number, ABK1Number, I1));
        }
        // The xor tensor transformation requires extra, otherwise unnecessary VGPRs and
        // can cause register spills in some cases.
        else if constexpr(is_same<ABMajorLayout, ABLayout>::value)
        {
            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeAB) / ABPackedSize;
            constexpr auto MNLdsLayer = LdsSize < 1 ? 1 : LdsSize;
            constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor(
                make_tuple(ABK0Number * Number<MNLdsLayer>{},
                           Number<MNPerBlock / MNLdsLayer>{},
                           ABK1Number),
                make_tuple(ABK1Number, Number<KPerBlock * MNLdsLayer>{}, I1));

            constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
                ab_lds_block_desc,
                make_tuple(
                    make_xor_with_modulo_transform(make_tuple(
                        Number<MNPerBlock / MNLdsLayer>{}, Number<ABK0Number * MNLdsLayer>{})),
                    make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
                make_tuple(Sequence<1, 0>{}, Sequence<2>{}));

            constexpr auto ab_lds_block_desc_abk0_mnldslayer_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_permuted,
                make_tuple(make_unmerge_transform(make_tuple(ABK0Number, Number<MNLdsLayer>{})),
                           make_pass_through_transform(Number<MNPerBlock / MNLdsLayer>{}),
                           make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0, 2>{}, Sequence<1>{}, Sequence<3>{}));

            constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_abk0_mnldslayer_mn_abk1,
                make_tuple(make_pass_through_transform(ABK0Number),
                           make_merge_transform_v3_division_mod(make_tuple(
                               Number<MNPerBlock / MNLdsLayer>{}, Number<MNLdsLayer>{})),
                           make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return ab_lds_block_desc_abk0_mn_abk1;
        }
        else
        {
            // The kfold and mpair dimensions are not always required.
            // More dimensions in merge_transform increase the difficulty of generating
            // immarg offsets for the compiler.
            constexpr auto MN0 = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I1);
            constexpr auto MN1 = MNPerBlock / MN0;

            constexpr auto KThreadWrite = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I0);
            constexpr auto K0PerThreadWrite = ABK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MNPerWmma;
            constexpr auto K0PerThreadRead  = ABK0Number / KThreadRead;

            constexpr auto kfold = (ABK1Number * MN0 * sizeof(LDSTypeAB) > 128)
                                       ? 1
                                       : 128 / (ABK1Number * MN0 * sizeof(LDSTypeAB));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= mpair <= MN0
            constexpr auto mpair = (ABK1Number * MNPerWmma * sizeof(LDSTypeAB) > 128)
                                       ? 1
                                       : ((128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB))) > MN0
                                              ? MN0
                                              : 128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB)));
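
            // Lay LDS out as a 6-d packed descriptor so an xor swizzle can reconcile the
            // thread-cluster write pattern with the wmma read pattern.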
            constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor_packed(
                make_tuple(Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<K0PerThreadWrite>{},
                           Number<KThreadReadPerm * MN1>{},
                           Number<kfold * MN0 / mpair>{},
                           Number<mpair>{},
                           ABK1Number));

            constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
                ab_lds_block_desc,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_xor_with_modulo_transform(
                        make_tuple(Number<KThreadReadPerm * MN1>{}, Number<kfold * MN0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(ABK1Number)),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}),
                make_tuple(
                    Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}, Sequence<5>{}));
            constexpr auto ab_lds_block_desc_unmerged = transform_tensor_descriptor(
                ab_lds_block_desc_permuted,
                make_tuple(
                    make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
                    make_pass_through_transform(Number<K0PerThreadWrite>{}),
                    make_unmerge_transform(make_tuple(Number<KThreadReadPerm>{}, Number<MN1>{})),
                    make_unmerge_transform(make_tuple(Number<kfold>{}, Number<MN0 / mpair>{})),
                    make_pass_through_transform(Number<mpair>{}),
                    make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                make_tuple(Sequence<1>{},
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));

            constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_unmerged,
                make_tuple(make_merge_transform_v3_division_mod(
                               make_tuple(Number<KThreadReadPerm>{},
                                          Number<KThreadWrite / kfold / KThreadReadPerm>{},
                                          Number<kfold>{},
                                          Number<K0PerThreadWrite>{})),
                           make_merge_transform_v3_division_mod(
                               make_tuple(Number<MN0 / mpair>{}, Number<mpair>{}, Number<MN1>{})),
                           make_pass_through_transform(ABK1Number)),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6>{}, Sequence<7>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

            return ab_lds_block_desc_abk0_mn_abk1;
        }
    }
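
    // Builds the blockwise copy that moves one (ABK0, MNPerBlock, ABK1) tile from global
    // memory to LDS: a multi-source v7r2 transfer when several AB tensors are fused,
    // otherwise the more general v4r1 transfer.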
    template <typename GridDescriptor,
              typename BlockDescriptor,
              typename ABsDataType,
              typename ABElementwiseOperation,
              index_t GlobalBufferNum>
    __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
                                            BlockDescriptor& block_descriptor,
                                            ABElementwiseOperation& ab_element_op,
                                            const index_t block_mn_id)
    {
        constexpr index_t NumABTensor = ABsDataType::Size();
        const index_t mn_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_mn_id * MNPerBlock);
        // workaround because v7r2 is not as general as v4r1
        if constexpr(NumABTensor > 1)
        {
            const auto idx_as_block_begin = generate_tuple(
                [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },
                Number<NumABTensor>{});

            // NOTE: parts of the template-argument lists on this and the v4r1 branch are
            // inferred from CK's other ThreadGroupTensorSliceTransfer call sites.
            return ThreadGroupTensorSliceTransfer_v7r2<
                ThisThreadBlock,
                ABsDataType,
                Tuple<LDSTypeAB>,
                GridDescriptor,
                decltype(tie(block_descriptor)),
                ABElementwiseOperation,
                Sequence<static_cast<index_t>(InMemoryDataOperationEnum::Set)>,
                Sequence<ABK0Number, MNPerBlock, ABK1Number>,
                ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
                ABBlockTransferThreadClusterArrangeOrder,
                ABBlockTransferSrcAccessOrder,
                Sequence<0, 1, 2>,
                ABBlockTransferSrcVectorDim,
                2,
                ABBlockTransferSrcScalarPerVector,
                ABBlockTransferDstScalarPerVector_ABK1,
                uniform_sequence_gen_t<NumABTensor, ABThreadTransferSrcResetCoordinateAfterRun>,
                uniform_sequence_gen_t<1, true>,
                GlobalBufferNum>{grid_descriptor,
                                 idx_as_block_begin,
                                 tie(block_descriptor),
                                 make_tuple(make_multi_index(0, 0, 0)),
                                 ab_element_op};
        }
        else
        {
            return ThreadGroupTensorSliceTransfer_v4r1<
                ThisThreadBlock,
                ABElementwiseOperation,
                tensor_operation::element_wise::PassThrough,
                InMemoryDataOperationEnum::Set,
                Sequence<ABK0Number, MNPerBlock, ABK1Number>,
                ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
                ABBlockTransferThreadClusterArrangeOrder,
                remove_cvref_t<tuple_element_t<0, ABsDataType>>,
                LDSTypeAB,
                decltype(grid_descriptor[I0]),
                decltype(block_descriptor),
                ABBlockTransferSrcAccessOrder,
                Sequence<0, 1, 2>,
                ABBlockTransferSrcVectorDim,
                2,
                ABBlockTransferSrcScalarPerVector,
                ABBlockTransferDstScalarPerVector_ABK1,
                1,
                1,
                ABThreadTransferSrcResetCoordinateAfterRun,
                true,
                GlobalBufferNum>(grid_descriptor[I0],
                                 make_multi_index(0, mn_block_data_idx_on_grid, 0),
                                 ab_element_op,
                                 block_descriptor,
                                 make_multi_index(0, 0, 0),
                                 tensor_operation::element_wise::PassThrough{});
        }
    }

    template <index_t MNRepeat, index_t MNWaves>
    __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
    {
        // This is a block descriptor used to read LDS memory into registers.
        // It's defined in a way consistent with the existing implementation to
        // avoid changes in the pipelines.
        using BlockDesc = decltype(GetBlockDescriptor());
        // ABK0_MN_ABK1 -> ABK0_MNRepeat_MNWaves_KRow_MNPerWmma_ABK1
        constexpr auto ABK0 = BlockDesc{}.GetLength(I0);
        constexpr auto ABK1 = BlockDesc{}.GetLength(I2);
#ifdef __gfx12__
        constexpr auto KRow = I2;
#else
        constexpr auto KRow = I1;
#endif
        return transform_tensor_descriptor(
            BlockDesc{},
            make_tuple(make_unmerge_transform(make_tuple(ABK0 / KRow, KRow)),
                       make_unmerge_transform(
                           make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
                       make_pass_through_transform(ABK1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
    }

    __device__ static constexpr auto GetBlockStep()
    {
        // Grid descriptor step (MoveSrcSliceWindow)
        return make_multi_index(KPerBlock / ABK1Number, 0, 0);
    }

    template <typename GridDescriptor>
    __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
    {
        // K dimension size. This should always be called with the A-matrix grid descriptor,
        // because it does not work for the B matrix when packed int4 is used.
        return grid_desc.GetLength(I0) * grid_desc.GetLength(I2);
    }
};

} // namespace ck
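
For orientation, here is a minimal sketch of how a gridwise GEMM pipeline might instantiate this helper for its A matrix (so MN plays the role of M). It is not part of the header: the tile sizes, layouts, and half_t element type are illustrative assumptions, and the struct name is the one inferred above.

// Hypothetical instantiation; all numeric tile parameters are example values.
using ATransfer = ck::GridwiseABTransferThreadTiles<
    ck::tensor_layout::gemm::RowMajor, // ABLayout: layout of A in global memory
    ck::tensor_layout::gemm::RowMajor, // ABMajorLayout: compared against ABLayout in GetBlockDescriptor
    ck::half_t,                        // LDSTypeAB
    256,                               // BlockSize
    128,                               // MNPerBlock (MPerBlock for A)
    64,                                // KPerBlock
    16,                                // MNPerWmma
    8,                                 // ABK1Value
    true,                              // UseBlockPaddingAB
    false,                             // PermuteAB
    ck::Sequence<8, 32, 1>,            // thread cluster: 8 x 32 x 1 == BlockSize
    ck::Sequence<1, 0, 2>,             // cluster arrange order
    ck::Sequence<1, 0, 2>,             // src access order
    2,                                 // src vector dim (the ABK1 axis)
    8,                                 // src scalars per vector
    8,                                 // dst scalars per vector along ABK1
    true>;                             // reset src coordinate after run

// A pipeline would then build the descriptors from it, e.g.:
//   constexpr auto a_block_desc = ATransfer::GetBlockDescriptor();
//   const auto a_grid_desc =
//       ATransfer::MakeGridDescriptor<true, true>(base_desc, M, MPad, K, KPad, StrideA, KPad / 8);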