Composable Kernel: blockwise_gemm_pipeline_wmmaops_v1.hpp Source File
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
6 #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"
7 
8 namespace ck {
9 
10 // Naive pipeline with lowest resource request per WGP
11 
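// This file provides the v1 (naive) blockwise GEMM pipeline for WMMA-based GPUs.
// The primary template below is specialized on the block GEMM pipeline scheduler
// and on BSkipLDS (whether the B operand bypasses LDS and stays in registers).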
12 template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
13  index_t BlockSize,
14  typename ADataType,
15  typename BDataType,
16  typename ComputeTypeA,
17  typename ComputeTypeB,
18  typename AccDataType,
19  typename AWmmaTileDesc,
20  typename BWmmaTileDesc,
21  index_t ABlockTransferSrcScalarPerVector,
22  index_t BBlockTransferSrcScalarPerVector,
23  index_t MPerBlock,
24  index_t NPerBlock,
25  index_t KPerBlock,
26  index_t MPerWmma,
27  index_t NPerWmma,
28  index_t MRepeat,
29  index_t NRepeat,
30  index_t KPack,
31  index_t KInner,
32  bool TransposeC = false,
33  bool BSkipLDS = false>
34 struct BlockwiseGemmWmmaops_pipeline_v1
35 {
36 };
37 
38 template <index_t BlockSize,
39  typename ADataType,
40  typename BDataType,
41  typename ComputeTypeA,
42  typename ComputeTypeB,
43  typename AccDataType,
44  typename AWmmaTileDesc,
45  typename BWmmaTileDesc,
46  index_t ABlockTransferSrcScalarPerVector,
47  index_t BBlockTransferSrcScalarPerVector,
48  index_t MPerBlock,
49  index_t NPerBlock,
50  index_t KPerBlock,
51  index_t MPerWmma,
52  index_t NPerWmma,
53  index_t MRepeat,
54  index_t NRepeat,
55  index_t KPack,
56  index_t KInner,
57  bool TransposeC>
58 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
59  BlockSize,
60  ADataType,
61  BDataType,
62  ComputeTypeA,
63  ComputeTypeB,
64  AccDataType,
65  AWmmaTileDesc,
66  BWmmaTileDesc,
67  ABlockTransferSrcScalarPerVector,
68  BBlockTransferSrcScalarPerVector,
69  MPerBlock,
70  NPerBlock,
71  KPerBlock,
72  MPerWmma,
73  NPerWmma,
74  MRepeat,
75  NRepeat,
76  KPack,
77  KInner,
78  TransposeC,
79  false>
80  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
81  ADataType,
82  BDataType,
83  ComputeTypeA,
84  ComputeTypeB,
85  AccDataType,
86  AWmmaTileDesc,
87  BWmmaTileDesc,
88  ABlockTransferSrcScalarPerVector,
89  BBlockTransferSrcScalarPerVector,
90  MPerBlock,
91  NPerBlock,
92  KPerBlock,
93  MPerWmma,
94  NPerWmma,
95  MRepeat,
96  NRepeat,
97  KPack,
98  KInner,
99  TransposeC>
100 {
101  // GlobalPrefetchStages: 1
102  // LocalPreFillStages: 1
103  // LocalPreFetchStages: 0
104  // LocalSharedMemoryBuffer: 1
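 // With a single global prefetch stage and one shared-memory buffer, each hot-loop
 // iteration issues the next global read, computes on the tile already resident in
 // LDS, and then refills the same LDS buffer, separated by block_sync_lds().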
105  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
106  ADataType,
107  BDataType,
108  ComputeTypeA,
109  ComputeTypeB,
110  AccDataType,
111  AWmmaTileDesc,
112  BWmmaTileDesc,
113  ABlockTransferSrcScalarPerVector,
114  BBlockTransferSrcScalarPerVector,
115  MPerBlock,
116  NPerBlock,
117  KPerBlock,
118  MPerWmma,
119  NPerWmma,
120  MRepeat,
121  NRepeat,
122  KPack,
123  KInner,
124  TransposeC>;
125  using Base::I0;
126  using Base::I1;
127  using typename Base::HotLoopInstList;
128 
129  using Base::A_K1;
130  using Base::A_KRow;
131  using Base::B_K1;
132  using Base::B_KRow;
133  using Base::KRepeat;
134  using Base::WmmaK;
135 
136  using Base::wmma_gemm;
137 
138  using Base::CalculateCThreadOriginDataIndex;
139  using Base::
140  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
141  using Base::GetCThreadBuffer;
142  using Base::
143  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
144 
145  using Base::a_block_desc_k0_m0_m1_m2_k1;
146  using Base::b_block_desc_k0_n0_n1_n2_k1;
147 
148  using typename Base::Empty;
149 
150  static constexpr index_t PrefetchStages = 1;
151  static constexpr index_t PrefillStages = 1;
152  static constexpr index_t GlobalBufferNum = 1;
153 
154  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
155 
156  static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
157  {
158  ignore = num_loop;
159  return TailNumber::Full;
160  }
161 
162  template <bool HasMainLoop,
163  TailNumber TailNum,
164  typename AGridDesc,
165  typename ABlockDesc,
166  typename ABlockTransfer,
167  typename AGridBuffer,
168  typename ABlockBuffer,
169  typename ABlockTransferStep,
170  typename BGridDesc,
171  typename BBlockDesc,
172  typename BBlockTransfer,
173  typename BGridBuffer,
174  typename BBlockBuffer,
175  typename BBlockTransferStep,
176  typename CThreadBuffer,
177  typename BScaleStruct>
178  __device__ void Run(const AGridDesc& a_grid_desc,
179  const ABlockDesc& a_block_desc,
180  ABlockTransfer& a_blockwise_copy,
181  const AGridBuffer& a_grid_buf,
182  ABlockBuffer& a_block_buf,
183  const ABlockTransferStep& a_block_copy_step,
184  const BGridDesc& b_grid_desc,
185  const BBlockDesc& b_block_desc,
186  BBlockTransfer& b_blockwise_copy,
187  const BGridBuffer& b_grid_buf,
188  BBlockBuffer& b_block_buf,
189  const BBlockTransferStep& b_block_copy_step,
190  CThreadBuffer& c_thread_buf,
191  // BScaleThreadCopy
192  BScaleStruct& b_scale_struct,
193  index_t num_loop,
194  index_t num_loop_per_scale) const
195  {
196  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
197 
198  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
199  a_thread_desc_.GetElementSpaceSize());
200  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
201  b_thread_desc_.GetElementSpaceSize());
202 
203  // Global prefetch 1
204  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
205  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
206 
207  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
208  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
209 
210  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
211 
212  // Local prefill 1
213  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
214  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
215 
216  // Initialize C
217  c_thread_buf.Clear();
218 
219  auto blockwise_gemm_func = [&]() {
220  static_for<0, KRepeat, 1>{}([&](auto k0) {
221  static_for<0, MRepeat, 1>{}([&](auto m0) {
222  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
223  make_tuple(I0, m0, k0, I0, I0, I0, I0),
224  a_block_buf,
225  a_thread_desc_,
226  make_tuple(I0, I0, I0, I0, I0, I0, I0),
227  a_thread_buf);
228  if constexpr(m0 == I0)
229  {
230  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
231  {
232  static_for<0, NRepeat, 1>{}([&](auto n0) {
233  b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
234  make_tuple(I0, n0, k0, I0, I0, I0, I0),
235  b_block_buf,
236  b_thread_desc_,
237  make_tuple(I0, n0, I0, I0, I0, I0, I0),
238  b_thread_buf);
239  });
240  }
241  else
242  {
243  static_for<0, NRepeat, 1>{}([&](auto n0) {
244  b_thread_copy_.Run(
245  b_block_desc_k0_n0_n1_n2_k1,
246  make_tuple(I0, n0, k0, I0, I0, I0, I0),
247  b_block_buf,
248  b_scale_struct.b_scale_thread_bufs(
249  I0)[Number<n0 * BScaleStruct::num_scale_k_block +
250  k0 / BScaleStruct::num_scale_krepeat>{}],
251  b_thread_desc_,
252  make_tuple(I0, n0, I0, I0, I0, I0, I0),
253  b_thread_buf);
254  });
255  }
256  }
257 
258  static_for<0, KInner, 1>{}([&](auto k_inner) {
259  static_for<0, NRepeat, 1>{}([&](auto n0) {
260  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
261  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
262 
263  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
264  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
265  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
266  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
267  make_tuple(Number<kk / A_K1>{},
268  I0,
269  I0,
270  I0,
271  I0,
272  I0,
273  Number<kk % A_K1>{}))>{}];
274  });
275  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
276  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
277  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
278  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
279  make_tuple(Number<kk / B_K1>{},
280  n0,
281  I0,
282  I0,
283  I0,
284  I0,
285  Number<kk % B_K1>{}))>{}];
286  });
287 
288  using wmma_input_type_a =
289  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
290  using wmma_input_type_b =
291  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
292 
293  constexpr index_t c_offset =
294  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
295 
296  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
297  b_thread_vec.template AsType<wmma_input_type_b>(),
298  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
299  });
300  });
301  });
302  });
303  };
304 
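 // Hot loop: prefetch the next K-tile from global memory while blockwise_gemm_func
 // consumes the tile currently resident in LDS; the sched_group_barrier hints below
 // interleave VMEM reads, DS reads, WMMA, and DS writes within one iteration.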
305  // main body
306  if constexpr(HasMainLoop)
307  {
308  index_t i = 0;
309  do
310  {
311  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
312  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
313 
314  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
315  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
316 
317  block_sync_lds();
318  blockwise_gemm_func();
319 
320  block_sync_lds();
321  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
322  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
323  {
324  block_sync_lds();
325  }
326  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
327  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
328 
329  constexpr index_t num_ds_write_inst =
330  HotLoopInstList::A_LDS_Write_Inst_Num + HotLoopInstList::B_LDS_Write_Inst_Num;
331 
332  constexpr index_t num_buffer_load_inst = HotLoopInstList::A_Buffer_Load_Inst_Num +
333  HotLoopInstList::B_Buffer_Load_Inst_Num;
334  static_for<0, num_buffer_load_inst, 1>{}([&](auto) {
335  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
336  });
337  static_for<0, KRepeat, 1>{}([&](auto) {
338  static_for<0, MRepeat, 1>{}([&](auto m0) {
339  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
340  if constexpr(m0 == I0)
341  {
342  static_for<0, NRepeat, 1>{}([&](auto) {
343  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
344  });
345  }
346  static_for<0, KInner, 1>{}([&](auto) {
347  static_for<0, NRepeat, 1>{}([&](auto) {
348  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
349  });
350  });
351  });
352  });
353  static_for<0, num_ds_write_inst, 1>{}([&](auto) {
354  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
355  });
356 
357  i += 1;
358  } while(i < (num_loop - 1));
359  }
360 
361  // tail
362  if constexpr(TailNum == TailNumber::Full)
363  {
364  block_sync_lds();
365  blockwise_gemm_func();
366  }
367  }
368 
369  protected:
370  // A[MRepeat, I1, I1, KPack]
371  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
372  make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, I1, Number<A_K1>{}));
373 
374  // B[NRepeat, N1, N2, KPack]
375  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
376  Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, I1, Number<B_K1>{}));
377 
378  using AThreadCopy =
379  ThreadwiseTensorSliceTransfer_v4<ADataType,
380  ComputeTypeA,
381  decltype(a_block_desc_k0_m0_m1_m2_k1),
382  decltype(a_thread_desc_),
383  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
384  Sequence<0, 1, 2, 3, 4, 5, 6>,
385  6,
386  A_K1,
387  A_K1>;
388 
389  using BThreadCopy =
390  ThreadwiseTensorSliceTransfer_v4<BDataType,
391  ComputeTypeB,
392  decltype(b_block_desc_k0_n0_n1_n2_k1),
393  decltype(b_thread_desc_),
394  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
395  Sequence<0, 1, 2, 3, 4, 5, 6>,
396  6,
397  B_K1,
398  B_K1>;
399 
400  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
401  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
402  using Base::c_thread_desc_;
403 };
404 
405 template <index_t BlockSize,
406  typename ADataType,
407  typename BDataType,
408  typename ComputeTypeA,
409  typename ComputeTypeB,
410  typename AccDataType,
411  typename AWmmaTileDesc,
412  typename BWmmaTileDesc,
413  index_t ABlockTransferSrcScalarPerVector,
414  index_t BBlockTransferSrcScalarPerVector,
415  index_t MPerBlock,
416  index_t NPerBlock,
417  index_t KPerBlock,
418  index_t MPerWmma,
419  index_t NPerWmma,
420  index_t MRepeat,
421  index_t NRepeat,
422  index_t KPack,
423  index_t KInner,
424  bool TransposeC>
425 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
426  BlockSize,
427  ADataType,
428  BDataType,
429  ComputeTypeA,
430  ComputeTypeB,
431  AccDataType,
432  AWmmaTileDesc,
433  BWmmaTileDesc,
434  ABlockTransferSrcScalarPerVector,
435  BBlockTransferSrcScalarPerVector,
436  MPerBlock,
437  NPerBlock,
438  KPerBlock,
439  MPerWmma,
440  NPerWmma,
441  MRepeat,
442  NRepeat,
443  KPack,
444  KInner,
445  TransposeC,
446  false>
447  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
448  ADataType,
449  BDataType,
450  ComputeTypeA,
451  ComputeTypeB,
452  AccDataType,
453  AWmmaTileDesc,
454  BWmmaTileDesc,
455  ABlockTransferSrcScalarPerVector,
456  BBlockTransferSrcScalarPerVector,
457  MPerBlock,
458  NPerBlock,
459  KPerBlock,
460  MPerWmma,
461  NPerWmma,
462  MRepeat,
463  NRepeat,
464  KPack,
465  KInner,
466  TransposeC>
467 {
468  // GlobalPrefetchStages: 1
469  // LocalPreFillStages: 1
470  // LocalPreFetchStages: 0
471  // LocalSharedMemoryBuffer: 1
472  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
473  ADataType,
474  BDataType,
475  ComputeTypeA,
476  ComputeTypeB,
477  AccDataType,
478  AWmmaTileDesc,
479  BWmmaTileDesc,
480  ABlockTransferSrcScalarPerVector,
481  BBlockTransferSrcScalarPerVector,
482  MPerBlock,
483  NPerBlock,
484  KPerBlock,
485  MPerWmma,
486  NPerWmma,
487  MRepeat,
488  NRepeat,
489  KPack,
490  KInner,
491  TransposeC>;
492  using Base::I0;
493  using Base::I1;
494 
495  using Base::A_K1;
496  using Base::A_KRow;
497  using Base::B_K1;
498  using Base::B_KRow;
499  using Base::KRepeat;
500  using Base::WmmaK;
501 
502  using Base::wmma_gemm;
503 
504  using Base::CalculateCThreadOriginDataIndex;
505  using Base::
506  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
507  using Base::GetCThreadBuffer;
508  using Base::
509  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
510 
511  using Base::a_block_desc_k0_m0_m1_m2_k1;
512  using Base::b_block_desc_k0_n0_n1_n2_k1;
513 
514  using typename Base::Empty;
515 
516  static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
517  static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
518 
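 // KRepeat is split into NumKClusters MAC clusters; each cluster does its LDS->VGPR
 // copies first and then runs a contiguous burst of WMMA instructions, so waves in
 // the workgroup execute their MAC phases roughly in lockstep (see NOTE below).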
519  static constexpr index_t PrefetchStages = 1;
520  static constexpr index_t PrefillStages = 1;
521  static constexpr index_t GlobalBufferNum = 1;
522 
523  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
524 
525  static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
526  {
527  ignore = num_loop;
528  return TailNumber::Full;
529  }
530 
531  template <bool HasMainLoop,
532  TailNumber TailNum,
533  typename AGridDesc,
534  typename ABlockDesc,
535  typename ABlockTransfer,
536  typename AGridBuffer,
537  typename ABlockBuffer,
538  typename ABlockTransferStep,
539  typename BGridDesc,
540  typename BBlockDesc,
541  typename BBlockTransfer,
542  typename BGridBuffer,
543  typename BBlockBuffer,
544  typename BBlockTransferStep,
545  typename CThreadBuffer,
546  typename BScaleStruct>
547  __device__ void Run(const AGridDesc& a_grid_desc,
548  const ABlockDesc& a_block_desc,
549  ABlockTransfer& a_blockwise_copy,
550  const AGridBuffer& a_grid_buf,
551  ABlockBuffer& a_block_buf,
552  const ABlockTransferStep& a_block_copy_step,
553  const BGridDesc& b_grid_desc,
554  const BBlockDesc& b_block_desc,
555  BBlockTransfer& b_blockwise_copy,
556  const BGridBuffer& b_grid_buf,
557  BBlockBuffer& b_block_buf,
558  const BBlockTransferStep& b_block_copy_step,
559  CThreadBuffer& c_thread_buf,
560  // BScaleThreadCopy
561  BScaleStruct& b_scale_struct,
562  index_t num_loop,
563  index_t num_loop_per_scale) const
564  {
565  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
566 
567  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
568  a_thread_desc_.GetElementSpaceSize());
569  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
570  b_thread_desc_.GetElementSpaceSize());
571 
572  // Global prefetch 1
573  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
574  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
575 
576  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
577  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
578 
579  b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
580 
581  // Local prefill 1
582  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
583  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
584 
585  // Initialize C
586  c_thread_buf.Clear();
587 
588  auto blockwise_gemm_func = [&]() {
589  static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
590  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
591  static_for<0, MRepeat, 1>{}([&](auto m0) {
592  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
593  make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
594  a_block_buf,
595  a_thread_desc_,
596  make_tuple(I0, m0, k0_inner, I0, I0, I0, I0),
597  a_thread_buf);
598  });
599  if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
600  {
601  static_for<0, NRepeat, 1>{}([&](auto n0) {
602  b_thread_copy_.Run(
603  b_block_desc_k0_n0_n1_n2_k1,
604  make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
605  b_block_buf,
606  b_thread_desc_,
607  make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
608  b_thread_buf);
609  });
610  }
611  else
612  {
613  static_for<0, NRepeat, 1>{}([&](auto n0) {
614  b_thread_copy_.Run(
615  b_block_desc_k0_n0_n1_n2_k1,
616  make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
617  b_block_buf,
618  b_scale_struct.b_scale_thread_bufs(I0)[Number<
619  n0 * BScaleStruct::num_scale_k_block +
620  (k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
621  b_thread_desc_,
622  make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
623  b_thread_buf);
624  });
625  }
626  });
627 
628  __builtin_amdgcn_sched_barrier(0);
629  // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
630  // except for the first, as we can shorten the non-MAC cluster a bit and there's no
631  // observable negative impact. The desired effect is waves in a workgroup
632  // executing MAC in sync. This avoids some out-of-sync waves hijacking MAC
633  // resource from other workgroups and reducing the chance of latency hiding by
634  // waiting for the rest of the workgroup at the eventual sync point.
635  if constexpr(k0_offset != 0 || KRepeat == 1)
636  {
637  __builtin_amdgcn_s_barrier();
638  __builtin_amdgcn_sched_barrier(0);
639  }
640  static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
641  static_for<0, KInner, 1>{}([&](auto k_inner) {
642  static_for<0, MRepeat, 1>{}([&](auto m0) {
643  static_for<0, NRepeat, 1>{}([&](auto n0) {
644  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
645  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
646 
647  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
648  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
649  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
650  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
651  make_tuple(Number<kk / A_K1>{},
652  m0,
653  k0_inner,
654  I0,
655  I0,
656  I0,
657  Number<kk % A_K1>{}))>{}];
658  });
659  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
660  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
661  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
662  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
663  make_tuple(Number<kk / B_K1>{},
664  n0,
665  k0_inner,
666  I0,
667  I0,
668  I0,
669  Number<kk % B_K1>{}))>{}];
670  });
671 
672  using wmma_input_type_a =
673  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
674  using wmma_input_type_b =
675  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
676 
677  constexpr index_t c_offset =
678  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
679 
680  // The block_sync_lds() here performs double duty:
681  // A) safeguard against data hazard.
682  // B) reduce VMEM FIFO congestion by applying small delays to
683  // different wavefronts.
684  // It is performed near the end of MAC cluster to minimize lgkmcnt
685  // penalty
686  if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
687  m0 == MRepeat - 1 && n0 == NRepeat - 1)
688  {
689  __builtin_amdgcn_sched_barrier(0);
690  block_sync_lds();
691  __builtin_amdgcn_sched_barrier(0);
692  }
693  wmma_gemm.Run(
694  a_thread_vec.template AsType<wmma_input_type_a>(),
695  b_thread_vec.template AsType<wmma_input_type_b>(),
696  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
697  if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
698  {
699  __builtin_amdgcn_sched_barrier(0);
700  __builtin_amdgcn_s_setprio(1);
701  __builtin_amdgcn_sched_barrier(0);
702  }
703  });
704  });
705  });
706  });
707  __builtin_amdgcn_sched_barrier(0);
708  __builtin_amdgcn_s_setprio(0);
709  __builtin_amdgcn_sched_barrier(0);
710  });
711  };
712 
713  // main body
714  if constexpr(HasMainLoop)
715  {
716  index_t i = 0;
717  do
718  {
719  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
720  b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
721 
722  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
723  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
724 
725  block_sync_lds();
726  blockwise_gemm_func();
727 
728  b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
729  if constexpr(ck::is_same<BScaleStruct, Empty>::value == false)
730  {
731  block_sync_lds();
732  }
733  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
734  b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
735 
736  i += 1;
737  } while(i < (num_loop - 1));
738  }
739 
740  // tail
741  if constexpr(TailNum == TailNumber::Full)
742  {
743  block_sync_lds();
744  blockwise_gemm_func();
745  }
746  }
747 
748  protected:
749  static constexpr auto a_thread_desc_ =
750  make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
751  Number<MRepeat>{},
752  Number<KRepeatPerCluster>{},
753  I1,
754  I1,
755  I1,
756  Number<A_K1>{}),
757  make_tuple(Number<A_K1>{},
758  Number<KPack / A_KRow>{},
759  Number<KPack / A_KRow * MRepeat>{},
760  I0,
761  I0,
762  I0,
763  I1));
764 
765  static constexpr auto b_thread_desc_ =
766  make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
767  Number<NRepeat>{},
768  Number<KRepeatPerCluster>{},
769  I1,
770  I1,
771  I1,
772  Number<B_K1>{}),
773  make_tuple(Number<B_K1>{},
774  Number<KPack / B_KRow>{},
775  Number<KPack / B_KRow * NRepeat>{},
776  I0,
777  I0,
778  I0,
779  I1));
780 
781  using AThreadCopy =
782  ThreadwiseTensorSliceTransfer_v4<ADataType,
783  ComputeTypeA,
784  decltype(a_block_desc_k0_m0_m1_m2_k1),
785  decltype(a_thread_desc_),
786  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
787  Sequence<0, 1, 2, 3, 4, 5, 6>,
788  6,
789  A_K1,
790  A_K1>;
791 
792  using BThreadCopy =
793  ThreadwiseTensorSliceTransfer_v4<BDataType,
794  ComputeTypeB,
795  decltype(b_block_desc_k0_n0_n1_n2_k1),
796  decltype(b_thread_desc_),
797  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
798  Sequence<0, 1, 2, 3, 4, 5, 6>,
799  6,
800  B_K1,
801  B_K1>;
802 
803  AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
804  BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
805  using Base::c_thread_desc_;
806 };
807 
808 template <index_t BlockSize,
809  typename ADataType,
810  typename BDataType,
811  typename ComputeTypeA,
812  typename ComputeTypeB,
813  typename AccDataType,
814  typename AWmmaTileDesc,
815  typename BWmmaTileDesc,
816  index_t ABlockTransferSrcScalarPerVector,
817  index_t BBlockTransferSrcScalarPerVector,
818  index_t MPerBlock,
819  index_t NPerBlock,
820  index_t KPerBlock,
821  index_t MPerWmma,
822  index_t NPerWmma,
823  index_t MRepeat,
824  index_t NRepeat,
825  index_t KPack,
826  index_t KInner,
827  bool TransposeC>
828 struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
829  BlockSize,
830  ADataType,
831  BDataType,
832  ComputeTypeA,
833  ComputeTypeB,
834  AccDataType,
835  AWmmaTileDesc,
836  BWmmaTileDesc,
837  ABlockTransferSrcScalarPerVector,
838  BBlockTransferSrcScalarPerVector,
839  MPerBlock,
840  NPerBlock,
841  KPerBlock,
842  MPerWmma,
843  NPerWmma,
844  MRepeat,
845  NRepeat,
846  KPack,
847  KInner,
848  TransposeC,
849  true>
850  : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
851  ADataType,
852  BDataType,
853  ComputeTypeA,
854  ComputeTypeB,
855  AccDataType,
856  AWmmaTileDesc,
857  BWmmaTileDesc,
858  ABlockTransferSrcScalarPerVector,
859  BBlockTransferSrcScalarPerVector,
860  MPerBlock,
861  NPerBlock,
862  KPerBlock,
863  MPerWmma,
864  NPerWmma,
865  MRepeat,
866  NRepeat,
867  KPack,
868  KInner,
869  TransposeC>
870 {
871  // GlobalPrefetchStages: 2
872  // LocalPreFillStages: 1
873  // LocalPreFetchStages: 1
874  // LocalSharedMemoryBuffer: 1
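 // In this BSkipLDS specialization only A is staged through LDS; B is loaded from
 // global memory directly into a double-buffered register array (b_thread_bufs),
 // hence the two global prefetch stages and two global buffers.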
875  using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
876  ADataType,
877  BDataType,
878  ComputeTypeA,
879  ComputeTypeB,
880  AccDataType,
881  AWmmaTileDesc,
882  BWmmaTileDesc,
883  ABlockTransferSrcScalarPerVector,
884  BBlockTransferSrcScalarPerVector,
885  MPerBlock,
886  NPerBlock,
887  KPerBlock,
888  MPerWmma,
889  NPerWmma,
890  MRepeat,
891  NRepeat,
892  KPack,
893  KInner,
894  TransposeC>;
895  using Base::I0;
896  using Base::I1;
897  using Base::MWaves;
898  using Base::WaveSize;
899  using typename Base::HotLoopInstList;
900 
901  using Base::A_K1;
902  using Base::A_KRow;
903  using Base::B_K1;
904  using Base::B_KRow;
905  using Base::KRepeat;
906  using Base::WmmaK;
907 
908  using Base::wmma_gemm;
909 
910  using Base::CalculateCThreadOriginDataIndex;
911  using Base::
912  GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
913  using Base::GetCThreadBuffer;
914  using Base::
915  GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
916 
917  using Base::a_block_desc_k0_m0_m1_m2_k1;
918  using Base::b_block_desc_k0_n0_n1_n2_k1;
919 
920  using typename Base::Empty;
921 
922  static constexpr index_t PrefetchStages = 2;
923  static constexpr index_t PrefillStages = 1;
924  static constexpr index_t GlobalBufferNum = 2;
925 
926  static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }
927 
928  static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
929  {
930  return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
931  }
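 // The main loop below advances two iterations at a time (LoopFunc(I0, I1) then
 // LoopFunc(I1, I0)), so the tail handles the remaining two iterations when
 // num_loop is even and the remaining one when it is odd.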
932 
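 // HotLoopScheduler spreads one iteration's B global loads, A global loads, and A
 // LDS reads between WMMA instructions via sched_group_barrier, hiding memory
 // latency behind the MAC work of the hot loop.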
933  __device__ static constexpr auto HotLoopScheduler()
934  {
935  constexpr auto num_ds_read_inst_a = HotLoopInstList::A_LDS_Read_Inst_Num;
936  constexpr auto num_buffer_load_inst_a = HotLoopInstList::A_Buffer_Load_Inst_Num;
937  constexpr auto num_buffer_load_inst_b = HotLoopInstList::B_Buffer_Load_Inst_Num * MWaves;
938  constexpr auto wmma_interleave = 2;
939  // B global
940  static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
941  ignore = i;
942  if constexpr(MPerBlock >= 128 && NPerBlock >= 128)
943  {
944  __builtin_amdgcn_sched_group_barrier(0x008, 2 * wmma_interleave, 0);
945  }
946  else
947  {
948  __builtin_amdgcn_sched_group_barrier(0x008, wmma_interleave, 0);
949  }
950  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
951  });
952 
953  // A global
954  static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
955  ignore = i;
956  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
957  __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
958  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
959  __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
960  });
961 
962  // A local
963  static_for<0, num_ds_read_inst_a, 1>{}([&](auto i) {
964  ignore = i;
965  __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
966  __builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
967  });
968  }
969 
970  template <bool HasMainLoop,
971  TailNumber TailNum,
972  typename AGridDesc,
973  typename ABlockDesc,
974  typename ABlockTransfer,
975  typename AGridBuffer,
976  typename ABlockBuffer,
977  typename ABlockTransferStep,
978  typename BGridDesc,
979  typename BBlockDesc,
980  typename BBlockTransfer,
981  typename BGridBuffer,
982  typename BBlockBuffer,
983  typename BBlockTransferStep,
984  typename CThreadBuffer,
985  typename BScaleStruct>
986  __device__ void Run(const AGridDesc& a_grid_desc,
987  const ABlockDesc& a_block_desc,
988  ABlockTransfer& a_blockwise_copy,
989  const AGridBuffer& a_grid_buf,
990  ABlockBuffer& a_block_buf,
991  const ABlockTransferStep& a_block_copy_step,
992  const BGridDesc& b_grid_desc,
993  const BBlockDesc&,
994  BBlockTransfer& b_blockwise_copy,
995  const BGridBuffer& b_grid_buf,
996  BBlockBuffer&,
997  const BBlockTransferStep& b_block_copy_step,
998  CThreadBuffer& c_thread_buf,
999  // BScaleThreadCopy
1000  BScaleStruct&,
1001  index_t num_loop,
1002  index_t) const
1003  {
1004  __builtin_amdgcn_sched_barrier(0);
1005  constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
1006 
1007  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
1008  a_thread_desc_.GetElementSpaceSize());
1009  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
1010  b_thread_desc_.GetElementSpaceSize());
1011 
1012  StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
1013  constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
1014 
1015  // Global prefetch A1 B1
1016  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
1017  b_blockwise_copy.Run(b_grid_desc,
1018  b_grid_buf,
1019  b_block_desc_k0_n0_n1_n2_k1,
1020  b_block_origin_idx,
1021  b_thread_bufs(I0));
1022 
1023  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1024  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1025  __builtin_amdgcn_sched_barrier(0);
1026 
1027  // Local prefill A1
1028  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1029 
1030  // Global prefetch A2
1031  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
1032  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1033 
1034  // Local prefetch A1
1035  block_sync_lds();
1036  static_for<0, MRepeat, 1>{}([&](auto m0) {
1037  static_for<0, KRepeat, 1>{}([&](auto k0) {
1038  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1039  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1040  a_block_buf,
1041  a_thread_desc_,
1042  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1043  a_thread_buf);
1044  });
1045  });
1046 
1047  // Initialize C
1048  c_thread_buf.Clear();
1049 
1050  __builtin_amdgcn_sched_barrier(0);
1051 
1052  // main body
1053  if constexpr(HasMainLoop)
1054  {
1055  index_t i = 0;
1056  do
1057  {
1058  auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
1059  b_blockwise_copy.Run(b_grid_desc,
1060  b_grid_buf,
1061  b_block_desc_k0_n0_n1_n2_k1,
1062  b_block_origin_idx,
1063  b_thread_bufs(local_read_buf));
1064 
1065  b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
1066 
1067  block_sync_lds();
1068 
1069  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
1070 
1071  a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
1072  a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
1073 
1074  static_for<0, MRepeat, 1>{}([&](auto m0) {
1075  static_for<0, NRepeat, 1>{}([&](auto n0) {
1076  static_for<0, KRepeat, 1>{}([&](auto k0) {
1077  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1078  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1079  static_for<0, KInner, 1>{}([&](auto k_inner) {
1080  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1081  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1082  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1083  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1084  make_tuple(Number<kk / A_K1>{},
1085  m0,
1086  k0,
1087  I0,
1088  I0,
1089  I0,
1090  Number<kk % A_K1>{}))>{}];
1091  });
1092  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1093  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1094  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1095  b_thread_bufs[wmma_reg_buf]
1096  [Number<b_thread_desc_.CalculateOffset(
1097  make_tuple(Number<kk / B_K1>{},
1098  I0,
1099  I0,
1100  n0,
1101  I0,
1102  k0,
1103  Number<kk % B_K1>{}))>{}];
1104  });
1105  using wmma_input_type_a =
1106  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1107  using wmma_input_type_b =
1108  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1109 
1110  constexpr index_t c_offset =
1111  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
1112 
1113  wmma_gemm.Run(
1114  a_thread_vec.template AsType<wmma_input_type_a>(),
1115  b_thread_vec.template AsType<wmma_input_type_b>(),
1116  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1117  });
1118  });
1119  });
1120  });
1121 
1122  block_sync_lds();
1123 
1124  // loop prefetch copy
1125  static_for<0, MRepeat, 1>{}([&](auto m0) {
1126  static_for<0, KRepeat, 1>{}([&](auto k0) {
1127  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1128  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1129  a_block_buf,
1130  a_thread_desc_,
1131  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1132  a_thread_buf);
1133  });
1134  });
1135 
1136  HotLoopScheduler();
1137  __builtin_amdgcn_sched_barrier(0);
1138  };
1139 
1140  LoopFunc(I0, I1);
1141  LoopFunc(I1, I0);
1142 
1143  i += 2;
1144  } while(i < (num_loop - 2));
1145  }
1146 
1147  // tail
1148  if constexpr(TailNum == TailNumber::Even)
1149  {
1150  b_blockwise_copy.Run(b_grid_desc,
1151  b_grid_buf,
1152  b_block_desc_k0_n0_n1_n2_k1,
1153  b_block_origin_idx,
1154  b_thread_bufs(I1));
1155 
1156  block_sync_lds();
1157 
1158  a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
1159 
1160  static_for<0, MRepeat, 1>{}([&](auto m0) {
1161  static_for<0, NRepeat, 1>{}([&](auto n0) {
1162  static_for<0, KRepeat, 1>{}([&](auto k0) {
1163  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1164  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1165  static_for<0, KInner, 1>{}([&](auto k_inner) {
1166  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1167  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1168  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1169  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1170  make_tuple(Number<kk / A_K1>{},
1171  m0,
1172  k0,
1173  I0,
1174  I0,
1175  I0,
1176  Number<kk % A_K1>{}))>{}];
1177  });
1178  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1179  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1180  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1181  b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1182  make_tuple(Number<kk / B_K1>{},
1183  I0,
1184  I0,
1185  n0,
1186  I0,
1187  k0,
1188  Number<kk % B_K1>{}))>{}];
1189  });
1190 
1191  using wmma_input_type_a =
1192  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1193  using wmma_input_type_b =
1194  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1195 
1196  constexpr index_t c_offset =
1197  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
1198 
1199  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1200  b_thread_vec.template AsType<wmma_input_type_b>(),
1201  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1202  });
1203  });
1204  });
1205  });
1206 
1207  block_sync_lds();
1208 
1209  // tail Local Prefetch A1
1210  static_for<0, MRepeat, 1>{}([&](auto m0) {
1211  static_for<0, KRepeat, 1>{}([&](auto k0) {
1212  a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
1213  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1214  a_block_buf,
1215  a_thread_desc_,
1216  make_tuple(I0, m0, k0, I0, I0, I0, I0),
1217  a_thread_buf);
1218  });
1219  });
1220 
1221  __builtin_amdgcn_sched_barrier(0);
1222 
1223  static_for<0, MRepeat, 1>{}([&](auto m0) {
1224  static_for<0, NRepeat, 1>{}([&](auto n0) {
1225  static_for<0, KRepeat, 1>{}([&](auto k0) {
1226  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1227  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1228  static_for<0, KInner, 1>{}([&](auto k_inner) {
1229  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1230  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1231  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1232  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1233  make_tuple(Number<kk / A_K1>{},
1234  m0,
1235  k0,
1236  I0,
1237  I0,
1238  I0,
1239  Number<kk % A_K1>{}))>{}];
1240  });
1241  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1242  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1243  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1244  b_thread_bufs[I1][Number<b_thread_desc_.CalculateOffset(
1245  make_tuple(Number<kk / B_K1>{},
1246  I0,
1247  I0,
1248  n0,
1249  I0,
1250  k0,
1251  Number<kk % B_K1>{}))>{}];
1252  });
1253  using wmma_input_type_a =
1254  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1255  using wmma_input_type_b =
1256  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1257 
1258  constexpr index_t c_offset =
1259  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
1260 
1261  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1262  b_thread_vec.template AsType<wmma_input_type_b>(),
1263  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1264  });
1265  });
1266  });
1267  });
1268  // Let the last WMMA block leak into the epilogue region to cover the potential
1269  // lds-shuffle latency
1270  // __builtin_amdgcn_sched_barrier(0);
1271  }
1272  else if constexpr(TailNum == TailNumber::Odd)
1273  {
1274  static_for<0, MRepeat, 1>{}([&](auto m0) {
1275  static_for<0, NRepeat, 1>{}([&](auto n0) {
1276  static_for<0, KRepeat, 1>{}([&](auto k0) {
1277  vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
1278  vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
1279  static_for<0, KInner, 1>{}([&](auto k_inner) {
1280  static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
1281  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1282  a_thread_vec.template AsType<ComputeTypeA>()(ik) =
1283  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
1284  make_tuple(Number<kk / A_K1>{},
1285  m0,
1286  k0,
1287  I0,
1288  I0,
1289  I0,
1290  Number<kk % A_K1>{}))>{}];
1291  });
1292  static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
1293  constexpr index_t kk = ik + k_inner * KPerWaveBlock;
1294  b_thread_vec.template AsType<ComputeTypeB>()(ik) =
1295  b_thread_bufs[I0][Number<b_thread_desc_.CalculateOffset(
1296  make_tuple(Number<kk / B_K1>{},
1297  I0,
1298  I0,
1299  n0,
1300  I0,
1301  k0,
1302  Number<kk % B_K1>{}))>{}];
1303  });
1304  using wmma_input_type_a =
1305  typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
1306  using wmma_input_type_b =
1307  typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
1308 
1309  constexpr index_t c_offset =
1310  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
1311 
1312  wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
1313  b_thread_vec.template AsType<wmma_input_type_b>(),
1314  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
1315  });
1316  });
1317  });
1318  });
1319  }
1320  }
1321 
1322  protected:
1323  static constexpr auto b_thread_desc_ =
1324  make_naive_tensor_descriptor_packed(make_tuple(Number<KPack / B_K1 / B_KRow>{},
1325  I1,
1326  I1,
1327  Number<NRepeat>{},
1328  I1,
1329  Number<KRepeat>{},
1330  Number<B_K1>{}));
1331 
1332  using Base::a_thread_copy_;
1333  using Base::a_thread_desc_;
1334  using Base::c_thread_desc_;
1335 };
1336 
1337 } // namespace ck