blockwise_gemm_pipeline_xdlops_v3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

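// How these stage counts map onto Run() below (descriptive note): two global->VGPR
// prefetch rounds and one LDS prefill are issued before the hot loop, the first
// LDS->VGPR prefetch happens ahead of the loop body, and a single shared LDS buffer
// is reused every iteration, with block_sync_lds() separating its write and read
// phases.
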
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v3
{
};
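// Note: the primary template above is intentionally empty; the scheduler-specific
// partial specialization below supplies the actual pipeline implementation.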

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC //disable transposec right now...
          >
struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

    __device__ static constexpr auto HotLoopScheduler()
    {
        // A/B split schedule
        // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;

        constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

        constexpr auto num_dsread_a_mfma =
            (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
        constexpr auto num_dsread_b_mfma =
            (num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;

        // stage 1
        // Separate this part?
        // constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
        //                                               sizeof(ComputeDataType) / sizeof(BDataType)
        //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
        //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
        constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
        constexpr auto num_mfma_per_issue =
            num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

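        // Worked example (illustrative only, assuming 16-byte LDS reads for both A and B
        // and NPerXDL == 32):
        //   mfma_cycle            = 32
        //   ds_read_*_issue_cycle = 8
        //   ds_read_*_mfma_rate   = (32 - 4 + 2 * 8 - 1) / (2 * 8) = 2
        // so stage 2 below groups up to two ds_read instructions with each MFMA, while
        // stage 1 distributes the ds_write and buffer_load issues over the remaining
        // num_mfma_stage1 MFMA slots, one buffer_load per issue group.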
        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
        });

        // stage 2
        static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
                         ds_read_a_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_mfma - 1) *
                                                                              ds_read_a_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });

        static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
                         ds_read_b_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_mfma - 1) *
                                                                              ds_read_b_mfma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k0) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k0, I0),
                                   a_thread_buf);
            });
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k0, I0),
                                   b_thread_buf);
            });
        });

        __builtin_amdgcn_sched_barrier(0);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeDataType, KPack> a_thread_vec;
                            vector_type<ComputeDataType, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<ComputeDataType,
                                                     xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k0 * AMmaKStride>{}),
                                           a_block_buf,
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k0, I0),
                                           a_thread_buf);
                    });
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k0 * BMmaKStride>{}),
                                           b_block_buf,
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k0, I0),
                                           b_thread_buf);
                    });
                });

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
            // Let's leak last MFMA block to epilogue region, cover the potential lds-shuffle
            // latency
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck