/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-7.0.0/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File
blockwise_gemm_pipeline_wmmaops_base.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
12 
13 namespace ck {
14 
15 template <index_t BlockSize,
16  typename ADataType,
17  typename BDataType,
18  typename ComputeTypeA,
19  typename ComputeTypeB,
20  typename AccDataType,
21  typename AWmmaTileDesc,
22  typename BWmmaTileDesc,
23  index_t ABlockTransferSrcScalarPerVector,
24  index_t BBlockTransferSrcScalarPerVector,
25  index_t MPerBlock,
26  index_t NPerBlock,
27  index_t KPerBlock,
28  index_t MPerWmma,
29  index_t NPerWmma,
30  index_t MRepeat,
31  index_t NRepeat,
32  index_t KPack,
33  bool TransposeC = false>
35 {
// Compile-time index constants (Number<N>) used throughout this type to
// subscript tuples and descriptor dimensions. Note that no I4 is defined here.
36  static constexpr auto I0 = Number<0>{};
37  static constexpr auto I1 = Number<1>{};
38  static constexpr auto I2 = Number<2>{};
39  static constexpr auto I3 = Number<3>{};
40  static constexpr auto I5 = Number<5>{};
41 
43 
// Threads per wave; fixed at 32 here. The constructor checks that the block
// holds exactly MWaves * NWaves * WaveSize threads.
44  static constexpr index_t WaveSize = 32;
45 
    // Waves per block along M and N: the block tile extent divided by the
    // extent one wave covers (repeat count * per-WMMA extent).
46  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
47  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
48 
    // KRow is 2 when compiling for __gfx12__ and 1 for every other target.
49 #if defined(__gfx12__)
50  static constexpr index_t A_KRow = 2;
51  static constexpr index_t B_KRow = 2;
52 #else
53  static constexpr index_t A_KRow = 1;
54  static constexpr index_t B_KRow = 1;
55 #endif
56 
    // K1 is read from dimension index 5 (the sixth length) of each WMMA tile
    // descriptor.
57  static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
58  static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
59 
    // KPack must be a whole multiple of K1 * KRow for both operands.
60  static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
61  static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
62 
// WMMA helper object used by the rest of this type for per-thread index
// calculation, register sizing and C-descriptor construction.
// NOTE(review): listing line 64 (the initializer of wmma_gemm) is missing from
// this extract, so its concrete type and template arguments are not shown.
63  static constexpr auto wmma_gemm =
65 
    // Number of KPack-sized steps that make up KPerBlock (integer division).
66  static constexpr index_t KRepeat = KPerBlock / KPack;
67 
    // k_per_wmma of the selected WMMA instruction, wrapped as a Number<>.
68  static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
69 
72  MPerBlock,
73  NPerBlock,
74  KPerBlock,
75  ABlockTransferSrcScalarPerVector,
76  BBlockTransferSrcScalarPerVector,
77  A_K1,
78  B_K1,
79  A_K1,
80  B_K1,
81  MRepeat,
82  NRepeat,
83  MPerWmma,
84  NPerWmma,
85  wmma_gemm.wmma_instr.k_per_wmma>;
86 
// Template arguments of the c_thread_buf_ member. The symbol index at the end
// of this listing gives the full declaration as
// StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, AccDataType,
// MRepeat * NRepeat, wmma_gemm.GetRegSizePerWmma(), true> c_thread_buf_.
// NOTE(review): listing lines 87 and 92 (head and tail of the declaration) are
// missing from this extract.
88  AccDataType,
89  MRepeat * NRepeat,
90  wmma_gemm.GetRegSizePerWmma(),
91  true>
93 
// Returns a mutable reference to this object's c_thread_buf_ member.
94  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
95 
// Maps the calling thread's id (ThisThreadBlock::GetThreadId()) through a
// single-stage tensor adaptor and returns the resulting bottom index.
// Callers in this file read element I0 of the result as the wave id along M
// and element I1 as the wave id along N.
96  __device__ static auto GetWaveIdx()
97  {
98  const index_t thread_id = ThisThreadBlock::GetThreadId();
99 
    // NOTE(review): listing lines 101-103 (the adaptor's transform and
    // dimension-id arguments) are missing from this extract, so the exact
    // decomposition of thread_id cannot be confirmed from here.
100  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
104 
105  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
106  }
107 
// Builds the 6-element origin index this thread uses for the A operand.
// Only three slots are thread-dependent: the wave id along M, the K-row and
// the per-thread A index reported by wmma_gemm; the remaining slots are 0.
108  __device__ static auto CalculateAThreadOriginDataIndex()
109  {
110  const auto wave_idx = GetWaveIdx();
111 
     // Element I0 of the wave index is the wave's position along M.
112  const auto waveId_m = wave_idx[I0];
113 
114  const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
115 
     // On __gfx12__ the K-row comes from wmma_gemm.GetSubGroupId(); on every
     // other target it is fixed at 0 (matching A_KRow == 1 there).
116 #if defined(__gfx12__)
117  const auto wmma_krow = wmma_gemm.GetSubGroupId();
118 #else
119  const auto wmma_krow = 0;
120 #endif
121 
     // Slot order of the returned tuple:
122  // |KRepeat |MRepeat|MWave |KRow |MLane |KPack
123  return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
124  }
125 
// Builds the 6-element origin index this thread uses for the B operand.
// Only three slots are thread-dependent: the wave id along N, the K-row and
// the per-thread B index reported by wmma_gemm; the remaining slots are 0.
126  __device__ static auto CalculateBThreadOriginDataIndex()
127  {
128  const auto wave_idx = GetWaveIdx();
129 
     // Element I1 of the wave index is the wave's position along N.
130  const auto waveId_n = wave_idx[I1];
131 
132  const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
133 
     // On __gfx12__ the K-row comes from wmma_gemm.GetSubGroupId(); on every
     // other target it is fixed at 0 (matching B_KRow == 1 there).
134 #if defined(__gfx12__)
135  const auto wmma_krow = wmma_gemm.GetSubGroupId();
136 #else
137  const auto wmma_krow = 0;
138 #endif
139 
     // Slot order of the returned tuple:
140  // |KRepeat |NRepeat|NWave |KRow |NLane |KPack
141  return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
142  }
143 
// Returns the (m, n) origin of this thread's C data for the compile-time
// repeat indices m0 and n0, as a 2-element tuple.
// NOTE(review): listing line 145 (the signature) is missing from this extract;
// the symbol index at the end of the listing gives it as
// static __device__ auto CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>).
144  template <index_t m0, index_t n0>
146  {
147  const auto wave_idx = GetWaveIdx();
148 
149  const auto waveId_m = wave_idx[I0];
150  const auto waveId_n = wave_idx[I1];
151 
     // Per-thread starting position inside one WMMA block; element I0 is used
     // for M and element I1 for N below.
152  const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
153 
     // Adaptors built from an unmerge transform over (repeat, wave, per-WMMA)
     // lengths, one for M and one for N.
     // NOTE(review): listing lines 156-157 and 161-162 (the dimension-id
     // arguments of both adaptors) are missing from this extract.
154  constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor(
155  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))),
158 
159  constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor(
160  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))),
163 
     // Collapse (repeat index, wave id, in-block offset) into a single
     // coordinate per dimension; presumably the row-major combination of the
     // three, TODO confirm against the adaptor definition.
