Composable Kernel: blockwise_gemm_pipeline_xdlops_v5.hpp Source File
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp"

namespace ck {

// Compute optimized pipeline
// GlobalPrefetchStages: 3
// LocalPreFillStages: 1
// LocalPreFetchStages: 1
// LocalSharedMemoryBuffer: 2
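//
// Reading the stage counts above: three tiles are fetched from global memory
// before the hot loop starts, the first is written ("prefilled") into LDS, and
// the hot loop body is unrolled twice so the two global staging buffers
// (I0/I1, GlobalBufferNum = 2) alternate between iterations.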

template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPacks>
struct BlockwiseGemmXdlops_pipeline_v5
{
};

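// Only the Intrawave specialization below is defined; the primary template is
// left empty so that unsupported scheduler/parameter combinations fail to
// instantiate.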
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeDataType,
          typename AccDataType,
          typename ATileDesc,
          typename BTileDesc,
          typename AMmaTileDesc,
          typename BMmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack
          // ,bool TransposeC // TransposeC is disabled for now
          >
struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
                                       BlockSize,
                                       ADataType,
                                       BDataType,
                                       ComputeDataType,
                                       AccDataType,
                                       ATileDesc,
                                       BTileDesc,
                                       AMmaTileDesc,
                                       BMmaTileDesc,
                                       ABlockTransferSrcScalarPerVector,
                                       BBlockTransferSrcScalarPerVector,
                                       MPerBlock,
                                       NPerBlock,
                                       KPerBlock,
                                       MPerXDL,
                                       NPerXDL,
                                       MRepeat,
                                       NRepeat,
                                       KPack>
    : BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeDataType,
                                        AccDataType,
                                        ATileDesc,
                                        BTileDesc,
                                        AMmaTileDesc,
                                        BMmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerXDL,
                                        NPerXDL,
                                        MRepeat,
                                        NRepeat,
                                        KPack>

{
    using Base = BlockwiseGemmXdlops_pipeline_base<BlockSize,
                                                   ADataType,
                                                   BDataType,
                                                   ComputeDataType,
                                                   AccDataType,
                                                   ATileDesc,
                                                   BTileDesc,
                                                   AMmaTileDesc,
                                                   BMmaTileDesc,
                                                   ABlockTransferSrcScalarPerVector,
                                                   BBlockTransferSrcScalarPerVector,
                                                   MPerBlock,
                                                   NPerBlock,
                                                   KPerBlock,
                                                   MPerXDL,
                                                   NPerXDL,
                                                   MRepeat,
                                                   NRepeat,
                                                   KPack>;
    using Base::A_K1;
    using Base::B_K1;
    using Base::I0;
    using Base::I1;
    using Base::KRepeat;
    using Base::xdlops_gemm;
    using typename Base::HotLoopInstList;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::CalculateCThreadOriginDataIndex8D;
    using Base::GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::GetCThreadBuffer;
    using Base::GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4;
    using Base::MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2;
    using Base::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2;

    using Base::a_block_desc_m0_m1_m2_k;
    using Base::b_block_desc_n0_n1_n2_k;

    using Base::AMmaKStride;
    using Base::BMmaKStride;

    static constexpr index_t PrefetchStages  = 3;
    static constexpr index_t PrefillStages   = 1;
    static constexpr index_t GlobalBufferNum = 2;
    static constexpr index_t HotloopUnroll   = 2;

    __host__ static constexpr bool BlockHasHotloop(index_t num_loop)
    {
        return num_loop > PrefetchStages;
    }

    __host__ static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        if(num_loop % HotloopUnroll == 1)
        {
            return TailNumber::Odd;
        }
        else
        {
            return TailNumber::Even;
        }
    }
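    // With HotloopUnroll == 2 this is just the parity of num_loop: an odd
    // iteration count leaves three pipeline stages to drain after the unrolled
    // hot loop, an even count leaves two (see the tail dispatch in Run()).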

    __device__ static constexpr auto HotLoopScheduler()
    {
        // TODO: Take the data type into consideration, as pipeline v3 does.
        // A/B split schedule
        // The compiler is likely to use ds_read2 when the instruction width is
        // smaller than 16 bytes.
        constexpr auto num_ds_read_inst_a =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16
                ? HotLoopInstList::A_LDS_Read_Inst_Num
                : HotLoopInstList::A_LDS_Read_Inst_Num / 2;
        constexpr auto num_ds_read_inst_b =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16
                ? HotLoopInstList::B_LDS_Read_Inst_Num
                : HotLoopInstList::B_LDS_Read_Inst_Num / 2;

        constexpr auto num_ds_write_inst_a = HotLoopInstList::A_LDS_Write_Inst_Num;
        constexpr auto num_ds_write_inst_b = HotLoopInstList::B_LDS_Write_Inst_Num;

        constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
        constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num;

        constexpr auto num_mfma_inst = HotLoopInstList::C_MFMA_Inst_Num;

        constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
        constexpr auto ds_read_a_issue_cycle =
            HotLoopInstList::A_LDS_Read_Width * sizeof(ADataType) == 16 ? 8 : 4;
        constexpr auto ds_read_b_issue_cycle =
            HotLoopInstList::B_LDS_Read_Width * sizeof(BDataType) == 16 ? 8 : 4;
        constexpr auto ds_read_a_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
        constexpr auto ds_read_b_mfma_rate =
            (mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
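        // The rate is the ceiling of (mfma_cycle - 4) / (2 * issue_cycle), i.e.
        // how many ds_read instructions can be hidden under one MFMA. For
        // example, with a 32-cycle MFMA and an 8-cycle ds_read issue:
        // (32 - 4 + 16 - 1) / 16 = 2 ds_reads per MFMA slot.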

        constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
        constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
        constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
        constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;

        constexpr auto num_dsread_stage1_a_mfma =
            (num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
        constexpr auto num_dsread_stage1_b_mfma =
            (num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
        constexpr auto num_dsread_stage3_a_mfma =
            (num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
        constexpr auto num_dsread_stage3_b_mfma =
            (num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;

        constexpr auto num_mfma_stage2 = num_mfma_inst - num_ds_read_inst_a / ds_read_a_mfma_rate -
                                         num_ds_read_inst_b / ds_read_b_mfma_rate;
        constexpr auto num_mfma_per_issue =
            num_mfma_stage2 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
        constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
        constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
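        // Stage 2 owns whatever MFMA issue slots remain after the DS reads of
        // stages 1 and 3 are paired off; each buffer_load issue then receives
        // an equal share of those MFMAs plus its share of the DS writes.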

        // stage 1
        static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) {
            ignore = i;
            if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
                         ds_read_a_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(
                    0x100,
                    num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
                    0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
        static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) {
            ignore = i;
            if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
                         ds_read_b_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(
                    0x100,
                    num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
                    0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });

        // stage 2
        static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
        });
        static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
            ignore = i;
            static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
                ignore = idswrite;
                __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
                __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
            });
            __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
            __builtin_amdgcn_sched_group_barrier(
                0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
        });

        // stage 3
        static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) {
            ignore = i;
            if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
                         ds_read_a_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(
                    0x100,
                    num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
                    0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });
        static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) {
            ignore = i;
            if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
                         ds_read_b_mfma_rate)
            {
                __builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
            }
            else
            {
                __builtin_amdgcn_sched_group_barrier(
                    0x100,
                    num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
                    0); // DS read
            }
            __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
        });

        // IGLP COMPILER BUG:
        // Removing the scheduler barrier below causes a sanity failure.
        __builtin_amdgcn_sched_barrier(0);
    }
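    // sched_group_barrier mask values used above: 0x008 = MFMA, 0x020 = VMEM
    // read, 0x100 = DS read, 0x200 = DS write.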

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, I0);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, I0);

        // Global prefetch 2
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I0);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Global prefetch 3
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I1);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, I1);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
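        // Note: prefetches 1 and 2 can both target staging buffer I0 because
        // the prefill above has already drained prefetch 1 into LDS; prefetch 3
        // then occupies I1 (GlobalBufferNum == 2).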

        // Initialize C
        c_thread_buf.Clear();

        // Local prefetch 1
        block_sync_lds();
        static_for<0, MRepeat, 1>{}([&](auto m0) {
            a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
                               make_tuple(m0, I0, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(m0, I0, I0, I0),
                               a_thread_buf);
        });
        static_for<0, NRepeat, 1>{}([&](auto n0) {
            b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                               make_tuple(n0, I0, I0, I0),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(n0, I0, I0, I0),
                               b_thread_buf);
        });

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                auto LoopFunc = [&](auto vmem_buf) {
                    vector_type<ComputeDataType, KPack> a_thread_vec;
                    vector_type<ComputeDataType, KPack> b_thread_vec;

                    static_for<0, KRepeat, 1>{}([&](auto k0) {
                        if constexpr(k0 == (KRepeat - 1))
                        {
                            block_sync_lds();

                            a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
                            b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);

                            a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, vmem_buf);
                            b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf, vmem_buf);

                            a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                            b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                            block_sync_lds();
                        }
                        static_for<0, MRepeat, 1>{}([&](auto m0) {
                            static_for<0, NRepeat, 1>{}([&](auto n0) {
                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                            make_tuple(m0, I0, I0, ik))>{}];
                                });
                                static_for<0, KPack, 1>{}([&](auto ik) {
                                    b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                        b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                            make_tuple(n0, I0, I0, ik))>{}];
                                });

                                using mfma_input_type =
                                    typename vector_type<ComputeDataType,
                                                         xdlops_gemm.K1PerXdlops>::type;

                                constexpr index_t c_offset =
                                    c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                                xdlops_gemm.Run(
                                    a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            });

                            a_thread_copy_.Run(
                                a_block_desc_m0_m1_m2_k,
                                make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                                a_block_buf,
                                a_thread_desc_,
                                make_tuple(m0, I0, I0, I0),
                                a_thread_buf);
                        });

                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            b_thread_copy_.Run(
                                b_block_desc_n0_n1_n2_k,
                                make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                                b_block_buf,
                                b_thread_desc_,
                                make_tuple(n0, I0, I0, I0),
                                b_thread_buf);
                        });
                    });

                    HotLoopScheduler();
                };

                LoopFunc(I0);
                LoopFunc(I1);

                i += HotloopUnroll;
            } while(i < (num_loop - PrefetchStages));
        }
        // tail
        auto ReadWriteCompFunc = [&](auto vmem_buf) {
            vector_type<ComputeDataType, KPack> a_thread_vec;
            vector_type<ComputeDataType, KPack> b_thread_vec;

            static_for<0, KRepeat, 1>{}([&](auto k0) {
                if constexpr(k0 == (KRepeat - 1))
                {
                    block_sync_lds();

                    a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, vmem_buf);
                    b_blockwise_copy.RunWrite(b_block_desc, b_block_buf, vmem_buf);

                    block_sync_lds();
                }
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, I0, ik))>{}];
                        });
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, I0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                    a_thread_copy_.Run(
                        a_block_desc_m0_m1_m2_k,
                        make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(m0, I0, I0, I0),
                        a_thread_buf);
                });

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(
                        b_block_desc_n0_n1_n2_k,
                        make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(n0, I0, I0, I0),
                        b_thread_buf);
                });
            });

            HotLoopScheduler();
        };
        auto ReadCompFunc = [&]() {
            vector_type<ComputeDataType, KPack> a_thread_vec;
            vector_type<ComputeDataType, KPack> b_thread_vec;

            static_for<0, KRepeat - 1, 1>{}([&](auto k0) {
                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeDataType>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                    make_tuple(m0, I0, I0, ik))>{}];
                        });
                        static_for<0, KPack, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeDataType>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                    make_tuple(n0, I0, I0, ik))>{}];
                        });

                        using mfma_input_type =
                            typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                        xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                        b_thread_vec.template AsType<mfma_input_type>(),
                                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });

                    a_thread_copy_.Run(
                        a_block_desc_m0_m1_m2_k,
                        make_tuple(m0, I0, I0, Number<(k0 + 1) % KRepeat * AMmaKStride>{}),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(m0, I0, I0, I0),
                        a_thread_buf);
                });

                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    b_thread_copy_.Run(
                        b_block_desc_n0_n1_n2_k,
                        make_tuple(n0, I0, I0, Number<(k0 + 1) % KRepeat * BMmaKStride>{}),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(n0, I0, I0, I0),
                        b_thread_buf);
                });
            });

            static_for<0, MRepeat, 1>{}([&](auto m0) {
                static_for<0, NRepeat, 1>{}([&](auto n0) {
                    static_for<0, KPack, 1>{}([&](auto ik) {
                        a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
                    });
                    static_for<0, KPack, 1>{}([&](auto ik) {
                        b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
                                    b_thread_vec.template AsType<mfma_input_type>(),
                                    c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });

            HotLoopScheduler();
        };

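        // The tail drains the tiles still in flight when the hot loop exits:
        // an odd loop count leaves three stages (two with an LDS write pending,
        // one compute-only), an even count leaves two.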
        if constexpr(TailNum == TailNumber::Odd)
        {
            ReadWriteCompFunc(I0);
            ReadWriteCompFunc(I1);
            ReadCompFunc();
        }
        else if constexpr(TailNum == TailNumber::Even)
        {
            ReadWriteCompFunc(I0);
            ReadCompFunc();
        }
    }

    protected:
    // A[MRepeat, I1, I1, KPack]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{}, I1, I1, Number<KPack>{}));

    // B[NRepeat, N1, N2, KPack]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
                                                         ComputeDataType,
                                                         decltype(a_block_desc_m0_m1_m2_k),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
                                                         ComputeDataType,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPack>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
    using Base::c_thread_desc_;
};

} // namespace ck
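
The loop bookkeeping above can be sanity-checked in isolation. The following standalone host-side sketch is not part of the header: it copies the PrefetchStages/HotloopUnroll constants and the two constexpr helpers, replays the do/while and tail dispatch from Run(), and asserts that every num_loop issues exactly num_loop compute iterations.

// Host-side sketch only; mirrors the constants and helpers defined above.
#include <cassert>
#include <cstdio>

constexpr int PrefetchStages = 3;
constexpr int HotloopUnroll  = 2;

enum class TailNumber { Odd, Even };

constexpr bool BlockHasHotloop(int num_loop) { return num_loop > PrefetchStages; }

constexpr TailNumber BlockLoopTailNum(int num_loop)
{
    return (num_loop % HotloopUnroll == 1) ? TailNumber::Odd : TailNumber::Even;
}

int main()
{
    for(int num_loop = 4; num_loop <= 12; ++num_loop)
    {
        int compute_iters = 0;
        if(BlockHasHotloop(num_loop))
        {
            // Mirrors the do/while in Run(): two unrolled LoopFunc calls per trip.
            int i = 0;
            do
            {
                compute_iters += HotloopUnroll; // LoopFunc(I0); LoopFunc(I1);
                i += HotloopUnroll;
            } while(i < (num_loop - PrefetchStages));
        }
        // Tail: Odd -> ReadWriteCompFunc(I0), ReadWriteCompFunc(I1), ReadCompFunc();
        //       Even -> ReadWriteCompFunc(I0), ReadCompFunc().
        compute_iters += (BlockLoopTailNum(num_loop) == TailNumber::Odd) ? 3 : 2;
        std::printf("num_loop = %2d -> tail %s, %2d compute iterations\n",
                    num_loop,
                    BlockLoopTailNum(num_loop) == TailNumber::Odd ? "Odd " : "Even",
                    compute_iters);
        assert(compute_iters == num_loop); // each K tile is consumed exactly once
    }
}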