blockwise_gemm_pipeline_xdlops_v4.hpp Source File

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute-optimal pipeline with the highest resource request
// GlobalPrefetchStages: 4
// LocalPreFillStages: 2
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
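// These stage counts map onto the buffers in Run() below: four global
// prefetches are kept in flight, two of them pre-filled into the two LDS
// buffers (a_block_buf.At(I0)/At(I1)), and one LDS read is staged ahead into
// the doubled VGPR buffers (a_thread_bufs/b_thread_bufs) before the MFMAs
// consume it.
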
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v4
{
};

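// The primary template is intentionally left empty; only the partial
// specialization for BlockGemmPipelineScheduler::Intrawave below provides an
// implementation.
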
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC // TransposeC is disabled right now
          >
struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    static constexpr index_t PrefetchStages = 4;
    static constexpr index_t PrefillStages = 2;
    static constexpr index_t GlobalBufferNum = 2;
    static constexpr index_t HotloopUnroll = 2;

    __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % HotloopUnroll == 1)
        {
            return TailNumber::Odd;
        }
        else
        {
            return TailNumber::Even;
        }
    }
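
    // Example: num_loop == 11 gives a hot loop (11 > PrefetchStages) and an
    // Odd tail (11 % HotloopUnroll == 1); the unrolled hot loop then executes
    // 8 stages and the Odd tail path in Run() drains the remaining 3. An even
    // num_loop takes the Even tail path, which drains 4 stages.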

    template <typename ScheduleGroup>
    __device__ static constexpr void HotLoopScheduler(ScheduleGroup schedule_group)
    {
        // TODO: Take the data type into consideration, as pipeline v3 does
        // A/B split schedule
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_issue_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_a =
            (HotLoopInstList::A_LDS_Write_Inst_Num + num_issue_a - 1) / num_issue_a;
        constexpr auto num_dsread_per_issue_a = num_ds_read_inst_a / num_issue_a;

        constexpr auto num_issue_b = HotLoopInstList::B_Buffer_Load_Inst_Num;
        constexpr auto num_dswrite_per_issue_b =
            (HotLoopInstList::B_LDS_Write_Inst_Num + num_issue_b - 1) / num_issue_b;
        constexpr auto num_dsread_per_issue_b = num_ds_read_inst_b / num_issue_b;

        constexpr auto num_mfma_per_issue =
            HotLoopInstList::C_MFMA_Inst_Num / (num_issue_a + num_issue_b);

        static_for<0, num_issue_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_a, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
            });

            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x008,
                                                 num_mfma_per_issue - num_dsread_per_issue_a -
                                                     num_dswrite_per_issue_a,
                                                 schedule_group); // MFMA
        });

        static_for<0, num_issue_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dsread_per_issue_b, 1>{}([&](auto idsread) {
                ignore = idsread;
                __builtin_amdgcn_sched_group_barrier(0x100, 1, schedule_group); // DS read
                __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
            });

            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, schedule_group); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, schedule_group); // MFMA
            });

            __builtin_amdgcn_sched_group_barrier(0x020, 1, schedule_group); // VMEM read
            __builtin_amdgcn_sched_group_barrier(0x008,
                                                 num_mfma_per_issue - num_dsread_per_issue_b -
                                                     num_dswrite_per_issue_b,
                                                 schedule_group); // MFMA
        });
        __builtin_amdgcn_sched_barrier(0);
    }
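
    // In the schedule above, the mask argument of
    // __builtin_amdgcn_sched_group_barrier selects an instruction class:
    // 0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write.
    // Each buffer-load issue slot therefore interleaves its DS reads, DS
    // writes, and one VMEM read with MFMAs, and the trailing sched_barrier(0)
    // fences this schedule off from the surrounding code.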

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0), I0);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0), I0);

        // Local prefill 2
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1), I1);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1), I1);

        // Local prefetch 1
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, Number<k * AMmaKStride>{}, I0),
                                   a_block_buf.At(I0),
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k, I0),
                                   a_thread_bufs(I0));
            });
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, Number<k * BMmaKStride>{}, I0),
                                   b_block_buf.At(I0),
                                   b_thread_desc_,
                                   make_tuple(n0, I0, k, I0),
                                   b_thread_bufs(I0));
            });
        });

        // Global prefetch 3
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Global prefetch 4
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

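        // The pipeline is now fully primed: four global prefetches have been
        // issued, two of them already written to the two LDS buffers, and one
        // LDS read staged into VGPR buffer 0 for the first round of MFMAs.
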
        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            // The hot loop is unrolled twice to implement the double
            // local-buffer strategy
            do
            {
                auto LoopFunc = [&](auto lds_read_buf,
                                    auto lds_read_reg_buf,
                                    auto lds_write_buf,
                                    auto vmem_buf,
                                    auto mfma_reg_buf,
                                    auto schedule_group) {
                    block_sync_lds();

                    static_for<0, KRepeat, 1>{}([&](auto k) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                               make_tuple(m0, I0, Number<k * AMmaKStride>{}, I0),
                                               a_block_buf.At(lds_read_buf),
                                               a_thread_desc_,
                                               make_tuple(m0, I0, k, I0),
                                               a_thread_bufs(lds_read_reg_buf));
                        });
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, Number<k * BMmaKStride>{}, I0),
                                               b_block_buf.At(lds_read_buf),
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(lds_read_reg_buf));
                        });
                    });

                    a_blockwise_copy.RunWrite(
                        a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
                    b_blockwise_copy.RunWrite(
                        b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);

                    a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
                    b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);

                    a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                    b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                vector_type<ComputeDataType, KPack> a_thread_vec;
                                vector_type<ComputeDataType, KPack> b_thread_vec;

                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_bufs[mfma_reg_buf]
                                                     [Number<a_thread_desc_.CalculateOffset(
                                                         make_tuple(m0, I0, k0, ik))>{}];
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_bufs[mfma_reg_buf]
                                                     [Number<b_thread_desc_.CalculateOffset(
                                                         make_tuple(n0, I0, k0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });
                        });
                    });

                    HotLoopScheduler(schedule_group);
                };

                LoopFunc(I1, I1, I0, I0, I0, I0);
                LoopFunc(I0, I0, I1, I1, I1, I0);

                i += HotloopUnroll;
            } while(i < (num_loop - PrefetchStages));
        }
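
        // Each unrolled pair of LoopFunc calls above ping-pongs the buffers:
        // the first call reads LDS buffer 1 into VGPR buffer 1 while LDS
        // buffer 0 is refilled from global-prefetch register buffer 0 and the
        // MFMAs consume VGPR buffer 0; the second call swaps the roles of
        // buffers 0 and 1.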

        auto ReadWriteCompFunc = [&](auto lds_read_buf,
                                     auto lds_read_reg_buf,
                                     auto lds_write_buf,
                                     auto vmem_buf,
                                     auto mfma_reg_buf,
                                     auto schedule_group) {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, Number<k * AMmaKStride>{}, I0),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, Number<k * BMmaKStride>{}, I0),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(lds_write_buf), vmem_buf);
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(lds_write_buf), vmem_buf);

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler(schedule_group);
        };

        auto ReadCompFunc = [&](auto lds_read_buf,
                                auto lds_read_reg_buf,
                                auto mfma_reg_buf,
                                auto schedule_group) {
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, Number<k * AMmaKStride>{}, I0),
                                       a_block_buf.At(lds_read_buf),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(lds_read_reg_buf));
                });
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, Number<k * BMmaKStride>{}, I0),
                                       b_block_buf.At(lds_read_buf),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(lds_read_reg_buf));
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            HotLoopScheduler(schedule_group);
        };

        auto CompFunc = [&](auto mfma_reg_buf) {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeDataType, KPack> a_thread_vec;
                        vector_type<ComputeDataType, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        };
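
        // The tail below drains the stages still in flight when the hot loop
        // exits: ReadWriteCompFunc consumes one LDS buffer while committing
        // the last prefetched global data, ReadCompFunc computes from the
        // other LDS buffer with nothing left to write, and CompFunc finishes
        // from registers alone. The Odd path drains 3 stages, the Even path 4.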
        // tail
        if constexpr(TailNum == TailNumber::Odd)
        {
            ReadWriteCompFunc(I1, I1, I0, I0, I0, I1);
            ReadCompFunc(I0, I0, I1, I1);
            CompFunc(I0);
        }
        else if constexpr(TailNum == TailNumber::Even)
        {
            ReadWriteCompFunc(I1, I1, I0, I0, I0, I1);
            ReadWriteCompFunc(I0, I0, I1, I1, I1, I1);
            ReadCompFunc(I1, I1, I0, I1);
            CompFunc(I1);
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

} // namespace ck