/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/develop/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp Source File
blockwise_gemm_pipeline_wmmaops_base.hpp
Go to the documentation of this file.
1 // Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
2 // SPDX-License-Identifier: MIT
3 
4 #pragma once
5 
12 
13 namespace ck {
14 
15 template <index_t BlockSize,
16  typename ADataType,
17  typename BDataType,
18  typename ComputeTypeA,
19  typename ComputeTypeB,
20  typename AccDataType,
21  typename AWmmaTileDesc,
22  typename BWmmaTileDesc,
23  index_t ABlockTransferSrcScalarPerVector,
24  index_t BBlockTransferSrcScalarPerVector,
25  index_t MPerBlock,
26  index_t NPerBlock,
27  index_t KPerBlock,
28  index_t MPerWmma,
29  index_t NPerWmma,
30  index_t MRepeat,
31  index_t NRepeat,
32  index_t KPack,
33  index_t KInner,
34  bool TransposeC = false>
36 {
// Compile-time index constants used to select tuple / descriptor dimensions
// (e.g. I6 selects the innermost dimension of the WMMA tile descriptors).
// Note: no I4 is declared in this listing.
37  static constexpr auto I0 = Number<0>{};
38  static constexpr auto I1 = Number<1>{};
39  static constexpr auto I2 = Number<2>{};
40  static constexpr auto I3 = Number<3>{};
41  static constexpr auto I5 = Number<5>{};
42  static constexpr auto I6 = Number<6>{};
43 
45 
// Threads per wave, hard-coded to 32. The constructor static_asserts that the
// block has exactly MWaves * NWaves * WaveSize threads.
46  static constexpr index_t WaveSize = 32;
47 
// Number of waves along M / N. Each wave covers MRepeat * MPerWmma rows
// (NRepeat * NPerWmma columns) of the block tile; the constructor asserts
// that these divisions are exact.
48  static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
49  static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
50 
// K-row split of each WMMA operand: 2 when compiling for gfx12, 1 otherwise.
// On gfx12 the A/B thread-origin helpers also use wmma_gemm.GetSubGroupId()
// as the K-row index.
51 #if defined(__gfx12__)
52  static constexpr index_t A_KRow = 2;
53  static constexpr index_t B_KRow = 2;
54 #else
55  static constexpr index_t A_KRow = 1;
56  static constexpr index_t B_KRow = 1;
57 #endif
58 
// Per-instruction WMMA helper, instantiated with a K of KPack / KInner.
// Presumably one KPack chunk is issued as KInner WMMA sub-steps -- confirm in
// the derived pipelines.
59  static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
60  ComputeTypeB,
61  AccDataType,
62  MPerWmma,
63  NPerWmma,
64  KPack / KInner,
65  TransposeC>{};
66 
// K extent per thread: the WMMA instruction's k_per_blk times KInner.
67  static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
// Innermost (dimension 6) vector length used by the A/B thread copies,
// taken from the tile descriptor and capped at KPerThread.
68  static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
69  static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);
70 
// KPack must be a whole multiple of one (K1 * KRow) slice for each operand.
71  static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
72  static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
// Number of KPack chunks per KPerBlock.
// NOTE(review): this is truncating integer division, and no static_assert that
// KPerBlock % KPack == 0 is visible here -- confirm callers guarantee it.
73  static constexpr index_t KRepeat = KPerBlock / KPack;
74 
// K depth of a single WMMA instruction, as a compile-time Number.
75  static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
76 
79  MPerBlock,
80  NPerBlock,
81  KPerBlock,
82  ABlockTransferSrcScalarPerVector,
83  BBlockTransferSrcScalarPerVector,
84  A_K1,
85  B_K1,
86  A_K1,
87  B_K1,
88  MRepeat,
89  NRepeat,
90  MPerWmma,
91  NPerWmma,
92  wmma_gemm.wmma_instr.k_per_wmma>;
93 
95  AccDataType,
96  MRepeat * NRepeat,
97  wmma_gemm.GetRegSizePerWmma(),
98  true>
100 
// No-op placeholder with the same GlobalLoad<NBuffer>(cond) interface as
// BScale, but it loads nothing. Presumably used when B scaling is disabled --
// confirm at the call sites.
101  struct Empty
102  {
103  __device__ Empty() {};
104  template <index_t NBuffer>
105  __device__ void GlobalLoad(bool cond)
106  {
107  ignore = NBuffer; // discard the otherwise-unused template parameter
108  ignore = cond;    // discard the otherwise-unused argument
109  }
110  };
111 
112  template <index_t ScaleSliceSizeN,
113  index_t ScaleSliceSizeK,
114  index_t NWaves,
115  index_t ScaleBlockK,
116  index_t NumberOfBuffers,
117  typename GridDesc,
118  typename ThreadCopy,
119  typename GridBuffer,
120  typename ThreadStaticBuffer,
121  typename BScaleThreadDesc>
122  struct BScale
123  {
124  __device__ BScale(GridDesc b_scale_grid_desc_,
125  ThreadCopy b_scale_thread_copy_,
126  GridBuffer b_scale_grid_buf_)
127  : b_scale_thread_copy(b_scale_thread_copy_),
128  b_scale_grid_desc(b_scale_grid_desc_),
129  b_scale_grid_buf(b_scale_grid_buf_) {};
130 
131  static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
133 
134  static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
135 
136  static constexpr auto b_scale_thread_copy_step =
137  make_tuple(make_multi_index(NWaves * NPerWmma, 0),
138  make_multi_index(-NPerBlock, 0),
139  make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
140 
141  template <index_t NBuffer>
142  __device__ void GlobalLoad(bool cond)
143  {
144  static_for<0, NRepeat, 1>{}([&](auto n0) {
148  make_tuple(n0, Number<0>{}),
150 
151  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
153  });
154 
155  if(cond)
156  {
157  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
159  }
160  else
161  {
162  b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
164  }
165  }
166 
169  GridBuffer b_scale_grid_buf;
171  };
172 
// Returns a mutable reference to this thread's accumulator buffer c_thread_buf_
// (MRepeat * NRepeat vectors of wmma_gemm.GetRegSizePerWmma() AccDataType values).
173  __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
174 
175  __device__ static auto GetWaveIdx()
176  {
177  const index_t thread_id = ThisThreadBlock::GetThreadId();
178 
179  constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
180  make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
183 
184  return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
185  }
186 
// Origin of this thread's A fragment as a 7-element index. By default the
// constructor builds a_thread_copy_ from this value. Only three entries vary:
//   [3] the wave index along M (GetWaveIdx()[I0]),
//   [4] the K-row sub-group (wmma_gemm.GetSubGroupId() on gfx12, 0 otherwise),
//   [5] the value of wmma_gemm.CalculateAThreadOriginDataIndex().
// NOTE(review): the dimension order presumably matches AWmmaTileDesc -- confirm.
187  __device__ static auto CalculateAThreadOriginDataIndex()
188  {
189  const auto wave_idx = GetWaveIdx();
190 
191  const auto waveId_m = wave_idx[I0];
192 
193  const auto wmma_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();
194 
195 #if defined(__gfx12__)
196  const auto wmma_krow = wmma_gemm.GetSubGroupId();
197 #else
198  const auto wmma_krow = 0;
199 #endif
200 
201  return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
202  }
203 
// Origin of this thread's B fragment as a 7-element index. By default the
// constructor builds b_thread_copy_ from this value. Only three entries vary:
//   [3] the wave index along N (GetWaveIdx()[I1]),
//   [4] the K-row sub-group (wmma_gemm.GetSubGroupId() on gfx12, 0 otherwise),
//   [5] the value of wmma_gemm.CalculateBThreadOriginDataIndex().
// NOTE(review): the dimension order presumably matches BWmmaTileDesc -- confirm.
204  __device__ static auto CalculateBThreadOriginDataIndex()
205  {
206  const auto wave_idx = GetWaveIdx();
207 
208  const auto waveId_n = wave_idx[I1];
209 
210  const auto wmma_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();
211 
212 #if defined(__gfx12__)
213  const auto wmma_krow = wmma_gemm.GetSubGroupId();
214 #else
215  const auto wmma_krow = 0;
216 #endif
217 
218  return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
219  }
220 
221  template <index_t m0, index_t n0>
223  {
224  const auto wave_idx = GetWaveIdx();
225 
226  const auto waveId_m = wave_idx[I0];
227  const auto waveId_n = wave_idx[I1];
228 
229  const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk();
230 
231  constexpr auto mrepeat_mwave_mperwmma_to_m_adaptor = make_single_stage_tensor_adaptor(
232  make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWmma))),
235 
236  constexpr auto nrepeat_nwave_nperwmma_to_n_adaptor = make_single_stage_tensor_adaptor(
237  make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWmma))),
240 
241  const index_t c_thread_m = mrepeat_mwave_mperwmma_to_m_adaptor.CalculateBottomIndex(
242  make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
243  const index_t c_thread_n = nrepeat_nwave_nperwmma_to_n_adaptor.CalculateBottomIndex(
244  make_tuple(n0, waveId_n, blk_idx[I1]))[I0];
245 
246  return make_tuple(c_thread_m, c_thread_n);
247  }
248 
// Type of the 7-element thread-origin index returned by
// CalculateAThreadOriginDataIndex(). The constructor uses it for both its A and
// B origin parameters.
249  using Tuple7 = decltype(CalculateAThreadOriginDataIndex());
250 
// Constructs the pipeline base and initialises a_thread_copy_ / b_thread_copy_
// with the given per-thread origins. The defaults are this thread's computed
// A and B origins.
// The static_asserts require:
//   - compile-time-known A/B tile descriptors;
//   - a thread block of exactly MWaves * NWaves * WaveSize threads;
//   - MPerBlock / NPerBlock divisible by (MPerWmma * MRepeat) / (NPerWmma * NRepeat).
268  __host__ __device__
269  BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
270  Tuple7 b_origin = CalculateBThreadOriginDataIndex())
271  : a_thread_copy_(a_origin), b_thread_copy_(b_origin)
272  {
273  static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
274  BWmmaTileDesc::IsKnownAtCompileTime(),
275  "wrong! Desc should be known at compile-time");
276 
277  static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
278  "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize");
279 
280  static_assert(MPerBlock % (MPerWmma * MRepeat) == 0 &&
281  NPerBlock % (NPerWmma * NRepeat) == 0,
282  "wrong!");
283  }
284 
285  // transposed WMMA output C' = B' * A'
286  __host__ __device__ static constexpr auto
288  {
289  constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
290  wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
291 
292  constexpr auto NAccVgprs = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I2];
293 
295  // |MRepeat |MWave |MSubGroup |NRepeat |NWave
296  // |NThreadPerSubGroup |MAccVgprs
297  make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, NAccVgprs));
298  }
299 
// Entry I2 of wmma_gemm's C thread-block lengths (by its name, the MAccVgprs
// extent). It is the innermost length of the per-thread C descriptor.
300  static constexpr auto MAccVgprs =
301  wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths()[I2];
302 
303  __host__ __device__ static constexpr auto
305  {
306  constexpr auto c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens =
307  wmma_gemm.GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths();
308 
309  constexpr auto AccStride = c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens[I3];
311  // |MRepeat |MWave |MSubGroup |NRepeat |NWave
312  // |NThreadPerSubGroup |MAccVgprs
313  make_tuple(Number<MRepeat>{}, I1, I1, Number<NRepeat>{}, I1, I1, MAccVgprs),
314  make_tuple(Number<NRepeat>{} * MAccVgprs * AccStride,
315  Number<NRepeat>{} * MAccVgprs * AccStride,
316  Number<NRepeat>{} * MAccVgprs * AccStride,
317  MAccVgprs * AccStride,
318  MAccVgprs * AccStride,
319  MAccVgprs * AccStride,
320  AccStride));
321  }
322 
323  __host__ __device__ static constexpr auto
325  {
326  constexpr auto c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma =
328  Number<MWaves>{},
330  Number<NRepeat>{},
331  Number<NWaves>{},
332  Number<NPerWmma>{}));
333 
334  return wmma_gemm
335  .MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs(
336  c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma);
337  }
338 
339  // Describe how data is laid out in the thread-copy source (LDS tile) buffers;
// the decltype of each is the source descriptor of AThreadCopy / BThreadCopy.
340  // M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
// NOTE(review): the names suggest 5 dimensions, but these descriptors are
// indexed up to I6 and copied with 7-element slice lengths, so the names may be
// stale -- confirm.
341  static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1;
342  static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1;
343 
344  protected:
345  static constexpr auto a_thread_desc_ =
347  Number<MRepeat>{},
348  Number<KRepeat>{},
349  I1,
350  I1,
351  I1,
352  Number<A_K1>{}),
353  make_tuple(Number<A_K1>{},
354  Number<KPack / A_KRow>{},
355  Number<KPack / A_KRow * MRepeat>{},
356  I0,
357  I0,
358  I0,
359  I1));
360 
361  static constexpr auto b_thread_desc_ =
363  Number<NRepeat>{},
364  Number<KRepeat>{},
365  I1,
366  I1,
367  I1,
368  Number<B_K1>{}),
369  make_tuple(Number<B_K1>{},
370  Number<KPack / B_KRow>{},
371  Number<KPack / B_KRow * NRepeat>{},
372  I0,
373  I0,
374  I0,
375  I1));
376 
377  // C[M, N, NumRegWmma]
// Packed per-thread accumulator layout: MRepeat x NRepeat WMMA outputs, each
// holding wmma_gemm.GetRegSizePerWmma() values. This matches the shape of
// c_thread_buf_.
378  static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
379  make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWmma()));
380 
381  using AThreadCopy =
383  ComputeTypeA,
384  decltype(a_block_desc_k0_m0_m1_m2_k1),
385  decltype(a_thread_desc_),
386  Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
388  6,
389  A_K1,
390  A_K1>;
391 
392  using BThreadCopy =
394  ComputeTypeB,
395  decltype(b_block_desc_k0_n0_n1_n2_k1),
396  decltype(b_thread_desc_),
397  Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
399  6,
400  B_K1,
401  B_K1>;
402 
405 };
406 
407 } // namespace ck
__host__ constexpr __device__ T min(T x)
Definition: math.hpp:116
Definition: ck.hpp:270
typename detail::StaticallyIndexedArrayImpl< T, N >::type StaticallyIndexedArray
Definition: statically_indexed_array.hpp:45
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_naive_tensor_descriptor(const Tuple< Lengths... > &lengths, const Tuple< Strides... > &strides)
Definition: tensor_descriptor_helper.hpp:49
__host__ constexpr __device__ auto make_naive_tensor_descriptor_packed(const Tuple< Lengths... > &lengths)
Definition: tensor_descriptor_helper.hpp:101
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:425
constexpr detail::ignore_t ignore
Definition: ignore.hpp:20
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
__host__ constexpr __device__ auto make_unmerge_transform(const UpLengths &up_lengths, integral_constant< bool, Use24BitIntegerCalculation >=integral_constant< bool, false >{})
Definition: multi_index_transform_helper.hpp:90
int32_t index_t
Definition: ck.hpp:301
integral_constant< index_t, N > Number
Definition: number.hpp:12
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:123
static constexpr auto b_scale_thread_copy_step
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:136
GridDesc b_scale_grid_desc
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:168
GridBuffer b_scale_grid_buf
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:169
__device__ void GlobalLoad(bool cond)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:142
StaticallyIndexedArray< ThreadStaticBuffer, Number< NumberOfBuffers >{}> b_scale_thread_bufs
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:170
ThreadCopy b_scale_thread_copy
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:167
__device__ BScale(GridDesc b_scale_grid_desc_, ThreadCopy b_scale_thread_copy_, GridBuffer b_scale_grid_buf_)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:124
static constexpr auto b_scale_thread_desc
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:134
static constexpr index_t num_scale_krepeat
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:132
static constexpr index_t num_scale_k_block
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:131
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:102
__device__ void GlobalLoad(bool cond)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:105
__device__ Empty()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:103
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:36
static constexpr auto I2
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:39
__host__ static constexpr __device__ auto GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:304
static constexpr index_t NWaves
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:49
decltype(CalculateAThreadOriginDataIndex()) Tuple7
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:249
static constexpr BWmmaTileDesc b_block_desc_k0_n0_n1_n2_k1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:342
static constexpr index_t KPerThread
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:67
static constexpr index_t MWaves
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:48
BThreadCopy b_thread_copy_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:404
static constexpr index_t A_KRow
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:55
ThisThreadBlock< BlockSize > ThisThreadBlock
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:44
static constexpr auto WmmaK
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:75
static __device__ auto CalculateAThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:187
__host__ static constexpr __device__ auto GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:324
static __device__ auto CalculateCThreadOriginDataIndex(Number< m0 >, Number< n0 >)
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:222
StaticBufferTupleOfVector< AddressSpaceEnum::Vgpr, AccDataType, MRepeat *NRepeat, wmma_gemm.GetRegSizePerWmma(), true > c_thread_buf_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:99
static constexpr auto I5
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:41
static constexpr auto I3
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:40
static constexpr index_t WaveSize
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:46
__host__ static constexpr __device__ auto GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:287
static __device__ auto GetWaveIdx()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:175
__host__ __device__ BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin=CalculateAThreadOriginDataIndex(), Tuple7 b_origin=CalculateBThreadOriginDataIndex())
Constructor for BlockwiseGemmWmmaops_pipeline_base.
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:269
__host__ constexpr __device__ auto & GetCThreadBuffer()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:173
static __device__ auto CalculateBThreadOriginDataIndex()
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:204
AThreadCopy a_thread_copy_
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:403
static constexpr index_t B_KRow
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:56
static constexpr index_t B_K1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:69
static constexpr index_t A_K1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:68
static constexpr auto I6
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:42
static constexpr index_t KRepeat
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:73
static constexpr AWmmaTileDesc a_block_desc_k0_m0_m1_m2_k1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:341
static constexpr auto I1
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:38
static constexpr auto I0
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:37
static constexpr auto wmma_gemm
Definition: blockwise_gemm_pipeline_wmmaops_base.hpp:59
Definition: blockwise_gemm_pipeline_wmmaops.hpp:26
Definition: sequence.hpp:43
Definition: static_buffer.hpp:75
Definition: wmma_gemm.hpp:675
Definition: integral_constant.hpp:20
Definition: functional2.hpp:33