/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp Source File

/home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp Source File#

Composable Kernel: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-composable-kernel/checkouts/docs-6.4.3/include/ck/tensor_operation/gpu/thread/threadwise_contraction_dl.hpp Source File
threadwise_contraction_dl.hpp
Go to the documentation of this file.
1 // SPDX-License-Identifier: MIT
2 // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
3 
4 #pragma once
5 
7 #include "ck/utility/math.hpp"
8 
9 namespace ck {
10 
11 // C[TM0, TM1, TN0, TN1] += A[TK, TM0, TM1] * B[TK, TN0, TN1]
12 // Tensor element can be vectorized data
13 // Assume:
14 // 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
15 // known at compile-time
16 // 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
//
// Template parameters:
//   FloatA, FloatB, FloatC  - type arguments passed to inner_product<> in Run().
//   A/B/CThreadDesc_...     - descriptor types; this struct only calls their
//                             IsKnownAtCompileTime() and CalculateOffset().
//   TKLengths, TMLengths, TNLengths
//                           - compile-time length lists; the constructor requires
//                             sizes 1, 2 and 2 respectively.
//   (unnamed enable_if)     - makes the template usable only when all three
//                             descriptors report IsKnownAtCompileTime().
//
// NOTE(review): despite the TK0/TK1 suffixes in the descriptor parameter names,
// Run() indexes A and B with three coordinates (tk, tm0/tn0, tm1/tn1).
// NOTE(review): this text is an extracted listing; listing line 30 (the struct
// declaration, ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1 according to the generated
// index) and line 32 (the constructor signature) are not present below.
17 template <typename FloatA,
18  typename FloatB,
19  typename FloatC,
20  typename AThreadDesc_TK0_TM0_TM1_TK1,
21  typename BThreadDesc_TK0_TN0_TN1_TK1,
22  typename CThreadDesc_TM0_TM1_TN0_TN1,
23  typename TKLengths,
24  typename TMLengths,
25  typename TNLengths,
26  typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
27  BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
28  CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
29  bool>::type = false>
31 {
 // Constructor body: performs compile-time checks only, no runtime work.
33  {
 // Repeats the enable_if condition as a hard error with a readable message.
34  static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
35  BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
36  CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
37  "wrong! Desc should be known at compile-time");
38 
39  // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
40  // CThreadDesc_TM0_TM1_TN0_TN1 Size with TKLengths, TMLengths and TNLengths
41 
42  // TODO remove this restriction
 // Run() reads exactly one TK length and two TM / two TN lengths (I0, I1).
43  static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
44  "wrong!");
45  }
46 
 // Run: visits every (tk, tm0, tm1, tn0, tn1) combination with compile-time
 // loops and calls inner_product on the matching A, B and C elements.
 //   a_buf, b_buf - read-only inputs, indexed with operator[] and a Number<> offset.
 //   c_buf        - indexed with operator() and passed to inner_product as its
 //                  third argument (presumably the accumulator, per the formula
 //                  in the header comment; inner_product is defined elsewhere).
 //   AOriginIdx, BOriginIdx, COriginIdx
 //                - unnamed parameters; only their types are used, by
 //                  default-constructing them to obtain the origin indices.