164  const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex(
165  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
166  const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex(
167  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
168 
169  return make_tuple(c_thread_m, c_thread_n);
170  }
171 
173 
// Constructor: initialises a_thread_copy_ and b_thread_copy_ from the given
// A and B origins and validates the block configuration at compile time.
// NOTE(review): listing lines 192-193 (the declarator) are missing from this
// extract; the symbol index at the end of the listing gives the signature as
// BlockwiseGemmWmmaops_pipeline_base(
//     Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
//     Tuple6 b_origin = CalculateBThreadOriginDataIndex()).
191  __host__ __device__
194  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
195  {
     // Both tile descriptors must be fully known at compile time.
196  static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
197  BWmmaTileDesc::IsKnownAtCompileTime(),
198  "wrong! Desc should be known at compile-time");
199 
     // The thread block must contain exactly one wave per (MWave, NWave) slot.
200  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
201  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize");
202 
     // The block tile must divide evenly into per-wave tiles; this guards the
     // integer divisions that define MWaves and NWaves.
203  static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
204  NPerBlock % (NPerWmma * NRepeat) == 0,
205  "wrong!");
206  }
207 
// Builds the 7-dimensional descriptor for this thread's C data, with lengths
// (MRepeat, 1, 1, NRepeat, 1, 1, MAccVgprs) and the explicit strides below.
// NOTE(review): listing line 209 (the function name) and line 216 (the head of
// the return expression) are missing from this extract; the symbol index names
// the function
// GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs.
208  __host__ __device__ static constexpr auto
210  {
211  constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
212  wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
213 
     // Element I2 of the lengths is used as the MAccVgprs length and element I3
     // as the accumulator stride.
214  constexpr auto MAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
215  constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
217  // |MRepeat |MWave |MSubGroup |NRepeat |NWave
218  // |NThreadPerSubGroup |MAccVgprs
219  make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
     // Strides: the three M-side dimensions share one stride and the three
     // N-side dimensions share another; the length-1 dimensions never advance,
     // so their stride values do not affect addressing.
220  make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
221  Number<NRepeat>{} * MAccVgprs * AccStride,
222  Number<NRepeat>{} * MAccVgprs * AccStride,
223  MAccVgprs * AccStride,
224  MAccVgprs * AccStride,
225  MAccVgprs * AccStride,
226  AccStride));
227  }
228 
// Builds a block-level C descriptor over (MRepeat, MWave, MPerWmma, NRepeat,
// NWave, NPerWmma) and hands it to wmma_gemm to obtain the
// MBlockxRepeat/MWave/MSubGroup/NBlockxRepeat/NWave/NThreadPerSubGroup/MAccVgprs
// form, which is returned.
// NOTE(review): listing lines 230 (function name), 233 and 235 (parts of the
// descriptor construction) are missing from this extract; the symbol index
// names the function
// GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs.
229  __host__ __device__ static constexpr auto
231  {
232  constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
234  Number<MWaves>{},
236  Number<NRepeat>{},
237  Number<NWaves>{},
238  Number<NPerWmma>{}));
239 
240  return wmma_gemm
241  .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
242  c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
243  }
244 
245  // Describes how data is laid out in the source buffers read by the thread copies.
246  // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
     // NOTE(review): the member names list five dimensions, but elsewhere in this
     // type the descriptors are queried at dimension index 5 and paired with
     // 6-element origin tuples — confirm the intended dimension naming.
