17 template <
typename FloatA,
20 typename AThreadDesc_TK0_TM0_TM1_TK1,
21 typename BThreadDesc_TK0_TN0_TN1_TK1,
22 typename CThreadDesc_TM0_TM1_TN0_TN1,
26 typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
27 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
28 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
34 static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
35 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
36 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
37 "wrong! Desc should be known at compile-time");
43 static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
47 template <
typename ABuffer,
53 __device__
static void Run(
const ABuffer& a_buf,
63 "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
69 "wrong! inconsistent type");
74 constexpr
auto TK = TKLengths{}[I0];
75 constexpr
auto TM0 = TMLengths{}[I0];
76 constexpr
auto TM1 = TMLengths{}[I1];
77 constexpr
auto TN0 = TNLengths{}[I0];
78 constexpr
auto TN1 = TNLengths{}[I1];
90 AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
93 BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
96 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
116 template <
typename FloatA,
119 typename AThreadDesc_TK0_TM0_TM1_TK1,
120 typename BThreadDesc_TK0_TN0_TN1_TK1,
121 typename CThreadDesc_TM0_TM1_TN0_TN1,
125 typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
126 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
127 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
133 static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
134 BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
135 CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
136 "wrong! Desc should be known at compile-time");
142 static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
146 template <
typename ABuffer,
152 __device__
static void Run(
const ABuffer& a_buf,
154 const BBuffer& b_buf,
162 "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
168 "wrong! inconsistent type");
173 constexpr
index_t TK0 = TKLengths{}[I0];
174 constexpr
index_t TK1 = TKLengths{}[I1];
175 constexpr
index_t TM0 = TMLengths{}[I0];
176 constexpr
index_t TM1 = TMLengths{}[I1];
177 constexpr
index_t TN0 = TNLengths{}[I0];
178 constexpr
index_t TN1 = TNLengths{}[I1];
194 AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
198 BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
209 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
212 inner_product<a_vector_t, b_vector_t, FloatC>(
213 a_vec.template AsType<a_vector_t>()[I0],
214 b_vec.template AsType<b_vector_t>()[I0],
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:10
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:300
int32_t index_t
Definition: ck.hpp:289
Definition: threadwise_contraction_dl.hpp:130
constexpr __device__ ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
Definition: threadwise_contraction_dl.hpp:131
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_contraction_dl.hpp:152
Definition: threadwise_contraction_dl.hpp:31
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_contraction_dl.hpp:53
constexpr __device__ ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
Definition: threadwise_contraction_dl.hpp:32
Definition: integral_constant.hpp:10
Definition: is_known_at_compile_time.hpp:14
Definition: functional2.hpp:31
Definition: data_type.hpp:347