Composable Kernel: include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp Source File

blockwise_gemm_xdlops.hpp
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 
12 namespace ck {
13 
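// Reshapes a (K0, MN, K1) block tile descriptor into the (MN0, MN1, MN2, K) layout expected by
// the per-wave MMA pipeline: MN is unmerged into (MNXdlPerWave, MNWaves, MNPerXdl) while K0 and
// K1 are merged back into a single K dimension.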
14 template <index_t MNXdlPerWave, index_t MNWaves, index_t MNPerXdl, typename TileDesc_K0_MN_K1>
15 __host__ __device__ static constexpr auto
16 MakeGemmMmaTileDescriptor_MN0_MN1_MN2_K(const TileDesc_K0_MN_K1&)
17 {
18  constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{});
19  constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{});
20 
 21  return transform_tensor_descriptor(
 22  TileDesc_K0_MN_K1{},
23  make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number<K0>{}, Number<K1>{})),
 24  make_unmerge_transform(
 25  make_tuple(Number<MNXdlPerWave>{}, Number<MNWaves>{}, Number<MNPerXdl>{}))),
26  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
27  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
28 }
29 
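// Block-level GEMM built on XDL/MFMA instructions: A (K0 x MPerBlock x K1) and B
// (K0 x NPerBlock x K1) tiles, typically staged in LDS, are consumed wave by wave; each wave
// owns an MRepeat x NRepeat grid of MPerXDL x NPerXDL xdlops tiles and accumulates into the
// per-thread VGPR buffer c_thread_buf_.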
30 template <index_t BlockSize,
31  typename FloatA,
32  typename FloatB,
33  typename FloatAcc,
34  typename AK0MK1BlockDesc,
35  typename BK0NK1BlockDesc,
36  index_t MPerXDL,
37  index_t NPerXDL,
38  index_t MRepeat,
39  index_t NRepeat,
40  index_t KPack,
41  typename ComputeTypeA = FloatA,
42  typename ComputeTypeB = FloatB>
43 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
44 {
45  static constexpr auto I0 = Number<0>{};
46  static constexpr auto I1 = Number<1>{};
47  static constexpr auto I2 = Number<2>{};
48  static constexpr auto I3 = Number<3>{};
49 
 50  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 51 
52  static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
53  static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
54  static constexpr index_t KPerBlock =
55  BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);
56 
57  static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
58  static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
59  static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
60  static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);
61 
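// Wave decomposition of the block: MWaves * NWaves waves per workgroup, each covering an
// (MRepeat * MPerXDL) x (NRepeat * NPerXDL) sub-tile of C; WaveSize follows from BlockSize.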
62  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
63  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
64  static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
65 
66  static constexpr auto xdlops_gemm =
67  XdlopsGemm<ComputeTypeA, MPerXDL, NPerXDL, KPack, ComputeTypeB>{};
68 
69  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
70 
71  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
72  FloatAcc,
73  MRepeat * NRepeat,
74  xdlops_gemm.GetRegSizePerXdlops(),
75  true>
 76  c_thread_buf_;
 77 
78  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
79 
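// Decomposes the flat thread id into (wave_m, wave_n, lane) coordinates via a merge transform
// over (MWaves, NWaves, WaveSize); the wave coordinates pick the sub-tile of C this wave owns.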
80  __device__ static auto GetWaveIdx()
81  {
82  const index_t thread_id = ThisThreadBlock::GetThreadId();
83 
84  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
 85  make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
 86  make_tuple(Sequence<0, 1, 2>{}),
 87  make_tuple(Sequence<0>{}));
 88 
89  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
90  }
91 
92  __device__ static auto CalculateAThreadOriginDataIndex()
93  {
94  const auto wave_idx = GetWaveIdx();
95 
96  const auto waveId_m = wave_idx[I0];
97 
98  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
99 
100  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
101  }
102 
103  __device__ static auto CalculateBThreadOriginDataIndex()
104  {
105  const auto wave_idx = GetWaveIdx();
106 
107  const auto waveId_n = wave_idx[I1];
108 
109  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
110 
111  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
112  }
113 
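// Maps a thread's accumulator coordinates (m0, n0, xdlops_i, blk_i) back to block-level (m, n)
// positions by unmerging (MRepeat, MWaves, MPerXDL) and (NRepeat, NWaves, NPerXDL).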
114  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
115  __device__ static auto
 116  CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
 117  {
118  const auto wave_idx = GetWaveIdx();
119 
120  const auto waveId_m = wave_idx[I0];
121  const auto waveId_n = wave_idx[I1];
122 
123  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
124 
125  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
126  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
 127  make_tuple(Sequence<0>{}),
 128  make_tuple(Sequence<0, 1, 2>{}));
 129 
130  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
131  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
 132  make_tuple(Sequence<0>{}),
 133  make_tuple(Sequence<0, 1, 2>{}));
 134 
135  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
136  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
137  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
138  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
139 
140  return make_tuple(c_thread_m, c_thread_n);
141  }
142 
143  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
144  __device__ static auto
 145  CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
 146  {
147  const auto wave_idx = GetWaveIdx();
148 
149  const auto waveId_m = wave_idx[I0];
150  const auto waveId_n = wave_idx[I1];
151 
152  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
153 
154  return make_tuple(Number<m0>{},
155  Number<n0>{},
156  waveId_m,
157  waveId_n,
158  blk_idx[I0],
159  blk_idx[I1],
160  blk_idx[I2],
161  blk_idx[I3]);
162  }
163 
 164  __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1()
 165  {
166  static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
167  BK0NK1BlockDesc::IsKnownAtCompileTime(),
168  "wrong! Desc should be known at compile-time");
169 
170  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
171  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
172 
173  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
174  "wrong!");
175  }
176 
177  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
178  {
179  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
180 
181  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
182  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
183  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
184  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
185 
 186  return make_naive_tensor_descriptor_packed(
 187  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
188  }
189 
190  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
191  {
192  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
193 
194  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
195  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
196  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
197  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
198 
 199  return make_naive_tensor_descriptor_packed(
 200  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
201  }
202 
203  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
204  {
205  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
 206  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
 207  Number<NRepeat>{},
208  Number<MWaves>{},
209  Number<NWaves>{},
210  Number<MPerXDL>{},
211  Number<NPerXDL>{}));
212 
213  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
214  }
215 
216  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
217  {
218  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
 219  make_naive_tensor_descriptor_packed(make_tuple(I1,
 220  Number<MRepeat>{},
221  Number<NRepeat>{},
222  Number<MWaves>{},
223  Number<NWaves>{},
224  Number<MPerXDL>{},
225  Number<NPerXDL>{}));
226 
227  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
228  c_block_desc_g_m0_n0_m1_n1_m2_n2);
229  }
230 
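// The C grid descriptors below unmerge M into (M0, MWaves, MPerXDL) and N into (N0, NWaves,
// NPerXDL), then defer to xdlops_gemm to further split each xdlops tile into its register layout.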
231  template <typename CGridDesc_M_N>
232  __host__ __device__ static constexpr auto
233  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
234  {
235  const auto M = c_grid_desc_m_n.GetLength(I0);
236  const auto N = c_grid_desc_m_n.GetLength(I1);
237 
238  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
239  c_grid_desc_m_n,
240  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
241  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
 242  make_tuple(Sequence<0>{}, Sequence<1>{}),
 243  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
 244 
245  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
246  }
247 
248  template <typename CGridDesc_G_M_N>
249  __host__ __device__ static constexpr auto
250  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
251  {
252  const auto G = c_grid_desc_g_m_n.GetLength(I0);
253  const auto M = c_grid_desc_g_m_n.GetLength(I1);
254  const auto N = c_grid_desc_g_m_n.GetLength(I2);
255 
256  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
257  c_grid_desc_g_m_n,
 258  make_tuple(make_pass_through_transform(G),
 259  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
260  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
 261  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
 262  make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4, 5, 6>{}));
 263 
264  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
265  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
266  }
267 
268  __host__ __device__ static constexpr auto MakeABlockDescriptor_M0_M1_M2_K()
269  {
 270  return transform_tensor_descriptor(
 271  AK0MK1BlockDesc{},
 272  make_tuple(
 273  make_merge_transform_v3_division_mod(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
 274  make_unmerge_transform(
 275  make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerXDL>{}))),
 276  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
 277  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
 278  }
279 
280  __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
281  {
 282  return transform_tensor_descriptor(
 283  BK0NK1BlockDesc{},
 284  make_tuple(
 285  make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
 286  make_unmerge_transform(
 287  make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
 288  make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
 289  make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
 290  }
291 
 292  static constexpr auto a_block_desc_m0_m1_m2_k = MakeABlockDescriptor_M0_M1_M2_K();
 293  static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
 294 
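// Default (intra-wave) pipeline: for every m0 repeat the A slice is copied from the block buffer
// into VGPRs, then for every n0 repeat the B slice is copied and K is walked in KPack steps,
// issuing one xdlops call per (m0, n0, k) with the results accumulated in c_thread_buf.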
295  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
296  __device__ void Run(const ABlockBuffer& a_block_buf,
297  const BBlockBuffer& b_block_buf,
298  CThreadBuffer& c_thread_buf) const
299  {
300  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
301  a_thread_desc_.GetElementSpaceSize());
302  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
303  b_thread_desc_.GetElementSpaceSize());
304 
305  static_for<0, MRepeat, 1>{}([&](auto m0) {
306  // read A
 307  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 308  make_tuple(m0, I0, I0, I0),
309  a_block_buf,
 310  a_thread_desc_,
 311  make_tuple(I0, I0, I0, I0),
312  a_thread_buf);
313 
314  static_for<0, NRepeat, 1>{}([&](auto n0) {
315  // read B
 316  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 317  make_tuple(n0, I0, I0, I0),
318  b_block_buf,
 319  b_thread_desc_,
 320  make_tuple(I0, I0, I0, I0),
321  b_thread_buf);
322 
323  static_for<0, KPerThread, KPack>{}([&](auto k) {
 324  vector_type<ComputeTypeA, KPack> a_thread_vec;
 325  vector_type<ComputeTypeB, KPack> b_thread_vec;
 326 
327  static_for<0, KPack, 1>{}([&](auto i) {
328  a_thread_vec.template AsType<ComputeTypeA>()(i) = a_thread_buf
329  [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
330  b_thread_vec.template AsType<ComputeTypeB>()(i) = b_thread_buf
331  [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
332  });
333 
334  using mfma_input_type_a =
335  typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
336  using mfma_input_type_b =
337  typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
338 
339  constexpr index_t c_offset =
340  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
341 
342  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
343  b_thread_vec.template AsType<mfma_input_type_b>(),
344  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
345  });
346  });
347  });
348  }
349 
350  protected:
351  // A[M0, M1, M2, KPerThread]
352  static constexpr auto a_thread_desc_ =
 353  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
 354 
355  // B[N0, N1, N2, KPerThread]
356  static constexpr auto b_thread_desc_ =
357  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));
358 
359  // C[M, N, NumRegXdlops]
360  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
361  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
362 
 363  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
 364  ComputeTypeA,
365  decltype(a_block_desc_m0_m1_m2_k),
366  decltype(a_thread_desc_),
 367  Sequence<1, 1, 1, KPerThread>,
 368  Sequence<0, 1, 2, 3>,
 369  3,
370  A_K1,
371  A_K1>;
372 
 373  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
 374  ComputeTypeB,
375  decltype(b_block_desc_n0_n1_n2_k),
376  decltype(b_thread_desc_),
 377  Sequence<1, 1, 1, KPerThread>,
 378  Sequence<0, 1, 2, 3>,
 379  3,
380  B_K1,
381  B_K1>;
382 
 383  AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
 384  BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
 385 };
386 
 387 // Note: To use the inter-wave loop scheduler, the macro CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=1
 388 // must be set explicitly, as a few intrinsics it relies on are not yet available in the latest
 389 // ROCm release. On unsupported compilers the inter-wave loop scheduler falls back to the default
 390 // loop scheduler, which is selected by CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING=0.
391 template <index_t BlockSize,
392  typename FloatA,
393  typename FloatB,
394  typename FloatAcc,
395  typename AK0MK1BlockDesc,
396  typename BK0NK1BlockDesc,
397  index_t MPerXDL,
398  index_t NPerXDL,
399  index_t MRepeat,
400  index_t NRepeat,
401  index_t KPack,
402  typename ComputeTypeA = FloatA,
403  typename ComputeTypeB = FloatB,
 404  index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS>
 405 struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1_2waveopt
 406  : public BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
 407  FloatA,
408  FloatB,
409  FloatAcc,
410  AK0MK1BlockDesc,
411  BK0NK1BlockDesc,
412  MPerXDL,
413  NPerXDL,
414  MRepeat,
415  NRepeat,
416  KPack,
417  ComputeTypeA,
418  ComputeTypeB>
419 {
 420  using Base = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
 421  FloatA,
422  FloatB,
423  FloatAcc,
424  AK0MK1BlockDesc,
425  BK0NK1BlockDesc,
426  MPerXDL,
427  NPerXDL,
428  MRepeat,
429  NRepeat,
430  KPack,
431  ComputeTypeA,
432  ComputeTypeB>;
433 
434 #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
 435  using Base::a_block_desc_m0_m1_m2_k;
 436  using Base::A_K1;
 437  using Base::b_block_desc_n0_n1_n2_k;
 438  using Base::B_K1;
439  using Base::c_thread_buf_;
440  using Base::c_thread_desc_;
 441  using Base::CalculateAThreadOriginDataIndex;
 442  using Base::CalculateBThreadOriginDataIndex;
 443  using Base::I0;
444  using Base::I1;
445  using Base::KPerThread;
446  using Base::xdlops_gemm;
447 
448  static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
449 
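// KPerThread is split into NumMacClusters clusters of KPerInnerLoop. Within a cluster all A/B
// fragments are loaded first, then the MFMA burst runs at raised priority (s_setprio) between
// sched_barriers so that the memory and math phases of different waves can overlap.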
450  // 2-wave optimized blockwise gemm
451  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
452  __device__ void Run(const ABlockBuffer& a_block_buf,
453  const BBlockBuffer& b_block_buf,
454  CThreadBuffer& c_thread_buf) const
455  {
456  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
457  a_thread_desc_.GetElementSpaceSize());
458  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
459  b_thread_desc_.GetElementSpaceSize());
460 
 461  static_for<0, KPerThread, KPerInnerLoop>{}([&](auto k) {
 462  static_for<0, MRepeat, 1>{}([&](auto m0) {
463  // read A
 464  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 465  make_tuple(m0, I0, I0, k),
466  a_block_buf,
 467  a_thread_desc_,
 468  make_tuple(m0, I0, I0, I0),
469  a_thread_buf);
470  });
471  static_for<0, NRepeat, 1>{}([&](auto n0) {
472  // read B
 473  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 474  make_tuple(n0, I0, I0, k),
475  b_block_buf,
 476  b_thread_desc_,
 477  make_tuple(n0, I0, I0, I0),
478  b_thread_buf);
479  });
480  __builtin_amdgcn_sched_barrier(0);
 481  // NOTE: Synchronize the threads in a workgroup at the start of each MAC cluster, except the
 482  // first one: skipping that barrier shortens the non-MAC portion slightly and shows no
 483  // observable negative impact. The desired effect is that the waves in a workgroup execute
 484  // their MACs in sync, which keeps out-of-sync waves from hijacking MAC resources from other
 485  // workgroups and from reducing the chance of latency hiding while waiting for the rest of the
 486  // workgroup at the eventual sync point.
487  if constexpr(k.value != 0 || KPerInnerLoop == KPerThread)
488  {
489 #ifdef __gfx12__
490  asm volatile("\
491  s_barrier_signal -1 \n \
492  s_barrier_wait -1 \
493  " ::);
494 #else
495  asm volatile("s_barrier" ::);
496 #endif
497  __builtin_amdgcn_sched_barrier(0);
498  }
499  static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
500  static_for<0, MRepeat, 1>{}([&](auto m0) {
501  static_for<0, NRepeat, 1>{}([&](auto n0) {
502  vector_type<ComputeTypeA, KPack> a_thread_vec;
503  vector_type<ComputeTypeB, KPack> b_thread_vec;
504 
505  static_for<0, KPack, 1>{}([&](auto i) {
506  a_thread_vec.template AsType<ComputeTypeA>()(i) =
507  a_thread_buf[Number<a_thread_desc_.CalculateOffset(
508  make_tuple(m0, 0, 0, k_ + i))>{}];
509  b_thread_vec.template AsType<ComputeTypeB>()(i) =
510  b_thread_buf[Number<b_thread_desc_.CalculateOffset(
511  make_tuple(n0, 0, 0, k_ + i))>{}];
512  });
513 
514  using mfma_input_type_a =
515  typename vector_type<ComputeTypeA, xdlops_gemm.K1PerXdlops>::type;
516  using mfma_input_type_b =
517  typename vector_type<ComputeTypeB, xdlops_gemm.K1PerXdlops>::type;
518 
519  constexpr index_t c_offset =
520  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
521 
 522  // The block_sync_lds() here performs double duty:
 523  // A) safeguard against data hazard, because the barrier from blockwise_gemm
 524  //    is moved here;
 525  // B) reduce VMEM FIFO congestion by applying small delays to different wavefronts.
 526  // It is performed near the end of the MAC cluster to minimize the lgkmcnt penalty.
527  if constexpr(k.value == KPerThread - KPerInnerLoop &&
528  k_.value == KPerInnerLoop - KPack && m0.value == MRepeat - 1 &&
529  n0.value == NRepeat - 1)
530  {
531  __builtin_amdgcn_sched_barrier(0);
532  block_sync_lds();
533  __builtin_amdgcn_sched_barrier(0);
534  }
535 
 536  // TODO: insert setprio in a more precise manner, since we could have
 537  // more than one MFMA instruction in a single call
538  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type_a>(),
539  b_thread_vec.template AsType<mfma_input_type_b>(),
540  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
541  if constexpr(k_.value == 0 && m0.value == 0 && n0.value == 0)
542  {
543  __builtin_amdgcn_sched_barrier(0);
544  __builtin_amdgcn_s_setprio(1);
545  __builtin_amdgcn_sched_barrier(0);
546  }
547  });
548  });
549  });
550  __builtin_amdgcn_sched_barrier(0);
551  __builtin_amdgcn_s_setprio(0);
552  __builtin_amdgcn_sched_barrier(0);
553  });
554  }
555 
556  protected:
557  // A[M0, M1, M2, KPerInnerLoop]
558  static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
559  make_tuple(Number<MRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
560 
561  // B[N0, N1, N2, KPerInnerLoop]
562  static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
563  make_tuple(Number<NRepeat>{}, I1, I1, Number<KPerInnerLoop>{}));
564 
565  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatA,
566  ComputeTypeA,
567  decltype(a_block_desc_m0_m1_m2_k),
568  decltype(a_thread_desc_),
569  Sequence<1, 1, 1, KPerInnerLoop>,
570  Sequence<0, 1, 2, 3>,
571  3,
572  A_K1,
573  A_K1>;
574 
575  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatB,
576  ComputeTypeB,
577  decltype(b_block_desc_n0_n1_n2_k),
578  decltype(b_thread_desc_),
579  Sequence<1, 1, 1, KPerInnerLoop>,
580  Sequence<0, 1, 2, 3>,
581  3,
582  B_K1,
583  B_K1>;
584 
 585  AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
 586  BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
 587 
588 #endif // #if CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING
589 };
590 
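// Selects the blockwise GEMM implementation for a given LoopScheduler: the default intra-wave
// pipeline or the inter-wave (2-wave optimized) variant.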
591 template <index_t BlockSize,
592  typename FloatA,
593  typename FloatB,
594  typename FloatAcc,
595  typename AK0MK1BlockDesc,
596  typename BK0NK1BlockDesc,
597  index_t MPerXDL,
598  index_t NPerXDL,
599  index_t MRepeat,
600  index_t NRepeat,
601  index_t KPack,
602  LoopScheduler LoopSched,
603  typename ComputeTypeA = FloatA,
604  typename ComputeTypeB = FloatB>
 605 constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
 606 {
607  if constexpr(LoopSched == LoopScheduler::Default)
608  {
 609  return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<BlockSize,
 610  FloatA,
611  FloatB,
612  FloatAcc,
613  AK0MK1BlockDesc,
614  BK0NK1BlockDesc,
615  MPerXDL,
616  NPerXDL,
617  MRepeat,
618  NRepeat,
619  KPack,
620  ComputeTypeA,
621  ComputeTypeB>{};
622  }
623  else if constexpr(LoopSched == LoopScheduler::Interwave)
624  {
 625  return BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1_2waveopt<BlockSize,
 626  FloatA,
627  FloatB,
628  FloatAcc,
629  AK0MK1BlockDesc,
630  BK0NK1BlockDesc,
631  MPerXDL,
632  NPerXDL,
633  MRepeat,
634  NRepeat,
635  KPack,
636  ComputeTypeA,
637  ComputeTypeB>{};
638  }
639 };
640 
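// Blockwise gemm (v2): the same block-level GEMM, but the A/B MMA tile descriptors are supplied
// directly (AMmaTileDesc / BMmaTileDesc), the thread-copy origins can be injected through the
// constructor, and TransposeC selects the transposed C_xdl' = B_xdl' * A_xdl' output layout.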
651 template <
652  index_t BlockSize,
653  typename FloatAB,
654  typename FloatAcc,
655  typename ATileDesc,
656  typename BTileDesc,
657  typename AMmaTileDesc,
658  typename BMmaTileDesc,
659  index_t MPerBlock,
660  index_t NPerBlock,
661  index_t KPerBlock,
662  index_t MPerXDL,
663  index_t NPerXDL,
664  index_t MRepeat,
665  index_t NRepeat,
666  index_t KPack,
667  bool TransposeC = false,
668  index_t AMmaKStride =
669  KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops,
670  index_t BMmaKStride =
671  KPack * XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{}.K0PerXdlops>
 672 struct BlockwiseGemmXdlops_v2
 673 {
674  static constexpr auto I0 = Number<0>{};
675  static constexpr auto I1 = Number<1>{};
676  static constexpr auto I2 = Number<2>{};
677  static constexpr auto I3 = Number<3>{};
678 
 679  using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 680 
681  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
682  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);
683  static constexpr index_t WaveSize = BlockSize / MWaves / NWaves;
684 
685  static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0);
686  static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0);
687  static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2);
688  static constexpr index_t B_K1 = BTileDesc{}.GetLength(I2);
689 
690  static constexpr auto xdlops_gemm =
 691  XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack, FloatAB, TransposeC>{};
 692 
693  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
694 
695  static_assert(KPerThread % KPack == 0,
696  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
697 
 698  StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
 699  FloatAcc,
700  MRepeat * NRepeat,
701  xdlops_gemm.GetRegSizePerXdlops(),
702  true>
 703  c_thread_buf_;
 704 
705  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
706 
707  __device__ static auto GetWaveIdx()
708  {
709  const index_t thread_id = ThisThreadBlock::GetThreadId();
710 
711  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
 712  make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
 713  make_tuple(Sequence<0, 1, 2>{}),
 714  make_tuple(Sequence<0>{}));
 715 
716  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
717  }
718 
719  __device__ static auto CalculateAThreadOriginDataIndex()
720  {
721  const auto wave_idx = GetWaveIdx();
722 
723  const auto waveId_m = wave_idx[I0];
724 
725  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
726 
727  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPack * xdlops_a_idx[I0]);
728  }
729 
730  __device__ static auto CalculateBThreadOriginDataIndex()
731  {
732  const auto wave_idx = GetWaveIdx();
733 
734  const auto waveId_n = wave_idx[I1];
735 
736  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
737 
738  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPack * xdlops_b_idx[I0]);
739  }
740 
741  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
742  __device__ static auto
 743  CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
 744  {
745  const auto wave_idx = GetWaveIdx();
746 
747  const auto waveId_m = wave_idx[I0];
748  const auto waveId_n = wave_idx[I1];
749 
750  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
751 
752  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
753  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
 754  make_tuple(Sequence<0>{}),
 755  make_tuple(Sequence<0, 1, 2>{}));
 756 
757  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
758  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
 759  make_tuple(Sequence<0>{}),
 760  make_tuple(Sequence<0, 1, 2>{}));
 761 
762  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
763  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
764  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
765  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
766 
767  return make_tuple(c_thread_m, c_thread_n);
768  }
769 
770  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
771  __device__ static auto
 772  CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
 773  {
774  const auto wave_idx = GetWaveIdx();
775 
776  const auto waveId_m = wave_idx[I0];
777  const auto waveId_n = wave_idx[I1];
778 
779  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
780 
781  return make_tuple(
782  m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
783  }
784 
 785  using Tuple4 = decltype(CalculateAThreadOriginDataIndex());
 786 
 787  __host__ __device__ BlockwiseGemmXdlops_v2(Tuple4 a_origin = CalculateAThreadOriginDataIndex(),
 788  Tuple4 b_origin = CalculateBThreadOriginDataIndex())
 789  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
790  {
791 #if defined(__HIP_DEVICE_COMPILE__)
792  static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
793  "wrong! Desc should be known at compile-time");
794 
795  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
796  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
797 
798  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
799  "wrong!");
800 #endif
801  }
802 
803  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
804  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
805  {
806  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
807 
808  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
809  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
810  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
811  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
812 
 813  return make_naive_tensor_descriptor_packed(
 814  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
815  }
816 
817  // XDL output supporting C_xdl = A_xdl * B_xdl
818  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
819  {
820  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
821 
822  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
823  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
824  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
825  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
826 
 827  return make_naive_tensor_descriptor_packed(
 828  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
829  }
830 
831  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
832  {
833  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
834 
835  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
836  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
837  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
838  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
839 
 840  return make_naive_tensor_descriptor_packed(
 841  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
842  }
843 
844  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
845  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
846  {
847  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
 848  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
 849  Number<NRepeat>{},
850  Number<MWaves>{},
851  Number<NWaves>{},
852  Number<MPerXDL>{},
853  Number<NPerXDL>{}));
854 
855  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
856  }
857 
858  // XDL output supporting C_xdl = A_xdl * B_xdl
859  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
860  {
861  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
 862  make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
 863  Number<NRepeat>{},
864  Number<MWaves>{},
865  Number<NWaves>{},
866  Number<MPerXDL>{},
867  Number<NPerXDL>{}));
868 
869  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
870  }
871 
872  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
873  {
874  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
 875  make_naive_tensor_descriptor_packed(make_tuple(I1,
 876  Number<MRepeat>{},
877  Number<NRepeat>{},
878  Number<MWaves>{},
879  Number<NWaves>{},
880  Number<MPerXDL>{},
881  Number<NPerXDL>{}));
882 
883  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
884  c_block_desc_g_m0_n0_m1_n1_m2_n2);
885  }
886 
887  template <typename CGridDesc_M_N>
888  __host__ __device__ static constexpr auto
889  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
890  {
891  const auto M = c_grid_desc_m_n.GetLength(I0);
892  const auto N = c_grid_desc_m_n.GetLength(I1);
893 
894  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
895  c_grid_desc_m_n,
896  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
897  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
 898  make_tuple(Sequence<0>{}, Sequence<1>{}),
 899  make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));
 900 
901  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
902  }
903 
904  template <typename CGridDesc_G_M_N>
905  __host__ __device__ static constexpr auto
906  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
907  {
908  const auto G = c_grid_desc_g_m_n.GetLength(I0);
909  const auto M = c_grid_desc_g_m_n.GetLength(I1);
910  const auto N = c_grid_desc_g_m_n.GetLength(I2);
911 
912  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
913  c_grid_desc_g_m_n,
 914  make_tuple(make_pass_through_transform(G),
 915  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
916  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
 917  make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
 918  make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4, 5, 6>{}));
 919 
920  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
921  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
922  }
923 
924  static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k;
925  static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k;
926 
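// v2 pipeline: the K loop is outermost (k = 0 .. KPerThread/KPack - 1); each step copies a single
// KPack-wide A/B fragment per (m0, n0) into VGPRs right before the corresponding xdlops call.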
927  template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
928  __device__ void Run(const ABlockBuffer& a_block_buf,
929  const BBlockBuffer& b_block_buf,
930  CThreadBuffer& c_thread_buf) const
931  {
932  auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
933  a_thread_desc_.GetElementSpaceSize());
934  auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
935  b_thread_desc_.GetElementSpaceSize());
936 
 937  static_for<0, KPerThread / KPack, 1>{}([&](auto k) { // k = 0, 1, 2, ... rather than 0, KPack, 2 * KPack, ...
938  static_for<0, MRepeat, 1>{}([&](auto m0) {
939  // read A
 940  a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
 941  make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
 942  a_block_buf,
 943  a_thread_desc_,
 944  make_tuple(I0, I0, I0, I0),
945  a_thread_buf);
946 
947  static_for<0, NRepeat, 1>{}([&](auto n0) {
948  // read B
 949  b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
 950  make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
 951  b_block_buf,
 952  b_thread_desc_,
 953  make_tuple(I0, I0, I0, I0),
954  b_thread_buf);
955  vector_type<FloatAB, KPack> a_thread_vec;
956  vector_type<FloatAB, KPack> b_thread_vec;
957 
958  static_for<0, KPack, 1>{}([&](auto i) {
959  a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
960  [Number<a_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
961  b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
962  [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, i))>{}];
963  });
964 
965  using mfma_input_type =
966  typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;
967 
968  constexpr index_t c_offset =
969  c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
970 
971  xdlops_gemm.Run(a_thread_vec.template AsType<mfma_input_type>(),
972  b_thread_vec.template AsType<mfma_input_type>(),
973  c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
974  });
975  });
976  });
977  }
978 
979  protected:
980  // A[M0, M1, M2, KPack]
981  static constexpr auto a_thread_desc_ =
 982  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
 983 
984  // B[N0, N1, N2, KPack]
985  static constexpr auto b_thread_desc_ =
 986  make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPack>{}));
 987 
988  // C[M, N, NumRegXdlops]
 989  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
 990  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
991 
 992  using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
 993  FloatAB,
994  decltype(a_block_desc_m0_m1_m2_k),
995  decltype(a_thread_desc_),
 996  Sequence<1, 1, 1, KPack>,
 997  Sequence<0, 1, 2, 3>,
 998  3,
999  A_K1,
1000  A_K1>;
1001 
 1002  using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
 1003  FloatAB,
1004  decltype(b_block_desc_n0_n1_n2_k),
1005  decltype(b_thread_desc_),
 1006  Sequence<1, 1, 1, KPack>,
 1007  Sequence<0, 1, 2, 3>,
 1008  3,
1009  B_K1,
1010  B_K1>;
1011 
 1012  AThreadCopy a_thread_copy_;
 1013  BThreadCopy b_thread_copy_;
 1014 };
1015 
1016 } // namespace ck