17 template <
typename FloatA,
 
   20           typename AThreadDesc_TK0_TM0_TM1_TK1,
 
   21           typename BThreadDesc_TK0_TN0_TN1_TK1,
 
   22           typename CThreadDesc_TM0_TM1_TN0_TN1,
 
   26           typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
 
   27                                  BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
 
   28                                  CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
 
   34         static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
 
   35                           BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
 
   36                           CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
 
   37                       "wrong! Desc should be known at compile-time");
 
   43         static_assert(TKLengths::Size() == 1 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
 
   47     template <
typename ABuffer,
 
   53     __device__ 
static void Run(
const ABuffer& a_buf,
 
   63                       "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
 
   69             "wrong! inconsistent type");
 
   74         constexpr 
auto TK  = TKLengths{}[I0];
 
   75         constexpr 
auto TM0 = TMLengths{}[I0];
 
   76         constexpr 
auto TM1 = TMLengths{}[I1];
 
   77         constexpr 
auto TN0 = TNLengths{}[I0];
 
   78         constexpr 
auto TN1 = TNLengths{}[I1];
 
   90                                 AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
 
   93                                 BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
 
   96                                 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
 
  116 template <
typename FloatA,
 
  119           typename AThreadDesc_TK0_TM0_TM1_TK1,
 
  120           typename BThreadDesc_TK0_TN0_TN1_TK1,
 
  121           typename CThreadDesc_TM0_TM1_TN0_TN1,
 
  125           typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
 
  126                                  BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
 
  127                                  CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
 
  133         static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
 
  134                           BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
 
  135                           CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
 
  136                       "wrong! Desc should be known at compile-time");
 
  142         static_assert(TKLengths::Size() == 2 && TMLengths::Size() == 2 && TNLengths::Size() == 2,
 
  146     template <
typename ABuffer,
 
  152     __device__ 
static void Run(
const ABuffer& a_buf,
 
  154                                const BBuffer& b_buf,
 
  162                       "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
 
  168             "wrong! inconsistent type");
 
  173         constexpr 
index_t TK0 = TKLengths{}[I0];
 
  174         constexpr 
index_t TK1 = TKLengths{}[I1];
 
  175         constexpr 
index_t TM0 = TMLengths{}[I0];
 
  176         constexpr 
index_t TM1 = TMLengths{}[I1];
 
  177         constexpr 
index_t TN0 = TNLengths{}[I0];
 
  178         constexpr 
index_t TN1 = TNLengths{}[I1];
 
  194                                     AThreadDesc_TK0_TM0_TM1_TK1{}.CalculateOffset(
 
  198                                     BThreadDesc_TK0_TN0_TN1_TK1{}.CalculateOffset(
 
  209                                 CThreadDesc_TM0_TM1_TN0_TN1{}.CalculateOffset(
 
  212                             inner_product<a_vector_t, b_vector_t, FloatC>(
 
  213                                 a_vec.template AsType<a_vector_t>()[I0],
 
  214                                 b_vec.template AsType<b_vector_t>()[I0],
 
__host__ constexpr __device__ auto make_multi_index(Xs &&... xs)
Definition: array_multi_index.hpp:15
 
__host__ constexpr __device__ auto to_multi_index(const T &x)
Definition: array_multi_index.hpp:28
 
std::enable_if< B, T > enable_if
Definition: enable_if.hpp:24
 
remove_cv_t< remove_reference_t< T > > remove_cvref_t
Definition: type.hpp:297
 
int32_t index_t
Definition: ck.hpp:300
 
Definition: threadwise_contraction_dl.hpp:130
 
constexpr __device__ ThreadwiseContractionDl_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
Definition: threadwise_contraction_dl.hpp:131
 
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_contraction_dl.hpp:152
 
Definition: threadwise_contraction_dl.hpp:31
 
static __device__ void Run(const ABuffer &a_buf, AOriginIdx, const BBuffer &b_buf, BOriginIdx, CBuffer &c_buf, COriginIdx)
Definition: threadwise_contraction_dl.hpp:53
 
constexpr __device__ ThreadwiseGemmDl_km0m1_kn0n1_m0m1n0n1()
Definition: threadwise_contraction_dl.hpp:32
 
Definition: integral_constant.hpp:20
 
Definition: is_known_at_compile_time.hpp:14
 
Definition: functional2.hpp:33
 
Definition: dtype_vector.hpp:10