54 template <
typename ADataType,
61 typename CompilerTarget =
62 decltype(get_compiler_target()),
65 typename MmaDefaultSelector<ADataType,
72 CompilerTarget>::SelectedOp,
73 typename MmaTransforms =
74 typename MmaTransformsDefaultSelector<MmaOp, CompilerTarget>::SelectedTransforms>
// Compile-time shape checks relating the fragment extents (FragM/N/K) to the
// block extents (BlockM/N/K).
//
// Size checks: the comparison is `>=`, so a fragment that is exactly one block
// in a given dimension is accepted. The diagnostic text says so explicitly.
109 static_assert(FragM >=
BlockM,
"FragM must be greater than or equal to BlockM");
110 static_assert(FragN >=
BlockN,
"FragN must be greater than or equal to BlockN");
111 static_assert(FragK >=
BlockK,
"FragK must be greater than or equal to BlockK");
// Divisibility checks: each fragment extent must be an exact multiple of the
// corresponding block extent (remainder compared against 0u).
112 static_assert(FragM %
BlockM == 0u,
"FragM must be a multiple of BlockM");
113 static_assert(FragN %
BlockN == 0u,
"FragN must be a multiple of BlockN");
114 static_assert(FragK %
BlockK == 0u,
"FragK must be a multiple of BlockK");
117 template <
typename DstT,
typename SrcT>
123 static_assert(
sizeof(DstT) ==
sizeof(SrcT),
"Size mismatch in formatBuffer");
124 return reinterpret_cast<DstT const&
>(inputBuffer);
127 template <
typename DstT,
typename SrcT>
133 static_assert(
sizeof(DstT) ==
sizeof(SrcT),
"Size mismatch in formatBuffer");
134 return reinterpret_cast<DstT&
>(inputBuffer);
142 template <
typename VecTA,
typename VecTB,
typename VecTC>
143 CK_TILE_DEVICE static decltype(
auto) exec_col_major(VecTA&&
a, VecTB&& b, VecTC&& accum)
150 auto a_frag = formatBuffer<ABufferType>(ATransform::exec(
a));
151 auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
152 auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
163 BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
170 return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
178 template <
typename VecTA,
typename VecTB,
typename VecTC>
179 CK_TILE_DEVICE static decltype(
auto) exec_row_major(VecTA&&
a, VecTB&& b, VecTC&& accum)
186 auto a_frag = formatBuffer<ABufferType>(ATransform::exec(
a));
187 auto b_frag = formatBuffer<BBufferType>(BTransform::exec(b));
188 auto c_frag = formatBuffer<CBufferType>(CTransform::exec(accum));
201 BlockWiseMmaOp::exec(a_frag[bm][bk], b_frag[bn][bk], c_frag[bm][bn]);
208 return DTransform::exec(formatBuffer<std::decay_t<VecTC>>(c_frag));
217 template <
typename VecTA,
typename VecTB,
typename VecTC>
222 return exec_row_major(
223 std::forward<VecTA>(
a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
227 return exec_col_major(
228 std::forward<VecTA>(
a), std::forward<VecTB>(b), std::forward<VecTC>(accum));
#define CK_TILE_DEVICE
Definition: config.hpp:45
Definition: amdgcn_mma.hpp:10
MmaAccumPolicy
Accumulation order for Mma decomposition.
Definition: mma.hpp:21
const GenericPointer< typename T::ValueType > T2 T::AllocatorType & a
Definition: pointer.h:1517
unsigned int uint32_t
Definition: stdint.h:126
Reflects the template parameters and static members of a given MmaOp.
Definition: mma_traits.hpp:125
typename MmaOp::CVecType CVecType
Definition: mma_traits.hpp:130
typename MmaOp::BVecType BVecType
Definition: mma_traits.hpp:129
typename MmaOp::AVecType AVecType
Definition: mma_traits.hpp:128
typename MmaTransforms::CTransform CTransform
Definition: mma.hpp:105
constexpr static uint32_t BlocksM
Definition: mma.hpp:87
typename MmaTransforms::ATransform ATransform
Definition: mma.hpp:103
constexpr static uint32_t BlockK
Definition: mma.hpp:84
constexpr static uint32_t BlocksC
Definition: mma.hpp:90
constexpr static uint32_t BlockM
Definition: mma.hpp:82
typename BlockWiseMmaOpTraits::BVecType BVecType
Definition: mma.hpp:94
constexpr static uint32_t BlockN
Definition: mma.hpp:83
static decltype(auto) CK_TILE_DEVICE exec(VecTA &&a, VecTB &&b, VecTC &&accum)
Forward to Mma operation with specified accumulation order.
Definition: mma.hpp:218
constexpr static uint32_t BlocksK
Definition: mma.hpp:89
typename MmaTransforms::BTransform BTransform
Definition: mma.hpp:104
AVecType[BlocksM][BlocksK] ABufferType
Definition: mma.hpp:98
MmaOp BlockWiseMmaOp
Definition: mma.hpp:78
typename BlockWiseMmaOpTraits::AVecType AVecType
Definition: mma.hpp:93
typename MmaTransforms::DTransform DTransform
Definition: mma.hpp:106
BVecType[BlocksN][BlocksK] BBufferType
Definition: mma.hpp:99
CVecType[BlocksM][BlocksN] CBufferType
Definition: mma.hpp:100
constexpr static uint32_t BlocksN
Definition: mma.hpp:88
typename BlockWiseMmaOpTraits::CVecType CVecType
Definition: mma.hpp:95