47  template <typename ABuffer,
48  typename AOriginIdx,
49  typename BBuffer,
50  typename BOriginIdx,
51  typename CBuffer,
52  typename COriginIdx>
53  __device__ static void Run(const ABuffer& a_buf,
54  AOriginIdx,
55  const BBuffer& b_buf,
56  BOriginIdx,
57  CBuffer& c_buf,
58  COriginIdx)
59  {
 // NOTE(review): listing lines 60-62 (the opening of this static_assert and
 // its condition) are missing from this extract; only the message remains.
63  "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
64 
 // NOTE(review): listing lines 66-68 (the condition of this static_assert)
 // are missing from this extract.
65  static_assert(
69  "wrong! inconsistent type");
70 
71  constexpr auto I0 = Number<0>{};
72  constexpr auto I1 = Number<1>{};
73 
 // Loop extents taken from the length lists.
74  constexpr auto TK = TKLengths{}[I0];
75  constexpr auto TM0 = TMLengths{}[I0];
76  constexpr auto TM1 = TMLengths{}[I1];
77  constexpr auto TN0 = TNLengths{}[I0];
78  constexpr auto TN1 = TNLengths{}[I1];
79 
 // Origins are materialized from the index types so they can be added to
 // the per-iteration indices inside constexpr offset computations.
80  constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
81  constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
82  constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
83 
 // Loop nest order: tk outermost, then tm0, tm1, tn0, with tn1 innermost.
84  static_for<0, TK, 1>{}([&](auto tk) {
85  static_for<0, TM0, 1>{}([&](auto tm0) {
86  static_for<0, TM1, 1>{}([&](auto tm1) {
87  static_for<0, TN0, 1>{}([&](auto tn0) {
88  static_for<0, TN1, 1>{}([&](auto tn1) {
 // All three offsets are constexpr, so the buffers are
 // addressed through Number<offset> compile-time indices.
89  constexpr index_t a_offset =
90  AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
91  a_origin_idx + make_multi_index(tk, tm0, tm1));
92  constexpr index_t b_offset =
93  BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
94  b_origin_idx + make_multi_index(tk, tn0, tn1));
95  constexpr index_t c_offset =
96  CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
97  c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
98 
 // One A element and one B element are combined into the C
 // element at (tm0, tm1, tn0, tn1).
99  inner_product<FloatA, FloatB, FloatC>(a_buf[Number<a_offset>{}],
100  b_buf[Number<b_offset>{}],
101  c_buf(Number<c_offset>{}));
102  });
103  });
104  });
105  });
106  });
107  }
108 };
109 
110 // C[TM0, TM1, TN0, TN1] += A[TK0, TM0, TM1, TK1] * B[TK0, TN0, TN1, TK1]
111 // Tensor element can be vectorized data
112 // Assume:
113 // 1. AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1, CThreadDesc_TM0_TM1_TN0_TN1 are
114 // known at compile-time
115 // 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
//
// Template parameters:
//   FloatA, FloatB          - element types of the TK1-wide vectors built in Run().
//   FloatC                  - third type argument passed to inner_product<>.
//   A/B/CThreadDesc_...     - descriptor types; this struct only calls their
//                             IsKnownAtCompileTime() and CalculateOffset().
//   TKLengths, TMLengths, TNLengths
//                           - compile-time length lists; the constructor requires
//                             each of them to have size 2.
//   (unnamed enable_if)     - makes the template usable only when all three
//                             descriptors report IsKnownAtCompileTime().
//
// NOTE(review): this text is an extracted listing; listing line 129 (the struct
// declaration) and line 131 (the constructor signature, named
// ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
// in the generated index) are not present below.
116 template <typename FloatA,
117  typename FloatB,
118  typename FloatC,
119  typename AThreadDesc_TK0_TM0_TM1_TK1,
120  typename BThreadDesc_TK0_TN0_TN1_TK1,
121  typename CThreadDesc_TM0_TM1_TN0_TN1,
122  typename TKLengths,
123  typename TMLengths,
124  typename TNLengths,
125  typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
126  BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
127  CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
128  bool>::type = false>
130 {
 // Constructor body: performs compile-time checks only, no runtime work.
132  {
 // Repeats the enable_if condition as a hard error with a readable message.
133  static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
134  BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
135  CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
136  "wrong! Desc should be known at compile-time");
137 
138  // TODO: sanity-check: compare AThreadDesc_TK0_TM0_TM1_TK1, BThreadDesc_TK0_TN0_TN1_TK1,
139  // CThreadDesc_TM0_TM1_TN0_TN1 Size with TKLengths, TMLengths and TNLengths
140 
141  // TODO remove this restriction
 // Run() reads exactly two entries (I0, I1) from each length list.
142  static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
143  "wrong!");
144  }
145 
 // Run: for every (tk0, tm0, tm1, tn0, tn1) combination, gathers TK1 elements
 // of A and of B into two vectors and issues a single inner_product call on
 // those vectors and the C element at (tm0, tm1, tn0, tn1).
 //   a_buf, b_buf - read-only inputs, indexed with operator[] and a Number<> offset.
 //   c_buf        - indexed with operator() and passed to inner_product as its
 //                  third argument (presumably the accumulator, per the formula
 //                  in the header comment; inner_product is defined elsewhere).
 //   AOriginIdx, BOriginIdx, COriginIdx
 //                - unnamed parameters; only their types are used, by
 //                  default-constructing them to obtain the origin indices.
