Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

blockwise_gemm_pipeline_wmmaops_v1.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 // GlobalPrefetchStages: 1
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 0
14 // LocalSharedMemoryBuffer: 1
15 
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeTypeA,
21  typename ComputeTypeB,
22  typename AccDataType,
23  typename AWmmaTileDesc,
24  typename BWmmaTileDesc,
25  index_t ABlockTransferSrcScalarPerVector,
26  index_t BBlockTransferSrcScalarPerVector,
27  index_t MPerBlock,
28  index_t NPerBlock,
29  index_t KPerBlock,
30  index_t MPerWmma,
31  index_t NPerWmma,
32  index_t MRepeat,
33  index_t NRepeat,
34  index_t KPack,
35  index_t KInner,
36  bool TransposeC = false>
37 struct BlockwiseGemmWmmaops_pipeline_v1
38 {
39 };
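
The empty primary template above acts purely as a dispatch point: each BlockGemmPipelineScheduler value gets its own partial specialization below. A minimal, self-contained sketch of that dispatch pattern, using hypothetical names rather than the CK types:

#include <cstdint>

enum class SchedulerSketch
{
    Intrawave,
    Interwave
};

// Primary template left empty, mirroring the struct above; only the
// specializations below carry an implementation.
template <SchedulerSketch Sched, std::int32_t BlockSize>
struct PipelineSketch
{
};

template <std::int32_t BlockSize>
struct PipelineSketch<SchedulerSketch::Intrawave, BlockSize>
{
    static constexpr std::int32_t PrefetchStages = 1;
};

template <std::int32_t BlockSize>
struct PipelineSketch<SchedulerSketch::Interwave, BlockSize>
{
    static constexpr std::int32_t PrefetchStages = 1;
};

// Usage: an unsupported scheduler value simply has no members, so misuse
// fails at compile time.
static_assert(PipelineSketch<SchedulerSketch::Intrawave, 256>::PrefetchStages == 1, "");
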
40 
41 template <index_t BlockSize,
42  typename ADataType,
43  typename BDataType,
44  typename ComputeTypeA,
45  typename ComputeTypeB,
46  typename AccDataType,
47  typename AWmmaTileDesc,
48  typename BWmmaTileDesc,
49  index_t ABlockTransferSrcScalarPerVector,
50  index_t BBlockTransferSrcScalarPerVector,
51  index_t MPerBlock,
52  index_t NPerBlock,
53  index_t KPerBlock,
54  index_t MPerWmma,
55  index_t NPerWmma,
56  index_t MRepeat,
57  index_t NRepeat,
58  index_t KPack,
59  index_t KInner,
60  bool TransposeC>
61 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
62  BlockSize,
63  ADataType,
64  BDataType,
65  ComputeTypeA,
66  ComputeTypeB,
67  AccDataType,
68  AWmmaTileDesc,
69  BWmmaTileDesc,
70  ABlockTransferSrcScalarPerVector,
71  BBlockTransferSrcScalarPerVector,
72  MPerBlock,
73  NPerBlock,
74  KPerBlock,
75  MPerWmma,
76  NPerWmma,
77  MRepeat,
78  NRepeat,
79  KPack,
80  KInner,
81  TransposeC>
82  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
83  ADataType,
84  BDataType,
85  ComputeTypeA,
86  ComputeTypeB,
87  AccDataType,
88  AWmmaTileDesc,
89  BWmmaTileDesc,
90  ABlockTransferSrcScalarPerVector,
91  BBlockTransferSrcScalarPerVector,
92  MPerBlock,
93  NPerBlock,
94  KPerBlock,
95  MPerWmma,
96  NPerWmma,
97  MRepeat,
98  NRepeat,
99  KPack,
100  KInner,
101  TransposeC>
102 {
103  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
104  ADataType,
105  BDataType,
106  ComputeTypeA,
107  ComputeTypeB,
108  AccDataType,
109  AWmmaTileDesc,
110  BWmmaTileDesc,
111  ABlockTransferSrcScalarPerVector,
112  BBlockTransferSrcScalarPerVector,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  MPerWmma,
117  NPerWmma,
118  MRepeat,
119  NRepeat,
120  KPack,
121  KInner,
122  TransposeC>;
123  using Base::I0;
124  using Base::I1;
125  using typename Base::HotLoopInstList;
126 
127  using Base::A_K1;
128  using Base::A_KRow;
129  using Base::B_K1;
130  using Base::B_KRow;
131  using Base::KRepeat;
132  using Base::WmmaK;
133 
134  using Base::wmma_gemm;
135 
136  using Base::CalculateCThreadOriginDataIndex;
137  using Base::
138  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
139  using Base::GetCThreadBuffer;
140  using Base::
141  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
142 
143  using Base::a_block_desc_k0_m0_m1_m2_k1;
144  using Base::b_block_desc_k0_n0_n1_n2_k1;
145 
146  using typename Base::Empty;
147 
148  static constexpr index_t PrefetchStages = 1;
149  static constexpr index_t PrefillStages = 1;
150  static constexpr index_t GlobalBufferNum = 1;
151 
152  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
153 
154  static TailNumber BlockLoopTailNum(index_t num_loop)
155  {
156  ignore = num_loop;
157  return TailNumber::Full;
158  }
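
BlockHasHotloop and BlockLoopTailNum are the queries the surrounding gridwise GEMM uses to choose the template arguments of Run. A hypothetical host-side sketch (not CK's gridwise code) of that handshake:

#include <cstdio>

enum class TailNumberSketch
{
    Full
};

struct PipelineQuerySketch
{
    static constexpr int PrefetchStages = 1;
    static bool BlockHasHotloop(int num_loop) { return num_loop > PrefetchStages; }
    static TailNumberSketch BlockLoopTailNum(int /*num_loop*/) { return TailNumberSketch::Full; }
};

template <bool HasMainLoop, TailNumberSketch TailNum>
void run_sketch(int num_loop)
{
    if constexpr(HasMainLoop)
        std::printf("main loop iterations: %d, then one tail iteration\n", num_loop - 1);
    else
        std::printf("no main loop, tail iteration only\n");
}

int main()
{
    const int num_loop = 4;
    // Runtime query -> compile-time dispatch, mirroring how Run<HasMainLoop, TailNum>
    // would be instantiated by the caller.
    if(PipelineQuerySketch::BlockHasHotloop(num_loop))
        run_sketch<true, TailNumberSketch::Full>(num_loop);
    else
        run_sketch<false, TailNumberSketch::Full>(num_loop);
    return 0;
}
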
159 
160  template <bool HasMainLoop,
161  TailNumber TailNum,
162  typename AGridDesc,
163  typename ABlockDesc,
164  typename ABlockTransfer,
165  typename AGridBuffer,
166  typename ABlockBuffer,
167  typename ABlockTransferStep,
168  typename BGridDesc,
169  typename BBlockDesc,
170  typename BBlockTransfer,
171  typename BGridBuffer,
172  typename BBlockBuffer,
173  typename BBlockTransferStep,
174  typename CThreadBuffer,
175  typename BScaleStruct>
176  __device__ void Run(const AGridDesc& a_grid_desc,
177  const ABlockDesc& a_block_desc,
178  ABlockTransfer& a_blockwise_copy,
179  const AGridBuffer& a_grid_buf,
180  ABlockBuffer& a_block_buf,
181  const ABlockTransferStep& a_block_copy_step,
182  const BGridDesc& b_grid_desc,
183  const BBlockDesc& b_block_desc,
184  BBlockTransfer& b_blockwise_copy,
185  const BGridBuffer& b_grid_buf,
186  BBlockBuffer& b_block_buf,
187  const BBlockTransferStep& b_block_copy_step,
188  CThreadBuffer& c_thread_buf,
189  // BScaleThreadCopy
190  BScaleStruct& b_scale_struct,
191  index_t num_loop,
192  index_t num_loop_per_scale) const
193  {
194  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
195 
196  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
197  a_thread_desc_.GetElementSpaceSize());
198  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
199  b_thread_desc_.GetElementSpaceSize());
200 
201  // Global prefetch 1
202  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
203  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
204 
205  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
206  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
207 
208  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
209 
210  // Local prefill 1
211  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
212  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
213 
214  // Initialize C
215  c_thread_buf.Clear();
216 
217  auto blockwise_gemm_func = [&]() {
218  static_for<0, KRepeat, 1>{}([&](auto k0) {
219  static_for<0, MRepeat, 1>{}([&](auto m0) {
220  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
221  make_tuple(I0, m0, k0, I0, I0, I0, I0),
222  a_block_buf,
223  a_thread_desc_,
224  make_tuple(I0, I0, I0, I0, I0, I0, I0),
225  a_thread_buf);
226  if constexpr(m0 == I0)
227  {
228  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
229  {
230  static_for<0, NRepeat, 1>{}([&](auto n0) {
231  b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
232  make_tuple(I0, n0, k0, I0, I0, I0, I0),
233  b_block_buf,
234  b_thread_desc_,
235  make_tuple(I0, n0, I0, I0, I0, I0, I0),
236  b_thread_buf);
237  });
238  }
239  else
240  {
241  static_for<0, NRepeat, 1>{}([&](auto n0) {
242  b_thread_copy_.Run(
243  b_block_desc_k0_n0_n1_n2_k1,
244  make_tuple(I0, n0, k0, I0, I0, I0, I0),
245  b_block_buf,
246  b_scale_struct.b_scale_thread_bufs(
247  I0)[Number<n0 * BScaleStruct::num_scale_k_block +
248  k0 / BScaleStruct::num_scale_krepeat>{}],
249  b_thread_desc_,
250  make_tuple(I0, n0, I0, I0, I0, I0, I0),
251  b_thread_buf);
252  });
253  }
254  }
255 
256  static_for<0, KInner, 1>{}([&](auto k_inner) {
257  static_for<0, NRepeat, 1>{}([&](auto n0) {
258  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
259  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
260 
261  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
262  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
263  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
264  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
265  make_tuple(Number<kk / A_K1>{},
266  I0,
267  I0,
268  I0,
269  I0,
270  I0,
271  Number<kk % A_K1>{}))>{}];
272  });
273  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
274  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
275  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
276  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
277  make_tuple(Number<kk / B_K1>{},
278  n0,
279  I0,
280  I0,
281  I0,
282  I0,
283  Number<kk % B_K1>{}))>{}];
284  });
285 
286  using wmma_input_type_a =
287  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
288  using wmma_input_type_b =
289  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
290 
291  constexpr index_t c_offset =
292  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
293 
294  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
295  b_thread_vec.template AsType<wmma_input_type_b>(),
296  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
297  });
298  });
299  });
300  });
301  };
302 
303  // main body
304  if constexpr(HasMainLoop)
305  {
306  index_t i = 0;
307  do
308  {
309  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
310  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
311 
312  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
313  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
314 
315  block_sync_lds();
316  blockwise_gemm_func();
317 
318  block_sync_lds();
319  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
320  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
321  {
322  block_sync_lds();
323  }
324  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
325  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
326 
327  constexpr index_t num_ds_write_inst =
328  HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
329 
330  constexpr index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
331  HotLoopInstList::B_Buffer_Load_Inst_Num;
332  static_for<0, num_buffer_load_inst, 1>{}([&](auto) {
333  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
334  });
335  static_for<0, KRepeat, 1>{}([&](auto) {
336  static_for<0, MRepeat, 1>{}([&](auto m0) {
337  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
338  if constexpr(m0 == I0)
339  {
340  static_for<0, NRepeat, 1>{}([&](auto) {
341  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
342  });
343  }
344  static_for<0, KInner, 1>{}([&](auto) {
345  static_for<0, NRepeat, 1>{}([&](auto) {
346  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
347  });
348  });
349  });
350  });
351  static_for<0, num_ds_write_inst, 1>{}([&](auto) {
352  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
353  });
354 
355  i += 1;
356  } while(i < (num_loop - 1));
357  }
358 
359  // tail
360  if constexpr(TailNum == TailNumber::Full)
361  {
362  block_sync_lds();
363  blockwise_gemm_func();
364  }
365  }
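
Run above is a one-LDS-buffer pipeline: the global prefetch of tile i+1 is issued before the compute on tile i, and the two block_sync_lds() calls separate the LDS reads of tile i from the LDS write of tile i+1. A host-side analogue of that ordering, written as a plain C++ sketch under assumed tile sizes (not CK code):

#include <array>
#include <cstddef>
#include <vector>

int main()
{
    constexpr std::size_t TileSize = 4;
    std::vector<int> global_a(16, 1);          // stand-in for global memory
    std::array<int, TileSize> lds{};           // stand-in for the single LDS buffer
    std::array<int, TileSize> prefetch_regs{}; // stand-in for the VGPR staging of RunRead

    const std::size_t num_loop = global_a.size() / TileSize;
    long long acc = 0;

    // Global prefetch 1 + local prefill 1 (before the loop, as above)
    for(std::size_t k = 0; k < TileSize; ++k)
        prefetch_regs[k] = global_a[k];
    lds = prefetch_regs;

    for(std::size_t i = 0; i + 1 < num_loop; ++i)
    {
        // RunRead of tile i+1; on the GPU this overlaps the compute below
        for(std::size_t k = 0; k < TileSize; ++k)
            prefetch_regs[k] = global_a[(i + 1) * TileSize + k];

        // block_sync_lds(): compute must not start before the LDS write finished
        for(int v : lds)
            acc += v; // blockwise_gemm_func() on tile i

        // block_sync_lds(): RunWrite must not overwrite data still being read
        lds = prefetch_regs; // local prefill of tile i+1
    }

    for(int v : lds)
        acc += v; // tail: last tile, no prefetch behind it

    return acc == 16 ? 0 : 1; // 16 elements, each equal to 1
}
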
366 
367  protected:
368  // A[MRepeat, I1, I1, KPack]
369  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
370  make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, I1, Number<A_K1>{}));
371 
372  // B[NRepeat, N1, N2, KPack]
373  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
374  Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, I1, Number<B_K1>{}));
375 
376  using AThreadCopy =
377  ThreadwiseTensorSliceTransfer_v4<ADataType,
378  ComputeTypeA,
379  decltype(a_block_desc_k0_m0_m1_m2_k1),
380  decltype(a_thread_desc_),
381  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
382  Sequence<0, 1, 2, 3, 4, 5, 6>,
383  6,
384  A_K1,
385  A_K1>;
386 
387  using BThreadCopy =
388  ThreadwiseTensorSliceTransfer_v4<BDataType,
389  ComputeTypeB,
390  decltype(b_block_desc_k0_n0_n1_n2_k1),
391  decltype(b_thread_desc_),
392  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
393  Sequence<0, 1, 2, 3, 4, 5, 6>,
394  6,
395  B_K1,
396  B_K1>;
397 
398  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
399  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
400  using Base::c_thread_desc_;
401 };
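
Both the per-repeat thread copies and the __builtin_amdgcn_sched_group_barrier hints in the hot loop above are emitted from static_for loops, which unroll fully at compile time so each iteration sees its index as a compile-time constant. A stand-alone stand-in for that utility (not the CK implementation) is sketched below:

#include <cstdio>
#include <type_traits>
#include <utility>

// Minimal recursive unroller; CK's static_for is more general, this only
// illustrates the compile-time-index behaviour relied on above.
template <int Begin, int End, int Step, typename F>
constexpr void static_for_sketch(F&& f)
{
    if constexpr(Begin < End)
    {
        f(std::integral_constant<int, Begin>{}); // body receives a compile-time index
        static_for_sketch<Begin + Step, End, Step>(std::forward<F>(f));
    }
}

int main()
{
    // Mirrors static_for<0, KRepeat, 1>{}([&](auto k0) { ... }); with KRepeat = 4
    static_for_sketch<0, 4, 1>([](auto k0) { std::printf("k0 = %d\n", k0.value); });
    return 0;
}
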
402 
403 template <index_t BlockSize,
404  typename ADataType,
405  typename BDataType,
406  typename ComputeTypeA,
407  typename ComputeTypeB,
408  typename AccDataType,
409  typename AWmmaTileDesc,
410  typename BWmmaTileDesc,
411  index_t ABlockTransferSrcScalarPerVector,
412  index_t BBlockTransferSrcScalarPerVector,
413  index_t MPerBlock,
414  index_t NPerBlock,
415  index_t KPerBlock,
416  index_t MPerWmma,
417  index_t NPerWmma,
418  index_t MRepeat,
419  index_t NRepeat,
420  index_t KPack,
421  index_t KInner,
422  bool TransposeC>
423 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
424  BlockSize,
425  ADataType,
426  BDataType,
427  ComputeTypeA,
428  ComputeTypeB,
429  AccDataType,
430  AWmmaTileDesc,
431  BWmmaTileDesc,
432  ABlockTransferSrcScalarPerVector,
433  BBlockTransferSrcScalarPerVector,
434  MPerBlock,
435  NPerBlock,
436  KPerBlock,
437  MPerWmma,
438  NPerWmma,
439  MRepeat,
440  NRepeat,
441  KPack,
442  KInner,
443  TransposeC>
444  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
445  ADataType,
446  BDataType,
447  ComputeTypeA,
448  ComputeTypeB,
449  AccDataType,
450  AWmmaTileDesc,
451  BWmmaTileDesc,
452  ABlockTransferSrcScalarPerVector,
453  BBlockTransferSrcScalarPerVector,
454  MPerBlock,
455  NPerBlock,
456  KPerBlock,
457  MPerWmma,
458  NPerWmma,
459  MRepeat,
460  NRepeat,
461  KPack,
462  KInner,
463  TransposeC>
464 {
465  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
466  ADataType,
467  BDataType,
468  ComputeTypeA,
469  ComputeTypeB,
470  AccDataType,
471  AWmmaTileDesc,
472  BWmmaTileDesc,
473  ABlockTransferSrcScalarPerVector,
474  BBlockTransferSrcScalarPerVector,
475  MPerBlock,
476  NPerBlock,
477  KPerBlock,
478  MPerWmma,
479  NPerWmma,
480  MRepeat,
481  NRepeat,
482  KPack,
483  KInner,
484  TransposeC>;
485  using Base::I0;
486  using Base::I1;
487 
488  using Base::A_K1;
489  using Base::A_KRow;
490  using Base::B_K1;
491  using Base::B_KRow;
492  using Base::KRepeat;
493  using Base::WmmaK;
494 
495  using Base::wmma_gemm;
496 
497  using Base::CalculateCThreadOriginDataIndex;
498  using Base::
499  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
500  using Base::GetCThreadBuffer;
501  using Base::
502  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
503 
504  using Base::a_block_desc_k0_m0_m1_m2_k1;
505  using Base::b_block_desc_k0_n0_n1_n2_k1;
506 
507  using typename Base::Empty;
508 
509  static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
510  static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
511 
512  static constexpr index_t PrefetchStages = 1;
513  static constexpr index_t PrefillStages = 1;
514  static constexpr index_t GlobalBufferNum = 1;
515 
516  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
517 
518  static TailNumber BlockLoopTailNum(index_t num_loop)
519  {
520  ignore = num_loop;
521  return TailNumber::Full;
522  }
523 
524  template <bool HasMainLoop,
525  TailNumber TailNum,
526  typename AGridDesc,
527  typename ABlockDesc,
528  typename ABlockTransfer,
529  typename AGridBuffer,
530  typename ABlockBuffer,
531  typename ABlockTransferStep,
532  typename BGridDesc,
533  typename BBlockDesc,
534  typename BBlockTransfer,
535  typename BGridBuffer,
536  typename BBlockBuffer,
537  typename BBlockTransferStep,
538  typename CThreadBuffer,
539  typename BScaleStruct>
540  __device__ void Run(const AGridDesc& a_grid_desc,
541  const ABlockDesc& a_block_desc,
542  ABlockTransfer& a_blockwise_copy,
543  const AGridBuffer& a_grid_buf,
544  ABlockBuffer& a_block_buf,
545  const ABlockTransferStep& a_block_copy_step,
546  const BGridDesc& b_grid_desc,
547  const BBlockDesc& b_block_desc,
548  BBlockTransfer& b_blockwise_copy,
549  const BGridBuffer& b_grid_buf,
550  BBlockBuffer& b_block_buf,
551  const BBlockTransferStep& b_block_copy_step,
552  CThreadBuffer& c_thread_buf,
553  // BScaleThreadCopy
554  BScaleStruct& b_scale_struct,
555  index_t num_loop,
556  index_t num_loop_per_scale) const
557  {
558  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
559 
560  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
561  a_thread_desc_.GetElementSpaceSize());
562  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
563  b_thread_desc_.GetElementSpaceSize());
564 
565  // Global prefetch 1
566  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
567  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
568 
569  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
570  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
571 
572  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
573 
574  // Local prefill 1
575  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
576  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
577 
578  // Initialize C
579  c_thread_buf.Clear();
580 
581  auto blockwise_gemm_func = [&]() {
582  static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
583  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
584  static_for<0, MRepeat, 1>{}([&](auto m0) {
585  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
586  make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
587  a_block_buf,
588  a_thread_desc_,
589  make_tuple(I0, m0, k0_inner, I0, I0, I0, I0),
590  a_thread_buf);
591  });
592  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
593  {
594  static_for<0, NRepeat, 1>{}([&](auto n0) {
595  b_thread_copy_.Run(
596  b_block_desc_k0_n0_n1_n2_k1,
597  make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
598  b_block_buf,
599  b_thread_desc_,
600  make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
601  b_thread_buf);
602  });
603  }
604  else
605  {
606  static_for<0, NRepeat, 1>{}([&](auto n0) {
607  b_thread_copy_.Run(
608  b_block_desc_k0_n0_n1_n2_k1,
609  make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
610  b_block_buf,
611  b_scale_struct.b_scale_thread_bufs(I0)[Number<
612  n0 * BScaleStruct::num_scale_k_block +
613  (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
614  b_thread_desc_,
615  make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
616  b_thread_buf);
617  });
618  }
619  });
620 
621  __builtin_amdgcn_sched_barrier(0);
622  // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
623  // except the first, since the non-MAC cluster can be shortened slightly with no
624  // observable negative impact. The desired effect is that waves in a workgroup
625  // execute their MACs in sync, which keeps out-of-sync waves from hijacking MAC
626  // resources from other workgroups and from losing latency hiding while waiting
627  // for the rest of the workgroup at the eventual sync point.
628  if constexpr(k0_offset != 0 || KRepeat == 1)
629  {
630  __builtin_amdgcn_s_barrier();
631  __builtin_amdgcn_sched_barrier(0);
632  }
633  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
634  static_for<0, KInner, 1>{}([&](auto k_inner) {
635  static_for<0, MRepeat, 1>{}([&](auto m0) {
636  static_for<0, NRepeat, 1>{}([&](auto n0) {
637  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
638  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
639 
640  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
641  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
642  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
643  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
644  make_tuple(Number<kk / A_K1>{},
645  m0,
646  k0_inner,
647  I0,
648  I0,
649  I0,
650  Number<kk % A_K1>{}))>{}];
651  });
652  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
653  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
654  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
655  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
656  make_tuple(Number<kk / B_K1>{},
657  n0,
658  k0_inner,
659  I0,
660  I0,
661  I0,
662  Number<kk % B_K1>{}))>{}];
663  });
664 
665  using wmma_input_type_a =
666  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
667  using wmma_input_type_b =
668  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
669 
670  constexpr index_t c_offset =
671  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
672 
673  // The block_sync_lds() here performs double duty:
674  // A) safeguard against data hazard.
675  // B) reduce VMEM FIFO congestion by applying small delays to
676  // different wavefronts.
677  // It is performed near the end of MAC cluster to minimize lgkmcnt
678  // penalty
679  if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
680  m0 == MRepeat - 1 && n0 == NRepeat - 1)
681  {
682  __builtin_amdgcn_sched_barrier(0);
683  block_sync_lds();
684  __builtin_amdgcn_sched_barrier(0);
685  }
686  wmma_gemm.Run(
687  a_thread_vec.template AsType<wmma_input_type_a>(),
688  b_thread_vec.template AsType<wmma_input_type_b>(),
689  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
690  if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
691  {
692  __builtin_amdgcn_sched_barrier(0);
693  __builtin_amdgcn_s_setprio(1);
694  __builtin_amdgcn_sched_barrier(0);
695  }
696  });
697  });
698  });
699  });
700  __builtin_amdgcn_sched_barrier(0);
701  __builtin_amdgcn_s_setprio(0);
702  __builtin_amdgcn_sched_barrier(0);
703  });
704  };
705 
706  // main body
707  if constexpr(HasMainLoop)
708  {
709  index_t i = 0;
710  do
711  {
712  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
713  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
714 
715  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
716  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
717 
718  block_sync_lds();
719  blockwise_gemm_func();
720 
721  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
722  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
723  {
724  block_sync_lds();
725  }
726  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
727  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
728 
729  i += 1;
730  } while(i < (num_loop - 1));
731  }
732 
733  // tail
734  if constexpr(TailNum == TailNumber::Full)
735  {
736  block_sync_lds();
737  blockwise_gemm_func();
738  }
739  }
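
The Interwave specialization above splits KRepeat into clusters of KRepeatPerCluster = max(KRepeat / NumKClusters, 1) iterations and places the workgroup barrier between clusters (see the NOTE in blockwise_gemm_func). A small worked example of that split with assumed values, KRepeat = 8 and NumKClusters = 2 (neither taken from a real instance):

#include <algorithm>
#include <cstdio>

int main()
{
    constexpr int KRepeat      = 8; // assumed for illustration
    constexpr int NumKClusters = 2; // assumed value of the cluster-count macro
    constexpr int KRepeatPerCluster = std::max(KRepeat / NumKClusters, 1);

    // Mirrors static_for<0, KRepeat, KRepeatPerCluster> over k0_offset with an
    // inner static_for<0, KRepeatPerCluster, 1> over k0_inner.
    for(int k0_offset = 0; k0_offset < KRepeat; k0_offset += KRepeatPerCluster)
        std::printf("cluster at k0_offset = %d covers k0_inner = 0..%d\n",
                    k0_offset,
                    KRepeatPerCluster - 1);
    return 0;
}
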
740 
741  protected:
742  static constexpr auto a_thread_desc_ =
743  make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
744  Number<MRepeat>{},
745  Number<KRepeatPerCluster>{},
746  I1,
747  I1,
748  I1,
749  Number<A_K1>{}),
750  make_tuple(Number<A_K1>{},
751  Number<KPack / A_KRow>{},
752  Number<KPack / A_KRow * MRepeat>{},
753  I0,
754  I0,
755  I0,
756  I1));
757 
758  static constexpr auto b_thread_desc_ =
759  make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
760  Number<NRepeat>{},
761  Number<KRepeatPerCluster>{},
762  I1,
763  I1,
764  I1,
765  Number<B_K1>{}),
766  make_tuple(Number<B_K1>{},
767  Number<KPack / B_KRow>{},
768  Number<KPack / B_KRow * NRepeat>{},
769  I0,
770  I0,
771  I0,
772  I1));
773 
774  using AThreadCopy =
775  ThreadwiseTensorSliceTransfer_v4<ADataType,
776  ComputeTypeA,
777  decltype(a_block_desc_k0_m0_m1_m2_k1),
778  decltype(a_thread_desc_),
779  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
780  Sequence<0, 1, 2, 3, 4, 5, 6>,
781  6,
782  A_K1,
783  A_K1>;
784 
785  using BThreadCopy =
787  ComputeTypeB,
788  decltype(b_block_desc_k0_n0_n1_n2_k1),
789  decltype(b_thread_desc_),
790  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
791  Sequence<0, 1, 2, 3, 4, 5, 6>,
792  6,
793  B_K1,
794  B_K1>;
795 
796  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
797  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
798  using Base::c_thread_desc_;
799 };
800 
801 } // namespace ck
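
As a closing reference, the a_thread_vec/b_thread_vec gathers in both Run bodies walk kk = ik + k_inner * KPerWaveBlock and address the thread buffer with the split (kk / A_K1, ..., kk % A_K1). The snippet below just prints that mapping for one assumed configuration (KPack = 16, A_KRow = 1, KInner = 2, KPerWaveBlock = 8, A_K1 = 8; all values are illustrative, not taken from a real instantiation):

#include <cstdio>

int main()
{
    constexpr int KPack = 16, A_KRow = 1, KInner = 2, KPerWaveBlock = 8, A_K1 = 8;

    for(int k_inner = 0; k_inner < KInner; ++k_inner)
        for(int ik = 0; ik < KPack / A_KRow / KInner; ++ik)
        {
            const int kk = ik + k_inner * KPerWaveBlock; // source element index
            std::printf("k_inner=%d ik=%d -> kk=%2d -> (K0=%d, K1=%d)\n",
                        k_inner, ik, kk, kk / A_K1, kk % A_K1);
        }
    return 0;
}
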