Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.2/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp Source File

blockwise_gemm_pipeline_xdlops_b_preshuffle_gufusion_v3.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"
7 
8 namespace ck {
9 
10 // Compute optimized pipeline
11 // GlobalPrefetchStages: 2
12 // LocalPreFillStages: 1
13 // LocalPreFetchStages: 1
14 // LocalSharedMemoryBuffer: 1
15 
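// Pipeline sketch (as read from the implementation below): the A tile is
// staged through LDS (a_block_buf), while both B tiles (the regular one and
// the second, fused "up" operand) are loaded from global memory directly into
// register buffers (b_thread_bufs / b_thread_bufs_up) that ping-pong between
// two entries across hot-loop iterations; the LDS B buffer passed to Run() is
// unused (ignore = b_block_buf).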
16 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
17  index_t BlockSize,
18  typename ADataType,
19  typename BDataType,
20  typename ComputeDataType,
21  typename AccDataType,
22  typename ATileDesc,
23  typename BTileDesc,
24  typename AMmaTileDesc,
25  typename BMmaTileDesc,
26  index_t ABlockTransferSrcScalarPerVector,
27  index_t BBlockTransferSrcScalarPerVector,
28  index_t MPerBlock,
29  index_t NPerBlock,
30  index_t KPerBlock,
31  index_t MPerXDL,
32  index_t NPerXDL,
33  index_t MRepeat,
34  index_t NRepeat,
 35  index_t KPacks>
 36 struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3
37 {
38 };
39 
40 template <index_t BlockSize,
41  typename ADataType,
42  typename BDataType,
43  typename ComputeDataType,
44  typename AccDataType,
45  typename ATileDesc,
46  typename BTileDesc,
47  typename AMmaTileDesc,
48  typename BMmaTileDesc,
49  index_t ABlockTransferSrcScalarPerVector,
50  index_t BBlockTransferSrcScalarPerVector,
51  index_t MPerBlock,
52  index_t NPerBlock,
53  index_t KPerBlock,
54  index_t MPerXDL,
55  index_t NPerXDL,
56  index_t MRepeat,
57  index_t NRepeat,
58  index_t KPack
 59  // ,bool TransposeC // TransposeC is disabled for now
 60  >
 61 struct BlockwiseGemmXdlops_pipeline_bpreshuffle_gufusion_v3<BlockGemmPipelineScheduler::Intrawave,
62  BlockSize,
63  ADataType,
64  BDataType,
65  ComputeDataType,
66  AccDataType,
67  ATileDesc,
68  BTileDesc,
69  AMmaTileDesc,
70  BMmaTileDesc,
71  ABlockTransferSrcScalarPerVector,
72  BBlockTransferSrcScalarPerVector,
73  MPerBlock,
74  NPerBlock,
75  KPerBlock,
76  MPerXDL,
77  NPerXDL,
78  MRepeat,
79  NRepeat,
 80  KPack>
 81  : BlockwiseGemmXdlops_pipeline_base<BlockSize,
82  ADataType,
83  BDataType,
84  ComputeDataType,
85  AccDataType,
86  ATileDesc,
87  BTileDesc,
88  AMmaTileDesc,
89  BMmaTileDesc,
90  ABlockTransferSrcScalarPerVector,
91  BBlockTransferSrcScalarPerVector,
92  MPerBlock,
93  NPerBlock,
94  KPerBlock,
95  MPerXDL,
96  NPerXDL,
97  MRepeat,
98  NRepeat,
99  KPack>
100 
 101 {
 102  using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
103  ADataType,
104  BDataType,
105  ComputeDataType,
106  AccDataType,
107  ATileDesc,
108  BTileDesc,
109  AMmaTileDesc,
110  BMmaTileDesc,
111  ABlockTransferSrcScalarPerVector,
112  BBlockTransferSrcScalarPerVector,
113  MPerBlock,
114  NPerBlock,
115  KPerBlock,
116  MPerXDL,
117  NPerXDL,
118  MRepeat,
119  NRepeat,
120  KPack>;
121  using Base::A_K1;
122  using Base::B_K1;
123  using Base::I0;
124  using Base::I1;
125  using Base::I2;
126  using Base::KGroup;
127  using Base::KRepeat;
128  using Base::xdlops_gemm;
129  using typename Base::HotLoopInstList;
130 
131  using Base::a_block_desc_m0_m1_m2_k;
132  using Base::CalculateCThreadOriginDataIndex;
133  using Base::CalculateCThreadOriginDataIndex8D;
134  using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
135  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
136  using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
137  using Base::GetCThreadBuffer;
138  using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
139  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
140  using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
141  using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
142  using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
143 
144  using Base::AMmaKStride;
145  using Base::BMmaKStride;
146 
147  using Base::MWaves;
148 
149  static constexpr index_t PrefetchStages = 2;
150  static constexpr index_t PrefillStages = 1;
151  static constexpr index_t GlobalBufferNum = 1;
152  static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
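// HotloopLocalBufSwitch flips the parity of the A register double buffer when
// MRepeat is odd, so the (m0 + HotloopLocalBufSwitch * mfma_reg_buf) % 2
// indexing in Run() keeps reads and writes on consistent halves of
// a_thread_buf across the two ping-pong passes of the hot loop.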
153 
154  template <typename TileDesc_M0_M1_M2_K>
155  __host__ __device__ static constexpr auto MakeAGemmMmaTileDescriptor(const TileDesc_M0_M1_M2_K&)
156  {
157  constexpr index_t M0 = TileDesc_M0_M1_M2_K{}.GetLength(Number<0>{});
158  constexpr index_t M1 = TileDesc_M0_M1_M2_K{}.GetLength(Number<1>{});
159  constexpr index_t M2 = TileDesc_M0_M1_M2_K{}.GetLength(Number<2>{});
160  constexpr index_t K2 = KPack / KGroup;
161  constexpr index_t K1 = 64 / NPerXDL;
162  constexpr index_t K0 = KRepeat * KGroup;
 163 
 164  return transform_tensor_descriptor(
 165  TileDesc_M0_M1_M2_K{},
 166  make_tuple(make_pass_through_transform(Number<M0>{}),
 167  make_pass_through_transform(Number<M1>{}),
 168  make_pass_through_transform(Number<M2>{}),
 169  make_unmerge_transform(
 170  make_tuple(Number<K0>{}, Number<K1>{}, Number<K2>{}))),
 171  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
 172  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3, 4, 5>{}));
173  }
174 
175  static constexpr auto a_block_desc_m0_m1_m2_k0_k1_k2 =
176  MakeAGemmMmaTileDescriptor(a_block_desc_m0_m1_m2_k);
177 
178  __host__ __device__ static constexpr bool BlockHasHotloop(index_t num_loop)
179  {
180  return num_loop > PrefetchStages;
181  }
182 
183  __host__ __device__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
184  {
185  return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
186  }
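// The hot loop below consumes two K blocks per pass (ping-ponging between
// buffers I0 and I1), so BlockHasHotloop() requires more than PrefetchStages
// (= 2) loop iterations and BlockLoopTailNum() only needs to distinguish an
// Even from an Odd remaining tail.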
187 
188  __device__ static constexpr auto HotLoopScheduler()
189  {
190  // A/B split schedule
 191  // the compiler is likely to use ds_read2 when the instruction width is smaller than 16 bytes
192  constexpr auto num_ds_read_inst_a =
193  HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
194  ? HotLoopInstList::A_LDS_Read_Inst_Num
195  : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
196 
197  constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
198 
199  constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
200  constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
201 
202  static_assert(num_buffer_load_inst_a == num_ds_write_inst_a);
203 
204  constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num * 2;
205  constexpr auto mfma_cycle = HotLoopInstList::C_MFMA_Inst_Cycle;
206 
207  constexpr auto ds_read_a_issue_cycle =
208  HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
209  constexpr auto ds_read_a_mfma_rate =
210  math::integer_divide_ceil(mfma_cycle - 4, 2 * ds_read_a_issue_cycle);
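// Illustration with assumed values: for a 16-byte A LDS read (issue cycle 8)
// and mfma_cycle = 32, ds_read_a_mfma_rate = ceil((32 - 4) / (2 * 8)) = 2,
// i.e. up to two A LDS reads are slotted behind each MFMA.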
211 
212  // constexpr auto num_dsread_a_mfma =
213  // (num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
214 
215  constexpr auto num_total_stages = MRepeat;
216 
 217  // Group num_mfma_perstage and num_ds_read_a_perstage together,
 218  // since we want to reuse a local register buffer
219  constexpr auto num_mfma_perstage = num_mfma_inst / num_total_stages;
220  constexpr auto num_ds_read_a_perstage = num_ds_read_inst_a / num_total_stages;
221 
222  constexpr auto num_ds_read_a_mfma_perstage =
223  math::integer_divide_ceil(num_ds_read_a_perstage, ds_read_a_mfma_rate);
224 
225  constexpr auto num_ds_read_a_prefetch_stages = 2;
226 
227  constexpr auto buffer_load_perstage_more = math::integer_divide_ceil(
228  (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
229  constexpr auto buffer_load_perstage_less = math::integer_divide_floor(
230  (num_buffer_load_inst_a + num_buffer_load_inst_b), (num_total_stages - 2));
231 
232  constexpr auto buffer_load_stages_more =
233  (num_buffer_load_inst_a + num_buffer_load_inst_b) -
234  math::integer_divide_floor((num_buffer_load_inst_a + num_buffer_load_inst_b),
235  (num_total_stages - 2)) *
236  ((num_total_stages - 2));
237 
238  constexpr auto buffer_load_b_stages =
239  buffer_load_perstage_more * buffer_load_stages_more > num_buffer_load_inst_b
240  ? num_buffer_load_inst_b / buffer_load_perstage_more
241  : (buffer_load_stages_more +
242  (num_buffer_load_inst_b - buffer_load_perstage_more * buffer_load_stages_more) /
243  buffer_load_perstage_less);
244 
245  constexpr auto buffer_load_a_stages =
246  num_total_stages - num_ds_read_a_prefetch_stages - buffer_load_b_stages;
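// Illustration with assumed values: with num_buffer_load_inst_a = 4,
// num_buffer_load_inst_b = 8 and num_total_stages = MRepeat = 8, the 12
// global loads are spread over 8 - 2 = 6 stages at 2 per stage, giving
// buffer_load_b_stages = 4 and buffer_load_a_stages = 8 - 2 - 4 = 2.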
247 
248  constexpr auto buffer_load_issue_point_b = 0;
249  constexpr auto buffer_load_issue_point_interval_more =
250  num_mfma_perstage / buffer_load_perstage_more;
251  constexpr auto buffer_load_issue_point_interval_less =
252  num_mfma_perstage / buffer_load_perstage_less;
253  constexpr auto ds_write_issue_point = 0;
254  constexpr auto buffer_load_issue_point_a = num_mfma_perstage >= 3 ? 1 : 0;
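// Within the A-load stages, the A global load is offset to MFMA slot 1
// (whenever at least three MFMAs fit per stage) so it is not issued in the
// same slot as the DS write (slot 0).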
255 
 256  // B global read
 257  static_for<0, buffer_load_b_stages, 1>{}([&](auto i) {
258  static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
259  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
260 
261  if constexpr(((i < buffer_load_stages_more) &&
262  (imfma % buffer_load_issue_point_interval_more ==
263  buffer_load_issue_point_b)) ||
264  ((i >= buffer_load_stages_more) &&
265  (imfma % buffer_load_issue_point_interval_less ==
266  buffer_load_issue_point_b)))
267  {
268  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
269  }
270 
271  if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
272  {
273  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
274  }
275  });
276  });
277 
 278  // A global read + A local write
 279  static_for<0, buffer_load_a_stages, 1>{}([&](auto i) {
280  static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
281  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
282  if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
283  (imfma % buffer_load_issue_point_interval_more ==
284  ds_write_issue_point)) ||
285  (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
286  (imfma % buffer_load_issue_point_interval_less ==
287  ds_write_issue_point)))
288  {
289  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
290  }
291  if constexpr((((i + buffer_load_b_stages) < buffer_load_stages_more) &&
292  (imfma % buffer_load_issue_point_interval_more ==
293  buffer_load_issue_point_a)) ||
294  (((i + buffer_load_b_stages) >= buffer_load_stages_more) &&
295  (imfma % buffer_load_issue_point_interval_less ==
296  buffer_load_issue_point_a)))
297  {
298  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
299  }
300  if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
301  {
302  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
303  }
304  });
305  });
306 
 307  // LDS synchronization, prefetch next loop's local A
 308  static_for<0, num_ds_read_a_prefetch_stages, 1>{}([&](auto i) {
 309  ignore = i;
310  static_for<0, num_mfma_perstage, 1>{}([&](auto imfma) {
311  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
312  if constexpr(imfma >= (num_mfma_perstage - num_ds_read_a_mfma_perstage))
313  {
314  __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
315  }
316  });
317  });
318  }
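// In summary, HotLoopScheduler() splits one hot-loop iteration into MRepeat
// stages: the first buffer_load_b_stages interleave B global loads with
// MFMAs, the next buffer_load_a_stages interleave the A global load and its
// LDS write, and the last two stages carry only MFMAs plus the A LDS reads
// that prefetch the next iteration.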
319 
320  template <typename Stage>
321  __device__ static constexpr auto EpilogueScheduler_1(Stage stage)
322  {
323  constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
324  constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
325  constexpr auto num_buffer_load_inst_b =
326  MWaves * HotLoopInstList::B_Buffer_Load_Inst_Num * 2;
327 
328  constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
329 
330  constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
331  constexpr auto staged_num_mfma = num_mfma / MRepeat;
332 
333  constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
334 
335  if constexpr(stage.value == 0)
336  {
337  constexpr auto staged_num_buffer_load_b_per_ds_read_a =
338  num_buffer_load_inst_b / staged_num_ds_read_inst_a;
339  constexpr auto staged_num_mfma_per_buffer_load_b =
340  staged_num_mfma / num_buffer_load_inst_b;
 341  // B global
 342  static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
 343  ignore = i_inst;
 344 
 345  static_for<0, staged_num_buffer_load_b_per_ds_read_a - 1, 1>{}([&](auto ibuf_inst) {
 346  ignore = ibuf_inst;
347  __builtin_amdgcn_sched_group_barrier(
348  0x008, staged_num_mfma_per_buffer_load_b, 0); // MFMA
349  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
350  });
351 
352  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
353  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
354  __builtin_amdgcn_sched_group_barrier(
355  0x008, staged_num_mfma_per_buffer_load_b - 1, 0); // MFMA
356  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
357  });
358 
359  __builtin_amdgcn_sched_barrier(0);
360  }
361  else if constexpr(stage.value == 1)
362  {
363  constexpr auto staged_num_mfma_per_ds_write_a =
364  math::integer_divide_ceil(staged_num_mfma, num_ds_write_inst_a);
365 
366  constexpr auto stage_more_mfma =
367  staged_num_mfma - (staged_num_mfma_per_ds_write_a - 1) * num_ds_write_inst_a;
368 
369  // A local write
370  static_for<0, num_ds_write_inst_a, 1>{}([&](auto i_inst) {
371  if constexpr(i_inst.value < stage_more_mfma)
372  {
373  if(i_inst.value < staged_num_ds_read_inst_a)
374  {
375  __builtin_amdgcn_sched_group_barrier(
376  0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
377  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
378  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
379  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
380  }
381  else
382  {
383  __builtin_amdgcn_sched_group_barrier(
384  0x008, staged_num_mfma_per_ds_write_a, 0); // MFMA
385  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
386  }
387  }
388  else
389  {
390  if(i_inst.value < staged_num_ds_read_inst_a)
391  {
392  __builtin_amdgcn_sched_group_barrier(
393  0x008, staged_num_mfma_per_ds_write_a - 2, 0); // MFMA
394  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
395  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
396  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
397  }
398  else
399  {
400  __builtin_amdgcn_sched_group_barrier(
401  0x008, staged_num_mfma_per_ds_write_a - 1, 0); // MFMA
402  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS Write
403  }
404  }
405  });
406  __builtin_amdgcn_sched_barrier(0);
407  }
408  else
409  {
 410  // A local Read
 411  static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
 412  ignore = i_inst;
413  __builtin_amdgcn_sched_group_barrier(
414  0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
415  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
416  });
417 
418  __builtin_amdgcn_sched_barrier(0);
419  }
420  }
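// EpilogueScheduler_1 orders the first tail iteration stage by stage: stage 0
// interleaves the remaining B global loads with MFMAs, stage 1 interleaves
// the A LDS writes (pairing in A LDS reads where available), and later stages
// simply pair MFMAs with A LDS reads.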
421 
422  __device__ static constexpr auto EpilogueScheduler_2()
423  {
424  constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
425 
426  constexpr auto num_mfma = HotLoopInstList::C_MFMA_Inst_Num * 2;
427 
428  constexpr auto staged_num_ds_read_inst_a = num_ds_read_inst_a / MRepeat;
429  constexpr auto staged_num_mfma = num_mfma / MRepeat;
430 
431  constexpr auto staged_num_mfma_per_ds_read_a = staged_num_mfma / staged_num_ds_read_inst_a;
432 
 433  // A local Read
 434  static_for<0, staged_num_ds_read_inst_a, 1>{}([&](auto i_inst) {
 435  ignore = i_inst;
436  __builtin_amdgcn_sched_group_barrier(0x008, staged_num_mfma_per_ds_read_a, 0); // MFMA
437  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
438  });
439 
440  __builtin_amdgcn_sched_barrier(0);
441  }
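// EpilogueScheduler_2 covers the final tail stages, where only MFMAs and the
// remaining A LDS reads are left to interleave.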
442 
443  template <bool HasMainLoop,
444  TailNumber TailNum,
445  typename AGridDesc,
446  typename ABlockDesc,
447  typename ABlockTransfer,
448  typename AGridBuffer,
449  typename ABlockBuffer,
450  typename ABlockTransferStep,
451  typename BGridDesc,
452  typename BBlockTransfer,
453  typename BGridBuffer,
454  typename BBlockBuffer,
455  typename BBlockTransferStep,
456  typename CThreadBuffer>
457  __device__ void Run(const AGridDesc& a_grid_desc,
458  const ABlockDesc& a_block_desc,
459  ABlockTransfer& a_blockwise_copy,
460  const AGridBuffer& a_grid_buf,
461  ABlockBuffer& a_block_buf,
462  const ABlockTransferStep& a_block_copy_step,
463  const BGridDesc& b_grid_desc,
464  BBlockTransfer& b_blockwise_copy,
465  BBlockTransfer& b_blockwise_copy_up,
466  const BGridBuffer& b_grid_buf,
467  const BGridBuffer& b_grid_buf_up,
468  BBlockBuffer& b_block_buf,
469  const BBlockTransferStep& b_block_copy_step,
470  CThreadBuffer& c_thread_buf,
471  CThreadBuffer& c_thread_buf_up,
472  index_t num_loop) const
473  {
474  ignore = b_block_buf;
475  __builtin_amdgcn_sched_barrier(0);
476  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
477  a_thread_desc_.GetElementSpaceSize());
478  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
479  b_thread_desc_.GetElementSpaceSize());
480 
481  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
482  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs_up;
483  constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0);
484 
485  // Global prefetch A1 B1
486  b_blockwise_copy.Run(b_grid_desc,
487  b_grid_buf,
488  b_block_desc_n0_n1_k0_k1,
489  b_block_origin_idx,
490  b_thread_bufs(I0));
491 
492  b_blockwise_copy_up.Run(b_grid_desc,
493  b_grid_buf_up,
494  b_block_desc_n0_n1_k0_k1,
495  b_block_origin_idx,
496  b_thread_bufs_up(I0));
497  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
498  b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
499 
500  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
501  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
502  __builtin_amdgcn_sched_barrier(0);
503 
 504  // Local prefill A1
505  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
506 
 507  // Global prefetch A2
508  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
509  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
510 
511  // Local prefetch A1
512  block_sync_lds();
513  static_for<0, 2, 1>{}([&](auto m0) {
514  static_for<0, KRepeat, 1>{}([&](auto k0) {
515  static_for<0, KGroup, 1>{}([&](auto kg0) {
516  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k0_k1_k2,
517  make_tuple(m0, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
518  a_block_buf.At(I0),
519  a_thread_desc_,
520  make_tuple(m0, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
521  a_thread_buf);
522  });
523  });
524  });
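// Only the first two MRepeat slices of A are prefetched into registers here:
// a_thread_desc_ holds exactly two slices (note its leading I2 dimension),
// and the remaining slices are fetched just in time inside the MFMA loops
// below.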
525 
526  // Initialize C
527  c_thread_buf.Clear();
528  c_thread_buf_up.Clear();
529 
530  __builtin_amdgcn_sched_barrier(0);
531 
532  // main body
533  if constexpr(HasMainLoop)
534  {
535  index_t i = 0;
536  do
537  {
538  auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
539  b_blockwise_copy.Run(b_grid_desc,
540  b_grid_buf,
541  b_block_desc_n0_n1_k0_k1,
542  b_block_origin_idx,
543  b_thread_bufs(local_read_buf));
544  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
545  b_blockwise_copy_up.Run(b_grid_desc,
546  b_grid_buf_up,
547  b_block_desc_n0_n1_k0_k1,
548  b_block_origin_idx,
549  b_thread_bufs_up(local_read_buf));
550  b_blockwise_copy_up.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
551 
552  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
553  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
554  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
555  static_for<0, MRepeat, 1>{}([&](auto m0) {
556  static_for<0, KRepeat, 1>{}([&](auto k0) {
 557  static_for<0, NRepeat, 1>{}([&](auto n0) {
 558  vector_type<ComputeDataType, KPack> a_thread_vec;
 559  vector_type<ComputeDataType, KPack> b_thread_vec;
560  vector_type<ComputeDataType, KPack> b_thread_vec_up;
561 
562  static_for<0, KPack, 1>{}([&](auto ik) {
563  a_thread_vec.template AsType<ComputeDataType>()(ik) =
564  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
565  make_tuple((m0 + HotloopLocalBufSwitch * mfma_reg_buf) %
566  2,
567  I0,
568  I0,
569  k0,
570  I0,
571  ik))>{}];
572  b_thread_vec.template AsType<ComputeDataType>()(ik) =
573  b_thread_bufs[mfma_reg_buf]
574  [Number<b_thread_desc_.CalculateOffset(
575  make_tuple(n0, I0, k0, ik))>{}];
576 
577  b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
578  b_thread_bufs_up[mfma_reg_buf]
579  [Number<b_thread_desc_.CalculateOffset(
580  make_tuple(n0, I0, k0, ik))>{}];
581  });
582 
583  using mfma_input_type =
584  typename vector_type<ComputeDataType,
585  xdlops_gemm.K1PerXdlops>::type;
586 
587  constexpr index_t c_offset =
588  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
589 
590  xdlops_gemm.Run(
591  a_thread_vec.template AsType<mfma_input_type>(),
592  b_thread_vec.template AsType<mfma_input_type>(),
593  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
594 
595  xdlops_gemm.Run(
596  a_thread_vec.template AsType<mfma_input_type>(),
597  b_thread_vec_up.template AsType<mfma_input_type>(),
598  c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
599  });
600  });
601 
602  if constexpr(m0.value == MRepeat - 2)
603  {
604  block_sync_lds();
605 
606  static_for<0, KRepeat, 1>{}([&](auto k0) {
607  static_for<0, KGroup, 1>{}([&](auto kg0) {
608  a_thread_copy_.Run(
609  a_block_desc_m0_m1_m2_k0_k1_k2,
610  make_tuple(Number<(m0 + 2) % MRepeat>{},
611  I0,
 612  I0,
 613  Number<k0 * KGroup + kg0>{},
614  I0,
615  I0),
616  a_block_buf.At(local_read_buf),
617  a_thread_desc_,
618  make_tuple(
619  Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
620  2>{},
621  I0,
622  I0,
623  k0,
 624  I0,
 625  Number<kg0 * A_K1>{}),
626  a_thread_buf);
627  });
628  });
629  }
630  else if constexpr(m0.value == (MRepeat - 1))
631  {
632  static_for<0, KRepeat, 1>{}([&](auto k0) {
633  static_for<0, KGroup, 1>{}([&](auto kg0) {
634  a_thread_copy_.Run(
635  a_block_desc_m0_m1_m2_k0_k1_k2,
636  make_tuple(Number<(m0 + 2) % MRepeat>{},
637  I0,
 638  I0,
 639  Number<k0 * KGroup + kg0>{},
640  I0,
641  I0),
642  a_block_buf.At(local_read_buf),
643  a_thread_desc_,
644  make_tuple(
645  Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
646  2>{},
647  I0,
648  I0,
649  k0,
 650  I0,
 651  Number<kg0 * A_K1>{}),
652  a_thread_buf);
653  });
654  });
655  }
656  else
657  {
658  static_for<0, KRepeat, 1>{}([&](auto k0) {
659  static_for<0, KGroup, 1>{}([&](auto kg0) {
660  a_thread_copy_.Run(
661  a_block_desc_m0_m1_m2_k0_k1_k2,
662  make_tuple(Number<(m0 + 2) % MRepeat>{},
663  I0,
 664  I0,
 665  Number<k0 * KGroup + kg0>{},
666  I0,
667  I0),
668  a_block_buf.At(mfma_reg_buf),
669  a_thread_desc_,
670  make_tuple(
671  Number<(m0 + 2 + HotloopLocalBufSwitch * mfma_reg_buf) %
672  2>{},
673  I0,
674  I0,
675  k0,
 676  I0,
 677  Number<kg0 * A_K1>{}),
678  a_thread_buf);
679  });
680  });
681  }
682  });
683  HotLoopScheduler();
684  };
685 
686  LoopFunc(I0, I1);
687  LoopFunc(I1, I0);
688 
689  i += 2;
690  } while(i < (num_loop - 2));
691  }
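// Each pass of the do/while above runs two K-block iterations,
// LoopFunc(I0, I1) followed by LoopFunc(I1, I0), and stops at num_loop - 2,
// leaving the last iterations to the Even/Odd tail handling below.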
692  // tail
693  if constexpr(TailNum == TailNumber::Even)
694  {
695  b_blockwise_copy.Run(b_grid_desc,
696  b_grid_buf,
697  b_block_desc_n0_n1_k0_k1,
698  b_block_origin_idx,
699  b_thread_bufs(I1));
700 
701  b_blockwise_copy_up.Run(b_grid_desc,
702  b_grid_buf_up,
703  b_block_desc_n0_n1_k0_k1,
704  b_block_origin_idx,
705  b_thread_bufs_up(I1));
706  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
707  static_for<0, MRepeat, 1>{}([&](auto m0) {
708  static_for<0, KRepeat, 1>{}([&](auto k0) {
 709  static_for<0, NRepeat, 1>{}([&](auto n0) {
 710  vector_type<ComputeDataType, KPack> a_thread_vec;
 711  vector_type<ComputeDataType, KPack> b_thread_vec;
712  vector_type<ComputeDataType, KPack> b_thread_vec_up;
713 
714  static_for<0, KPack, 1>{}([&](auto ik) {
715  a_thread_vec.template AsType<ComputeDataType>()(ik) =
716  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
717  make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
718  b_thread_vec.template AsType<ComputeDataType>()(ik) =
719  b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
720  make_tuple(n0, I0, k0, ik))>{}];
721 
722  b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
723  b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
724  make_tuple(n0, I0, k0, ik))>{}];
725  });
726 
727  using mfma_input_type =
728  typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
729 
730  constexpr index_t c_offset =
731  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
732 
733  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
734  b_thread_vec.template AsType<mfma_input_type>(),
735  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
736 
737  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
738  b_thread_vec_up.template AsType<mfma_input_type>(),
739  c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
740  });
741  });
742  if constexpr(m0.value == (MRepeat - 2))
743  {
744  block_sync_lds();
745 
746  static_for<0, KRepeat, 1>{}([&](auto k0) {
747  static_for<0, KGroup, 1>{}([&](auto kg0) {
748  a_thread_copy_.Run(
749  a_block_desc_m0_m1_m2_k0_k1_k2,
750  make_tuple(Number<(m0 + 2) % MRepeat>{},
751  I0,
 752  I0,
 753  Number<k0 * KGroup + kg0>{},
754  I0,
755  I0),
756  a_block_buf.At(I1),
757  a_thread_desc_,
758  make_tuple(
759  Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
760  a_thread_buf);
761  });
762  });
763  }
764  else if constexpr(m0.value == MRepeat - 1)
765  {
766  static_for<0, KRepeat, 1>{}([&](auto k0) {
767  static_for<0, KGroup, 1>{}([&](auto kg0) {
768  a_thread_copy_.Run(
769  a_block_desc_m0_m1_m2_k0_k1_k2,
770  make_tuple(Number<(m0 + 2) % MRepeat>{},
771  I0,
 772  I0,
 773  Number<k0 * KGroup + kg0>{},
774  I0,
775  I0),
776  a_block_buf.At(I1),
777  a_thread_desc_,
778  make_tuple(
779  Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
780  a_thread_buf);
781  });
782  });
783  }
784  else
785  {
786  static_for<0, KRepeat, 1>{}([&](auto k0) {
787  static_for<0, KGroup, 1>{}([&](auto kg0) {
788  a_thread_copy_.Run(
789  a_block_desc_m0_m1_m2_k0_k1_k2,
790  make_tuple(Number<(m0 + 2) % MRepeat>{},
791  I0,
 792  I0,
 793  Number<k0 * KGroup + kg0>{},
794  I0,
795  I0),
796  a_block_buf.At(I0),
797  a_thread_desc_,
798  make_tuple(
799  Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
800  a_thread_buf);
801  });
802  });
803  }
804  });
805 
806  HotLoopScheduler();
807 
808  static_for<0, MRepeat, 1>{}([&](auto m0) {
809  static_for<0, KRepeat, 1>{}([&](auto k0) {
 810  static_for<0, NRepeat, 1>{}([&](auto n0) {
 811  vector_type<ComputeDataType, KPack> a_thread_vec;
 812  vector_type<ComputeDataType, KPack> b_thread_vec;
813  vector_type<ComputeDataType, KPack> b_thread_vec_up;
814 
815  static_for<0, KPack, 1>{}([&](auto ik) {
816  a_thread_vec.template AsType<ComputeDataType>()(ik) =
817  a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
818  (m0 + HotloopLocalBufSwitch) % 2, I0, I0, k0, I0, ik))>{}];
819  b_thread_vec.template AsType<ComputeDataType>()(ik) =
820  b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
821  make_tuple(n0, I0, k0, ik))>{}];
822  b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
823  b_thread_bufs_up[I1][Number<b_thread_desc_.CalculateOffset(
824  make_tuple(n0, I0, k0, ik))>{}];
825  });
826 
827  using mfma_input_type =
828  typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
829 
830  constexpr index_t c_offset =
831  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
832 
833  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
834  b_thread_vec.template AsType<mfma_input_type>(),
835  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
836 
837  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
838  b_thread_vec_up.template AsType<mfma_input_type>(),
839  c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
840  });
841  });
842 
843  if constexpr(m0.value < (MRepeat - 2))
844  {
845  static_for<0, KRepeat, 1>{}([&](auto k0) {
846  static_for<0, KGroup, 1>{}([&](auto kg0) {
847  a_thread_copy_.Run(
848  a_block_desc_m0_m1_m2_k0_k1_k2,
849  make_tuple(
850  Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
851  a_block_buf.At(I1),
852  a_thread_desc_,
853  make_tuple(Number<(m0 + 2 + HotloopLocalBufSwitch) % 2>{},
854  I0,
855  I0,
856  k0,
 857  I0,
 858  Number<kg0 * A_K1>{}),
859  a_thread_buf);
860  });
861  });
862  }
863  });
864 
865  HotLoopScheduler();
 866  // Leak the last MFMA block into the epilogue region to cover the potential
 867  // lds-shuffle latency
868  }
869  else if constexpr(TailNum == TailNumber::Odd)
870  {
871  static_for<0, MRepeat, 1>{}([&](auto m0) {
872  static_for<0, KRepeat, 1>{}([&](auto k0) {
 873  static_for<0, NRepeat, 1>{}([&](auto n0) {
 874  vector_type<ComputeDataType, KPack> a_thread_vec;
 875  vector_type<ComputeDataType, KPack> b_thread_vec;
876  vector_type<ComputeDataType, KPack> b_thread_vec_up;
877 
878  static_for<0, KPack, 1>{}([&](auto ik) {
879  a_thread_vec.template AsType<ComputeDataType>()(ik) =
880  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
881  make_tuple(m0 % 2, I0, I0, k0, I0, ik))>{}];
882  b_thread_vec.template AsType<ComputeDataType>()(ik) =
883  b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
884  make_tuple(n0, I0, k0, ik))>{}];
885  b_thread_vec_up.template AsType<ComputeDataType>()(ik) =
886  b_thread_bufs_up[I0][Number<b_thread_desc_.CalculateOffset(
887  make_tuple(n0, I0, k0, ik))>{}];
888  });
889 
890  using mfma_input_type =
891  typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
892 
893  constexpr index_t c_offset =
894  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
895 
896  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
897  b_thread_vec.template AsType<mfma_input_type>(),
898  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
899  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
900  b_thread_vec_up.template AsType<mfma_input_type>(),
901  c_thread_buf_up.GetVectorTypeReference(Number<c_offset>{}));
902  });
903  });
904 
905  if constexpr(m0.value < (MRepeat - 2))
906  {
907  static_for<0, KRepeat, 1>{}([&](auto k0) {
908  static_for<0, KGroup, 1>{}([&](auto kg0) {
909  a_thread_copy_.Run(
910  a_block_desc_m0_m1_m2_k0_k1_k2,
911  make_tuple(
912  Number<m0 + 2>{}, I0, I0, Number<k0 * KGroup + kg0>{}, I0, I0),
913  a_block_buf.At(I0),
914  a_thread_desc_,
915  make_tuple(
916  Number<(m0 + 2) % 2>{}, I0, I0, k0, I0, Number<kg0 * A_K1>{}),
917  a_thread_buf);
918  });
919  });
920  }
921  });
922  }
923  }
924 
925  protected:
 926  // MRepeat MWave MLane KRepeat KLane KPack
 927  // KRepeat -> MRepeat -> MWave -> KLane -> MLane -> KPack
 928  // Reduce the vgpr usage here.
929  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
930  make_tuple(I2, I1, I1, Number<KRepeat>{}, I1, Number<KPack>{}));
 931 
 932  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
 933  ComputeDataType,
 934  decltype(a_block_desc_m0_m1_m2_k0_k1_k2),
 935  decltype(a_thread_desc_),
 936  Sequence<1, 1, 1, 1, 1, KPack / KGroup>,
 937  Sequence<0, 1, 2, 3, 4, 5>,
938  5,
939  A_K1,
940  A_K1>;
941 
942  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex6D()};
943 
944  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
945  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}));
946 
947  static constexpr BTileDesc b_block_desc_n0_n1_k0_k1;
948 
949  using Base::c_thread_desc_;
950 };
951 
952 } // namespace ck