Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp (docs-6.4.3)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

// Dependency headers (elided in the rendered listing; this set is assumed):
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_operation/gpu/warp/xdlops_gemm.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"

// Double LDS buffer
// Prefetch 2 stage
// Local prefetch 1 stage
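//
// That is: global->LDS staging alternates between two LDS buffers (ping/pong),
// two global prefetches are kept in flight ahead of the hot loop, and one
// LDS->VGPR prefetch stage runs ahead of the MFMA stage (see Run() below).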

namespace ck {

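// Instruction counts of one hot-loop iteration for a given block tile and MFMA
// shape; the pipeline schedulers below use them to interleave instruction classes.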
template <index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t ABufferLoadWidth,
          index_t BBufferLoadWidth,
          index_t ALDSWriteWidth,
          index_t BLDSWriteWidth,
          index_t ALDSReadWidth,
          index_t BLDSReadWidth,
          index_t MRepeat,
          index_t NRepeat,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t KPerXDL>
struct BlockwiseGemmXdlops_pipeline_hotloop_inst // struct name assumed; elided in the listing
{
    static constexpr index_t WaveSize = 64;
    static constexpr index_t WaveNumM = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t WaveNumN = NPerBlock / (NRepeat * NPerXDL);
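    // WaveNumM/WaveNumN: waves tiling the block along M and N; each wave covers
    // an (MRepeat * MPerXDL) x (NRepeat * NPerXDL) sub-tile.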

    static constexpr index_t A_Buffer_Load_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * ABufferLoadWidth);
    static constexpr index_t B_Buffer_Load_Inst_Num =
        NPerBlock * KPerBlock / (BlockSize * BBufferLoadWidth);

    static constexpr index_t A_LDS_Write_Inst_Num =
        MPerBlock * KPerBlock / (BlockSize * ALDSWriteWidth);
    static constexpr index_t B_LDS_Write_Inst_Num =
        NPerBlock * KPerBlock / (BlockSize * BLDSWriteWidth);

    static constexpr index_t A_LDS_Read_Inst_Num =
        WaveNumN * MPerBlock * KPerBlock / (BlockSize * ALDSReadWidth);
    static constexpr index_t B_LDS_Read_Inst_Num =
        WaveNumM * MPerBlock * KPerBlock / (BlockSize * BLDSReadWidth);

    static constexpr index_t C_MFMA_Inst_Num =
        MPerBlock * NPerBlock * KPerBlock / (BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
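    // The load/write counts are per thread: elements moved per block tile divided
    // by (BlockSize * access width). LDS reads carry an extra WaveNumN (A) or
    // WaveNumM (B) factor, since the same tile is re-read by every wave along the
    // other dimension. C_MFMA_Inst_Num is per wave: block tile volume divided by
    // the wave count and by the M*N*K volume of a single MFMA.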

    static constexpr auto Print()
    {
        printf(" Blk/Wave Size: %d, %d, M/N/K PerBlk: %d, %d, %d, M/N/K PerXdl: %d, %d, %d\n",
               BlockSize,
               WaveSize,
               MPerBlock,
               NPerBlock,
               KPerBlock,
               MPerXDL,
               NPerXDL,
               KPerXDL);

        printf(" A/B buffer load inst: %d, %d\n A/B LDS write inst: %d, %d\n A/B LDS read inst: "
               "%d, %d\n C MFMA inst: %d\n",
               A_Buffer_Load_Inst_Num,
               B_Buffer_Load_Inst_Num,
               A_LDS_Write_Inst_Num,
               B_LDS_Write_Inst_Num,
               A_LDS_Read_Inst_Num,
               B_LDS_Read_Inst_Num,
               C_MFMA_Inst_Num);
    }
};

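// Pipelined blockwise GEMM on XDL (MFMA) instructions, version 4: A/B tiles are
// staged global -> LDS -> VGPR with double-buffered LDS, a two-stage global
// prefetch, and a one-stage register prefetch; C accumulates in VGPRs.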
template <
    index_t BlockSize,
    typename FloatAB,
    typename FloatAcc,
    typename ATileDesc,
    typename BTileDesc,
    typename AMmaTileDesc,
    typename BMmaTileDesc,
    index_t MPerBlock,
    index_t NPerBlock,
    index_t KPerBlock,
    index_t MPerXDL,
    index_t NPerXDL,
    index_t MRepeat,
    index_t NRepeat,
    index_t KPack,
    bool TransposeC = false,
    index_t AMmaKStride =
        KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
    index_t BMmaKStride =
        KPack* XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
struct BlockwiseGemmXdlops_pipeline_v4
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    static constexpr index_t WaveSize = get_warp_size();

    static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);

    static constexpr auto xdlops_gemm =
        XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};

    static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
    static constexpr index_t KRepeat    = KPerThread / KPack;
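    // KPerThread: K elements each lane consumes per block tile; KRepeat: how many
    // KPack-wide MFMA steps span them (one xdlops_gemm.Run call per step).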

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

    // (RHS name elided in the listing; assumed to match the inst-count struct above)
    using HotLoopInstList =
        BlockwiseGemmXdlops_pipeline_hotloop_inst<BlockSize,
                                                  MPerBlock,
                                                  NPerBlock,
                                                  KPerBlock,
                                                  A_K1,
                                                  B_K1,
                                                  A_K1,
                                                  B_K1,
                                                  KPack,
                                                  KPack,
                                                  MRepeat,
                                                  NRepeat,
                                                  MPerXDL,
                                                  NPerXDL,
                                                  xdlops_gemm.KPerXdlops>;

    static_assert(KPerThread % KPack == 0,
                  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              xdlops_gemm.GetRegSizePerXdlops(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];

        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_n = wave_idx[I1];

        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);

        return make_tuple(
            m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
    }

    using Tuple4 = decltype(CalculateAThreadOriginDataIndex());

    __host__ __device__
    BlockwiseGemmXdlops_pipeline_v4(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
                                    Tuple4 b_origin = CalculateBThreadOriginDataIndex())
        : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
    {
        static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                      "wrong!");

        // HotLoopInstList::Print();
    }

    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
    }

    // XDL output supporting C_xdl = A_xdl * B_xdl
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    // XDL output supporting C_xdl = A_xdl * B_xdl
    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }

    __device__ static constexpr auto HotLoopScheduler()
    {
        // schedule
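        // __builtin_amdgcn_sched_group_barrier(mask, size, sync_id) pins `size`
        // instructions of the class selected by `mask` at this point of the
        // schedule. Masks (LLVM AMDGPU scheduling groups): 0x008 = MFMA,
        // 0x020 = VMEM read, 0x100 = DS read, 0x200 = DS write. Each loop
        // iteration below schedules one buffer_load slot together with its share
        // of DS reads, DS writes, and MFMAs.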
        constexpr auto num_ds_read_inst =
            HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
        constexpr auto num_ds_write_inst =
            HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
        constexpr auto num_buffer_load_inst =
            HotLoopInstList::A_Buffer_Load_Inst_Num + HotLoopInstList::B_Buffer_Load_Inst_Num;
        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;

        constexpr auto num_issue = num_buffer_load_inst;

        static_for<0, num_issue, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(
                0x100, num_ds_read_inst / num_buffer_load_inst, 0); // DS read
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(
                0x200, num_ds_write_inst / num_buffer_load_inst, 0); // DS write
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_inst / num_buffer_load_inst - 3, 0); // MFMA
        });
    }

    template <index_t stage>
    __device__ static constexpr auto TailScheduler()
    {
    }

    template <>
    __device__ constexpr auto TailScheduler<1>()
    {
        // schedule
        constexpr auto num_ds_read_inst =
            HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
        constexpr auto num_ds_write_inst =
            HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;

        constexpr auto num_issue = num_ds_write_inst;

        static_for<0, num_issue, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            __builtin_amdgcn_sched_group_barrier(
                0x100, num_ds_read_inst / num_ds_write_inst - 1, 0); // DS read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_inst / num_ds_write_inst - 3, 0); // MFMA
        });
    }

    template <>
    __device__ constexpr auto TailScheduler<2>()
    {
        // schedule
        constexpr auto num_ds_read_inst =
            HotLoopInstList::A_LDS_Read_Inst_Num + HotLoopInstList::B_LDS_Read_Inst_Num;
        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;

        constexpr auto num_issue = num_ds_read_inst;

        static_for<0, num_issue, 1>{}([&](auto i) {
            ignore = i;
            __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_inst / num_ds_read_inst, 0); // MFMA
        });
    }

    static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
    static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;

    template <bool HasMainLoop,
              index_t TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        __builtin_amdgcn_sched_barrier(0);
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

        StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
        StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
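        // Two VGPR copies of the A/B fragments: while MFMAs consume one copy, DS
        // reads for the next stage fill the other (register-level ping/pong).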
        // Inst List:
        // ds_read_b128: 16
        // ds_write_b128: 8
        // buffer_load_dwordx4: 16
        // v_mfma: 0
        // -------------------------------------------------------------------------------------------

        // Global prefetch 1st, Fill Ping LDS
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));

        // Local prefetch 1st, Fill Ping Reg
        block_sync_lds();
        static_for<0, KRepeat, 1>{}([&](auto k) {
            static_for<0, MRepeat, 1>{}([&](auto m0) {
                a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                   make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                   a_block_buf.At(I0),
                                   a_thread_desc_,
                                   make_tuple(m0, I0, k, I0),
                                   a_thread_bufs(I0));
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                       make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                       b_block_buf.At(I0),
                                       b_thread_desc_,
                                       make_tuple(n0, I0, k, I0),
                                       b_thread_bufs(I0));
                });
            });
        });

        // Global prefetch 2nd, Fill Pong LDS
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I1));
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I1));

        // Global prefetch 3rd
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Initialize C
        c_thread_buf.Clear();

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            // The hot loop body is unrolled two-fold to implement the double local
            // buffer (ping/pong) strategy
            do
            {
                // -------------------------------------------------------------------------------------------
                using PingP1 = Number<0>;
                using PongP1 = Number<1>;
                // MFMA: Ping Reg
                // DS_WRITE: To Ping LDS
                // DS_READ: Pong LDS to Pong Reg
                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                           a_block_buf.At(PongP1{}),
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, I0),
                                           a_thread_bufs(PongP1{}));
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf.At(PongP1{}),
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(PongP1{}));
                        });
                    });
                });

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<FloatAB, KPack> a_thread_vec;
                            vector_type<FloatAB, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<FloatAB>()(ik) =
                                    a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<FloatAB>()(ik) =
                                    b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                // -------------------------------------------------------------------------------------------
                using PingP2 = Number<1>;
                using PongP2 = Number<0>;
                // MFMA: Pong Reg
                // DS_WRITE: To Pong LDS
                // DS_READ: Ping LDS to Ping Reg
                block_sync_lds();

                static_for<0, KRepeat, 1>{}([&](auto k) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                           make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                           a_block_buf.At(PongP2{}),
                                           a_thread_desc_,
                                           make_tuple(m0, I0, k, I0),
                                           a_thread_bufs(PongP2{}));
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                               make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                               b_block_buf.At(PongP2{}),
                                               b_thread_desc_,
                                               make_tuple(n0, I0, k, I0),
                                               b_thread_bufs(PongP2{}));
                        });
                    });
                });

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP2{}));
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP2{}));

                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                static_for<0, KRepeat, 1>{}([&](auto k0) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<FloatAB, KPack> a_thread_vec;
                            vector_type<FloatAB, KPack> b_thread_vec;

                            static_for<0, KPack, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<FloatAB>()(ik) =
                                    a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(m0, I0, k0, ik))>{}];
                                b_thread_vec.template AsType<FloatAB>()(ik) =
                                    b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(n0, I0, k0, ik))>{}];
                            });

                            using mfma_input_type =
                                typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                            xdlops_gemm.Run(
                                a_thread_vec.template AsType<mfma_input_type>(),
                                b_thread_vec.template AsType<mfma_input_type>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                        });
                    });
                });

                HotLoopScheduler();
                __builtin_amdgcn_sched_barrier(0);

                i += 2;
            } while(i < (num_loop - 3));
        }

        // tail
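        // The main loop exits three iterations early, so 2 or 3 stages of data are
        // still in flight in LDS/VGPR; TailNum selects the matching drain sequence.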
        if constexpr(TailNum == 3)
        {
            using PingP1 = Number<0>;
            using PongP1 = Number<1>;
            // MFMA: Ping Reg
            // DS_WRITE: To Ping LDS
            // DS_READ: Pong LDS to Pong Reg
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(PongP1{}),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(PongP1{}));
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                           b_block_buf.At(PongP1{}),
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k, I0),
                                           b_thread_bufs(PongP1{}));
                    });
                });
            });

            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(PingP1{}));
            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(PingP1{}));

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<FloatAB, KPack> a_thread_vec;
                        vector_type<FloatAB, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<FloatAB>()(ik) =
                                a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<FloatAB>()(ik) =
                                b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            TailScheduler<1>();
            __builtin_amdgcn_sched_barrier(0);

            // -------------------------------------------------------------------------------------------
            using PingP2 = Number<1>;
            using PongP2 = Number<0>;
            // MFMA: Pong Reg
            // DS_WRITE: To Pong LDS
            // DS_READ: Ping LDS to Ping Reg
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(PongP2{}),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(PongP2{}));
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                           b_block_buf.At(PongP2{}),
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k, I0),
                                           b_thread_bufs(PongP2{}));
                    });
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<FloatAB, KPack> a_thread_vec;
                        vector_type<FloatAB, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<FloatAB>()(ik) =
                                a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<FloatAB>()(ik) =
                                b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            TailScheduler<2>();
            __builtin_amdgcn_sched_barrier(0);
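
            // Final MFMA pass: consume the registers prefetched above; no DS or
            // buffer_load traffic remains for this tile.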
            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<FloatAB, KPack> a_thread_vec;
                        vector_type<FloatAB, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<FloatAB>()(ik) =
                                a_thread_bufs[PongP2{}][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k, ik))>{}];
                            b_thread_vec.template AsType<FloatAB>()(ik) =
                                b_thread_bufs[PongP2{}][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            // 64 v_mfma
            __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
            __builtin_amdgcn_sched_barrier(0);
        }
        else if constexpr(TailNum == 2)
        {
            using PingP1 = Number<0>;
            using PongP1 = Number<1>;
            // MFMA: Ping Reg
            // DS_WRITE: To Ping LDS
            // DS_READ: Pong LDS to Pong Reg
            block_sync_lds();

            static_for<0, KRepeat, 1>{}([&](auto k) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                                       make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
                                       a_block_buf.At(PongP1{}),
                                       a_thread_desc_,
                                       make_tuple(m0, I0, k, I0),
                                       a_thread_bufs(PongP1{}));
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                           make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
                                           b_block_buf.At(PongP1{}),
                                           b_thread_desc_,
                                           make_tuple(n0, I0, k, I0),
                                           b_thread_bufs(PongP1{}));
                    });
                });
            });

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<FloatAB, KPack> a_thread_vec;
                        vector_type<FloatAB, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<FloatAB>()(ik) =
                                a_thread_bufs[PingP1{}][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<FloatAB>()(ik) =
                                b_thread_bufs[PingP1{}][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            TailScheduler<2>();
            __builtin_amdgcn_sched_barrier(0);

            // -------------------------------------------------------------------------------------------
            using PingP2 = Number<1>;
            // MFMA: Pong Reg
            // DS_WRITE: To Pong LDS
            // DS_READ: Ping LDS to Ping Reg

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<FloatAB, KPack> a_thread_vec;
                        vector_type<FloatAB, KPack> b_thread_vec;

                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<FloatAB>()(ik) =
                                a_thread_bufs[PingP2{}][Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, k0, ik))>{}];
                            b_thread_vec.template AsType<FloatAB>()(ik) =
                                b_thread_bufs[PingP2{}][Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, k0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });

            // 64 v_mfma
            __builtin_amdgcn_sched_group_barrier(0x008, 64, 0); // MFMA
            __builtin_amdgcn_sched_barrier(0);
        }
    }

    protected:
    // M1, N1 as double buffer index
    // Read buffer + Compute buffer
    // A[M0, M1, M2, KPack]
    static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
        make_tuple(
            Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));

    // B[N0, N1, N2, KPack]
    static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor(
        make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
        make_tuple(
            Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));

    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>, // assumed; elided in listing
                                                         Sequence<0, 1, 2, 3>,     // assumed; elided in listing
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>, // assumed; elided in listing
                                                         Sequence<0, 1, 2, 3>,     // assumed; elided in listing
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck