Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp Source File

blockwise_gemm_pipeline_wmmaops_v3.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 2
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 1

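// Illustrative timeline of the schedule implemented by Run() below (a sketch, not emitted
// code), for the hot-loop case num_loop > PrefetchStages:
//
//   prologue : global read of tile 0 -> LDS prefill of tile 0 -> global read of tile 1
//              -> LDS-to-VGPR prefetch of tile 0
//   hot loop : for tile i, WMMA on the registers holding tile i is interleaved with the LDS
//              write of tile i+1 and the global read of tile i+2, followed by the LDS-to-VGPR
//              prefetch of tile i+1
//   tail     : WMMA on the last prefetched tile (TailNumber::Full)
//
// A single LDS buffer is reused, so block_sync_lds() separates its write and read phases.
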
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v3
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
                                        BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeTypeA,
                                        ComputeTypeB,
                                        AccDataType,
                                        AWmmaTileDesc,
                                        BWmmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerWmma,
                                        NPerWmma,
                                        MRepeat,
                                        NRepeat,
                                        KPack>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
                                         ComputeTypeA,
                                         ComputeTypeB,
                                         AccDataType,
                                         AWmmaTileDesc,
                                         BWmmaTileDesc,
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         MPerBlock,
                                         NPerBlock,
                                         KPerBlock,
                                         MPerWmma,
                                         NPerWmma,
                                         MRepeat,
                                         NRepeat,
                                         KPack>
{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    ADataType,
                                                    BDataType,
                                                    ComputeTypeA,
                                                    ComputeTypeB,
                                                    AccDataType,
                                                    AWmmaTileDesc,
                                                    BWmmaTileDesc,
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    MPerBlock,
                                                    NPerBlock,
                                                    KPerBlock,
                                                    MPerWmma,
                                                    NPerWmma,
                                                    MRepeat,
                                                    NRepeat,
                                                    KPack>;
    using Base::I0;

    using Base::A_K1;
    using Base::A_KRow;
    using Base::B_K1;
    using Base::B_KRow;
    using Base::KRepeat;
    using Base::WmmaK;

    using Base::wmma_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    static constexpr index_t PrefetchStages  = 2;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 1;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

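    // Hypothetical caller-side sketch (not part of this file): a gridwise GEMM driver would
    // typically pick the Run() instantiation from these two queries, e.g.
    //
    //   if(BlockwiseGemmPipe::BlockHasHotloop(num_k_loop))
    //       blockwise_gemm_pipeline.template Run<true, TailNumber::Full>(/* ... */);
    //   else
    //       blockwise_gemm_pipeline.template Run<false, TailNumber::Full>(/* ... */);
    //
    // BlockLoopTailNum() always reports TailNumber::Full because this pipeline only implements
    // the full-tail epilogue.
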
    __device__ static constexpr auto HotLoopScheduler()
    {
        // TODO: Calculation of the number of instructions may require changes for WMMA
        /*
        // A/B split schedule
        // compiler is likely to use ds_read2 when instruction width smaller than 16bytes
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_wmma_inst = HotLoopInstList::C_WMMA_Inst_Num;

        constexpr auto wmma_cycle = NPerWmma == 16 ? 16 : 32;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_wmma_rate =
            (wmma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);

        constexpr auto num_dsread_a_wmma =
            (num_ds_read_inst_a + ds_read_a_wmma_rate - 1) / ds_read_a_wmma_rate;
        constexpr auto num_dsread_b_wmma =
            (num_ds_read_inst_b + ds_read_b_wmma_rate - 1) / ds_read_b_wmma_rate;

        // stage 1
        // Separate this part?
        // constexpr auto num_wmma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
        //                                               sizeof(ComputeDataType) / sizeof(BDataType)
        //                                           ? sizeof(ComputeDataType) / sizeof(ADataType)
        //                                           : sizeof(ComputeDataType) / sizeof(BDataType);
        constexpr auto num_wmma_stage1 = num_wmma_inst - (num_dsread_a_wmma + num_dsread_b_wmma);
        constexpr auto num_wmma_per_issue =
            num_wmma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;

        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_a, 0); // WMMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_wmma_per_issue - num_dswrite_per_issue_b, 0); // WMMA
        });

        // stage 2
        static_for<0, num_dsread_a_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_wmma_rate) >=
                         ds_read_a_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_a - (num_dsread_a_wmma - 1) *
                                                                              ds_read_a_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });

        static_for<0, num_dsread_b_wmma, 1>{}([&](auto i) {
            if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_wmma_rate) >=
                         ds_read_b_wmma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_wmma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(0x100,
                                                     num_ds_read_inst_b - (num_dsread_b_wmma - 1) *
                                                                              ds_read_b_wmma_rate,
                                                     0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
        });
        */
    }

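    // Note on the (currently disabled) schedule above: __builtin_amdgcn_sched_group_barrier(mask,
    // size, sync_id) asks the compiler to place `size` instructions matching `mask` at that point
    // in the schedule. The masks used here follow the LLVM AMDGPU convention: 0x008 = MFMA/WMMA,
    // 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write, matching the inline comments.
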
    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        __builtin_amdgcn_sched_barrier(0);
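        // A mask of 0 makes the sched_barrier above a full scheduling fence: the compiler may not
        // move any instruction across it, so the prefetch/prefill sequence below stays in place.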
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k0) {
            a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                               make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, k0, I0, I0, I0),
                               a_thread_buf);
            b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                               make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(I0, I0, k0, I0, I0, I0),
                               b_thread_buf);
        });

        __builtin_amdgcn_sched_barrier(0);

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                block_sync_lds();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(Number<ik / A_K1>{},
                                                   m0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   Number<ik % A_K1>{}))>{}];
                            });
                            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(Number<ik / B_K1>{},
                                                   n0,
                                                   k0,
                                                   I0,
                                                   I0,
                                                   Number<ik % B_K1>{}))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                          b_thread_vec.template AsType<wmma_input_type_b>(),
                                          c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    a_thread_copy_.Run(
                        a_block_desc_k0_m0_m1_m2_k1,
                        make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(I0, I0, k0, I0, I0, I0),
                        a_thread_buf);
                    b_thread_copy_.Run(
                        b_block_desc_k0_n0_n1_n2_k1,
                        make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(I0, I0, k0, I0, I0, I0),
                        b_thread_buf);
                });

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 1;
            } while(i < (num_loop - 1));
        }
        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                      b_thread_vec.template AsType<wmma_input_type_b>(),
                                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
            // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
            // latency
            // __builtin_amdgcn_sched_barrier(0);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};
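
// Hypothetical instantiation sketch (the descriptor types and tile sizes below are illustrative
// only, not taken from this file):
//
//   using BlockGemmPipe =
//       BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
//                                        256,          // BlockSize
//                                        half_t,       // ADataType
//                                        half_t,       // BDataType
//                                        half_t,       // ComputeTypeA
//                                        half_t,       // ComputeTypeB
//                                        float,        // AccDataType
//                                        AWmmaTileDesc,
//                                        BWmmaTileDesc,
//                                        8,            // ABlockTransferSrcScalarPerVector
//                                        8,            // BBlockTransferSrcScalarPerVector
//                                        128,          // MPerBlock
//                                        128,          // NPerBlock
//                                        64,           // KPerBlock
//                                        16,           // MPerWmma
//                                        16,           // NPerWmma
//                                        4,            // MRepeat
//                                        2,            // NRepeat
//                                        16>;          // KPack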

} // namespace ck