Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

blockwise_gemm_pipeline_wmmaops_v1.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 // GlobalPrefetchStages: 1
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 0
14 // LocalSharedMemoryBuffer: 1
15 
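Read as a schedule, the four characteristics above describe a minimal pipeline: one global prefetch, one LDS prefill, a hot loop that overlaps the next global read with WMMA math on a single shared-memory buffer, and one tail iteration. The following condensed sketch is illustrative only; helper names such as global_read(), lds_write() and wmma_compute() are placeholders for the blockwise copy and GEMM calls used in Run() below, not APIs of this file.

// Hedged sketch of the v1 schedule (not part of the source file).
void pipeline_v1_schedule(int num_loop)
{
    global_read();               // GlobalPrefetchStages: 1
    move_src_slice_window();
    lds_write();                 // LocalPreFillStages: 1 (single LDS buffer)

    for(int i = 0; i < num_loop - 1; ++i)   // hot loop, taken when num_loop > 1
    {
        global_read();           // fetch iteration i+1 from global memory ...
        move_src_slice_window();
        block_sync_lds();        // wait for the LDS writes of iteration i
        wmma_compute();          // ... while computing on iteration i
        block_sync_lds();        // all lanes are done reading the single buffer
        lds_write();             // overwrite it with iteration i+1
    }

    block_sync_lds();            // tail (TailNumber::Full)
    wmma_compute();
}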
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeTypeA,
21  typename ComputeTypeB,
22  typename AccDataType,
23  typename AWmmaTileDesc,
24  typename BWmmaTileDesc,
25  index_t ABlockTransferSrcScalarPerVector,
26  index_t BBlockTransferSrcScalarPerVector,
27  index_t MPerBlock,
28  index_t NPerBlock,
29  index_t KPerBlock,
30  index_t MPerWmma,
31  index_t NPerWmma,
32  index_t MRepeat,
33  index_t NRepeat,
34  index_t KPack,
35  bool TransposeC = false>
36 struct BlockwiseGemmWmmaops_pipeline_v1
37 {
38 };
39 
40 template <index_t BlockSize,
41  typename ADataType,
42  typename BDataType,
43  typename ComputeTypeA,
44  typename ComputeTypeB,
45  typename AccDataType,
46  typename AWmmaTileDesc,
47  typename BWmmaTileDesc,
48  index_t ABlockTransferSrcScalarPerVector,
49  index_t BBlockTransferSrcScalarPerVector,
50  index_t MPerBlock,
51  index_t NPerBlock,
52  index_t KPerBlock,
53  index_t MPerWmma,
54  index_t NPerWmma,
55  index_t MRepeat,
56  index_t NRepeat,
57  index_t KPack,
58  bool TransposeC>
59 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
60  BlockSize,
61  ADataType,
62  BDataType,
63  ComputeTypeA,
64  ComputeTypeB,
65  AccDataType,
66  AWmmaTileDesc,
67  BWmmaTileDesc,
68  ABlockTransferSrcScalarPerVector,
69  BBlockTransferSrcScalarPerVector,
70  MPerBlock,
71  NPerBlock,
72  KPerBlock,
73  MPerWmma,
74  NPerWmma,
75  MRepeat,
76  NRepeat,
77  KPack,
78  TransposeC>
79  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
80  ADataType,
81  BDataType,
82  ComputeTypeA,
83  ComputeTypeB,
84  AccDataType,
85  AWmmaTileDesc,
86  BWmmaTileDesc,
87  ABlockTransferSrcScalarPerVector,
88  BBlockTransferSrcScalarPerVector,
89  MPerBlock,
90  NPerBlock,
91  KPerBlock,
92  MPerWmma,
93  NPerWmma,
94  MRepeat,
95  NRepeat,
96  KPack,
97  TransposeC>
98 {
99  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
100  ADataType,
101  BDataType,
102  ComputeTypeA,
103  ComputeTypeB,
104  AccDataType,
105  AWmmaTileDesc,
106  BWmmaTileDesc,
107  ABlockTransferSrcScalarPerVector,
108  BBlockTransferSrcScalarPerVector,
109  MPerBlock,
110  NPerBlock,
111  KPerBlock,
112  MPerWmma,
113  NPerWmma,
114  MRepeat,
115  NRepeat,
116  KPack,
117  TransposeC>;
118  using Base::I0;
119 
120  using Base::A_K1;
121  using Base::A_KRow;
122  using Base::B_K1;
123  using Base::B_KRow;
124  using Base::KRepeat;
125  using Base::WmmaK;
126 
127  using Base::wmma_gemm;
128 
129  using Base::CalculateCThreadOriginDataIndex;
130  using Base::
131  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
132  using Base::GetCThreadBuffer;
133  using Base::
134  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
135 
136  using Base::a_block_desc_k0_m0_m1_m2_k1;
137  using Base::b_block_desc_k0_n0_n1_n2_k1;
138 
139  using typename Base::Empty;
140 
141  static constexpr index_t PrefetchStages = 1;
142  static constexpr index_t PrefillStages = 1;
143  static constexpr index_t GlobalBufferNum = 1;
144 
145  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
146 
147  static TailNumber BlockLoopTailNum(index_t num_loop)
148  {
149  ignore = num_loop;
150  return TailNumber::Full;
151  }
152 
153  template <bool HasMainLoop,
154  TailNumber TailNum,
155  typename AGridDesc,
156  typename ABlockDesc,
157  typename ABlockTransfer,
158  typename AGridBuffer,
159  typename ABlockBuffer,
160  typename ABlockTransferStep,
161  typename BGridDesc,
162  typename BBlockDesc,
163  typename BBlockTransfer,
164  typename BGridBuffer,
165  typename BBlockBuffer,
166  typename BBlockTransferStep,
167  typename CThreadBuffer,
168  typename BScaleStruct>
169  __device__ void Run(const AGridDesc& a_grid_desc,
170  const ABlockDesc& a_block_desc,
171  ABlockTransfer& a_blockwise_copy,
172  const AGridBuffer& a_grid_buf,
173  ABlockBuffer& a_block_buf,
174  const ABlockTransferStep& a_block_copy_step,
175  const BGridDesc& b_grid_desc,
176  const BBlockDesc& b_block_desc,
177  BBlockTransfer& b_blockwise_copy,
178  const BGridBuffer& b_grid_buf,
179  BBlockBuffer& b_block_buf,
180  const BBlockTransferStep& b_block_copy_step,
181  CThreadBuffer& c_thread_buf,
182  // BScaleThreadCopy
183  BScaleStruct& b_scale_struct,
184  index_t num_loop,
185  index_t num_loop_per_scale) const
186  {
187  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
188  a_thread_desc_.GetElementSpaceSize());
189  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
190  b_thread_desc_.GetElementSpaceSize());
191 
192  // Global prefetch 1
193  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
194  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
195 
196  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
197  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
198 
199  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
200 
201  // Local prefill 1
202  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
203  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
204 
205  // Initialize C
206  c_thread_buf.Clear();
207 
208  auto blockwise_gemm_func = [&]() {
209  static_for<0, KRepeat, 1>{}([&](auto k0) {
210  static_for<0, MRepeat, 1>{}([&](auto m0) {
211  a_thread_copy_.Run(
212  a_block_desc_k0_m0_m1_m2_k1,
213  make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
214  a_block_buf,
215  a_thread_desc_,
216  make_tuple(I0, m0, k0, I0, I0, I0),
217  a_thread_buf);
218  });
219  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
220  {
221  static_for<0, NRepeat, 1>{}([&](auto n0) {
222  b_thread_copy_.Run(
223  b_block_desc_k0_n0_n1_n2_k1,
224  make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
225  b_block_buf,
226  b_thread_desc_,
227  make_tuple(I0, n0, k0, I0, I0, I0),
228  b_thread_buf);
229  });
230  }
231  else
232  {
233  static_for<0, NRepeat, 1>{}([&](auto n0) {
234  b_thread_copy_.Run(
235  b_block_desc_k0_n0_n1_n2_k1,
236  make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
237  b_block_buf,
238  b_scale_struct.b_scale_thread_bufs(
239  I0)[Number<n0 * BScaleStruct::num_scale_k_block +
240  k0 / BScaleStruct::num_scale_krepeat>{}],
241  b_thread_desc_,
242  make_tuple(I0, n0, k0, I0, I0, I0),
243  b_thread_buf);
244  });
245  }
246 
247  static_for<0, MRepeat, 1>{}([&](auto m0) {
248  static_for<0, NRepeat, 1>{}([&](auto n0) {
249  vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
250  vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
251 
252  static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
253  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
254  a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
255  Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
256  });
257  static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
258  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
259  b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
260  Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
261  });
262 
263  using wmma_input_type_a =
264  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
265  using wmma_input_type_b =
266  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
267 
268  constexpr index_t c_offset =
269  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
270 
271  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
272  b_thread_vec.template AsType<wmma_input_type_b>(),
273  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
274  });
275  });
276  });
277  };
278 
279  // main body
280  if constexpr(HasMainLoop)
281  {
282  index_t i = 0;
283  do
284  {
285  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
286  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
287 
288  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
289  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
290 
291  block_sync_lds();
292  blockwise_gemm_func();
293 
294  block_sync_lds();
295  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
296  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
297  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
298 
299  i += 1;
300  } while(i < (num_loop - 1));
301  }
302 
303  // tail
304  if constexpr(TailNum == TailNumber::Full)
305  {
306  block_sync_lds();
307  blockwise_gemm_func();
308  }
309  }
310 
311  protected:
312  using Base::a_thread_copy_;
313  using Base::a_thread_desc_;
314  using Base::b_thread_copy_;
315  using Base::b_thread_desc_;
316  using Base::c_thread_desc_;
317 };
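The static helpers BlockHasHotloop and BlockLoopTailNum above are what a calling gridwise kernel would typically query when choosing the HasMainLoop and TailNum template arguments of Run. A minimal, hypothetical dispatch sketch follows; the Pipeline alias, the pipeline object and num_k_loop are assumptions for illustration, not code from this file.

// Hypothetical kernel-side dispatch, assuming Pipeline aliases this struct.
const index_t num_k_loop = K / KPerBlock;      // assumed loop count
if(Pipeline::BlockHasHotloop(num_k_loop))      // true when num_k_loop > PrefetchStages (== 1)
{
    pipeline.template Run<true, TailNumber::Full>(/* grid/block descriptors, copies, buffers ... */);
}
else
{
    // v1 always reports TailNumber::Full, so only the Full tail path is instantiated.
    pipeline.template Run<false, TailNumber::Full>(/* ... */);
}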
318 
319 template <index_t BlockSize,
320  typename ADataType,
321  typename BDataType,
322  typename ComputeTypeA,
323  typename ComputeTypeB,
324  typename AccDataType,
325  typename AWmmaTileDesc,
326  typename BWmmaTileDesc,
327  index_t ABlockTransferSrcScalarPerVector,
328  index_t BBlockTransferSrcScalarPerVector,
329  index_t MPerBlock,
330  index_t NPerBlock,
331  index_t KPerBlock,
332  index_t MPerWmma,
333  index_t NPerWmma,
334  index_t MRepeat,
335  index_t NRepeat,
336  index_t KPack,
337  bool TransposeC>
338 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
339  BlockSize,
340  ADataType,
341  BDataType,
342  ComputeTypeA,
343  ComputeTypeB,
344  AccDataType,
345  AWmmaTileDesc,
346  BWmmaTileDesc,
347  ABlockTransferSrcScalarPerVector,
348  BBlockTransferSrcScalarPerVector,
349  MPerBlock,
350  NPerBlock,
351  KPerBlock,
352  MPerWmma,
353  NPerWmma,
354  MRepeat,
355  NRepeat,
356  KPack,
357  TransposeC>
358  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
359  ADataType,
360  BDataType,
361  ComputeTypeA,
362  ComputeTypeB,
363  AccDataType,
364  AWmmaTileDesc,
365  BWmmaTileDesc,
366  ABlockTransferSrcScalarPerVector,
367  BBlockTransferSrcScalarPerVector,
368  MPerBlock,
369  NPerBlock,
370  KPerBlock,
371  MPerWmma,
372  NPerWmma,
373  MRepeat,
374  NRepeat,
375  KPack,
376  TransposeC>
377 {
378  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
379  ADataType,
380  BDataType,
381  ComputeTypeA,
382  ComputeTypeB,
383  AccDataType,
384  AWmmaTileDesc,
385  BWmmaTileDesc,
386  ABlockTransferSrcScalarPerVector,
387  BBlockTransferSrcScalarPerVector,
388  MPerBlock,
389  NPerBlock,
390  KPerBlock,
391  MPerWmma,
392  NPerWmma,
393  MRepeat,
394  NRepeat,
395  KPack,
396  TransposeC>;
397  using Base::I0;
398  using Base::I1;
399 
400  using Base::A_K1;
401  using Base::A_KRow;
402  using Base::B_K1;
403  using Base::B_KRow;
404  using Base::KRepeat;
405  using Base::WmmaK;
406 
407  using Base::wmma_gemm;
408 
409  using Base::CalculateCThreadOriginDataIndex;
410  using Base::
411  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
412  using Base::GetCThreadBuffer;
413  using Base::
414  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
415 
416  using Base::a_block_desc_k0_m0_m1_m2_k1;
417  using Base::b_block_desc_k0_n0_n1_n2_k1;
418 
419  using typename Base::Empty;
420 
421  static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
422  static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
423 
424  static constexpr index_t PrefetchStages = 1;
425  static constexpr index_t PrefillStages = 1;
426  static constexpr index_t GlobalBufferNum = 1;
427 
428  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
429 
430  static TailNumber BlockLoopTailNum(index_t num_loop)
431  {
432  ignore = num_loop;
433  return TailNumber::Full;
434  }
435 
436  template <bool HasMainLoop,
437  TailNumber TailNum,
438  typename AGridDesc,
439  typename ABlockDesc,
440  typename ABlockTransfer,
441  typename AGridBuffer,
442  typename ABlockBuffer,
443  typename ABlockTransferStep,
444  typename BGridDesc,
445  typename BBlockDesc,
446  typename BBlockTransfer,
447  typename BGridBuffer,
448  typename BBlockBuffer,
449  typename BBlockTransferStep,
450  typename CThreadBuffer,
451  typename BScaleStruct>
452  __device__ void Run(const AGridDesc& a_grid_desc,
453  const ABlockDesc& a_block_desc,
454  ABlockTransfer& a_blockwise_copy,
455  const AGridBuffer& a_grid_buf,
456  ABlockBuffer& a_block_buf,
457  const ABlockTransferStep& a_block_copy_step,
458  const BGridDesc& b_grid_desc,
459  const BBlockDesc& b_block_desc,
460  BBlockTransfer& b_blockwise_copy,
461  const BGridBuffer& b_grid_buf,
462  BBlockBuffer& b_block_buf,
463  const BBlockTransferStep& b_block_copy_step,
464  CThreadBuffer& c_thread_buf,
465  // BScaleThreadCopy
466  BScaleStruct& b_scale_struct,
467  index_t num_loop,
468  index_t num_loop_per_scale) const
469  {
470  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
471  a_thread_desc_.GetElementSpaceSize());
472  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
473  b_thread_desc_.GetElementSpaceSize());
474 
475  // Global prefetch 1
476  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
477  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
478 
479  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
480  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
481 
482  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
483 
484  // Local prefill 1
485  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
486  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
487 
488  // Initialize C
489  c_thread_buf.Clear();
490 
491  auto blockwise_gemm_func = [&]() {
492  static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
493  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
494  static_for<0, MRepeat, 1>{}([&](auto m0) {
495  a_thread_copy_.Run(
496  a_block_desc_k0_m0_m1_m2_k1,
497  make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
498  m0,
499  I0,
500  I0,
501  I0,
502  I0),
503  a_block_buf,
504  a_thread_desc_,
505  make_tuple(I0, m0, k0_inner, I0, I0, I0),
506  a_thread_buf);
507  });
508  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
509  {
510  static_for<0, NRepeat, 1>{}([&](auto n0) {
511  b_thread_copy_.Run(
512  b_block_desc_k0_n0_n1_n2_k1,
513  make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
514  n0,
515  I0,
516  I0,
517  I0,
518  I0),
519  b_block_buf,
520  b_thread_desc_,
521  make_tuple(I0, n0, k0_inner, I0, I0, I0),
522  b_thread_buf);
523  });
524  }
525  else
526  {
527  static_for<0, NRepeat, 1>{}([&](auto n0) {
528  b_thread_copy_.Run(
529  b_block_desc_k0_n0_n1_n2_k1,
530  make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
531  n0,
532  I0,
533  I0,
534  I0,
535  I0),
536  b_block_buf,
537  b_scale_struct.b_scale_thread_bufs(I0)[Number<
538  n0 * BScaleStruct::num_scale_k_block +
539  (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
540  b_thread_desc_,
541  make_tuple(I0, n0, k0_inner, I0, I0, I0),
542  b_thread_buf);
543  });
544  }
545  });
546 
547  __builtin_amdgcn_sched_barrier(0);
548  // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
549  // except the first, since the non-MAC cluster can be shortened a bit with no
550  // observable negative impact. The desired effect is waves in a workgroup
551  // executing MAC in sync. This avoids out-of-sync waves hijacking MAC
552  // resources from other workgroups, which would reduce the chance of latency
553  // hiding by waiting for the rest of the workgroup at the eventual sync point.
554  if constexpr(k0_offset != 0 || KRepeat == 1)
555  {
556  __builtin_amdgcn_s_barrier();
557  __builtin_amdgcn_sched_barrier(0);
558  }
559  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
560  static_for<0, MRepeat, 1>{}([&](auto m0) {
561  static_for<0, NRepeat, 1>{}([&](auto n0) {
562  vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
563  vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
564 
565  static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
566  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
567  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
568  make_tuple(Number<ik / A_K1>{},
569  m0,
570  k0_inner,
571  I0,
572  I0,
573  Number<ik % A_K1>{}))>{}];
574  });
575  static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
576  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
577  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
578  make_tuple(Number<ik / B_K1>{},
579  n0,
580  k0_inner,
581  I0,
582  I0,
583  Number<ik % B_K1>{}))>{}];
584  });
585 
586  using wmma_input_type_a =
587  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
588  using wmma_input_type_b =
589  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
590 
591  constexpr index_t c_offset =
592  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
593 
594  // The block_sync_lds() here performs double duty:
595  // A) safeguard against data hazard.
596  // B) reduce VMEM FIFO congestion by applying small delays to
597  // different wavefronts.
598  // It is performed near the end of MAC cluster to minimize lgkmcnt
599  // penalty
600  if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
601  n0 == NRepeat - 1)
602  {
603  __builtin_amdgcn_sched_barrier(0);
604  block_sync_lds();
605  __builtin_amdgcn_sched_barrier(0);
606  }
607  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
608  b_thread_vec.template AsType<wmma_input_type_b>(),
609  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
610  if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
611  {
612  __builtin_amdgcn_sched_barrier(0);
613  __builtin_amdgcn_s_setprio(1);
614  __builtin_amdgcn_sched_barrier(0);
615  }
616  });
617  });
618  });
619  __builtin_amdgcn_sched_barrier(0);
620  __builtin_amdgcn_s_setprio(0);
621  __builtin_amdgcn_sched_barrier(0);
622  });
623  };
624 
625  // main body
626  if constexpr(HasMainLoop)
627  {
628  index_t i = 0;
629  do
630  {
631  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
632  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
633 
634  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
635  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
636 
637  block_sync_lds();
638  blockwise_gemm_func();
639 
640  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
641  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
642  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
643 
644  i += 1;
645  } while(i < (num_loop - 1));
646  }
647 
648  // tail
649  if constexpr(TailNum == TailNumber::Full)
650  {
651  block_sync_lds();
652  blockwise_gemm_func();
653  }
654  }
655 
656  protected:
657  static constexpr auto a_thread_desc_ =
658  make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
659  Number<MRepeat>{},
660  Number<KRepeatPerCluster>{},
661  I1,
662  I1,
663  Number<A_K1>{}),
664  make_tuple(Number<A_K1>{},
665  Number<KPack / A_KRow>{},
666  Number<KPack / A_KRow * MRepeat>{},
667  I0,
668  I0,
669  I1));
670 
671  static constexpr auto b_thread_desc_ =
672  make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
673  Number<NRepeat>{},
674  Number<KRepeatPerCluster>{},
675  I1,
676  I1,
677  Number<B_K1>{}),
678  make_tuple(Number<B_K1>{},
679  Number<KPack / B_KRow>{},
680  Number<KPack / B_KRow * NRepeat>{},
681  I0,
682  I0,
683  I1));
684 
685  using AThreadCopy =
686  ThreadwiseTensorSliceTransfer_v4<ADataType,
687  ComputeTypeA,
688  decltype(a_block_desc_k0_m0_m1_m2_k1),
689  decltype(a_thread_desc_),
690  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
691  Sequence<0, 1, 2, 3, 4, 5>,
692  5,
693  A_K1,
694  A_K1>;
695 
696  using BThreadCopy =
697  ThreadwiseTensorSliceTransfer_v4<BDataType,
698  ComputeTypeB,
699  decltype(b_block_desc_k0_n0_n1_n2_k1),
700  decltype(b_thread_desc_),
701  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
702  Sequence<0, 1, 2, 3, 4, 5>,
703  5,
704  B_K1,
705  B_K1>;
706 
707  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
708  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
709  using Base::c_thread_desc_;
710 };
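In this interwave specialization the k dimension of a block tile is processed in MAC clusters of KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1) k-steps, with NumKClusters taken from CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS. A small worked example with assumed values (not taken from this file):

// Illustrative numbers only.
constexpr index_t KRepeat           = 8;   // assumed
constexpr index_t NumKClusters      = 2;   // assumed value of the macro
constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1); // = 4
// blockwise_gemm_func then runs two MAC clusters:
//   static_for<0, KRepeat, KRepeatPerCluster>  ->  k0_offset in {0, 4}
//   static_for<0, KRepeatPerCluster, 1>        ->  k0_inner  in {0, 1, 2, 3}
// Each cluster first stages its A/B fragments from LDS into registers, then issues
// the WMMA MACs at raised wave priority, with an s_barrier before every cluster
// except the first so that the waves of a workgroup enter the MAC phase together.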
711 
712 } // namespace ck
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:208
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
Definition: ck.hpp:267
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
TailNumber
Definition: blkgemmpipe_scheduler.hpp:31
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:25
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:298
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
integral_constant< index_t, N > Number
Definition: number.hpp:12
BlockwiseGemmWmmaops_pipeline_base
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:35
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:452
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, BScaleStruct &b_scale_struct, index_t num_loop, index_t num_loop_per_scale) const
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:169
BlockwiseGemmWmmaops_pipeline_v1
Definition: blockwise_gemm_pipeline_wmmaops_v1.hpp:37
Sequence
Definition: sequence.hpp:43
integral_constant
Definition: integral_constant.hpp:20
is_same
Definition: type.hpp:177
static_for
Definition: functional2.hpp:33
vector_type
Definition: dtype_vector.hpp:10