247  static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1;
248  static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1;
249 
250  protected:
// Descriptor for this thread's A data, given as a lengths tuple and a strides
// tuple.
// NOTE(review): listing line 252 (the builder call and the first length) is
// missing from this extract, so only five of the six lengths are visible.
251  static constexpr auto a_thread_desc_ =
253  Number<MRepeat>{},
254  Number<KRepeat>{},
255  I1,
256  I1,
257  Number<A_K1>{}),
     // Strides: the two length-1 dimensions get stride 0, and the last
     // dimension (length A_K1) is contiguous with stride 1.
258  make_tuple(Number<A_K1>{},
259  Number<KPack / A_KRow>{},
260  Number<KPack / A_KRow * MRepeat>{},
261  I0,
262  I0,
263  I1));
264 
     // Descriptor for this thread's B data; mirrors a_thread_desc_ with
     // NRepeat, B_K1 and B_KRow in place of the A-side values.
     // NOTE(review): listing line 266 (the builder call and the first length) is
     // missing from this extract.
265  static constexpr auto b_thread_desc_ =
267  Number<NRepeat>{},
268  Number<KRepeat>{},
269  I1,
270  I1,
271  Number<B_K1>{}),
272  make_tuple(Number<B_K1>{},
273  Number<KPack / B_KRow>{},
274  Number<KPack / B_KRow * NRepeat>{},
275  I0,
276  I0,
277  I1));
278 
279  // C[M, N, NumRegWmma]
     // Lengths are (MRepeat, NRepeat, wmma_gemm.GetRegSizePerWmma()).
     // NOTE(review): listing line 280 (the declaration head) is missing from this
     // extract; the symbol index lists it as `static constexpr auto
     // c_thread_desc_`, and the builder call is not visible.
