gridwise_ab_transfer_thread_tiles.hpp
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

namespace ck {

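// Helper that builds the grid and LDS tensor descriptors plus the blockwise copy used to stage
// the A/B thread tiles for the WMMA pipelines (see MakeGridDescriptor, GetBlockDescriptor,
// GetBlockTransfer and MakeWmmaTileDescriptor below).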
template <typename ABLayout,
          typename ABMajorLayout,
          typename LDSTypeAB,
          index_t BlockSize,
          index_t MNPerBlock,
          index_t KPerBlock,
          index_t MNPerWmma,
          index_t ABK1Value,
          index_t KPack,
          index_t KInner,
          index_t KPerWmmaBlk,
          bool UseBlockPaddingAB,
          bool PermuteAB,
          typename ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
          typename ABBlockTransferThreadClusterArrangeOrder,
          typename ABBlockTransferSrcAccessOrder,
          index_t ABBlockTransferSrcVectorDim,
          index_t ABBlockTransferSrcScalarPerVector,
          index_t ABBlockTransferDstScalarPerVector_ABK1,
          bool ABThreadTransferSrcResetCoordinateAfterRun>
{
    static constexpr auto ABK0Number = Number<KPerBlock / ABK1Value>{};
    static constexpr auto ABK1Number = Number<ABK1Value>{};

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};

    // Assumed check for packed int4 (pk_i4_t), which stores two elements per byte.
    static constexpr index_t ABPackedSize = []() {
        if constexpr(is_same_v<remove_cvref_t<LDSTypeAB>, pk_i4_t>)
            return 2;
        else
            return 1;
    }();

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    template <bool PadMN, bool PadK, typename GridDescriptorBase>
    __host__ __device__ static auto MakeGridDescriptor(const GridDescriptorBase& ab_grid_desc,
                                                        index_t MN,
                                                        index_t MNPad,
                                                        index_t K,
                                                        index_t KPad,
                                                        index_t StrideAB,
                                                        index_t ABK0)
    {

        if constexpr(PadMN && PadK)
        {
            // pad both MN and K
            const auto ab_grid_desc_n_k =
                transform_tensor_descriptor(ab_grid_desc,
                                            make_tuple(make_right_pad_transform(MN, MNPad - MN),
                                                       make_right_pad_transform(K, KPad - K)),

            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc_n_k,

            return ab_grid_desc_bk0_n_bk1;
        }
        else if constexpr(PadMN && !PadK)
        {
            // pad MN, but not K
            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc,
                           make_right_pad_transform(MN, MNPad - MN)),

            return ab_grid_desc_bk0_n_bk1;
        }
        else if constexpr(!PadMN && PadK)
        {
            // pad K, but not MN
            const auto ab_grid_desc_n_k = transform_tensor_descriptor(
                ab_grid_desc,

            const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                ab_grid_desc_n_k,
                make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),

            return ab_grid_desc_bk0_n_bk1;
        }
        else
        {
            if constexpr(!PermuteAB)
            {
                // not pad MN or K
                const auto ab_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
                    ab_grid_desc,
                    make_tuple(make_unmerge_transform(make_tuple(ABK0, ABK1Value)),

                return ab_grid_desc_bk0_n_bk1;
            }
            else
            {
                // Pre-shuffled Weight
                // BGlobal[K / KPerBlock, MN, KPerBlock / K1, K1] -> BTile[K / K1, MN, K1]
                constexpr index_t ABK01 = KPerBlock / ABK1Value;
                const index_t ABK0_     = StrideAB / ABK1Value;
                const index_t ABK00     = ABK0_ / ABK01;

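                // Illustration (hypothetical values): with KPerBlock = 64 and ABK1Value = 8,
                // ABK01 = 8; if StrideAB / ABK1Value = 32, the K0 dimension splits into
                // ABK00 = 4 groups of ABK01 = 8 for the pre-shuffled layout below.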
                const auto ab_grid_desc_abk00_mn_abk01_abk1_permute =
                    make_naive_tensor_descriptor_packed(make_tuple(ABK00, MN, ABK01, ABK1Value));

                const auto ab_grid_desc_abk0_mn_abk1_permute = transform_tensor_descriptor(
                    ab_grid_desc_abk00_mn_abk01_abk1_permute,
                               make_pass_through_transform(ABK1Value)),

                return ab_grid_desc_abk0_mn_abk1_permute;
            }
        }
    }

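    // Builds the (ABK0, MNPerBlock, ABK1) block descriptor for LDS, the destination of the
    // blockwise copy; the layout variants below trade LDS bank-conflict avoidance against
    // extra VGPR usage and address-computation cost.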
    __device__ static constexpr auto GetBlockDescriptor()
    {
        // A matrix in LDS memory, dst of blockwise copy
        if constexpr(UseBlockPaddingAB)
        {
            // Writing the data into LDS causes bank conflicts, but the whole loop is available
            // to hide them in v4, and padding may save some VALU instructions in address
            // computation.
        }
        // The xor tensor transformation requires extra VGPRs, which can cause register spills
        // in some cases.
        else if constexpr(is_same<ABMajorLayout, ABLayout>::value)
        {
            constexpr index_t LdsSize = 32 * 4 / KPerBlock / sizeof(LDSTypeAB) / ABPackedSize;
            constexpr auto MNLdsLayer = LdsSize < 1 ? 1 : LdsSize;
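            // Example (hypothetical values): with KPerBlock = 32 and a 2-byte LDSTypeAB,
            // LdsSize = 32 * 4 / 32 / 2 / 1 = 2, so MNLdsLayer = 2; when the expression
            // rounds down to 0, MNLdsLayer falls back to 1.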
            constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor(
                           Number<MNPerBlock / MNLdsLayer>{},
                           ABK1Number),

            constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
                ab_lds_block_desc,
                make_tuple(

            constexpr auto ab_lds_block_desc_abk0_mnldslayer_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_permuted,

            constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_abk0_mnldslayer_mn_abk1,

            return ab_lds_block_desc_abk0_mn_abk1;
        }
        else
        {
            // The kfold and mpair dimensions are not always required.
            // More dimensions in merge_transform make it harder for the compiler to generate
            // immediate-argument offsets.
            constexpr auto MN0 = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I1);
            constexpr auto MN1 = MNPerBlock / MN0;

            constexpr auto KThreadWrite     = ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1{}.At(I0);
            constexpr auto K0PerThreadWrite = ABK0Number / KThreadWrite;
            constexpr auto KThreadRead      = 64 / MNPerWmma;
            constexpr auto K0PerThreadRead  = ABK0Number / KThreadRead;

            constexpr auto kfold = (ABK1Number * MN0 * sizeof(LDSTypeAB) > 128)
                                       ? 1
                                       : 128 / (ABK1Number * MN0 * sizeof(LDSTypeAB));
            constexpr auto KThreadReadPerm =
                (kfold * K0PerThreadWrite / K0PerThreadRead) > 1
                    ? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
                    : KThreadRead;

            // 1 <= mpair <= MN0
            constexpr auto mpair = (ABK1Number * MNPerWmma * sizeof(LDSTypeAB) > 128)
                                       ? 1
                                       : ((128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB))) > MN0
                                              ? MN0
                                              : 128 / (ABK1Number * MNPerWmma * sizeof(LDSTypeAB)));
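            // Example (hypothetical values): for ABK1Number = 8, MN0 = 32, MNPerWmma = 16 and a
            // 2-byte LDSTypeAB, ABK1Number * MN0 * sizeof(LDSTypeAB) = 512 > 128, so kfold = 1;
            // ABK1Number * MNPerWmma * sizeof(LDSTypeAB) = 256 > 128, so mpair = 1 as well.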

            constexpr auto ab_lds_block_desc = make_naive_tensor_descriptor_packed(
                           Number<kfold * MN0 / mpair>{},
                           Number<mpair>{},
                           ABK1Number));

            constexpr auto ab_lds_block_desc_permuted = transform_tensor_descriptor(
                ab_lds_block_desc,
                make_tuple(
                    make_tuple(Number<KThreadReadPerm * MN1>{}, Number<kfold * MN0 / mpair>{})),
                make_tuple(
                make_tuple(

            constexpr auto ab_lds_block_desc_unmerged = transform_tensor_descriptor(
                ab_lds_block_desc_permuted,
                make_tuple(
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{}),
                           Sequence<2>{},
                           Sequence<0, 3>{},
                           Sequence<4, 5>{},
                           Sequence<6>{},
                           Sequence<7>{}));

            constexpr auto ab_lds_block_desc_abk0_mn_abk1 = transform_tensor_descriptor(
                ab_lds_block_desc_unmerged,
                           Number<KThreadWrite / kfold / KThreadReadPerm>{},
                           Number<kfold>{},

            return ab_lds_block_desc_abk0_mn_abk1;
        }
    }

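    // Creates the blockwise copy object that loads this block's MN tile of A/B from global
    // memory into the LDS block descriptor above.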
    template <typename GridDescriptor,
              typename BlockDescriptor,
              typename ABsDataType,
              typename ABElementwiseOperation,
              index_t GlobalBufferNum>
    __device__ static auto GetBlockTransfer(GridDescriptor& grid_descriptor,
                                            BlockDescriptor& block_descriptor,
                                            ABElementwiseOperation& ab_element_op,
                                            const index_t block_mn_id)
    {
        constexpr index_t NumABTensor = ABsDataType::Size();
        const index_t mn_block_data_idx_on_grid =
            __builtin_amdgcn_readfirstlane(block_mn_id * MNPerBlock);
        // workaround because v7r2 is not as general as v4r1
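        // When multiple A/B source tensors are present the copy uses
        // ThreadGroupTensorSliceTransfer_v7r2; with a single tensor it falls back to
        // ThreadGroupTensorSliceTransfer_v4r1.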
        if constexpr(NumABTensor > 1)
        {
            const auto idx_as_block_begin = generate_tuple(
                [&](auto) { return make_multi_index(0, mn_block_data_idx_on_grid, 0); },

                ABsDataType,
                GridDescriptor,
                decltype(tie(block_descriptor)),
                ABElementwiseOperation,
                ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
                ABBlockTransferThreadClusterArrangeOrder,
                ABBlockTransferSrcAccessOrder,
                ABBlockTransferSrcVectorDim,
                2,
                ABBlockTransferSrcScalarPerVector,
                ABBlockTransferDstScalarPerVector_ABK1,
                GlobalBufferNum>{grid_descriptor,
                                 idx_as_block_begin,
                                 tie(block_descriptor),
                                 make_tuple(make_multi_index(0, 0, 0)),
                                 ab_element_op};
        }
        else
        {
                ABElementwiseOperation,
                ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
                ABBlockTransferThreadClusterArrangeOrder,
                decltype(grid_descriptor[I0]),
                decltype(block_descriptor),
                ABBlockTransferSrcAccessOrder,
                ABBlockTransferSrcVectorDim,
                2,
                ABBlockTransferSrcScalarPerVector,
                ABBlockTransferDstScalarPerVector_ABK1,
                1,
                1,
                ABThreadTransferSrcResetCoordinateAfterRun,
                true,
                GlobalBufferNum>(grid_descriptor[I0],
                                 make_multi_index(0, mn_block_data_idx_on_grid, 0),
                                 ab_element_op,
                                 block_descriptor,
                                 make_multi_index(0, 0, 0),
        }
    }

    template <index_t MNRepeat, index_t MNWaves>
    __host__ __device__ static constexpr auto MakeWmmaTileDescriptor()
    {
        // This is a block descriptor used to read LDS memory into registers.
        // It is defined in a way consistent with the existing implementation to
        // avoid changes in the pipelines.
        using BlockDesc = decltype(GetBlockDescriptor());
        // ABK0_MN_ABK1 -> ABK0_MNRepeat_MNWaves_KRow_MNPerWmma_ABK1
        constexpr auto ABK0 = BlockDesc{}.GetLength(I0);
        constexpr auto ABK1 = BlockDesc{}.GetLength(I2);
#ifdef __gfx12__
        constexpr auto KRow = I2;
#else
        constexpr auto KRow = I1;
#endif
        if constexpr(KInner > 1)
        {
            // KPack = KInner * KPerWmma
            // K1    = KInner * KPerWmmaBlk
            // Each thread loads multiple tiles with one instruction
            // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
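            // For illustration (hypothetical values): with KPerWmmaBlk = 8 and KInner = 2,
            // K1 = 16, so a single load per thread covers two WMMA K-blocks.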
                BlockDesc{},
                make_tuple(
        }
        else
        {
            // KPack = KPerWmma (KInner == 1)
            if constexpr(ABK1 <= KPerWmmaBlk)
            {
                // K1 <= single tile (KPerWmmaBlk)
                // Each thread will load KPerWmmaBlk for the WMMA instruction
                // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1
                // (rest of the single WMMA tile for single thread) and then over KRow
                // (rest of the single WMMA tile for single wave)
                // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
                    BlockDesc{},
                    make_tuple(
                        Number<ABK0 / (KPack / ABK1)>{}, KRow, Number<KPack / KRow / ABK1>{})),

            else
            {
                // K1 > single tile (KPerWmmaBlk)
                // Each thread will load KPerWmmaBlk for the WMMA instruction
                // Since K1 > single tile, each thread loads KPerWmmaBlk and the next
                // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout).
                // This layout is needed to support, for example, AK1 > single tile and
                // BK1 <= single tile in the same GEMM.
                // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma -
                // K1
                constexpr auto desc1 = transform_tensor_descriptor(
                    BlockDesc{},
                    make_tuple(
                               Number<KPack / KPerWmmaBlk / KRow>{},
                               Number<KRow>{},
                               Number<KPerWmmaBlk>{}))),

                    desc1,
                    make_tuple(
                        make_merge_transform(make_tuple(Number<ABK0>{}, Number<ABK1 / KPack>{})),
                               Sequence<1>{},
                               Sequence<2, 3>{},
                               Sequence<4>{},
                               Sequence<5>{},
                               Sequence<6>{},
                               Sequence<7>{}),
                               Sequence<1>{},
                               Sequence<2>{},
                               Sequence<3>{},
                               Sequence<4>{},
                               Sequence<5>{},
                               Sequence<6>{}));
            }
        }
    }

    __device__ static constexpr auto GetBlockStep()
    {
        // Grid descriptor step (MoveSrcSliceWindow)
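        // i.e. advance ABK0 = KPerBlock / ABK1 along the K0 axis of the
        // (ABK0, MN, ABK1) grid descriptor each main-loop iteration.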
        return make_multi_index(KPerBlock / ABK1Number, 0, 0);
    }

    template <typename GridDescriptor>
    __device__ static constexpr index_t GetKDimension(const GridDescriptor& grid_desc)
    {
        // K dimension size (ABK0 * ABK1). This should always be called with the A matrix grid
        // descriptor, because it does not work for the B matrix when packed int4 is used.
        return grid_desc.GetLength(I0) * grid_desc.GetLength(I2);
    }
};

} // namespace ck