Composable Kernel documentation — source listing for
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp
(ROCm docs build 6.4.3, rendered from readthedocs.org).

blockwise_gemm_pipeline_xdlops_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 
12 namespace ck {
13 
// NOTE(review): this is a Doxygen text dump; the #include block and several
// source lines were lost in extraction. Gaps in the embedded line numbers
// below mark missing code.
//
// Template parameters describe one block-level GEMM tile computed with
// XDLOPS/MFMA instructions: MPerBlock x NPerBlock x KPerBlock is the block
// tile, MPerXDL/NPerXDL the per-instruction tile, MRepeat/NRepeat how many
// instruction tiles each wave iterates over, KPack the K elements fed to one
// MFMA issue.
14 template <index_t BlockSize,
15  typename ADataType,
16  typename BDataType,
17  typename ComputeDataType,
18  typename AccDataType,
19  typename ATileDesc,
20  typename BTileDesc,
21  typename AMmaTileDesc,
22  typename BMmaTileDesc,
23  index_t ABlockTransferSrcScalarPerVector,
24  index_t BBlockTransferSrcScalarPerVector,
25  index_t MPerBlock,
26  index_t NPerBlock,
27  index_t KPerBlock,
28  index_t MPerXDL,
29  index_t NPerXDL,
30  index_t MRepeat,
31  index_t NRepeat,
32  index_t KPack,
33  bool TransposeC = false>
// (struct name line, orig line 34, lost in extraction; the cross-reference
//  index below identifies it as BlockwiseGemmXdlops_pipeline_base.)
35 {
// Compile-time index constants used for tuple/descriptor access throughout.
36  static constexpr auto I0 = Number<0>{};
37  static constexpr auto I1 = Number<1>{};
38  static constexpr auto I2 = Number<2>{};
39  static constexpr auto I3 = Number<3>{};
40 
42 
43  // Hardcode to 64, as HIP-provided "warpSize" would return 32 on RDNA GPUs.
44  static constexpr index_t WaveSize = 64;
45 
// K0/K1 lengths of the A and B block-tile descriptors (K factored as K0*K1;
// K1 is also the per-thread vector width used by the thread copies below).
46  static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
47  static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
48  static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
49  static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
50 
// Selected XDLOPS (MFMA) instruction wrapper; its initializer (orig line 52)
// was lost in extraction.
51  static constexpr auto xdlops_gemm =
53 
54  static constexpr index_t AMmaKStride = KPack;
55  static constexpr index_t BMmaKStride = KPack;
56 
// K elements each thread owns within the block tile, and the number of
// KPack-sized MFMA steps that implies.
57  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
58  static constexpr index_t KRepeat = KPerThread / KPack;
59 
// Wave counts along M and N needed to cover the block tile.
60  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
61  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
62 
// Alias/type parameterized on the tile shape (its head, orig lines 63-64,
// lost in extraction — presumably a hot-loop scheduler/config; verify
// against the repository source).
65  MPerBlock,
66  NPerBlock,
67  KPerBlock,
68  ABlockTransferSrcScalarPerVector,
69  BBlockTransferSrcScalarPerVector,
70  A_K1,
71  B_K1,
72  A_K1,
73  B_K1,
74  MRepeat,
75  NRepeat,
76  MPerXDL,
77  NPerXDL,
78  xdlops_gemm.KPerXdlops>;
79 
80  static_assert(KPerThread % KPack == 0,
81  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
82 
// Per-thread accumulator registers: per the cross-reference index this is a
// StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, ...> named c_thread_buf_
// (declaration head, orig line 83, and the trailing "c_thread_buf_;",
// orig line 88, lost in extraction): one vector of
// GetRegSizePerXdlops() elements per (MRepeat * NRepeat) xdlops tile.
84  AccDataType,
85  MRepeat * NRepeat,
86  xdlops_gemm.GetRegSizePerXdlops(),
87  true>
89 
// Mutable access to this thread's accumulator register buffer.
90  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
91 
// Decompose the flat thread id into a wave-level index via a single-stage
// tensor adaptor. The adaptor's transform arguments (orig lines 97-99) were
// lost in extraction; given the (MWaves, NWaves, WaveSize) constants and the
// [I0]/[I1] accesses by the callers below, it presumably unmerges the thread
// id into (m-wave, n-wave, lane) — TODO confirm against repository source.
92  __device__ static auto GetWaveIdx()
93  {
94  const index_t thread_id = ThisThreadBlock::GetThreadId();
95 
96  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
100 
101  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
102  }
103 
// Origin of this thread's slice of the A tile: (0, m-wave id, intra-xdlops
// index, K offset). The K offset scales the xdlops K group index by
// KPerThread so consecutive K groups of one thread are KPerThread apart.
104  __device__ static auto CalculateAThreadOriginDataIndex()
105  {
106  const auto wave_idx = GetWaveIdx();
107 
108  const auto waveId_m = wave_idx[I0];
109 
110  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
111 
112  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
113  }
114 
// Origin of this thread's slice of the B tile; mirrors
// CalculateAThreadOriginDataIndex but along the N/wave_idx[I1] dimension.
115  __device__ static auto CalculateBThreadOriginDataIndex()
116  {
117  const auto wave_idx = GetWaveIdx();
118 
119  const auto waveId_n = wave_idx[I1];
120 
121  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
122 
123  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
124  }
125 
// Map (repeat m0/n0, xdlops blk indices) for this thread to its (m, n)
// origin in the C block tile. The function-name line (orig 128,
// CalculateCThreadOriginDataIndex per the index below) and the adaptors'
// Sequence dimension arguments (orig 139-140, 144-145) were lost in
// extraction. Each adaptor folds (repeat, wave, per-xdl) coordinates into a
// single flat m (resp. n) coordinate.
126  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
127  __device__ static auto
129  {
130  const auto wave_idx = GetWaveIdx();
131 
132  const auto waveId_m = wave_idx[I0];
133  const auto waveId_n = wave_idx[I1];
134 
// Starting (m, n) of the blk_i-th sub-block of the xdlops_i-th xdlops tile.
135  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
136 
137  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
138  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
141 
142  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
143  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
146 
147  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
148  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
149  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
150  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
151 
152  return make_tuple(c_thread_m, c_thread_n);
153  }
154 
// 8-D variant of the C-origin calculation: returns the unflattened
// (m0, n0, m-wave, n-wave, blk0..blk3) coordinates instead of a folded
// (m, n) pair. The function-name line (orig 157,
// CalculateCThreadOriginDataIndex8D per the index below) was lost in
// extraction.
155  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
156  __device__ static auto
158  {
159  const auto wave_idx = GetWaveIdx();
160 
161  const auto waveId_m = wave_idx[I0];
162  const auto waveId_n = wave_idx[I1];
163 
164  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
165 
166  return make_tuple(
167  m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
168  }
169 
171 
// Constructor. Its name/parameter lines (orig 170, 173-174) were lost in
// extraction; the cross-reference index shows the full signature:
//   BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin =
//       CalculateAThreadOriginDataIndex(),
//     Tuple4 b_origin = CalculateBThreadOriginDataIndex())
// It seeds the A/B thread copies with this thread's tile origins and
// validates the tile configuration at compile time.
172  __host__ __device__
175  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
176  {
177  static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
178  "wrong! Desc should be known at compile-time");
179 
// The launch config must supply exactly one thread per lane of every wave.
180  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
181  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
182 
183  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
184  "wrong!");
185  }
186 
187  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
// Per-thread C descriptor in transposed layout; splits the xdlops thread-blk
// lengths (M0, M1, M2, N) around the repeat dimensions. The call head on
// orig line 197 (presumably make_naive_tensor_descriptor_packed( — TODO
// confirm) was lost in extraction.
188  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
189  {
190  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
191 
192  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
193  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
194  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
195  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
196 
198  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
199  }
200 
201  // XDL output supporting C_xdl = A_xdl * B_xdl
// Per-thread C descriptor in the standard (non-transposed) layout; note the
// dimension order (M0, M1, M2, N) vs. the transposed variant above. The call
// head on orig line 211 was lost in extraction.
202  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
203  {
204  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
205 
206  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
207  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
208  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
209  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
210 
212  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
213  }
214 
// Batched (G) variant of the per-thread C descriptor: identical to the
// non-G version above with a leading unit G dimension. The call head on
// orig line 224 was lost in extraction.
215  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
216  {
217  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
218 
219  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
220  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
221  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
222  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
223 
225  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
226  }
227 
228  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
// Block-level C descriptor (transposed layout): builds a
// (MRepeat, NRepeat, MWaves, NWaves, MPerXDL, NPerXDL) descriptor and lets
// the xdlops wrapper refine it into the N2/N3/N4 split. The descriptor head
// on orig line 232 (presumably make_naive_tensor_descriptor_packed(
// make_tuple(Number<MRepeat>{}, — TODO confirm) was lost in extraction.
229  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
230  {
231  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
233  Number<NRepeat>{},
234  Number<MWaves>{},
235  Number<NWaves>{},
236  Number<MPerXDL>{},
237  Number<NPerXDL>{}));
238 
239  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
240  }
241 
242  // XDL output supporting C_xdl = A_xdl * B_xdl
// Block-level C descriptor (standard layout); same intermediate shape as the
// transposed variant, refined into the M3/M4 split instead. The descriptor
// head on orig line 246 was lost in extraction.
243  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
244  {
245  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
247  Number<NRepeat>{},
248  Number<MWaves>{},
249  Number<NWaves>{},
250  Number<MPerXDL>{},
251  Number<NPerXDL>{}));
252 
253  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
254  }
255 
// Batched (G) block-level C descriptor; prepends a G dimension to the
// standard-layout block descriptor above. The descriptor head on orig line
// 259 was lost in extraction.
256  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
257  {
258  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
260  Number<MRepeat>{},
261  Number<NRepeat>{},
262  Number<MWaves>{},
263  Number<NWaves>{},
264  Number<MPerXDL>{},
265  Number<NPerXDL>{}));
266 
267  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
268  c_block_desc_g_m0_n0_m1_n1_m2_n2);
269  }
270 
// Transform a grid-level (M, N) C descriptor into the xdlops 8-D form:
// unmerge M into (M/(MWaves*MPerXDL), MWaves, MPerXDL) and N likewise, then
// let the xdlops wrapper split further. The transform's Sequence dimension
// arguments (orig lines 282-283) were lost in extraction.
// Precondition (implied by the divisions): M and N are multiples of
// MWaves*MPerXDL and NWaves*NPerXDL respectively.
271  template <typename CGridDesc_M_N>
272  __host__ __device__ static constexpr auto
273  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
274  {
275  const auto M = c_grid_desc_m_n.GetLength(I0);
276  const auto N = c_grid_desc_m_n.GetLength(I1);
277 
278  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
279  c_grid_desc_m_n,
280  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
281  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
284 
285  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
286  }
287 
// Batched (G, M, N) variant of the grid C-descriptor transform. The G
// pass-through transform line (orig 298, presumably
// make_tuple(make_pass_through_transform(G), — TODO confirm) and the
// Sequence dimension arguments (orig 301-302) were lost in extraction.
288  template <typename CGridDesc_G_M_N>
289  __host__ __device__ static constexpr auto
290  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
291  {
292  const auto G = c_grid_desc_g_m_n.GetLength(I0);
293  const auto M = c_grid_desc_g_m_n.GetLength(I1);
294  const auto N = c_grid_desc_g_m_n.GetLength(I2);
295 
296  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
297  c_grid_desc_g_m_n,
299  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
300  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
303 
304  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
305  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
306  }
307 
// Compile-time A/B block tile descriptors (types supplied as template args).
308  static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
309  static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
310 
311  protected:
312  // M1, N1 as double buffer index
313  // Read buffer + Compute buffer
314  // A[M0, M1, M2, KPack]
// Per-thread A register-tile descriptor a_thread_desc_ (declaration head,
// orig line 315, lost in extraction; per the index it is a
// make_naive_tensor_descriptor with the lengths/strides below).
316  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
317  make_tuple(
318  Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
319 
320  // B[N0, N1, N2, KPack]
// Per-thread B register-tile descriptor b_thread_desc_ (head, orig line 321,
// lost in extraction); mirrors a_thread_desc_ along N.
322  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
323  make_tuple(
324  Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
325 
326  // C[M, N, NumRegXdlops]
// Per-thread C accumulator descriptor c_thread_desc_ (head, orig line 327,
// lost in extraction).
328  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
329 
// AThreadCopy alias (head, orig lines 330/334-335, lost in extraction):
// copies A from the block tile into registers, A_K1 elements per access.
331  ComputeDataType,
332  decltype(a_block_desc_m0_m1_m2_k),
333  decltype(a_thread_desc_),
336  3,
337  A_K1,
338  A_K1>;
339 
// BThreadCopy alias (head, orig lines 340/344-345, lost in extraction):
// B-side counterpart, B_K1 elements per access.
341  ComputeDataType,
342  decltype(b_block_desc_n0_n1_n2_k),
343  decltype(b_thread_desc_),
346  3,
347  B_K1,
348  B_K1>;
349 
// Member copy objects a_thread_copy_ / b_thread_copy_ (orig lines 350-351)
// were lost in extraction; see the cross-reference index below.
352 };
353 
354 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:289
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:81
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:61
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:173
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:60
static constexpr index_t A_K0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:229
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:243
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:327
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:51
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:115
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:309
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:92
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:188
static constexpr index_t AMmaKStride
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:54
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:351
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:104
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:44
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:49
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:41
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:290
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:273
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:36
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:202
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:128
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:157
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:315
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:308
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:256
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:321
static constexpr auto I2
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:38
static constexpr auto I3
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:39
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:48
static constexpr index_t BMmaKStride
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:55
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:170
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:350
static constexpr index_t KPerThread
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:57
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:215
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:90
static constexpr index_t B_K0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:47
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
Definition: xdlops_gemm.hpp:1181
Definition: integral_constant.hpp:10