146  template <typename ABuffer,
147  typename AOriginIdx,
148  typename BBuffer,
149  typename BOriginIdx,
150  typename CBuffer,
151  typename COriginIdx>
152  __device__ static void Run(const ABuffer& a_buf,
153  AOriginIdx,
154  const BBuffer& b_buf,
155  BOriginIdx,
156  CBuffer& c_buf,
157  COriginIdx)
158  {
 // NOTE(review): listing lines 159-161 (the opening of this static_assert
 // and its condition) are missing from this extract; only the message remains.
162  "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
163 
 // NOTE(review): listing lines 165-167 (the condition of this static_assert)
 // are missing from this extract.
164  static_assert(
168  "wrong! inconsistent type");
169 
170  constexpr auto I0 = Number<0>{};
171  constexpr auto I1 = Number<1>{};
172 
 // Loop extents taken from the length lists; TK1 is also the vector width
 // used for a_vector_t / b_vector_t below.
173  constexpr index_t TK0 = TKLengths{}[I0];
174  constexpr index_t TK1 = TKLengths{}[I1];
175  constexpr index_t TM0 = TMLengths{}[I0];
176  constexpr index_t TM1 = TMLengths{}[I1];
177  constexpr index_t TN0 = TNLengths{}[I0];
178  constexpr index_t TN1 = TNLengths{}[I1];
179 
 // Origins are materialized from the index types so they can be added to
 // the per-iteration indices inside constexpr offset computations.
180  constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
181  constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
182  constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
183 
 // Loop nest order: tk0 outermost, then tm0, tm1, tn0, tn1; the tk1 loop
 // is innermost and only fills the vectors.
184  static_for<0, TK0, 1>{}([&](auto tk0) {
185  static_for<0, TM0, 1>{}([&](auto tm0) {
186  static_for<0, TM1, 1>{}([&](auto tm1) {
187  static_for<0, TN0, 1>{}([&](auto tn0) {
188  static_for<0, TN1, 1>{}([&](auto tn1) {
 // NOTE(review): listing lines 189-190 are missing from this
 // extract; they presumably declare a_vec and b_vec, which
 // are used below - confirm against the original header.
191 
 // Gather: slot tk1 of each vector receives the A / B element
 // at (tk0, tm0|tn0, tm1|tn1, tk1).
192  static_for<0, TK1, 1>{}([&](auto tk1) {
193  constexpr index_t a_offset =
194  AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
195  a_origin_idx + make_multi_index(tk0, tm0, tm1, tk1));
196 
197  constexpr index_t b_offset =
198  BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
199  b_origin_idx + make_multi_index(tk0, tn0, tn1, tk1));
200 
201  a_vec.template AsType<FloatA>()(tk1) = a_buf[Number<a_offset>{}];
202  b_vec.template AsType<FloatB>()(tk1) = b_buf[Number<b_offset>{}];
203  });
204 
 // Whole-vector views of the gathered data (TK1 elements each).
205  using a_vector_t = typename vector_type<FloatA, TK1>::type;
206  using b_vector_t = typename vector_type<FloatB, TK1>::type;
207 
208  constexpr index_t c_offset =
209  CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
210  c_origin_idx + make_multi_index(tm0, tm1, tn0, tn1));
211 
 // One call consumes the full TK1-wide vectors; element I0 of
 // the a_vector_t / b_vector_t view is the entire vector.
212  inner_product<a_vector_t, b_vector_t, FloatC>(
213  a_vec.template AsType<a_vector_t>()[I0],
214  b_vec.template AsType<b_vector_t>()[I0],
215  c_buf(Number<c_offset>{}));
216  });
217  });
218  });
219  });
220  });
221  }
222 };
223 
224 } // namespace ck
Definition: ck.hpp:264
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:10
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
constexpr __device__ ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
Definition: threadwise_contraction_dl.hpp:131
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_contraction_dl.hpp:152
Definition: threadwise_contraction_dl.hpp:31
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_contraction_dl.hpp:53
constexpr __device__ ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
Definition: threadwise_contraction_dl.hpp:32
Definition: integral_constant.hpp:10
Definition: is_known_at_compile_time.hpp:14
Definition: type.hpp:177
Definition: functional2.hpp:31
Definition: data_type.hpp:347