Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 // GlobalPrefetchStages: 1
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 0
14 // LocalSharedMemoryBuffer: 1
15 
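// Reading the stage counts above, each hot-loop iteration of this pipeline (in
// outline) prefetches the next K-tile from global memory into registers, runs
// the LDS->VGPR reads and WMMA MACs for the current tile, and then refills the
// single shared-memory buffer, with block_sync_lds() separating the read and
// write phases of that one buffer.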
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeTypeA,
21  typename ComputeTypeB,
22  typename AccDataType,
23  typename AWmmaTileDesc,
24  typename BWmmaTileDesc,
25  index_t ABlockTransferSrcScalarPerVector,
26  index_t BBlockTransferSrcScalarPerVector,
27  index_t MPerBlock,
28  index_t NPerBlock,
29  index_t KPerBlock,
30  index_t MPerWmma,
31  index_t NPerWmma,
32  index_t MRepeat,
33  index_t NRepeat,
34  index_t KPack,
35  bool TransposeC = false>
36 struct BlockwiseGemmWmmaops_pipeline_v1
37 {
38 };
39 
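// The primary template above is left empty; the two specializations below
// implement the v1 pipeline for the Intrawave and Interwave block-GEMM
// schedulers respectively.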
40 template <index_t BlockSize,
41  typename ADataType,
42  typename BDataType,
43  typename ComputeTypeA,
44  typename ComputeTypeB,
45  typename AccDataType,
46  typename AWmmaTileDesc,
47  typename BWmmaTileDesc,
48  index_t ABlockTransferSrcScalarPerVector,
49  index_t BBlockTransferSrcScalarPerVector,
50  index_t MPerBlock,
51  index_t NPerBlock,
52  index_t KPerBlock,
53  index_t MPerWmma,
54  index_t NPerWmma,
55  index_t MRepeat,
56  index_t NRepeat,
57  index_t KPack,
58  bool TransposeC>
59 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
60  BlockSize,
61  ADataType,
62  BDataType,
63  ComputeTypeA,
64  ComputeTypeB,
65  AccDataType,
66  AWmmaTileDesc,
67  BWmmaTileDesc,
68  ABlockTransferSrcScalarPerVector,
69  BBlockTransferSrcScalarPerVector,
70  MPerBlock,
71  NPerBlock,
72  KPerBlock,
73  MPerWmma,
74  NPerWmma,
75  MRepeat,
76  NRepeat,
77  KPack,
78  TransposeC>
79  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
80  ADataType,
81  BDataType,
82  ComputeTypeA,
83  ComputeTypeB,
84  AccDataType,
85  AWmmaTileDesc,
86  BWmmaTileDesc,
87  ABlockTransferSrcScalarPerVector,
88  BBlockTransferSrcScalarPerVector,
89  MPerBlock,
90  NPerBlock,
91  KPerBlock,
92  MPerWmma,
93  NPerWmma,
94  MRepeat,
95  NRepeat,
96  KPack,
97  TransposeC>
98 {
99  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
100  ADataType,
101  BDataType,
102  ComputeTypeA,
103  ComputeTypeB,
104  AccDataType,
105  AWmmaTileDesc,
106  BWmmaTileDesc,
107  ABlockTransferSrcScalarPerVector,
108  BBlockTransferSrcScalarPerVector,
109  MPerBlock,
110  NPerBlock,
111  KPerBlock,
112  MPerWmma,
113  NPerWmma,
114  MRepeat,
115  NRepeat,
116  KPack,
117  TransposeC>;
118  using Base::I0;
119  using Base::I1;
120  using Base::WaveSize;
121  using typename Base::HotLoopInstList;
122 
123  using Base::A_K1;
124  using Base::A_KRow;
125  using Base::B_K1;
126  using Base::B_KRow;
127  using Base::KRepeat;
128  using Base::WmmaK;
129 
130  using Base::wmma_gemm;
131 
132  using Base::CalculateCThreadOriginDataIndex;
133  using Base::
134  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
135  using Base::GetCThreadBuffer;
136  using Base::
137  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
138 
139  using Base::a_block_desc_k0_m0_m1_m2_k1;
140  using Base::b_block_desc_k0_n0_n1_n2_k1;
141 
142  using typename Base::Empty;
143 
144  static constexpr index_t PrefetchStages = 1;
145  static constexpr index_t PrefillStages = 1;
146  static constexpr index_t GlobalBufferNum = 1;
147 
148  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
149 
150  static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
151  {
152  ignore = num_loop;
153  return TailNumber::Full;
154  }
155 
156  template <bool HasMainLoop,
157  TailNumber TailNum,
158  typename AGridDesc,
159  typename ABlockDesc,
160  typename ABlockTransfer,
161  typename AGridBuffer,
162  typename ABlockBuffer,
163  typename ABlockTransferStep,
164  typename BGridDesc,
165  typename BBlockDesc,
166  typename BBlockTransfer,
167  typename BGridBuffer,
168  typename BBlockBuffer,
169  typename BBlockTransferStep,
170  typename CThreadBuffer,
171  typename BScaleStruct>
172  __device__ void Run(const AGridDesc& a_grid_desc,
173  const ABlockDesc& a_block_desc,
174  ABlockTransfer& a_blockwise_copy,
175  const AGridBuffer& a_grid_buf,
176  ABlockBuffer& a_block_buf,
177  const ABlockTransferStep& a_block_copy_step,
178  const BGridDesc& b_grid_desc,
179  const BBlockDesc& b_block_desc,
180  BBlockTransfer& b_blockwise_copy,
181  const BGridBuffer& b_grid_buf,
182  BBlockBuffer& b_block_buf,
183  const BBlockTransferStep& b_block_copy_step,
184  CThreadBuffer& c_thread_buf,
185  // BScaleThreadCopy
186  BScaleStruct& b_scale_struct,
187  index_t num_loop,
188  index_t num_loop_per_scale) const
189  {
190  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
191  a_thread_desc_.GetElementSpaceSize());
192  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
193  b_thread_desc_.GetElementSpaceSize());
194 
195  // Global prefetch 1
196  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
197  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
198 
199  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
200  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
201 
202  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
203 
204  // Local prefill 1
205  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
206  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
207 
208  // Initialize C
209  c_thread_buf.Clear();
210 
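// blockwise_gemm_func: for each k0 in [0, KRepeat), copy one A slice per m0
// and the B slices per n0 (scaled via b_scale_struct when a B-scale is
// present) from LDS into thread-local VGPR buffers, then issue one wmma_gemm
// per (m0, n0) pair, accumulating into c_thread_buf.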
211  auto blockwise_gemm_func = [&]() {
212  static_for<0, KRepeat, 1>{}([&](auto k0) {
213  static_for<0, MRepeat, 1>{}([&](auto m0) {
214  a_thread_copy_.Run(
215  a_block_desc_k0_m0_m1_m2_k1,
216  make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
217  a_block_buf,
218  a_thread_desc_,
219  make_tuple(I0, I0, I0, I0, I0, I0),
220  a_thread_buf);
221 
222  if constexpr(m0 == I0)
223  {
224  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
225  {
226  static_for<0, NRepeat, 1>{}([&](auto n0) {
227  b_thread_copy_.Run(
228  b_block_desc_k0_n0_n1_n2_k1,
229  make_tuple(
230  Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
231  b_block_buf,
232  b_thread_desc_,
233  make_tuple(I0, n0, I0, I0, I0, I0),
234  b_thread_buf);
235  });
236  }
237  else
238  {
239  static_for<0, NRepeat, 1>{}([&](auto n0) {
240  b_thread_copy_.Run(
241  b_block_desc_k0_n0_n1_n2_k1,
242  make_tuple(
243  Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
244  b_block_buf,
245  b_scale_struct.b_scale_thread_bufs(
246  I0)[Number<n0 * BScaleStruct::num_scale_k_block +
247  k0 / BScaleStruct::num_scale_krepeat>{}],
248  b_thread_desc_,
249  make_tuple(I0, n0, I0, I0, I0, I0),
250  b_thread_buf);
251  });
252  }
253  }
254 
255  static_for<0, NRepeat, 1>{}([&](auto n0) {
256  vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
257  vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
258 
259  static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
260  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
261  a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
262  Number<ik / A_K1>{}, I0, I0, I0, I0, Number<ik % A_K1>{}))>{}];
263  });
264  static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
265  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
266  b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
267  Number<ik / B_K1>{}, n0, I0, I0, I0, Number<ik % B_K1>{}))>{}];
268  });
269 
270  using wmma_input_type_a =
271  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
272  using wmma_input_type_b =
273  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
274 
275  constexpr index_t c_offset =
276  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
277 
278  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
279  b_thread_vec.template AsType<wmma_input_type_b>(),
280  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
281  });
282  });
283  });
284  };
285 
286  // main body
287  if constexpr(HasMainLoop)
288  {
289  index_t i = 0;
290  do
291  {
292  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
293  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
294 
295  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
296  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
297 
298  block_sync_lds();
299  blockwise_gemm_func();
300 
301  block_sync_lds();
302  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
303  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
304  {
305  block_sync_lds();
306  }
307  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
308  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
309 
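// __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) hints the backend
// scheduler to place `size` instructions of the class selected by `mask` at
// this point. The masks used below are 0x020 (VMEM read), 0x100 (DS read),
// 0x008 (WMMA/MFMA) and 0x200 (DS write), interleaving global loads, LDS
// traffic and MACs across the hot-loop body.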
310  constexpr index_t num_ds_write_inst =
311  HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
312 
313  constexpr index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
314  HotLoopInstList::B_Buffer_Load_Inst_Num;
315  static_for<0, num_buffer_load_inst, 1>{}([&](auto) {
316  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
317  });
318  static_for<0, KRepeat, 1>{}([&](auto) {
319  static_for<0, MRepeat, 1>{}([&](auto m0) {
320  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
321  if constexpr(m0 == I0)
322  {
323  static_for<0, NRepeat, 1>{}([&](auto) {
324  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
325  });
326  }
327  static_for<0, NRepeat, 1>{}([&](auto) {
328  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
329  });
330  });
331  });
332  static_for<0, num_ds_write_inst, 1>{}([&](auto) {
333  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
334  });
335 
336  i += 1;
337  } while(i < (num_loop - 1));
338  }
339 
340  // tail
341  if constexpr(TailNum == TailNumber::Full)
342  {
343  block_sync_lds();
344  blockwise_gemm_func();
345  }
346  }
347 
348  protected:
349  // A[MRepeat, I1, I1, KPack]
350  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
351  make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, Number<A_K1>{}));
352 
353  // B[NRepeat, N1, N2, KPack]
354  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
355  make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}));
356 
357  using AThreadCopy =
358  ThreadwiseTensorSliceTransfer_v4<ADataType,
359  ComputeTypeA,
360  decltype(a_block_desc_k0_m0_m1_m2_k1),
361  decltype(a_thread_desc_),
362  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
363  Sequence<0, 1, 2, 3, 4, 5>,
364  5,
365  A_K1,
366  A_K1>;
367 
368  using BThreadCopy =
369  ThreadwiseTensorSliceTransfer_v4<BDataType,
370  ComputeTypeB,
371  decltype(b_block_desc_k0_n0_n1_n2_k1),
372  decltype(b_thread_desc_),
373  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
374  Sequence<0, 1, 2, 3, 4, 5>,
375  5,
376  B_K1,
377  B_K1>;
378 
379  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
380  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
381  using Base::c_thread_desc_;
382 };
383 
384 template <index_t BlockSize,
385  typename ADataType,
386  typename BDataType,
387  typename ComputeTypeA,
388  typename ComputeTypeB,
389  typename AccDataType,
390  typename AWmmaTileDesc,
391  typename BWmmaTileDesc,
392  index_t ABlockTransferSrcScalarPerVector,
393  index_t BBlockTransferSrcScalarPerVector,
394  index_t MPerBlock,
395  index_t NPerBlock,
396  index_t KPerBlock,
397  index_t MPerWmma,
398  index_t NPerWmma,
399  index_t MRepeat,
400  index_t NRepeat,
401  index_t KPack,
402  bool TransposeC>
403 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
404  BlockSize,
405  ADataType,
406  BDataType,
407  ComputeTypeA,
408  ComputeTypeB,
409  AccDataType,
410  AWmmaTileDesc,
411  BWmmaTileDesc,
412  ABlockTransferSrcScalarPerVector,
413  BBlockTransferSrcScalarPerVector,
414  MPerBlock,
415  NPerBlock,
416  KPerBlock,
417  MPerWmma,
418  NPerWmma,
419  MRepeat,
420  NRepeat,
421  KPack,
422  TransposeC>
423  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
424  ADataType,
425  BDataType,
426  ComputeTypeA,
427  ComputeTypeB,
428  AccDataType,
429  AWmmaTileDesc,
430  BWmmaTileDesc,
431  ABlockTransferSrcScalarPerVector,
432  BBlockTransferSrcScalarPerVector,
433  MPerBlock,
434  NPerBlock,
435  KPerBlock,
436  MPerWmma,
437  NPerWmma,
438  MRepeat,
439  NRepeat,
440  KPack,
441  TransposeC>
442 {
443  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
444  ADataType,
445  BDataType,
446  ComputeTypeA,
447  ComputeTypeB,
448  AccDataType,
449  AWmmaTileDesc,
450  BWmmaTileDesc,
451  ABlockTransferSrcScalarPerVector,
452  BBlockTransferSrcScalarPerVector,
453  MPerBlock,
454  NPerBlock,
455  KPerBlock,
456  MPerWmma,
457  NPerWmma,
458  MRepeat,
459  NRepeat,
460  KPack,
461  TransposeC>;
462  using Base::I0;
463  using Base::I1;
464 
465  using Base::A_K1;
466  using Base::A_KRow;
467  using Base::B_K1;
468  using Base::B_KRow;
469  using Base::KRepeat;
470  using Base::WmmaK;
471 
472  using Base::wmma_gemm;
473 
474  using Base::CalculateCThreadOriginDataIndex;
475  using Base::
476  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
477  using Base::GetCThreadBuffer;
478  using Base::
479  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
480 
481  using Base::a_block_desc_k0_m0_m1_m2_k1;
482  using Base::b_block_desc_k0_n0_n1_n2_k1;
483 
484  using typename Base::Empty;
485 
486  static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
487  static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
488 
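// The KRepeat loop is split into NumKClusters MAC clusters of KRepeatPerCluster
// iterations each: the LDS->VGPR reads for a whole cluster are issued first,
// then the waves synchronize and run that cluster's WMMAs back-to-back at
// raised priority (see blockwise_gemm_func below).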
489  static constexpr index_t PrefetchStages = 1;
490  static constexpr index_t PrefillStages = 1;
491  static constexpr index_t GlobalBufferNum = 1;
492 
493  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
494 
495  static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
496  {
497  ignore = num_loop;
498  return TailNumber::Full;
499  }
500 
501  template <bool HasMainLoop,
502  TailNumber TailNum,
503  typename AGridDesc,
504  typename ABlockDesc,
505  typename ABlockTransfer,
506  typename AGridBuffer,
507  typename ABlockBuffer,
508  typename ABlockTransferStep,
509  typename BGridDesc,
510  typename BBlockDesc,
511  typename BBlockTransfer,
512  typename BGridBuffer,
513  typename BBlockBuffer,
514  typename BBlockTransferStep,
515  typename CThreadBuffer,
516  typename BScaleStruct>
517  __device__ void Run(const AGridDesc& a_grid_desc,
518  const ABlockDesc& a_block_desc,
519  ABlockTransfer& a_blockwise_copy,
520  const AGridBuffer& a_grid_buf,
521  ABlockBuffer& a_block_buf,
522  const ABlockTransferStep& a_block_copy_step,
523  const BGridDesc& b_grid_desc,
524  const BBlockDesc& b_block_desc,
525  BBlockTransfer& b_blockwise_copy,
526  const BGridBuffer& b_grid_buf,
527  BBlockBuffer& b_block_buf,
528  const BBlockTransferStep& b_block_copy_step,
529  CThreadBuffer& c_thread_buf,
530  // BScaleThreadCopy
531  BScaleStruct& b_scale_struct,
532  index_t num_loop,
533  index_t num_loop_per_scale) const
534  {
535  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
536  a_thread_desc_.GetElementSpaceSize());
537  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
538  b_thread_desc_.GetElementSpaceSize());
539 
540  // Global prefetch 1
541  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
542  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
543 
544  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
545  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
546 
547  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
548 
549  // Local prefill 1
550  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
551  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
552 
553  // Initialize C
554  c_thread_buf.Clear();
555 
556  auto blockwise_gemm_func = [&]() {
557  static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
558  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
559  static_for<0, MRepeat, 1>{}([&](auto m0) {
560  a_thread_copy_.Run(
561  a_block_desc_k0_m0_m1_m2_k1,
562  make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
563  m0,
564  I0,
565  I0,
566  I0,
567  I0),
568  a_block_buf,
569  a_thread_desc_,
570  make_tuple(I0, m0, k0_inner, I0, I0, I0),
571  a_thread_buf);
572  });
573  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
574  {
575  static_for<0, NRepeat, 1>{}([&](auto n0) {
576  b_thread_copy_.Run(
577  b_block_desc_k0_n0_n1_n2_k1,
578  make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
579  n0,
580  I0,
581  I0,
582  I0,
583  I0),
584  b_block_buf,
585  b_thread_desc_,
586  make_tuple(I0, n0, k0_inner, I0, I0, I0),
587  b_thread_buf);
588  });
589  }
590  else
591  {
592  static_for<0, NRepeat, 1>{}([&](auto n0) {
593  b_thread_copy_.Run(
594  b_block_desc_k0_n0_n1_n2_k1,
595  make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
596  n0,
597  I0,
598  I0,
599  I0,
600  I0),
601  b_block_buf,
602  b_scale_struct.b_scale_thread_bufs(I0)[Number<
603  n0 * BScaleStruct::num_scale_k_block +
604  (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
605  b_thread_desc_,
606  make_tuple(I0, n0, k0_inner, I0, I0, I0),
607  b_thread_buf);
608  });
609  }
610  });
611 
612  __builtin_amdgcn_sched_barrier(0);
613  // NOTE: Synchronize the waves in a workgroup at the start of each MAC cluster
614  // except the first; this shortens the non-MAC cluster a bit and has no
615  // observable negative impact. The desired effect is that the waves of a
616  // workgroup execute their MACs in sync, which keeps out-of-sync waves from
617  // hijacking MAC resources from other workgroups and from losing the latency
618  // hiding gained by waiting for the rest of the workgroup at the eventual sync point.
619  if constexpr(k0_offset != 0 || KRepeat == 1)
620  {
621  __builtin_amdgcn_s_barrier();
622  __builtin_amdgcn_sched_barrier(0);
623  }
624  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
625  static_for<0, MRepeat, 1>{}([&](auto m0) {
626  static_for<0, NRepeat, 1>{}([&](auto n0) {
627  vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
628  vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
629 
630  static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
631  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
632  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
633  make_tuple(Number<ik / A_K1>{},
634  m0,
635  k0_inner,
636  I0,
637  I0,
638  Number<ik % A_K1>{}))>{}];
639  });
640  static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
641  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
642  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
643  make_tuple(Number<ik / B_K1>{},
644  n0,
645  k0_inner,
646  I0,
647  I0,
648  Number<ik % B_K1>{}))>{}];
649  });
650 
651  using wmma_input_type_a =
652  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
653  using wmma_input_type_b =
654  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
655 
656  constexpr index_t c_offset =
657  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
658 
659  // The block_sync_lds() here performs double duty:
660  // A) safeguards against data hazards on the shared LDS buffer;
661  // B) reduces VMEM FIFO congestion by applying small delays to
662  // different wavefronts.
663  // It is placed near the end of the MAC cluster to minimize the
664  // lgkmcnt penalty.
665  if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
666  n0 == NRepeat - 1)
667  {
668  __builtin_amdgcn_sched_barrier(0);
669  block_sync_lds();
670  __builtin_amdgcn_sched_barrier(0);
671  }
672  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
673  b_thread_vec.template AsType<wmma_input_type_b>(),
674  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
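// After the first WMMA of a cluster has been issued, raise this wave's
// priority so it keeps the matrix pipes until the whole cluster has been
// issued; the priority is dropped back to 0 after the cluster loop below.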
675  if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
676  {
677  __builtin_amdgcn_sched_barrier(0);
678  __builtin_amdgcn_s_setprio(1);
679  __builtin_amdgcn_sched_barrier(0);
680  }
681  });
682  });
683  });
684  __builtin_amdgcn_sched_barrier(0);
685  __builtin_amdgcn_s_setprio(0);
686  __builtin_amdgcn_sched_barrier(0);
687  });
688  };
689 
690  // main body
691  if constexpr(HasMainLoop)
692  {
693  index_t i = 0;
694  do
695  {
696  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
697  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
698 
699  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
700  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
701 
702  block_sync_lds();
703  blockwise_gemm_func();
704 
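// Unlike the Intrawave pipeline, no extra block_sync_lds() is needed between
// the math and the LDS refill here: blockwise_gemm_func() already issues one
// near the end of the last MAC cluster, before the shared buffer is
// overwritten by the RunWrite calls below.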
705  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
706  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
707  {
708  block_sync_lds();
709  }
710  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
711  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
712 
713  i += 1;
714  } while(i < (num_loop - 1));
715  }
716 
717  // tail
718  if constexpr(TailNum == TailNumber::Full)
719  {
720  block_sync_lds();
721  blockwise_gemm_func();
722  }
723  }
724 
725  protected:
726  static constexpr auto a_thread_desc_ =
727  make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
728  Number<MRepeat>{},
729  Number<KRepeatPerCluster>{},
730  I1,
731  I1,
732  Number<A_K1>{}),
733  make_tuple(Number<A_K1>{},
734  Number<KPack / A_KRow>{},
735  Number<KPack / A_KRow * MRepeat>{},
736  I0,
737  I0,
738  I1));
739 
740  static constexpr auto b_thread_desc_ =
741  make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
742  Number<NRepeat>{},
743  Number<KRepeatPerCluster>{},
744  I1,
745  I1,
746  Number<B_K1>{}),
747  make_tuple(Number<B_K1>{},
748  Number<KPack / B_KRow>{},
749  Number<KPack / B_KRow * NRepeat>{},
750  I0,
751  I0,
752  I1));
753 
754  using AThreadCopy =
755  ThreadwiseTensorSliceTransfer_v4<ADataType,
756  ComputeTypeA,
757  decltype(a_block_desc_k0_m0_m1_m2_k1),
758  decltype(a_thread_desc_),
759  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
760  Sequence<0, 1, 2, 3, 4, 5>,
761  5,
762  A_K1,
763  A_K1>;
764 
765  using BThreadCopy =
766  ThreadwiseTensorSliceTransfer_v4<BDataType,
767  ComputeTypeB,
768  decltype(b_block_desc_k0_n0_n1_n2_k1),
769  decltype(b_thread_desc_),
770  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
771  Sequence<0, 1, 2, 3, 4, 5>,
772  5,
773  B_K1,
774  B_K1>;
775 
776  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
777  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
778  using Base::c_thread_desc_;
779 };
780 
781 } // namespace ck