Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp Source File

blockwise_gemm_pipeline_xdlops_v2.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
7 
8 namespace ck {
9 
10 // Maximum Global Memory throughput pipeline with >=32KB of data in flight
11 // GlobalPrefetchStages: >=2
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 0
14 // LocalSharedMemoryBuffer: 1
15 
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeDataType,
21  typename AccDataType,
22  typename ATileDesc,
23  typename BTileDesc,
24  typename AMmaTileDesc,
25  typename BMmaTileDesc,
26  index_t ABlockTransferSrcScalarPerVector,
27  index_t BBlockTransferSrcScalarPerVector,
28  index_t MPerBlock,
29  index_t NPerBlock,
30  index_t KPerBlock,
31  index_t MPerXDL,
32  index_t NPerXDL,
33  index_t MRepeat,
34  index_t NRepeat,
35  index_t KPacks>
36 struct BlockwiseGemmXdlops_pipeline_v2
37 {
38 };
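// Base template with an empty body; only the BlockGemmPipelineScheduler specializations
// below (selected via BlkGemmPipelineVer) provide an implementation.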
39 
40 template <index_t BlockSize,
41  typename ADataType,
42  typename BDataType,
43  typename ComputeDataType,
44  typename AccDataType,
45  typename ATileDesc,
46  typename BTileDesc,
47  typename AMmaTileDesc,
48  typename BMmaTileDesc,
49  index_t ABlockTransferSrcScalarPerVector,
50  index_t BBlockTransferSrcScalarPerVector,
51  index_t MPerBlock,
52  index_t NPerBlock,
53  index_t KPerBlock,
54  index_t MPerXDL,
55  index_t NPerXDL,
56  index_t MRepeat,
57  index_t NRepeat,
58  index_t KPack
59  // ,bool TransposeC //disable transposec right now...
60  >
61 struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
62  BlockSize,
63  ADataType,
64  BDataType,
65  ComputeDataType,
66  AccDataType,
67  ATileDesc,
68  BTileDesc,
69  AMmaTileDesc,
70  BMmaTileDesc,
71  ABlockTransferSrcScalarPerVector,
72  BBlockTransferSrcScalarPerVector,
73  MPerBlock,
74  NPerBlock,
75  KPerBlock,
76  MPerXDL,
77  NPerXDL,
78  MRepeat,
79  NRepeat,
80  KPack>
81  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
82  ADataType,
83  BDataType,
84  ComputeDataType,
85  AccDataType,
86  ATileDesc,
87  BTileDesc,
88  AMmaTileDesc,
89  BMmaTileDesc,
90  ABlockTransferSrcScalarPerVector,
91  BBlockTransferSrcScalarPerVector,
92  MPerBlock,
93  NPerBlock,
94  KPerBlock,
95  MPerXDL,
96  NPerXDL,
97  MRepeat,
98  NRepeat,
99  KPack>
100 
101 {
102  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
103  ADataType,
104  BDataType,
105  ComputeDataType,
106  AccDataType,
107  ATileDesc,
108  BTileDesc,
109  AMmaTileDesc,
110  BMmaTileDesc,
111  ABlockTransferSrcScalarPerVector,
112  BBlockTransferSrcScalarPerVector,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  MPerXDL,
117  NPerXDL,
118  MRepeat,
119  NRepeat,
120  KPack>;
121  using Base::I0;
122  using Base::KRepeat;
123  using Base::xdlops_gemm;
124 
125  using Base::CalculateCThreadOriginDataIndex;
126  using Base::CalculateCThreadOriginDataIndex8D;
127  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
128  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
129  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
130  using Base::GetCThreadBuffer;
131  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
132  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
133  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
134  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
135  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
136 
137  using Base::a_block_desc_m0_m1_m2_k;
138  using Base::b_block_desc_n0_n1_n2_k;
139 
140  using Base::AMmaKStride;
141  using Base::BMmaKStride;
142 
143  static constexpr index_t WgpPerCU =
144  (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
145  static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
146  32768 / WgpPerCU,
147  (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
148  static constexpr index_t PrefetchStages =
149  FullMemBandPrefetchStages >= 2
150  ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
151  : 2;
152 
153  static constexpr index_t PrefillStages = 1;
154  static constexpr index_t GlobalBufferNum = PrefetchStages;
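// Illustrative sizing (hypothetical tile, not taken from this file): with warpSize = 64,
// BlockSize = 256, fp16 A/B (2 bytes each), MPerBlock = NPerBlock = 128 and KPerBlock = 32:
//   WgpPerCU                  = max(4 * 64 / 256, 1)                     = 1
//   FullMemBandPrefetchStages = ceil(32768 / 1 / ((128*2 + 128*2) * 32)) = ceil(32768 / 16384) = 2
//   PrefetchStages            = clamp(FullMemBandPrefetchStages, 2, 8)   = 2
// i.e. two global prefetch buffers already keep >=32KB of A/B data in flight.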
155 
156  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
157  {
158  return num_loop > PrefetchStages;
159  }
160 
161  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
162  {
163  if(num_loop % PrefetchStages == 1)
164  {
165  return TailNumber::One;
166  }
167  else if(num_loop % PrefetchStages == 2)
168  {
169  return TailNumber::Two;
170  }
171  else if(num_loop % PrefetchStages == 3)
172  {
173  return TailNumber::Three;
174  }
175  else if(num_loop % PrefetchStages == 4)
176  {
177  return TailNumber::Four;
178  }
179  else if(num_loop % PrefetchStages == 5)
180  {
181  return TailNumber::Five;
182  }
183  else if(num_loop % PrefetchStages == 6)
184  {
185  return TailNumber::Six;
186  }
187  else if(num_loop % PrefetchStages == 7)
188  {
189  return TailNumber::Seven;
190  }
191  else
192  {
193  return TailNumber::Full;
194  }
195  }
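// Example (illustrative numbers only): with PrefetchStages == 4 and num_loop == 10,
// BlockHasHotloop() returns true (10 > 4) and 10 % 4 == 2, so BlockLoopTailNum() returns
// TailNumber::Two; the hot loop then runs in groups of PrefetchStages iterations and the
// remaining two iterations are handled by the tail path.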
196 
197  template <bool HasMainLoop,
198  TailNumber TailNum,
199  typename AGridDesc,
200  typename ABlockDesc,
201  typename ABlockTransfer,
202  typename AGridBuffer,
203  typename ABlockBuffer,
204  typename ABlockTransferStep,
205  typename BGridDesc,
206  typename BBlockDesc,
207  typename BBlockTransfer,
208  typename BGridBuffer,
209  typename BBlockBuffer,
210  typename BBlockTransferStep,
211  typename CThreadBuffer>
212  __device__ void Run(const AGridDesc& a_grid_desc,
213  const ABlockDesc& a_block_desc,
214  ABlockTransfer& a_blockwise_copy,
215  const AGridBuffer& a_grid_buf,
216  ABlockBuffer& a_block_buf,
217  const ABlockTransferStep& a_block_copy_step,
218  const BGridDesc& b_grid_desc,
219  const BBlockDesc& b_block_desc,
220  BBlockTransfer& b_blockwise_copy,
221  const BGridBuffer& b_grid_buf,
222  BBlockBuffer& b_block_buf,
223  const BBlockTransferStep& b_block_copy_step,
224  CThreadBuffer& c_thread_buf,
225  index_t num_loop) const
226  {
227  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
228  a_thread_desc_.GetElementSpaceSize());
229  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
230  b_thread_desc_.GetElementSpaceSize());
231 
232  // Global prefetch 1
233  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
234  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
235 
236  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
237  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
238 
239  // Initialize C
240  c_thread_buf.Clear();
241 
242  // Local prefill 1
243  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
244  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
245 
246  // Global prefetch [2, PrefetchStages]
247  static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
248  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
249  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
250 
251  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
252  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
253  });
254 
255  // main body
256  if constexpr(HasMainLoop)
257  {
258  index_t i = 0;
259  do
260  {
261  static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
262  // -------------------------------------------------------------------------------------------
263  block_sync_lds();
264  static_for<0, KRepeat, 1>{}([&](auto k) {
265  static_for<0, MRepeat, 1>{}([&](auto m0) {
266  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
267  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
268  a_block_buf,
269  a_thread_desc_,
270  make_tuple(m0, I0, k, I0),
271  a_thread_buf);
272  });
273  static_for<0, NRepeat, 1>{}([&](auto n0) {
274  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
275  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
276  b_block_buf,
277  b_thread_desc_,
278  make_tuple(n0, I0, k, I0),
279  b_thread_buf);
280  });
281  });
282 
283  static_for<0, KRepeat, 1>{}([&](auto k0) {
284  static_for<0, MRepeat, 1>{}([&](auto m0) {
285  static_for<0, NRepeat, 1>{}([&](auto n0) {
286  vector_type<ComputeDataType, KPack> a_thread_vec;
287  vector_type<ComputeDataType, KPack> b_thread_vec;
288 
289  static_for<0, KPack, 1>{}([&](auto ik) {
290  a_thread_vec.template AsType<ComputeDataType>()(ik) =
291  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
292  make_tuple(m0, I0, k0, ik))>{}];
293  b_thread_vec.template AsType<ComputeDataType>()(ik) =
294  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
295  make_tuple(n0, I0, k0, ik))>{}];
296  });
297 
298  using mfma_input_type =
299  typename vector_type<ComputeDataType,
300  xdlops_gemm.K1PerXdlops>::type;
301 
302  constexpr index_t c_offset =
303  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
304 
305  xdlops_gemm.Run(
306  a_thread_vec.template AsType<mfma_input_type>(),
307  b_thread_vec.template AsType<mfma_input_type>(),
308  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
309  });
310  });
311  });
312 
313  block_sync_lds();
314  a_blockwise_copy.RunWrite(
315  a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
316  b_blockwise_copy.RunWrite(
317  b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
318 
319  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
320  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
321 
322  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
323  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
324  });
325 
326  i += PrefetchStages;
327  } while(i < (num_loop - PrefetchStages));
328  }
329 
330  // tail
331 
332  auto LoopTailFunc = [&](auto tail_num) {
333  static_for<1, tail_num, 1>{}([&](auto iprefetch) {
334  block_sync_lds();
335  static_for<0, KRepeat, 1>{}([&](auto k) {
336  static_for<0, MRepeat, 1>{}([&](auto m0) {
337  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
338  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
339  a_block_buf,
340  a_thread_desc_,
341  make_tuple(m0, I0, k, I0),
342  a_thread_buf);
343  });
344  static_for<0, NRepeat, 1>{}([&](auto n0) {
345  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
346  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
347  b_block_buf,
348  b_thread_desc_,
349  make_tuple(n0, I0, k, I0),
350  b_thread_buf);
351  });
352  });
353 
354  static_for<0, KRepeat, 1>{}([&](auto k0) {
355  static_for<0, MRepeat, 1>{}([&](auto m0) {
356  static_for<0, NRepeat, 1>{}([&](auto n0) {
357  vector_type<ComputeDataType, KPack> a_thread_vec;
358  vector_type<ComputeDataType, KPack> b_thread_vec;
359 
360  static_for<0, KPack, 1>{}([&](auto ik) {
361  a_thread_vec.template AsType<ComputeDataType>()(ik) =
362  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
363  make_tuple(m0, I0, k0, ik))>{}];
364  b_thread_vec.template AsType<ComputeDataType>()(ik) =
365  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
366  make_tuple(n0, I0, k0, ik))>{}];
367  });
368 
369  using mfma_input_type =
370  typename vector_type<ComputeDataType,
371  xdlops_gemm.K1PerXdlops>::type;
372 
373  constexpr index_t c_offset =
374  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
375 
376  xdlops_gemm.Run(
377  a_thread_vec.template AsType<mfma_input_type>(),
378  b_thread_vec.template AsType<mfma_input_type>(),
379  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
380  });
381  });
382  });
383 
384  block_sync_lds();
385  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
386  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
387  });
388 
389  block_sync_lds();
390  static_for<0, KRepeat, 1>{}([&](auto k) {
391  static_for<0, MRepeat, 1>{}([&](auto m0) {
392  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
393  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
394  a_block_buf,
395  a_thread_desc_,
396  make_tuple(m0, I0, k, I0),
397  a_thread_buf);
398  });
399  static_for<0, NRepeat, 1>{}([&](auto n0) {
400  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
401  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
402  b_block_buf,
403  b_thread_desc_,
404  make_tuple(n0, I0, k, I0),
405  b_thread_buf);
406  });
407  });
408 
409  static_for<0, KRepeat, 1>{}([&](auto k0) {
410  static_for<0, MRepeat, 1>{}([&](auto m0) {
411  static_for<0, NRepeat, 1>{}([&](auto n0) {
412  vector_type<ComputeDataType, KPack> a_thread_vec;
413  vector_type<ComputeDataType, KPack> b_thread_vec;
414 
415  static_for<0, KPack, 1>{}([&](auto ik) {
416  a_thread_vec.template AsType<ComputeDataType>()(ik) =
417  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
418  make_tuple(m0, I0, k0, ik))>{}];
419  b_thread_vec.template AsType<ComputeDataType>()(ik) =
420  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
421  make_tuple(n0, I0, k0, ik))>{}];
422  });
423 
424  using mfma_input_type =
425  typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
426 
427  constexpr index_t c_offset =
428  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
429 
430  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
431  b_thread_vec.template AsType<mfma_input_type>(),
432  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
433  });
434  });
435  });
436  };
437 
438  if constexpr(TailNum == TailNumber::One)
439  {
440  block_sync_lds();
441  static_for<0, KRepeat, 1>{}([&](auto k) {
442  static_for<0, MRepeat, 1>{}([&](auto m0) {
443  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
444  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
445  a_block_buf,
446  a_thread_desc_,
447  make_tuple(m0, I0, k, I0),
448  a_thread_buf);
449  });
450  static_for<0, NRepeat, 1>{}([&](auto n0) {
451  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
452  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
453  b_block_buf,
454  b_thread_desc_,
455  make_tuple(n0, I0, k, I0),
456  b_thread_buf);
457  });
458  });
459 
460  static_for<0, KRepeat, 1>{}([&](auto k0) {
461  static_for<0, MRepeat, 1>{}([&](auto m0) {
462  static_for<0, NRepeat, 1>{}([&](auto n0) {
463  vector_type<ComputeDataType, KPack> a_thread_vec;
464  vector_type<ComputeDataType, KPack> b_thread_vec;
465 
466  static_for<0, KPack, 1>{}([&](auto ik) {
467  a_thread_vec.template AsType<ComputeDataType>()(ik) =
468  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
469  make_tuple(m0, I0, k0, ik))>{}];
470  b_thread_vec.template AsType<ComputeDataType>()(ik) =
471  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
472  make_tuple(n0, I0, k0, ik))>{}];
473  });
474 
475  using mfma_input_type =
476  typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
477 
478  constexpr index_t c_offset =
479  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
480 
481  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
482  b_thread_vec.template AsType<mfma_input_type>(),
483  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
484  });
485  });
486  });
487  }
488  else if constexpr(TailNum == TailNumber::Two)
489  {
490  LoopTailFunc(Number<2>{});
491  }
492  else if constexpr(TailNum == TailNumber::Three)
493  {
494  LoopTailFunc(Number<3>{});
495  }
496  else if constexpr(TailNum == TailNumber::Four)
497  {
498  LoopTailFunc(Number<4>{});
499  }
500  else if constexpr(TailNum == TailNumber::Five)
501  {
502  LoopTailFunc(Number<5>{});
503  }
504  else if constexpr(TailNum == TailNumber::Six)
505  {
506  LoopTailFunc(Number<6>{});
507  }
508  else if constexpr(TailNum == TailNumber::Seven)
509  {
510  LoopTailFunc(Number<7>{});
511  }
512  else if constexpr(TailNum == TailNumber::Full)
513  {
514  LoopTailFunc(Number<PrefetchStages>{});
515  }
516  }
517 
518  protected:
519  using Base::a_thread_copy_;
520  using Base::a_thread_desc_;
521  using Base::b_thread_copy_;
522  using Base::b_thread_desc_;
523  using Base::c_thread_desc_;
524 };
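// Illustrative caller-side dispatch (hypothetical code, not part of this file): a gridwise
// GEMM kernel would typically select the Run<> instantiation from num_loop, e.g.
//
//   const bool has_hot_loop = BlockwiseGemmPipe::BlockHasHotloop(num_loop);
//   const auto tail_number  = BlockwiseGemmPipe::BlockLoopTailNum(num_loop);
//   if(has_hot_loop && tail_number == TailNumber::Full)
//   {
//       blockwise_gemm_pipeline.template Run<true, TailNumber::Full>(
//           a_grid_desc, a_block_desc, a_blockwise_copy, a_grid_buf, a_block_buf,
//           a_block_copy_step, b_grid_desc, b_block_desc, b_blockwise_copy, b_grid_buf,
//           b_block_buf, b_block_copy_step, c_thread_buf, num_loop);
//   }
//   // ... one branch per (has_hot_loop, tail_number) combination, since both are
//   // compile-time template parameters of Run().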
525 
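// Inter-wave scheduled variant of the same pipeline: each wave's KPerThread is split into
// NumMacClusters MAC clusters of KPerInnerLoop elements, and s_barrier/s_setprio scheduling
// hints keep the MFMA phases of the waves in a workgroup aligned with one another.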
526 template <index_t BlockSize,
527  typename ADataType,
528  typename BDataType,
529  typename ComputeDataType,
530  typename AccDataType,
531  typename ATileDesc,
532  typename BTileDesc,
533  typename AMmaTileDesc,
534  typename BMmaTileDesc,
535  index_t ABlockTransferSrcScalarPerVector,
536  index_t BBlockTransferSrcScalarPerVector,
537  index_t MPerBlock,
538  index_t NPerBlock,
539  index_t KPerBlock,
540  index_t MPerXDL,
541  index_t NPerXDL,
542  index_t MRepeat,
543  index_t NRepeat,
544  index_t KPack
545  // ,bool TransposeC //disable transposec right now...
546  >
547 struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
548  BlockSize,
549  ADataType,
550  BDataType,
551  ComputeDataType,
552  AccDataType,
553  ATileDesc,
554  BTileDesc,
555  AMmaTileDesc,
556  BMmaTileDesc,
557  ABlockTransferSrcScalarPerVector,
558  BBlockTransferSrcScalarPerVector,
559  MPerBlock,
560  NPerBlock,
561  KPerBlock,
562  MPerXDL,
563  NPerXDL,
564  MRepeat,
565  NRepeat,
566  KPack>
567  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
568  ADataType,
569  BDataType,
570  ComputeDataType,
571  AccDataType,
572  ATileDesc,
573  BTileDesc,
574  AMmaTileDesc,
575  BMmaTileDesc,
576  ABlockTransferSrcScalarPerVector,
577  BBlockTransferSrcScalarPerVector,
578  MPerBlock,
579  NPerBlock,
580  KPerBlock,
581  MPerXDL,
582  NPerXDL,
583  MRepeat,
584  NRepeat,
585  KPack>
586 
587 {
588  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
589  ADataType,
590  BDataType,
591  ComputeDataType,
592  AccDataType,
593  ATileDesc,
594  BTileDesc,
595  AMmaTileDesc,
596  BMmaTileDesc,
597  ABlockTransferSrcScalarPerVector,
598  BBlockTransferSrcScalarPerVector,
599  MPerBlock,
600  NPerBlock,
601  KPerBlock,
602  MPerXDL,
603  NPerXDL,
604  MRepeat,
605  NRepeat,
606  KPack>;
607  using Base::A_K1;
608  using Base::B_K1;
609  using Base::I0;
610  using Base::I1;
611  using Base::KPerThread;
612  using Base::xdlops_gemm;
613 
614  using Base::CalculateCThreadOriginDataIndex;
615  using Base::CalculateCThreadOriginDataIndex8D;
616  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
617  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
618  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
619  using Base::GetCThreadBuffer;
620  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
621  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
622  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
623  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
624  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
625 
626  using Base::a_block_desc_m0_m1_m2_k;
627  using Base::b_block_desc_n0_n1_n2_k;
628 
629  static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
630  static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
631  static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
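// Illustrative split (hypothetical values, for illustration only): with KPerThread = 16,
// KPack = 4 and NumMacClusters = 2, KPerInnerLoop = max(16 / 2, 4) = 8 and
// KRepeat = 16 / 8 = 2, i.e. each wave issues its MFMAs in two clusters of 8 K-elements each.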
632 
633  static constexpr index_t WgpPerCU =
634  (4 * warpSize / BlockSize) >= 1 ? 4 * warpSize / BlockSize : 1;
635  static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
636  32768 / WgpPerCU,
637  (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
638  static constexpr index_t PrefetchStages =
639  FullMemBandPrefetchStages >= 2
640  ? FullMemBandPrefetchStages <= 8 ? FullMemBandPrefetchStages : 8
641  : 2;
642 
643  static constexpr index_t PrefillStages = 1;
644  static constexpr index_t GlobalBufferNum = PrefetchStages;
645 
646  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
647  {
648  return num_loop > PrefetchStages;
649  }
650 
651  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
652  {
653  if(num_loop % PrefetchStages == 1)
654  {
655  return TailNumber::One;
656  }
657  else if(num_loop % PrefetchStages == 2)
658  {
659  return TailNumber::Two;
660  }
661  else if(num_loop % PrefetchStages == 3)
662  {
663  return TailNumber::Three;
664  }
665  else if(num_loop % PrefetchStages == 4)
666  {
667  return TailNumber::Four;
668  }
669  else if(num_loop % PrefetchStages == 5)
670  {
671  return TailNumber::Five;
672  }
673  else if(num_loop % PrefetchStages == 6)
674  {
675  return TailNumber::Six;
676  }
677  else if(num_loop % PrefetchStages == 7)
678  {
679  return TailNumber::Seven;
680  }
681  else
682  {
683  return TailNumber::Full;
684  }
685  }
686 
687  template <bool HasMainLoop,
688  TailNumber TailNum,
689  typename AGridDesc,
690  typename ABlockDesc,
691  typename ABlockTransfer,
692  typename AGridBuffer,
693  typename ABlockBuffer,
694  typename ABlockTransferStep,
695  typename BGridDesc,
696  typename BBlockDesc,
697  typename BBlockTransfer,
698  typename BGridBuffer,
699  typename BBlockBuffer,
700  typename BBlockTransferStep,
701  typename CThreadBuffer>
702  __device__ void Run(const AGridDesc& a_grid_desc,
703  const ABlockDesc& a_block_desc,
704  ABlockTransfer& a_blockwise_copy,
705  const AGridBuffer& a_grid_buf,
706  ABlockBuffer& a_block_buf,
707  const ABlockTransferStep& a_block_copy_step,
708  const BGridDesc& b_grid_desc,
709  const BBlockDesc& b_block_desc,
710  BBlockTransfer& b_blockwise_copy,
711  const BGridBuffer& b_grid_buf,
712  BBlockBuffer& b_block_buf,
713  const BBlockTransferStep& b_block_copy_step,
714  CThreadBuffer& c_thread_buf,
715  index_t num_loop) const
716  {
717  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
718  a_thread_desc_.GetElementSpaceSize());
719  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
720  b_thread_desc_.GetElementSpaceSize());
721 
722  // Global prefetch 1
723  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
724  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);
725 
726  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
727  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
728 
729  // Initialize C
730  c_thread_buf.Clear();
731 
732  // Local prefill 1
733  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
734  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);
735 
736  // Global prefetch [2, PrefetchStages]
737  static_for<1, PrefetchStages, 1>{}([&](auto iprefetch) {
738  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
739  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
740 
741  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
742  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
743  });
744 
745  // main body
746  if constexpr(HasMainLoop)
747  {
748  index_t i = 0;
749  do
750  {
751  static_for<0, PrefetchStages, 1>{}([&](auto iprefetch) {
752  // -------------------------------------------------------------------------------------------
753  block_sync_lds();
754  static_for<0, KRepeat, 1>{}([&](auto k0) {
755  static_for<0, MRepeat, 1>{}([&](auto m0) {
756  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
757  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
758  a_block_buf,
759  a_thread_desc_,
760  make_tuple(m0, I0, k0, I0),
761  a_thread_buf);
762  });
763  static_for<0, NRepeat, 1>{}([&](auto n0) {
764  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
765  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
766  b_block_buf,
767  b_thread_desc_,
768  make_tuple(n0, I0, k0, I0),
769  b_thread_buf);
770  });
771  __builtin_amdgcn_sched_barrier(0);
772  // NOTE: Synchronize the threads in a workgroup at the start of each MAC
773  // cluster, except for the first one, since the non-MAC cluster can be
774  // shortened a bit with no observable negative impact. The desired effect is
775  // that the waves in a workgroup execute their MAC phases in sync. This keeps
776  // out-of-sync waves from hijacking MAC resources from other workgroups and
777  // reducing the chance of latency hiding by waiting for the rest of the
778  // workgroup at the eventual sync point.
779  if constexpr(k0.value != 0 || KRepeat == 1)
780  {
781  __builtin_amdgcn_s_barrier();
782  __builtin_amdgcn_sched_barrier(0);
783  }
784  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
785  static_for<0, MRepeat, 1>{}([&](auto m0) {
786  static_for<0, NRepeat, 1>{}([&](auto n0) {
787  vector_type<ComputeDataType, KPack> a_thread_vec;
788  vector_type<ComputeDataType, KPack> b_thread_vec;
789 
790  static_for<0, KPack, 1>{}([&](auto ik) {
791  a_thread_vec.template AsType<ComputeDataType>()(ik) =
792  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
793  make_tuple(m0, I0, k0, k_ + ik))>{}];
794  b_thread_vec.template AsType<ComputeDataType>()(ik) =
795  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
796  make_tuple(n0, I0, k0, k_ + ik))>{}];
797  });
798 
799  using mfma_input_type =
800  typename vector_type<ComputeDataType,
801  xdlops_gemm.K1PerXdlops>::type;
802 
803  constexpr index_t c_offset =
804  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
805 
806  // The block_sync_lds() here performs double duty:
807  // A) it safeguards against a data hazard, because the barrier from
808  //    blockwise_gemm has been moved here;
809  // B) it reduces VMEM FIFO congestion by applying small delays to
810  //    different wavefronts.
811  // It is placed near the end of the MAC cluster to minimize the lgkmcnt penalty.
812  if constexpr(k0.value == KRepeat - 1 &&
813  k_.value == KPerInnerLoop - KPack &&
814  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
815  {
816  __builtin_amdgcn_sched_barrier(0);
817  block_sync_lds();
818  __builtin_amdgcn_sched_barrier(0);
819  }
820  xdlops_gemm.Run(
821  a_thread_vec.template AsType<mfma_input_type>(),
822  b_thread_vec.template AsType<mfma_input_type>(),
823  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
824  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
825  {
826  __builtin_amdgcn_sched_barrier(0);
827  __builtin_amdgcn_s_setprio(1);
828  __builtin_amdgcn_sched_barrier(0);
829  }
830  });
831  });
832  });
833  __builtin_amdgcn_sched_barrier(0);
834  __builtin_amdgcn_s_setprio(0);
835  __builtin_amdgcn_sched_barrier(0);
836  });
837 
838  // block_sync_lds();
839  a_blockwise_copy.RunWrite(
840  a_block_desc, a_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
841  b_blockwise_copy.RunWrite(
842  b_block_desc, b_block_buf, Number<(iprefetch + 1) % PrefetchStages>{});
843 
844  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, iprefetch);
845  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, iprefetch);
846 
847  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
848  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
849  });
850  i += PrefetchStages;
851  } while(i < (num_loop - PrefetchStages));
852  }
853 
854  // tail
855 
856  auto LoopTailFunc = [&](auto tail_num) {
857  static_for<1, tail_num, 1>{}([&](auto iprefetch) {
858  block_sync_lds();
859  static_for<0, KRepeat, 1>{}([&](auto k0) {
860  static_for<0, MRepeat, 1>{}([&](auto m0) {
861  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
862  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
863  a_block_buf,
864  a_thread_desc_,
865  make_tuple(m0, I0, k0, I0),
866  a_thread_buf);
867  });
868  static_for<0, NRepeat, 1>{}([&](auto n0) {
869  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
870  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
871  b_block_buf,
872  b_thread_desc_,
873  make_tuple(n0, I0, k0, I0),
874  b_thread_buf);
875  });
876 
877  __builtin_amdgcn_sched_barrier(0);
878  if constexpr(k0.value != 0 || KRepeat == 1)
879  {
880  __builtin_amdgcn_s_barrier();
881  __builtin_amdgcn_sched_barrier(0);
882  }
883  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
884  static_for<0, MRepeat, 1>{}([&](auto m0) {
885  static_for<0, NRepeat, 1>{}([&](auto n0) {
886  vector_type<ComputeDataType, KPack> a_thread_vec;
887  vector_type<ComputeDataType, KPack> b_thread_vec;
888 
889  static_for<0, KPack, 1>{}([&](auto ik) {
890  a_thread_vec.template AsType<ComputeDataType>()(ik) =
891  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
892  make_tuple(m0, I0, k0, k_ + ik))>{}];
893  b_thread_vec.template AsType<ComputeDataType>()(ik) =
894  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
895  make_tuple(n0, I0, k0, k_ + ik))>{}];
896  });
897 
898  using mfma_input_type =
899  typename vector_type<ComputeDataType,
900  xdlops_gemm.K1PerXdlops>::type;
901 
902  constexpr index_t c_offset =
903  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
904 
905  if constexpr(k0.value == KRepeat - 1 &&
906  k_.value == KPerInnerLoop - KPack &&
907  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
908  {
909  __builtin_amdgcn_sched_barrier(0);
910  block_sync_lds();
911  __builtin_amdgcn_sched_barrier(0);
912  }
913  xdlops_gemm.Run(
914  a_thread_vec.template AsType<mfma_input_type>(),
915  b_thread_vec.template AsType<mfma_input_type>(),
916  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
917  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
918  {
919  __builtin_amdgcn_sched_barrier(0);
920  __builtin_amdgcn_s_setprio(1);
921  __builtin_amdgcn_sched_barrier(0);
922  }
923  });
924  });
925  });
926  __builtin_amdgcn_sched_barrier(0);
927  __builtin_amdgcn_s_setprio(0);
928  __builtin_amdgcn_sched_barrier(0);
929  });
930 
931  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, iprefetch);
932  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, iprefetch);
933  });
934  block_sync_lds();
935  static_for<0, KRepeat, 1>{}([&](auto k0) {
936  static_for<0, MRepeat, 1>{}([&](auto m0) {
937  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
938  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
939  a_block_buf,
940  a_thread_desc_,
941  make_tuple(m0, I0, k0, I0),
942  a_thread_buf);
943  });
944  static_for<0, NRepeat, 1>{}([&](auto n0) {
945  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
946  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
947  b_block_buf,
948  b_thread_desc_,
949  make_tuple(n0, I0, k0, I0),
950  b_thread_buf);
951  });
952 
953  __builtin_amdgcn_sched_barrier(0);
954  if constexpr(k0.value != 0 || KRepeat == 1)
955  {
956  __builtin_amdgcn_s_barrier();
957  __builtin_amdgcn_sched_barrier(0);
958  }
959  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
960  static_for<0, MRepeat, 1>{}([&](auto m0) {
961  static_for<0, NRepeat, 1>{}([&](auto n0) {
962  vector_type<ComputeDataType, KPack> a_thread_vec;
963  vector_type<ComputeDataType, KPack> b_thread_vec;
964 
965  static_for<0, KPack, 1>{}([&](auto ik) {
966  a_thread_vec.template AsType<ComputeDataType>()(ik) =
967  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
968  make_tuple(m0, I0, k0, k_ + ik))>{}];
969  b_thread_vec.template AsType<ComputeDataType>()(ik) =
970  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
971  make_tuple(n0, I0, k0, k_ + ik))>{}];
972  });
973 
974  using mfma_input_type =
975  typename vector_type<ComputeDataType,
976  xdlops_gemm.K1PerXdlops>::type;
977 
978  constexpr index_t c_offset =
979  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
980 
981  if constexpr(k0.value == KRepeat - 1 &&
982  k_.value == KPerInnerLoop - KPack &&
983  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
984  {
985  __builtin_amdgcn_sched_barrier(0);
986  block_sync_lds();
987  __builtin_amdgcn_sched_barrier(0);
988  }
989  xdlops_gemm.Run(
990  a_thread_vec.template AsType<mfma_input_type>(),
991  b_thread_vec.template AsType<mfma_input_type>(),
992  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
993  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
994  {
995  __builtin_amdgcn_sched_barrier(0);
996  __builtin_amdgcn_s_setprio(1);
997  __builtin_amdgcn_sched_barrier(0);
998  }
999  });
1000  });
1001  });
1002  __builtin_amdgcn_sched_barrier(0);
1003  __builtin_amdgcn_s_setprio(0);
1004  __builtin_amdgcn_sched_barrier(0);
1005  });
1006  };
1007 
1008  if constexpr(TailNum == TailNumber::One)
1009  {
1010  block_sync_lds();
1011  static_for<0, KRepeat, 1>{}([&](auto k0) {
1012  static_for<0, MRepeat, 1>{}([&](auto m0) {
1013  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
1014  make_tuple(m0, I0, I0, Number<k0 * KPerInnerLoop>{}),
1015  a_block_buf,
1016  a_thread_desc_,
1017  make_tuple(m0, I0, k0, I0),
1018  a_thread_buf);
1019  });
1020  static_for<0, NRepeat, 1>{}([&](auto n0) {
1021  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
1022  make_tuple(n0, I0, I0, Number<k0 * KPerInnerLoop>{}),
1023  b_block_buf,
1024  b_thread_desc_,
1025  make_tuple(n0, I0, k0, I0),
1026  b_thread_buf);
1027  });
1028 
1029  __builtin_amdgcn_sched_barrier(0);
1030  if constexpr(k0.value != 0 || KRepeat == 1)
1031  {
1032  __builtin_amdgcn_s_barrier();
1033  __builtin_amdgcn_sched_barrier(0);
1034  }
1035  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
1036  static_for<0, MRepeat, 1>{}([&](auto m0) {
1037  static_for<0, NRepeat, 1>{}([&](auto n0) {
1038  vector_type<ComputeDataType, KPack> a_thread_vec;
1039  vector_type<ComputeDataType, KPack> b_thread_vec;
1040 
1041  static_for<0, KPack, 1>{}([&](auto ik) {
1042  a_thread_vec.template AsType<ComputeDataType>()(ik) =
1043  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1044  make_tuple(m0, I0, k0, k_ + ik))>{}];
1045  b_thread_vec.template AsType<ComputeDataType>()(ik) =
1046  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
1047  make_tuple(n0, I0, k0, k_ + ik))>{}];
1048  });
1049 
1050  using mfma_input_type =
1051  typename vector_type<ComputeDataType,
1052  xdlops_gemm.K1PerXdlops>::type;
1053 
1054  constexpr index_t c_offset =
1055  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
1056 
1057  if constexpr(k0.value == KRepeat - 1 &&
1058  k_.value == KPerInnerLoop - KPack &&
1059  m0.value == MRepeat - 1 && n0.value == NRepeat - 1)
1060  {
1061  __builtin_amdgcn_sched_barrier(0);
1062  block_sync_lds();
1063  __builtin_amdgcn_sched_barrier(0);
1064  }
1065  xdlops_gemm.Run(
1066  a_thread_vec.template AsType<mfma_input_type>(),
1067  b_thread_vec.template AsType<mfma_input_type>(),
1068  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1069  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
1070  {
1071  __builtin_amdgcn_sched_barrier(0);
1072  __builtin_amdgcn_s_setprio(1);
1073  __builtin_amdgcn_sched_barrier(0);
1074  }
1075  });
1076  });
1077  });
1078  __builtin_amdgcn_sched_barrier(0);
1079  __builtin_amdgcn_s_setprio(0);
1080  __builtin_amdgcn_sched_barrier(0);
1081  });
1082  }
1083  else if constexpr(TailNum == TailNumber::Two)
1084  {
1085  LoopTailFunc(Number<2>{});
1086  }
1087  else if constexpr(TailNum == TailNumber::Three)
1088  {
1089  LoopTailFunc(Number<3>{});
1090  }
1091  else if constexpr(TailNum == TailNumber::Four)
1092  {
1093  LoopTailFunc(Number<4>{});
1094  }
1095  else if constexpr(TailNum == TailNumber::Five)
1096  {
1097  LoopTailFunc(Number<5>{});
1098  }
1099  else if constexpr(TailNum == TailNumber::Six)
1100  {
1101  LoopTailFunc(Number<6>{});
1102  }
1103  else if constexpr(TailNum == TailNumber::Seven)
1104  {
1105  LoopTailFunc(Number<7>{});
1106  }
1107  else if constexpr(TailNum == TailNumber::Full)
1108  {
1109  LoopTailFunc(Number<PrefetchStages>{});
1110  }
1111  }
1112 
1113  protected:
1114  // K->M loopover
1115  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
1116  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
1117  make_tuple(Number<KPerInnerLoop>{},
1118  Number<KRepeat * MRepeat * KPerInnerLoop>{},
1119  Number<MRepeat * KPerInnerLoop>{},
1120  I1));
1121 
1122  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
1123  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPerInnerLoop>{}),
1124  make_tuple(Number<KPerInnerLoop>{},
1125  Number<KRepeat * NRepeat * KPerInnerLoop>{},
1126  Number<NRepeat * KPerInnerLoop>{},
1127  I1));
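// Offset example (hypothetical sizes, for illustration): with MRepeat = 4, KRepeat = 2 and
// KPerInnerLoop = 8, a_thread_desc_.CalculateOffset(make_tuple(m0, 0, k0, ik)) evaluates to
// m0 * 8 + k0 * (4 * 8) + ik, so (m0, 0, k0, ik) = (1, 0, 1, 3) maps to VGPR element
// 8 + 32 + 3 = 43, with the ik (KPack) index contiguous.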
1128 
1129  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
1130  ComputeDataType,
1131  decltype(a_block_desc_m0_m1_m2_k),
1132  decltype(a_thread_desc_),
1133  Sequence<1, 1, 1, KPerInnerLoop>,
1134  Sequence<0, 1, 2, 3>,
1135  3,
1136  A_K1,
1137  A_K1>;
1138 
1139  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
1140  ComputeDataType,
1141  decltype(b_block_desc_n0_n1_n2_k),
1142  decltype(b_thread_desc_),
1143  Sequence<1, 1, 1, KPerInnerLoop>,
1144  Sequence<0, 1, 2, 3>,
1145  3,
1146  B_K1,
1147  B_K1>;
1148 
1149  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
1150  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
1151  using Base::c_thread_desc_;
1152 };
1153 
1154 } // namespace ck
#define CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS
Definition: ck.hpp:211
__host__ constexpr __device__ auto integer_divide_ceil(X x, Y y)
Definition: math.hpp:72
__host__ constexpr __device__ T max(T x)
Definition: math.hpp:84
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
TailNumber
Definition: blkgemmpipe_scheduler.hpp:18
BlockGemmPipelineScheduler
Definition: blkgemmpipe_scheduler.hpp:12
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
__device__ void block_sync_lds()
Definition: synchronization.hpp:10
BlockwiseGemmXdlops_pipeline_base
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:702
__device__ void Run(const AGridDesc &a_grid_desc, const ABlockDesc &a_block_desc, ABlockTransfer &a_blockwise_copy, const AGridBuffer &a_grid_buf, ABlockBuffer &a_block_buf, const ABlockTransferStep &a_block_copy_step, const BGridDesc &b_grid_desc, const BBlockDesc &b_block_desc, BBlockTransfer &b_blockwise_copy, const BGridBuffer &b_grid_buf, BBlockBuffer &b_block_buf, const BBlockTransferStep &b_block_copy_step, CThreadBuffer &c_thread_buf, index_t num_loop) const
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:212
BlockwiseGemmXdlops_pipeline_v2
Definition: blockwise_gemm_pipeline_xdlops_v2.hpp:37
Sequence
Definition: sequence.hpp:43
integral_constant
Definition: integral_constant.hpp:10
static_for
Definition: functional2.hpp:31
vector_type
Definition: data_type.hpp:347