20 namespace tensor_operation {
23 template <
typename ADataType,
30 typename AElementwiseOperation,
31 typename BElementwiseOperation,
32 typename CElementwiseOperation,
43 typename ABlockTransferThreadClusterLengths_K0_M_K1,
44 typename ABlockTransferThreadClusterArrangeOrder,
45 typename ABlockTransferSrcAccessOrder,
49 bool ABlockLdsAddExtraM,
50 typename BBlockTransferThreadClusterLengths_K0_N_K1,
51 typename BBlockTransferThreadClusterArrangeOrder,
52 typename BBlockTransferSrcAccessOrder,
56 bool BBlockLdsAddExtraN,
68 AElementwiseOperation,
69 BElementwiseOperation,
70 CElementwiseOperation>
88 AElementwiseOperation,
89 BElementwiseOperation,
90 CElementwiseOperation,
100 ABlockTransferThreadClusterLengths_K0_M_K1,
101 ABlockTransferThreadClusterArrangeOrder,
102 ABlockTransferSrcAccessOrder,
103 ABlockTransferSrcVectorDim,
104 ABlockTransferSrcScalarPerVector,
105 ABlockTransferDstScalarPerVector_K1,
108 BBlockTransferThreadClusterLengths_K0_N_K1,
109 BBlockTransferThreadClusterArrangeOrder,
110 BBlockTransferSrcAccessOrder,
111 BBlockTransferSrcVectorDim,
112 BBlockTransferSrcScalarPerVector,
113 BBlockTransferDstScalarPerVector_K1,
117 CThreadTransferSrcDstVectorDim,
118 CThreadTransferDstScalarPerVector,
130 if(stream_config.log_level_ > 0)
137 throw std::runtime_error(
138 "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
147 const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
150 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
154 const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, false>;
157 stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
167 return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
181 if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
182 is_same_v<AccDataType, int32_t>))
189 if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
190 is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
215 const BDataType* p_b,
223 AElementwiseOperation,
224 BElementwiseOperation,
225 CElementwiseOperation)
227 return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
242 AElementwiseOperation,
243 BElementwiseOperation,
244 CElementwiseOperation)
override
246 return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
247 static_cast<const BDataType*
>(p_b),
248 static_cast<CDataType*
>(p_c),
260 return std::make_unique<Invoker>(
Invoker{});
266 auto str = std::stringstream();
268 std::map<LoopScheduler, std::string> LoopSchedToString{
271 std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1,
"v1"},
275 str <<
"DeviceGemmXdl"
280 << K0PerBlock <<
", "
284 << MXdlPerWave <<
", "
285 << NXdlPerWave <<
", "
286 << ABlockTransferSrcScalarPerVector <<
", "
287 << ABlockTransferDstScalarPerVector_K1 <<
", "
288 << BBlockTransferSrcScalarPerVector <<
", "
289 << BBlockTransferDstScalarPerVector_K1
292 << NumPrefetch <<
", "
294 << LoopSchedToString[LoopSched] <<
", "
295 <<
"PipelineVersion: "
296 << PipelineVersionToString[PipelineVer];
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
GemmSpecialization
Definition: gemm_specialization.hpp:11
bool is_lds_direct_load_supported()
Definition: device_prop.hpp:61
std::string get_device_name()
Definition: device_prop.hpp:12
LoopScheduler
Definition: loop_scheduler.hpp:15
int32_t index_t
Definition: ck.hpp:289
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
Definition: stream_config.hpp:10
Definition: gridwise_gemm_xdlops_v2r3.hpp:781
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdlops_v2r3.hpp:968
ck::GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, FloatAB, FloatAcc, FloatC, CGlobalMemoryDataOperation, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1Value, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, 1, make_default_loop_scheduler(), PipelineVersion::v1 >::CalculateHasMainKBlockLoop static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v2r3.hpp:382
ck::GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, FloatAB, FloatAcc, FloatC, CGlobalMemoryDataOperation, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1Value, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, 1, make_default_loop_scheduler(), PipelineVersion::v1 >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_xdlops_v2r3.hpp:151
Definition: sequence.hpp:43
Definition: integral_constant.hpp:10
Definition: device_base.hpp:50
Definition: device_base.hpp:61
Definition: device_gemm.hpp:22
Definition: device_gemm_xdl.hpp:127
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl.hpp:164
float Run(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl.hpp:128
Definition: device_gemm_xdl.hpp:71
static bool IsSupportedArgument(const Argument &karg)
Definition: device_gemm_xdl.hpp:177
static constexpr auto K1Number
Definition: device_gemm_xdl.hpp:76
static constexpr auto I0
Definition: device_gemm_xdl.hpp:72
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl.hpp:209
static auto MakeInvoker()
Definition: device_gemm_xdl.hpp:230
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl.hpp:258
std::string GetTypeString() const override
Definition: device_gemm_xdl.hpp:264
static constexpr auto I2
Definition: device_gemm_xdl.hpp:74
static constexpr auto I1
Definition: device_gemm_xdl.hpp:73
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl.hpp:214
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl.hpp:171
typename GridwiseGemm::Argument Argument
Definition: device_gemm_xdl.hpp:123
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl.hpp:233