281  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
282 
// Type of the a_thread_copy_ member: a transfer from the A block descriptor
// type to the A thread descriptor type, producing ComputeTypeA.
// NOTE(review): listing lines 284 (template name and first argument) and 289
// (the argument after the slice lengths) are missing from this extract, so the
// meaning of the trailing arguments (5, A_K1, A_K1) cannot be confirmed here —
// presumably a vector dimension index and vector widths.
283  using AThreadCopy =
285  ComputeTypeA,
286  decltype(a_block_desc_k0_m0_m1_m2_k1),
287  decltype(a_thread_desc_),
288  Sequence<KPack / A_K1 / A_KRow, MRepeat, 1, 1, 1, A_K1>,
290  5,
291  A_K1,
292  A_K1>;
293 
// Type of the b_thread_copy_ member: a transfer from the B block descriptor
// type to the B thread descriptor type, producing ComputeTypeB.
// NOTE(review): listing lines 295 (template name and first argument) and 300
// (the argument after the slice lengths) are missing from this extract, so the
// meaning of the trailing arguments (5, B_K1, B_K1) cannot be confirmed here —
// presumably a vector dimension index and vector widths.
294  using BThreadCopy =
296  ComputeTypeB,
297  decltype(b_block_desc_k0_n0_n1_n2_k1),
298  decltype(b_thread_desc_),
299  Sequence<KPack / B_K1 / B_KRow, NRepeat, 1, 1, 1, B_K1>,
301  5,
302  B_K1,
303  B_K1>;
304 
307 };
308 
309 } // namespace ck
Definition: ck.hpp:269
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:300
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:35
static constexpr auto a_thread_desc_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:251
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:305
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:94
static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:248
static constexpr auto I1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:37
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:57
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:145
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:126
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, wmma_gemm.GetRegSizePerWmma(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:92
__host__ __device__ BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin=CalculateAThreadOriginDataIndex(), Tuple6 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmWmmaops_pipeline_base.
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:192
static constexpr auto b_thread_desc_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:265
static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:247
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:46
static constexpr auto wmma_gemm
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:63
static constexpr index_t B_KRow
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:54
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:96
static constexpr auto I3
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:39
static constexpr auto I0
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:36
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:42
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:58
__host__ static constexpr __device__ auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:209
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:44
__host__ static constexpr __device__ auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:230
static constexpr auto WmmaK
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:68
static constexpr auto I5
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:40
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:306
decltype(CalculateAThreadOriginDataIndex()) Tuple6
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:172
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:66
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:47
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:108
static constexpr index_t A_KRow
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:53
static constexpr auto c_thread_desc_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:280
static constexpr auto I2
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:38
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
static __device__ index_t GetThreadId()
Definition: thread_group.hpp:19
static constexpr __device__ index_t GetNumOfThread()
Definition: thread_group.hpp:15
Definition: wmma_gemm.hpp:663
Definition: integral_constant.hpp:20