/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_base.hpp Source File
blockwise_gemm_pipeline_xdlops_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
11 
12 namespace ck {
13 
14 template <index_t BlockSize,
15  typename ADataType,
16  typename BDataType,
17  typename ComputeDataType,
18  typename AccDataType,
19  typename ATileDesc,
20  typename BTileDesc,
21  typename AMmaTileDesc,
22  typename BMmaTileDesc,
23  index_t ABlockTransferSrcScalarPerVector,
24  index_t BBlockTransferSrcScalarPerVector,
25  index_t MPerBlock,
26  index_t NPerBlock,
27  index_t KPerBlock,
28  index_t MPerXDL,
29  index_t NPerXDL,
30  index_t MRepeat,
31  index_t NRepeat,
32  index_t KPack,
33  bool TransposeC = false>
35 {
36  static constexpr auto I0 = Number<0>{}; // compile-time index constants; used below to subscript
37  static constexpr auto I1 = Number<1>{}; // tuples (e.g. wave_idx[I0]) and to select descriptor
38  static constexpr auto I2 = Number<2>{}; // dimensions (e.g. GetLength(I2))
39  static constexpr auto I3 = Number<3>{};
40 
42 
43  // Hardcode to 64, as HIP-provided "WarpSize" would return 32 on RDNA GPUs.
44  static constexpr index_t WaveSize = 64;
45 
46  static constexpr index_t A_K0 = ATileDesc{}.GetLength(I0); // length of ATileDesc dimension 0
47  static constexpr index_t B_K0 = BTileDesc{}.GetLength(I0); // length of BTileDesc dimension 0
48  static constexpr index_t A_K1 = ATileDesc{}.GetLength(I2); // length of ATileDesc dimension 2
49  static constexpr index_t B_K1 = // length of BTileDesc dimension 3 if it has 4 dimensions, else of dimension 2
50  BTileDesc{}.GetLength(Number < BTileDesc{}.GetNumOfDimension() == 4 ? 3 : 2 > {});
51 
52  static constexpr auto xdlops_gemm =
54 
55  static constexpr index_t AMmaKStride = KPack; // A and B use the same K stride, equal to KPack
56  static constexpr index_t BMmaKStride = KPack;
57  // K bookkeeping derived from the template parameters and xdlops_gemm
58  static constexpr index_t KPerThread = KPerBlock / xdlops_gemm.K0PerXdlops;
59  static constexpr index_t KRepeat = KPerThread / KPack; // exact: KPerThread % KPack == 0 is static_assert-ed further down
60  static constexpr index_t KPerInnerLoop = KPack;
61 
62  static constexpr index_t KGroup = []() { // evaluates to 2 for the two tile shapes tested below, otherwise 1
64  // On gfx950, we have MFMA instructions that require 32 f8 elements as input,
65  // split into 2 groups of 16 f8 elements.
66  // The 2 groups are not contiguous in the B preshuffled layout,
67  // and we do not want them to be contiguous in the B preshuffled layout
68  // because a memory instruction can only read 16 f8 elements at a time.
69  return ((MPerXDL == 16 && NPerXDL == 16 && xdlops_gemm.KPerXdlops == 128) || // check both tile dims, not MPerXDL twice
70  (MPerXDL == 32 && NPerXDL == 32 && xdlops_gemm.KPerXdlops == 64))
71  ? 2
72  : 1;
73  else // NOTE(review): the matching `if` (listing line 63) is not visible in this extract - confirm against the real header
74  return 1;
75  }();
76 
77  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); // constructor asserts MWaves * NWaves * WaveSize == block thread count
78  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); // and that both divisions are exact
79 
82  MPerBlock,
83  NPerBlock,
84  KPerBlock,
85  ABlockTransferSrcScalarPerVector,
86  BBlockTransferSrcScalarPerVector,
87  A_K1,
88  B_K1,
89  A_K1,
90  B_K1,
91  MRepeat,
92  NRepeat,
93  MPerXDL,
94  NPerXDL,
95  xdlops_gemm.KPerXdlops>;
96 
97  static_assert(KPerThread % KPack == 0, // guarantees KRepeat = KPerThread / KPack has no remainder
98  "Wrong KPack setting; try increasing KPerThread or decreasing KPack");
99 
101  AccDataType,
102  MRepeat * NRepeat,
103  xdlops_gemm.GetRegSizePerXdlops(),
104  true>
106 
107  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } // returns the c_thread_buf_ member by reference
108 
109  __device__ static auto GetWaveIdx()
110  {
111  const index_t thread_id = ThisThreadBlock::GetThreadId();
112 
113  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
117 
118  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
119  }
120 
121  __device__ static auto CalculateAThreadOriginDataIndex() // builds the calling thread's 4-component origin tuple for A
122  {
123  const auto wave_idx = GetWaveIdx();
124  // component I0 of the wave index; named waveId_m, presumably the wave's position along M
125  const auto waveId_m = wave_idx[I0];
126  // the remaining components come from the xdlops helper
127  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
128  // result: (0, waveId_m, xdlops index [I1], xdlops index [I0] scaled by KPerThread)
129  return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
130  }
131 
132  __device__ static auto CalculateAThreadOriginDataIndex6D() // 6-component variant of the A origin tuple
133  {
134  const auto wave_idx = GetWaveIdx();
135  // component I0 of the wave index; named waveId_m, presumably the wave's position along M
136  const auto waveId_m = wave_idx[I0];
137  // same xdlops helper as the 4-component version
138  const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();
139  // result: (0, waveId_m, xdlops index [I1], 0, xdlops index [I0], 0); here [I0] is not scaled by KPerThread
140  return make_tuple(0, waveId_m, xdlops_a_idx[I1], 0, xdlops_a_idx[I0], 0);
141  }
142 
143  __device__ static auto CalculateBThreadOriginDataIndex() // builds the calling thread's 4-component origin tuple for B
144  {
145  const auto wave_idx = GetWaveIdx();
146  // component I1 of the wave index; named waveId_n, presumably the wave's position along N
147  const auto waveId_n = wave_idx[I1];
148  // the remaining components come from the xdlops helper
149  const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();
150  // result: (0, waveId_n, xdlops index [I1], xdlops index [I0] scaled by KPerThread)
151  return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
152  }
153 
154  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
155  __device__ static auto
157  {
158  const auto wave_idx = GetWaveIdx();
159 
160  const auto waveId_m = wave_idx[I0];
161  const auto waveId_n = wave_idx[I1];
162 
163  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);
164 
165  constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
166  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
169 
170  constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
171  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
174 
175  const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
176  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
177  const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
178  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
179 
180  return make_tuple(c_thread_m, c_thread_n);
181  }
182 
183  template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
184  __device__ static auto
186  {
187  const auto wave_idx = GetWaveIdx();
188 
189  const auto waveId_m = wave_idx[I0];
190  const auto waveId_n = wave_idx[I1];
191 
192  const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk4D(xdlops_i, blk_i);
193 
194  return make_tuple(
195  m0, n0, waveId_m, waveId_n, blk_idx[I0], blk_idx[I1], blk_idx[I2], blk_idx[I3]);
196  }
197 
199 
217  __host__ __device__
220  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
221  {
222  static_assert(AMmaTileDesc::IsKnownAtCompileTime() && BMmaTileDesc::IsKnownAtCompileTime(),
223  "wrong! Desc should be known at compile-time");
224 
225  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
226  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");
227 
228  static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
229  "wrong!");
230  }
231 
232  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
233  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
234  {
235  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
236 
237  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
238  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
239  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
240  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
241 
243  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, N, M0, M1, M2));
244  }
245 
246  // XDL output supporting C_xdl = A_xdl * B_xdl
247  __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
248  {
249  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
250 
251  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
252  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
253  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
254  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
255 
257  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
258  }
259 
260  __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
261  {
262  constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();
263 
264  constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
265  constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
266  constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
267  constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3];
268 
270  make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
271  }
272 
273  // transposed XDL output supporting C_xdl' = B_xdl' * A_xdl'
274  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
275  {
276  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
278  Number<NRepeat>{},
279  Number<MWaves>{},
280  Number<NWaves>{},
281  Number<MPerXDL>{},
282  Number<NPerXDL>{}));
283 
284  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(c_block_desc_m0_n0_m1_n1_m2_n2);
285  }
286 
287  // XDL output supporting C_xdl = A_xdl * B_xdl
288  __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
289  {
290  constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
292  Number<NRepeat>{},
293  Number<MWaves>{},
294  Number<NWaves>{},
295  Number<MPerXDL>{},
296  Number<NPerXDL>{}));
297 
298  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
299  }
300 
301  __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
302  {
303  constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
305  Number<MRepeat>{},
306  Number<NRepeat>{},
307  Number<MWaves>{},
308  Number<NWaves>{},
309  Number<MPerXDL>{},
310  Number<NPerXDL>{}));
311 
312  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
313  c_block_desc_g_m0_n0_m1_n1_m2_n2);
314  }
315 
316  template <typename CGridDesc_M_N>
317  __host__ __device__ static constexpr auto
318  MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
319  {
320  const auto M = c_grid_desc_m_n.GetLength(I0);
321  const auto N = c_grid_desc_m_n.GetLength(I1);
322 
323  const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
324  c_grid_desc_m_n,
325  make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
326  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
329 
330  return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
331  }
332 
333  template <typename CGridDesc_G_M_N>
334  __host__ __device__ static constexpr auto
335  MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
336  {
337  const auto G = c_grid_desc_g_m_n.GetLength(I0);
338  const auto M = c_grid_desc_g_m_n.GetLength(I1);
339  const auto N = c_grid_desc_g_m_n.GetLength(I2);
340 
341  const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
342  c_grid_desc_g_m_n,
344  make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
345  make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
348 
349  return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
350  c_grid_desc_g_m0_n0_m1_n1_m2_n2);
351  }
352  __host__ __device__ static constexpr auto GetCThreadDesc() { return c_thread_desc_; } // returns the static c_thread_desc_ member by value
353  static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k; // default-constructed instances of the MMA tile descriptor types;
354  static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k; // their decltype appears in the thread-copy type arguments further down
355 
356  protected:
357  // M1, N1 as double buffer index
358  // Read buffer + Compute buffer
359  // A[M0, M1, M2, KPack]
361  make_tuple(Number<MRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
362  make_tuple(
363  Number<KPack>{}, Number<KRepeat * MRepeat * KPack>{}, Number<MRepeat * KPack>{}, I1));
364 
365  // B[N0, N1, N2, KPack]
367  make_tuple(Number<NRepeat>{}, I1, Number<KRepeat>{}, Number<KPack>{}),
368  make_tuple(
369  Number<KPack>{}, Number<KRepeat * NRepeat * KPack>{}, Number<NRepeat * KPack>{}, I1));
370 
371  // C[M, N, NumRegXdlops]
373  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
374 
376  ComputeDataType,
377  decltype(a_block_desc_m0_m1_m2_k),
378  decltype(a_thread_desc_),
381  3,
382  A_K1,
383  A_K1>;
384 
386  ComputeDataType,
387  decltype(b_block_desc_n0_n1_n2_k),
388  decltype(b_thread_desc_),
391  3,
392  B_K1,
393  B_K1>;
394 
397 };
398 
399 } // namespace ck
Definition: ck.hpp:269
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
f8_fnuz_t f8_t
Definition: amd_ck_fp8.hpp:1737
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
__host__ constexpr __device__ auto make_pass_through_transform(const LowLength &low_length)
Definition: multi_index_transform_helper.hpp:12
constexpr bool is_same_v
Definition: type.hpp:283
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:300
__host__ constexpr __device__ auto transform_tensor_descriptor(const OldTensorDescriptor &old_tensor_desc, const NewTransforms &new_transforms, NewLowerDimensionOldVisibleIdss, NewUpperDimensionNewVisibleIdss)
Definition: tensor_descriptor.hpp:319
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:35
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, xdlops_gemm.GetRegSizePerXdlops(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:98
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:78
__host__ __device__ BlockwiseGemmXdlops_pipeline_base(Tuple4 a_origin=CalculateAThreadOriginDataIndex(), Tuple4 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmXdlops_pipeline_base.
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:218
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:77
static constexpr index_t A_K0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:46
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:274
__host__ static constexpr __device__ auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:288
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:372
static constexpr auto xdlops_gemm
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:52
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:143
static constexpr BMmaTileDesc b_block_desc_n0_n1_n2_k
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:354
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:109
static constexpr index_t KGroup
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:62
static constexpr auto I1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:37
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_N2_N3_N4()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:233
static constexpr index_t AMmaKStride
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:55
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:396
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:121
static __device__ auto CalculateAThreadOriginDataIndex6D()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:132
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:44
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:49
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:41
__host__ static constexpr __device__ auto MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N &c_grid_desc_g_m_n)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:335
__host__ static constexpr __device__ auto MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N &c_grid_desc_m_n)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:318
static constexpr index_t KPerInnerLoop
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:60
static constexpr auto I0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:36
__host__ static constexpr __device__ auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:247
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:156
static __device__ auto CalculateCThreadOriginDataIndex8D(Number< m0 >, Number< n0 >, Number< xdlops_i >, Number< blk_i >)
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:185
__host__ static constexpr __device__ auto GetCThreadDesc()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:352
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:360
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:59
static constexpr AMmaTileDesc a_block_desc_m0_m1_m2_k
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:353
__host__ static constexpr __device__ auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:301
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:366
static constexpr auto I2
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:38
static constexpr auto I3
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:39
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:48
static constexpr index_t BMmaKStride
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:56
decltype(CalculateAThreadOriginDataIndex()) Tuple4
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:198
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:395
static constexpr index_t KPerThread
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:58
__host__ static constexpr __device__ auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:260
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:107
static constexpr index_t B_K0
Definition: blockwise_gemm_pipeline_xdlops_base.hpp:47
Definition: blockwise_gemm_pipeline_xdlops.hpp:34
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
Definition: xdlops_gemm.hpp:1399
Definition: integral_constant.hpp:20