Composable Kernel documentation — source listing for include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp (docs-6.4.3).
dpp_gemm.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
8 #include "ck/utility/math.hpp"
9 
10 namespace ck {
11 
// Supported DPP8 fp16 GEMM tile shapes, named dpp8_f16_<MPerWave>x<NPerWave>x<KPerDpp>.
// NOTE(review): every enumerator after the first was lost in extraction; the set is
// reconstructed from the dpp_type/DppSelector specializations below (one per
// (MPerWave, NPerWave) pair) — confirm declaration order against upstream.
enum struct DppInstr
{
    dpp8_f16_1x32x2 = 0,
    dpp8_f16_2x16x2,
    dpp8_f16_2x32x2,
    dpp8_f16_4x16x2,
    dpp8_f16_4x32x2,
    dpp8_f16_8x16x2,
    dpp8_f16_8x32x2,
    dpp8_f16_16x16x2,
    dpp8_f16_32x8x2
};

47 template <DppInstr instr>
48 struct dpp_type;
49 
50 template <>
52 {
53  static constexpr index_t wave_size = 32;
54  static constexpr index_t lanegroup_size = 8;
55  static constexpr index_t m_per_wave = 32;
56  static constexpr index_t n_per_wave = 8;
57  static constexpr index_t m_per_lanegroup = 8;
58  static constexpr index_t n_per_lanegroup = 8;
59  static constexpr index_t m_per_thread = 8;
60  static constexpr index_t n_per_thread = 1;
61  static constexpr index_t k_per_dpp = 2;
62  static constexpr bool share_a = true;
63  using BaseType = half_t;
64 
65  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
66  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
67  {
68  dpp8::DppLanegroupGemm<m_per_thread,
69  n_per_thread,
70  k_per_dpp,
71  BaseType,
72  ADataType,
73  BDataType,
74  CDataType,
75  share_a>{}
76  .Run(a, b, reg_c);
77  }
78 };
79 
80 template <>
82 {
83  static constexpr index_t wave_size = 32;
84  static constexpr index_t lanegroup_size = 8;
85  static constexpr index_t m_per_wave = 8;
86  static constexpr index_t n_per_wave = 32;
87  static constexpr index_t m_per_lanegroup = 8;
88  static constexpr index_t n_per_lanegroup = 8;
89  static constexpr index_t m_per_thread = 8;
90  static constexpr index_t n_per_thread = 1;
91  static constexpr index_t k_per_dpp = 2;
92  static constexpr bool share_a = true;
93  using BaseType = half_t;
94 
95  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
96  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
97  {
98  dpp8::DppLanegroupGemm<m_per_thread,
99  n_per_thread,
100  k_per_dpp,
101  BaseType,
102  ADataType,
103  BDataType,
104  CDataType,
105  share_a>{}
106  .Run(a, b, reg_c);
107  }
108 };
109 
110 template <>
112 {
113  static constexpr index_t wave_size = 32;
114  static constexpr index_t lanegroup_size = 8;
115  static constexpr index_t m_per_wave = 8;
116  static constexpr index_t n_per_wave = 16;
117  static constexpr index_t m_per_lanegroup = 4;
118  static constexpr index_t n_per_lanegroup = 8;
119  static constexpr index_t m_per_thread = 4;
120  static constexpr index_t n_per_thread = 1;
121  static constexpr index_t k_per_dpp = 2;
122  static constexpr bool share_a = true;
123  using BaseType = half_t;
124 
125  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
126  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
127  {
128  dpp8::DppLanegroupGemm<m_per_thread,
129  n_per_thread,
130  k_per_dpp,
131  BaseType,
132  ADataType,
133  BDataType,
134  CDataType,
135  share_a>{}
136  .Run(a, b, reg_c);
137  }
138 };
139 
140 template <>
142 {
143  static constexpr index_t wave_size = 32;
144  static constexpr index_t lanegroup_size = 8;
145  static constexpr index_t m_per_wave = 16;
146  static constexpr index_t n_per_wave = 16;
147  static constexpr index_t m_per_lanegroup = 8;
148  static constexpr index_t n_per_lanegroup = 8;
149  static constexpr index_t m_per_thread = 8;
150  static constexpr index_t n_per_thread = 1;
151  static constexpr index_t k_per_dpp = 2;
152  static constexpr bool share_a = true;
153  using BaseType = half_t;
154 
155  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
156  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
157  {
158  dpp8::DppLanegroupGemm<m_per_thread,
159  n_per_thread,
160  k_per_dpp,
161  BaseType,
162  ADataType,
163  BDataType,
164  CDataType,
165  share_a>{}
166  .Run(a, b, reg_c);
167  }
168 };
169 
170 template <>
172 {
173  static constexpr index_t wave_size = 32;
174  static constexpr index_t lanegroup_size = 8;
175  static constexpr index_t m_per_wave = 4;
176  static constexpr index_t n_per_wave = 32;
177  static constexpr index_t m_per_lanegroup = 4;
178  static constexpr index_t n_per_lanegroup = 8;
179  static constexpr index_t m_per_thread = 4;
180  static constexpr index_t n_per_thread = 1;
181  static constexpr index_t k_per_dpp = 2;
182  static constexpr bool share_a = true;
183  using BaseType = half_t;
184 
185  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
186  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
187  {
188  dpp8::DppLanegroupGemm<m_per_thread,
189  n_per_thread,
190  k_per_dpp,
191  BaseType,
192  ADataType,
193  BDataType,
194  CDataType,
195  share_a>{}
196  .Run(a, b, reg_c);
197  }
198 };
199 
200 template <>
202 {
203  static constexpr index_t wave_size = 32;
204  static constexpr index_t lanegroup_size = 8;
205  static constexpr index_t m_per_wave = 4;
206  static constexpr index_t n_per_wave = 16;
207  static constexpr index_t m_per_lanegroup = 2;
208  static constexpr index_t n_per_lanegroup = 8;
209  static constexpr index_t m_per_thread = 2;
210  static constexpr index_t n_per_thread = 1;
211  static constexpr index_t k_per_dpp = 2;
212  static constexpr bool share_a = true;
213  using BaseType = half_t;
214 
215  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
216  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
217  {
218  dpp8::DppLanegroupGemm<m_per_thread,
219  n_per_thread,
220  k_per_dpp,
221  BaseType,
222  ADataType,
223  BDataType,
224  CDataType,
225  share_a>{}
226  .Run(a, b, reg_c);
227  }
228 };
229 
230 template <>
232 {
233  static constexpr index_t wave_size = 32;
234  static constexpr index_t lanegroup_size = 8;
235  static constexpr index_t m_per_wave = 1;
236  static constexpr index_t n_per_wave = 32;
237  static constexpr index_t m_per_lanegroup = 1;
238  static constexpr index_t n_per_lanegroup = 8;
239  static constexpr index_t m_per_thread = 1;
240  static constexpr index_t n_per_thread = 1;
241  static constexpr index_t k_per_dpp = 2;
242  static constexpr bool share_a = true;
243  using BaseType = half_t;
244 
245  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
246  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
247  {
248  dpp8::DppLanegroupGemm<m_per_thread,
249  n_per_thread,
250  k_per_dpp,
251  BaseType,
252  ADataType,
253  BDataType,
254  CDataType,
255  share_a>{}
256  .Run(a, b, reg_c);
257  }
258 };
259 
260 template <>
262 {
263  static constexpr index_t wave_size = 32;
264  static constexpr index_t lanegroup_size = 8;
265  static constexpr index_t m_per_wave = 2;
266  static constexpr index_t n_per_wave = 32;
267  static constexpr index_t m_per_lanegroup = 2;
268  static constexpr index_t n_per_lanegroup = 8;
269  static constexpr index_t m_per_thread = 2;
270  static constexpr index_t n_per_thread = 1;
271  static constexpr index_t k_per_dpp = 2;
272  static constexpr bool share_a = true;
273  using BaseType = half_t;
274 
275  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
276  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
277  {
278  dpp8::DppLanegroupGemm<m_per_thread,
279  n_per_thread,
280  k_per_dpp,
281  BaseType,
282  ADataType,
283  BDataType,
284  CDataType,
285  share_a>{}
286  .Run(a, b, reg_c);
287  }
288 };
289 
290 template <>
292 {
293  static constexpr index_t wave_size = 32;
294  static constexpr index_t lanegroup_size = 8;
295  static constexpr index_t m_per_wave = 2;
296  static constexpr index_t n_per_wave = 16;
297  static constexpr index_t m_per_lanegroup = 1;
298  static constexpr index_t n_per_lanegroup = 8;
299  static constexpr index_t m_per_thread = 1;
300  static constexpr index_t n_per_thread = 1;
301  static constexpr index_t k_per_dpp = 2;
302  static constexpr bool share_a = true;
303  using BaseType = half_t;
304 
305  template <index_t MPerDpp, index_t NPerDpp, class ADataType, class BDataType, class CDataType>
306  __device__ void run(const ADataType& a, const BDataType& b, CDataType& reg_c) const
307  {
308  dpp8::DppLanegroupGemm<m_per_thread,
309  n_per_thread,
310  k_per_dpp,
311  BaseType,
312  ADataType,
313  BDataType,
314  CDataType,
315  share_a>{}
316  .Run(a, b, reg_c);
317  }
318 };
319 
320 template <typename BaseType, index_t MPerDpp, index_t NPerDpp>
322 {
323  template <typename BaseType_, index_t MPerDpp_, index_t NPerDpp_>
324  static constexpr auto GetDpp();
325 
326  template <>
327  constexpr auto GetDpp<half_t, 8, 32>()
328  {
330  }
331 
332  template <>
333  constexpr auto GetDpp<half_t, 8, 16>()
334  {
336  }
337 
338  template <>
339  constexpr auto GetDpp<half_t, 16, 16>()
340  {
342  }
343 
344  template <>
345  constexpr auto GetDpp<half_t, 32, 8>()
346  {
348  }
349 
350  template <>
351  constexpr auto GetDpp<half_t, 1, 32>()
352  {
354  }
355 
356  template <>
357  constexpr auto GetDpp<half_t, 2, 32>()
358  {
360  }
361 
362  template <>
363  constexpr auto GetDpp<half_t, 2, 16>()
364  {
366  }
367 
368  template <>
369  constexpr auto GetDpp<half_t, 4, 16>()
370  {
372  }
373 
374  template <>
375  constexpr auto GetDpp<half_t, 4, 32>()
376  {
378  }
379 
381 
382  __host__ __device__ constexpr DppSelector()
383  {
384  static_assert(selected_dpp.m_per_wave % selected_dpp.m_per_lanegroup == 0);
385  static_assert(selected_dpp.n_per_wave % selected_dpp.n_per_lanegroup == 0);
386 
387  static_assert(selected_dpp.k_per_dpp % 2 == 0);
388 
389  static_assert(selected_dpp.wave_size % selected_dpp.lanegroup_size == 0);
390  constexpr index_t num_dpp_per_wave = selected_dpp.wave_size / selected_dpp.lanegroup_size;
391  constexpr index_t num_wave_c_elems = selected_dpp.m_per_wave * selected_dpp.n_per_wave;
392  constexpr index_t num_dpp_c_elems =
393  selected_dpp.m_per_lanegroup * selected_dpp.n_per_lanegroup;
394  static_assert(num_wave_c_elems % num_dpp_c_elems == 0);
395  static_assert(num_dpp_per_wave == num_wave_c_elems / num_dpp_c_elems);
396 
397  if constexpr(selected_dpp.share_a)
398  {
399  static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
400  static_assert(selected_dpp.n_per_lanegroup % selected_dpp.n_per_thread == 0);
401  static_assert(selected_dpp.n_per_lanegroup / selected_dpp.n_per_thread ==
402  selected_dpp.lanegroup_size);
403  }
404  else
405  {
406  static_assert(selected_dpp.m_per_lanegroup % selected_dpp.n_per_thread == 0);
407  static_assert(selected_dpp.m_per_lanegroup / selected_dpp.n_per_thread ==
408  selected_dpp.lanegroup_size);
409  static_assert(selected_dpp.n_per_lanegroup == selected_dpp.n_per_thread);
410  }
411 
412  // Below checks come from the restrictions of the current implementation, could be removed
413  // in the future when the implementation is more generalized.
414  static_assert(selected_dpp.share_a);
415  static_assert(selected_dpp.n_per_thread == 1);
416  static_assert(selected_dpp.m_per_lanegroup == selected_dpp.m_per_thread);
417  static_assert(selected_dpp.n_per_lanegroup ==
418  selected_dpp.n_per_thread * selected_dpp.lanegroup_size);
419  }
420 
421  static constexpr index_t GetK1PerDpp() { return selected_dpp.k_per_dpp; }
422 };
423 
424 template <typename BaseType, index_t MPerDpp, index_t NPerDpp, index_t KPack>
425 struct DppGemm
426 {
427  static constexpr auto I0 = Number<0>{};
428  static constexpr auto I1 = Number<1>{};
429  static constexpr auto I2 = Number<2>{};
430  static constexpr auto I3 = Number<3>{};
431  static constexpr auto I4 = Number<4>{};
432  static constexpr auto I5 = Number<5>{};
433 
436 
437  __host__ __device__ constexpr DppGemm()
438  {
439  static_assert(KPack % dpp_instr.k_per_dpp == 0, "KPack must be divisible by k_per_dpp.");
440  }
441 
442  __device__ static constexpr index_t GetRegSizePerDpp()
443  {
444  return MPerDpp * NPerDpp / dpp_instr.wave_size;
445  }
446 
447  template <class ADataType, class BDataType, class CDataType>
448  __device__ void
449  Run(const ADataType& p_a_wave, const BDataType& p_b_wave, CDataType& p_c_thread) const
450  {
454  "base BaseType must be double, float, half, bfloat16, and int8_t!");
455 
456  static_for<0, KPack / dpp_instr.k_per_dpp, 1>{}([&](auto k) {
457  dpp_instr.template run<MPerDpp, NPerDpp>(p_a_wave[k], p_b_wave[k], p_c_thread);
458  });
459  }
460 
461  __device__ static auto GetLaneIdInWave()
462  {
463  return get_thread_local_1d_id() % dpp_instr.wave_size;
464  }
465 
466  __device__ static auto GetWaveId() { return get_thread_local_1d_id() / dpp_instr.wave_size; }
467 
468  __device__ static auto GetLaneIdInLaneGroup()
469  {
470  return get_thread_local_1d_id() % dpp_instr.lanegroup_size;
471  }
472 
473  __device__ static auto GetLaneGroupIdInWave()
474  {
475  return GetLaneIdInWave() / dpp_instr.lanegroup_size;
476  }
477 
478  __device__ static auto GetDppOpIdx()
479  {
480  const auto lanegroupId = GetLaneGroupIdInWave();
481 
482  constexpr auto lanegroup_idx_1d_to_dpp_idx_2d_adaptor = make_single_stage_tensor_adaptor(
483  make_tuple(
484  make_merge_transform(make_tuple(dpp_instr.m_per_wave / dpp_instr.m_per_lanegroup,
485  dpp_instr.n_per_wave / dpp_instr.n_per_lanegroup))),
488 
489  const auto dpp_idx = lanegroup_idx_1d_to_dpp_idx_2d_adaptor.CalculateBottomIndex(
490  make_multi_index(lanegroupId));
491 
492  const auto m_dpp_idx = dpp_idx[I0];
493  const auto n_dpp_idx = dpp_idx[I1];
494 
495  return make_tuple(m_dpp_idx, n_dpp_idx);
496  }
497 
498  __host__ __device__ static auto CalculateAThreadOriginDataIndex_K_M()
499  {
500  const auto laneId = get_thread_local_1d_id();
501  const auto wave_row = laneId / dpp_instr.n_per_wave;
502  auto m_idx = dpp_instr.m_per_thread * wave_row + GetLaneIdInLaneGroup();
503  return make_tuple(0, m_idx % dpp_instr.m_per_wave);
504  }
505 
506  __host__ __device__ static auto CalculateBThreadOriginDataIndex_K_N()
507  {
508  const auto laneId = get_thread_local_1d_id();
509  return make_tuple(0, laneId % dpp_instr.n_per_wave);
510  }
511 
512  __device__ static CIndex GetBeginOfThreadBlk()
513  {
514  const auto dpp_op_idx = GetDppOpIdx();
515 
516  const auto m_dpp_op_idx = dpp_op_idx[I0];
517  const auto n_dpp_op_idx = dpp_op_idx[I1];
518 
519  index_t n_offset = n_dpp_op_idx * dpp_instr.n_per_lanegroup + GetLaneIdInLaneGroup();
520  index_t m_offset = m_dpp_op_idx * dpp_instr.m_per_lanegroup;
521 
522  return CIndex{m_offset, n_offset};
523  }
524 
525  static constexpr auto dpp = DppSelector<BaseType, MPerDpp, NPerDpp>{};
526 
527  static constexpr auto dpp_instr = dpp.selected_dpp;
528 
529  static constexpr auto K0PerDpp = 1;
530  static constexpr auto K1PerDpp = dpp.GetK1PerDpp();
531 
532  __host__ __device__ static constexpr auto GetCMNThreadBlkLengths()
533  {
534  return make_tuple(Number<dpp_instr.m_per_thread>{}, Number<dpp_instr.n_per_thread>{});
535  }
536 };
537 
538 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto make_merge_transform(const LowLengths &low_lengths)
Definition: multi_index_transform_helper.hpp:55
_Float16 half_t
Definition: data_type.hpp:25
__host__ constexpr __device__ auto make_single_stage_tensor_adaptor(const Transforms &transforms, LowerDimensionOldTopIdss, UpperDimensionNewTopIdss)
Definition: tensor_adaptor.hpp:429
DppInstr
Definition: dpp_gemm.hpp:13
__host__ constexpr __device__ auto make_tuple(Xs &&... xs)
Definition: tuple.hpp:211
int32_t index_t
Definition: ck.hpp:289
__device__ index_t get_thread_local_1d_id()
Definition: get_id.hpp:16
Definition: array.hpp:14
Definition: dpp_gemm.hpp:426
__host__ static __device__ auto CalculateBThreadOriginDataIndex_K_N()
Definition: dpp_gemm.hpp:506
__device__ void Run(const ADataType &p_a_wave, const BDataType &p_b_wave, CDataType &p_c_thread) const
Definition: dpp_gemm.hpp:449
static constexpr auto dpp_instr
Definition: dpp_gemm.hpp:527
__host__ static constexpr __device__ auto GetCMNThreadBlkLengths()
Definition: dpp_gemm.hpp:532
__host__ constexpr __device__ DppGemm()
Definition: dpp_gemm.hpp:437
static constexpr auto I3
Definition: dpp_gemm.hpp:430
static constexpr auto I1
Definition: dpp_gemm.hpp:428
static __device__ auto GetWaveId()
Definition: dpp_gemm.hpp:466
static constexpr auto I5
Definition: dpp_gemm.hpp:432
static constexpr __device__ index_t GetRegSizePerDpp()
Definition: dpp_gemm.hpp:442
static __device__ auto GetLaneGroupIdInWave()
Definition: dpp_gemm.hpp:473
static __device__ CIndex GetBeginOfThreadBlk()
Definition: dpp_gemm.hpp:512
static constexpr auto I4
Definition: dpp_gemm.hpp:431
static constexpr auto I2
Definition: dpp_gemm.hpp:429
static __device__ auto GetLaneIdInLaneGroup()
Definition: dpp_gemm.hpp:468
static constexpr auto K1PerDpp
Definition: dpp_gemm.hpp:530
static constexpr auto dpp
Definition: dpp_gemm.hpp:525
__host__ static __device__ auto CalculateAThreadOriginDataIndex_K_M()
Definition: dpp_gemm.hpp:498
static constexpr auto I0
Definition: dpp_gemm.hpp:427
static __device__ auto GetDppOpIdx()
Definition: dpp_gemm.hpp:478
static __device__ auto GetLaneIdInWave()
Definition: dpp_gemm.hpp:461
static constexpr auto K0PerDpp
Definition: dpp_gemm.hpp:529
Definition: dpp_gemm.hpp:322
static constexpr index_t GetK1PerDpp()
Definition: dpp_gemm.hpp:421
static constexpr auto selected_dpp
Definition: dpp_gemm.hpp:380
static constexpr auto GetDpp()
__host__ constexpr __device__ DppSelector()
Definition: dpp_gemm.hpp:382
Definition: sequence.hpp:43
Definition: amd_gemm_dpp.hpp:37
__device__ void Run(const AVecDataType &a_vec, const BVecDataType &b_vec, CVecDataType &c_vec)
Definition: amd_gemm_dpp.hpp:43
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:156
half_t BaseType
Definition: dpp_gemm.hpp:153
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:246
half_t BaseType
Definition: dpp_gemm.hpp:243
half_t BaseType
Definition: dpp_gemm.hpp:303
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:306
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:276
half_t BaseType
Definition: dpp_gemm.hpp:273
half_t BaseType
Definition: dpp_gemm.hpp:63
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:66
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:216
half_t BaseType
Definition: dpp_gemm.hpp:213
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:186
half_t BaseType
Definition: dpp_gemm.hpp:183
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:126
half_t BaseType
Definition: dpp_gemm.hpp:123
half_t BaseType
Definition: dpp_gemm.hpp:93
__device__ void run(const ADataType &a, const BDataType &b, CDataType &reg_c) const
Definition: dpp_gemm.hpp:96
Definition: dpp_gemm.hpp:48
Definition: integral_constant.hpp:10
Definition: type.hpp:177
Definition: functional2.hpp:31