Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp Source File

blockwise_gemm_pipeline_wmmaops_v1.hpp
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp"

namespace ck {

// Naive pipeline with lowest resource request per WGP
// GlobalPrefetchStages: 1
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1

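// Illustrative summary of the schedule the Run() members below implement
// (Intrawave form shown; the Interwave specialization folds the second LDS
// barrier into its MAC clusters):
//
//   RunRead(A/B); MoveSrcSliceWindow(A/B)              // global prefetch 1
//   RunWrite(A/B) into the single LDS buffer           // local prefill 1
//   while(i < num_loop - 1)                            // hot loop
//   {
//       RunRead(A/B) next tile; MoveSrcSliceWindow(A/B)
//       block_sync_lds(); blockwise_gemm_func();       // consume LDS tile
//       block_sync_lds(); RunWrite(A/B)                // refill LDS buffer
//   }
//   block_sync_lds(); blockwise_gemm_func();           // TailNumber::Full tail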
template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
          index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1
{
};

template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
                                        BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeTypeA,
                                        ComputeTypeB,
                                        AccDataType,
                                        AWmmaTileDesc,
                                        BWmmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerWmma,
                                        NPerWmma,
                                        MRepeat,
                                        NRepeat,
                                        KPack>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
                                         ComputeTypeA,
                                         ComputeTypeB,
                                         AccDataType,
                                         AWmmaTileDesc,
                                         BWmmaTileDesc,
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         MPerBlock,
                                         NPerBlock,
                                         KPerBlock,
                                         MPerWmma,
                                         NPerWmma,
                                         MRepeat,
                                         NRepeat,
                                         KPack>

{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    ADataType,
                                                    BDataType,
                                                    ComputeTypeA,
                                                    ComputeTypeB,
                                                    AccDataType,
                                                    AWmmaTileDesc,
                                                    BWmmaTileDesc,
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    MPerBlock,
                                                    NPerBlock,
                                                    KPerBlock,
                                                    MPerWmma,
                                                    NPerWmma,
                                                    MRepeat,
                                                    NRepeat,
                                                    KPack>;
    using Base::I0;

    using Base::A_K1;
    using Base::A_KRow;
    using Base::B_K1;
    using Base::B_KRow;
    using Base::KRepeat;
    using Base::WmmaK;

    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    static constexpr index_t PrefetchStages = 1;
    static constexpr index_t PrefillStages  = 1;
    static constexpr index_t GlobalBufferNum = 1;

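    // With a single prefetch stage the loop shape is fixed: the hot loop (when
    // num_loop > PrefetchStages) runs num_loop - 1 iterations and the final
    // tile is always consumed by a TailNumber::Full tail, so the tail kind
    // does not depend on num_loop.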
    static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }

    static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Initialize C
        c_thread_buf.Clear();

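        // One pass over the current LDS tile: for each of the KRepeat k-steps,
        // copy the A/B fragments from LDS into per-thread VGPR buffers, then
        // issue a WMMA for every (MRepeat, NRepeat) sub-tile of C.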
        auto blockwise_gemm_func = [&]() {
            static_for<0, KRepeat, 1>{}([&](auto k0) {
                a_thread_copy_.Run(
                    a_block_desc_k0_m0_m1_m2_k1,
                    make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
                    a_block_buf,
                    a_thread_desc_,
                    make_tuple(I0, I0, k0, I0, I0, I0),
                    a_thread_buf);
                b_thread_copy_.Run(
                    b_block_desc_k0_n0_n1_n2_k1,
                    make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
                    b_block_buf,
                    b_thread_desc_,
                    make_tuple(I0, I0, k0, I0, I0, I0),
                    b_thread_buf);

                static_for<0, MRepeat, 1>{}([&](auto m0) {
                    static_for<0, NRepeat, 1>{}([&](auto n0) {
                        vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                        vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                        static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                            a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
                        });
                        static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                            b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
                                    Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
                        });

                        using wmma_input_type_a =
                            typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                        using wmma_input_type_b =
                            typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                        constexpr index_t c_offset =
                            c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                        wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
                                      b_thread_vec.template AsType<wmma_input_type_b>(),
                                      c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                    });
                });
            });
        };

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                blockwise_gemm_func();

                block_sync_lds();
                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            block_sync_lds();
            blockwise_gemm_func();
        }
    }

    protected:
    using Base::a_thread_copy_;
    using Base::a_thread_desc_;
    using Base::b_thread_copy_;
    using Base::b_thread_desc_;
    using Base::c_thread_desc_;
};

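// Interwave-scheduled variant of the same pipeline. The KRepeat loop is tiled
// into NumKClusters MAC clusters of KRepeatPerCluster k-steps; workgroup
// barriers and s_setprio hints around each cluster keep the waves' MAC phases
// aligned so that memory and MAC instructions from different waves can overlap.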
template <index_t BlockSize,
          typename ADataType,
          typename BDataType,
          typename ComputeTypeA,
          typename ComputeTypeB,
          typename AccDataType,
          typename AWmmaTileDesc,
          typename BWmmaTileDesc,
          index_t ABlockTransferSrcScalarPerVector,
          index_t BBlockTransferSrcScalarPerVector,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWmma,
          index_t NPerWmma,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
                                        BlockSize,
                                        ADataType,
                                        BDataType,
                                        ComputeTypeA,
                                        ComputeTypeB,
                                        AccDataType,
                                        AWmmaTileDesc,
                                        BWmmaTileDesc,
                                        ABlockTransferSrcScalarPerVector,
                                        BBlockTransferSrcScalarPerVector,
                                        MPerBlock,
                                        NPerBlock,
                                        KPerBlock,
                                        MPerWmma,
                                        NPerWmma,
                                        MRepeat,
                                        NRepeat,
                                        KPack>
    : BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                         ADataType,
                                         BDataType,
                                         ComputeTypeA,
                                         ComputeTypeB,
                                         AccDataType,
                                         AWmmaTileDesc,
                                         BWmmaTileDesc,
                                         ABlockTransferSrcScalarPerVector,
                                         BBlockTransferSrcScalarPerVector,
                                         MPerBlock,
                                         NPerBlock,
                                         KPerBlock,
                                         MPerWmma,
                                         NPerWmma,
                                         MRepeat,
                                         NRepeat,
                                         KPack>

{
    using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
                                                    ADataType,
                                                    BDataType,
                                                    ComputeTypeA,
                                                    ComputeTypeB,
                                                    AccDataType,
                                                    AWmmaTileDesc,
                                                    BWmmaTileDesc,
                                                    ABlockTransferSrcScalarPerVector,
                                                    BBlockTransferSrcScalarPerVector,
                                                    MPerBlock,
                                                    NPerBlock,
                                                    KPerBlock,
                                                    MPerWmma,
                                                    NPerWmma,
                                                    MRepeat,
                                                    NRepeat,
                                                    KPack>;
    using Base::I0;
    using Base::I1;

    using Base::A_K1;
    using Base::A_KRow;
    using Base::B_K1;
    using Base::B_KRow;
    using Base::KRepeat;
    using Base::WmmaK;

    using Base::wmma_gemm;

    using Base::CalculateCThreadOriginDataIndex;
    using Base::
        GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;
    using Base::GetCThreadBuffer;
    using Base::
        GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs;

    using Base::a_block_desc_k0_m0_m1_m2_k1;
    using Base::b_block_desc_k0_n0_n1_n2_k1;

    static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
    static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);

    static constexpr index_t PrefetchStages = 1;
    static constexpr index_t PrefillStages  = 1;
    static constexpr index_t GlobalBufferNum = 1;

    static bool BlockHasHotloop(index_t num_loop) { return num_loop > PrefetchStages; }

    static constexpr TailNumber BlockLoopTailNum(index_t num_loop)
    {
        ignore = num_loop;
        return TailNumber::Full;
    }

    template <bool HasMainLoop,
              TailNumber TailNum,
              typename AGridDesc,
              typename ABlockDesc,
              typename ABlockTransfer,
              typename AGridBuffer,
              typename ABlockBuffer,
              typename ABlockTransferStep,
              typename BGridDesc,
              typename BBlockDesc,
              typename BBlockTransfer,
              typename BGridBuffer,
              typename BBlockBuffer,
              typename BBlockTransferStep,
              typename CThreadBuffer>
    __device__ void Run(const AGridDesc& a_grid_desc,
                        const ABlockDesc& a_block_desc,
                        ABlockTransfer& a_blockwise_copy,
                        const AGridBuffer& a_grid_buf,
                        ABlockBuffer& a_block_buf,
                        const ABlockTransferStep& a_block_copy_step,
                        const BGridDesc& b_grid_desc,
                        const BBlockDesc& b_block_desc,
                        BBlockTransfer& b_blockwise_copy,
                        const BGridBuffer& b_grid_buf,
                        BBlockBuffer& b_block_buf,
                        const BBlockTransferStep& b_block_copy_step,
                        CThreadBuffer& c_thread_buf,
                        index_t num_loop) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
            b_thread_desc_.GetElementSpaceSize());

        // Global prefetch 1
        a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

        a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

        // Local prefill 1
        a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

        // Initialize C
        c_thread_buf.Clear();

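        // Same LDS -> VGPR -> WMMA structure as the Intrawave specialization,
        // but the KRepeat loop is tiled by KRepeatPerCluster and the second
        // LDS barrier of the main loop is issued from inside the last MAC of
        // the last cluster (see the comments below).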
        auto blockwise_gemm_func = [&]() {
            static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
                static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
                    a_thread_copy_.Run(
                        a_block_desc_k0_m0_m1_m2_k1,
                        make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
                                   I0,
                                   I0,
                                   I0,
                                   I0,
                                   I0),
                        a_block_buf,
                        a_thread_desc_,
                        make_tuple(I0, I0, k0_inner, I0, I0, I0),
                        a_thread_buf);
                    b_thread_copy_.Run(
                        b_block_desc_k0_n0_n1_n2_k1,
                        make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
                                   I0,
                                   I0,
                                   I0,
                                   I0,
                                   I0),
                        b_block_buf,
                        b_thread_desc_,
                        make_tuple(I0, I0, k0_inner, I0, I0, I0),
                        b_thread_buf);
                });

                __builtin_amdgcn_sched_barrier(0);
                // NOTE: Synchronize threads in a workgroup at the start of each MAC cluster,
                // except the first, since this shortens the non-MAC cluster a bit and has no
                // observable negative impact. The desired effect is that waves in a workgroup
                // execute MAC in sync, which keeps out-of-sync waves from hijacking MAC
                // resources from other workgroups and from reducing the chance of latency
                // hiding by making the rest of the workgroup wait at the eventual sync point.
                if constexpr(k0_offset != 0 || KRepeat == 1)
                {
                    __builtin_amdgcn_s_barrier();
                    __builtin_amdgcn_sched_barrier(0);
                }
                static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
                    static_for<0, MRepeat, 1>{}([&](auto m0) {
                        static_for<0, NRepeat, 1>{}([&](auto n0) {
                            vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
                            vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;

                            static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
                                a_thread_vec.template AsType<ComputeTypeA>()(ik) =
                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                        make_tuple(Number<ik / A_K1>{},
                                                   m0,
                                                   k0_inner,
                                                   I0,
                                                   I0,
                                                   Number<ik % A_K1>{}))>{}];
                            });
                            static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
                                b_thread_vec.template AsType<ComputeTypeB>()(ik) =
                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                        make_tuple(Number<ik / B_K1>{},
                                                   n0,
                                                   k0_inner,
                                                   I0,
                                                   I0,
                                                   Number<ik % B_K1>{}))>{}];
                            });

                            using wmma_input_type_a =
                                typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
                            using wmma_input_type_b =
                                typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;

                            constexpr index_t c_offset =
                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));

                            // The block_sync_lds() here performs double duty:
                            // A) safeguard against data hazards.
                            // B) reduce VMEM FIFO congestion by applying small delays to
                            //    different wavefronts.
                            // It is performed near the end of the MAC cluster to minimize the
                            // lgkmcnt penalty.
                            if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
                                         m0 == MRepeat - 1 && n0 == NRepeat - 1)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                block_sync_lds();
                                __builtin_amdgcn_sched_barrier(0);
                            }
                            wmma_gemm.Run(
                                a_thread_vec.template AsType<wmma_input_type_a>(),
                                b_thread_vec.template AsType<wmma_input_type_b>(),
                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                            if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
                            {
                                __builtin_amdgcn_sched_barrier(0);
                                __builtin_amdgcn_s_setprio(1);
                                __builtin_amdgcn_sched_barrier(0);
                            }
                        });
                    });
                });
                __builtin_amdgcn_sched_barrier(0);
                __builtin_amdgcn_s_setprio(0);
                __builtin_amdgcn_sched_barrier(0);
            });
        };

        // main body
        if constexpr(HasMainLoop)
        {
            index_t i = 0;
            do
            {
                a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);

                a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);

                block_sync_lds();
                blockwise_gemm_func();

                a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);

                i += 1;
            } while(i < (num_loop - 1));
        }

        // tail
        if constexpr(TailNum == TailNumber::Full)
        {
            block_sync_lds();
            blockwise_gemm_func();
        }
    }

    protected:
    // Per-thread VGPR descriptors are defined locally (instead of reused from
    // Base) so that the k dimension covers one MAC cluster (KRepeatPerCluster
    // k-steps) at a time.
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor(make_tuple(Number<KPack / A_K1 / A_KRow>{},
                                                Number<MRepeat>{},
                                                Number<KRepeatPerCluster>{},
                                                I1,
                                                I1,
                                                Number<A_K1>{}),
                                     make_tuple(Number<A_K1>{},
                                                Number<KPack / A_KRow>{},
                                                Number<KPack / A_KRow * MRepeat>{},
                                                I0,
                                                I0,
                                                I1));

    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor(make_tuple(Number<KPack / B_K1 / B_KRow>{},
                                                Number<NRepeat>{},
                                                Number<KRepeatPerCluster>{},
                                                I1,
                                                I1,
                                                Number<B_K1>{}),
                                     make_tuple(Number<B_K1>{},
                                                Number<KPack / B_KRow>{},
                                                Number<KPack / B_KRow * NRepeat>{},
                                                I0,
                                                I0,
                                                I1));

    using AThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<ADataType,
                                         ComputeTypeA,
                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                         decltype(a_thread_desc_),
                                         Sequence<KPack / A_K1 / A_KRow, MRepeat, 1, 1, 1, A_K1>,
                                         Sequence<0, 1, 2, 3, 4, 5>,
                                         5,
                                         A_K1,
                                         A_K1>;

    using BThreadCopy =
        ThreadwiseTensorSliceTransfer_v4<BDataType,
                                         ComputeTypeB,
                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                         decltype(b_thread_desc_),
                                         Sequence<KPack / B_K1 / B_KRow, NRepeat, 1, 1, 1, B_K1>,
                                         Sequence<0, 1, 2, 3, 4, 5>,
                                         5,
                                         B_K1,
                                         B_K1>;

    AThreadCopy a_thread_copy_{Base::CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{Base::CalculateBThreadOriginDataIndex()};
    using Base::c_thread_desc_;
};

} // namespace ck