20 namespace tensor_operation {
 
   23 template <
typename ADataType,
 
   30           typename AElementwiseOperation,
 
   31           typename BElementwiseOperation,
 
   32           typename CElementwiseOperation,
 
   43           typename ABlockTransferThreadClusterLengths_K0_M_K1,
 
   44           typename ABlockTransferThreadClusterArrangeOrder,
 
   45           typename ABlockTransferSrcAccessOrder,
 
   49           bool ABlockLdsAddExtraM,
 
   50           typename BBlockTransferThreadClusterLengths_K0_N_K1,
 
   51           typename BBlockTransferThreadClusterArrangeOrder,
 
   52           typename BBlockTransferSrcAccessOrder,
 
   56           bool BBlockLdsAddExtraN,
 
   68                                          AElementwiseOperation,
 
   69                                          BElementwiseOperation,
 
   70                                          CElementwiseOperation>
 
   88         AElementwiseOperation,
 
   89         BElementwiseOperation,
 
   90         CElementwiseOperation,
 
  100         ABlockTransferThreadClusterLengths_K0_M_K1,
 
  101         ABlockTransferThreadClusterArrangeOrder,
 
  102         ABlockTransferSrcAccessOrder,
 
  103         ABlockTransferSrcVectorDim,
 
  104         ABlockTransferSrcScalarPerVector,
 
  105         ABlockTransferDstScalarPerVector_K1,
 
  108         BBlockTransferThreadClusterLengths_K0_N_K1,
 
  109         BBlockTransferThreadClusterArrangeOrder,
 
  110         BBlockTransferSrcAccessOrder,
 
  111         BBlockTransferSrcVectorDim,
 
  112         BBlockTransferSrcScalarPerVector,
 
  113         BBlockTransferDstScalarPerVector_K1,
 
  117         CThreadTransferSrcDstVectorDim,
 
  118         CThreadTransferDstScalarPerVector,
 
  130             if(stream_config.log_level_ > 0)
 
  137                 throw std::runtime_error(
 
  138                     "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3_ext has invalid setting");
 
  147                 const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, true>;
 
  150                     stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
 
  154                 const auto kernel = kernel_gemm_xdlops_v2r3<GridwiseGemm, false>;
 
  157                     stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, karg);
 
  167             return Run(*
dynamic_cast<const Argument*
>(p_arg), stream_config);
 
  181             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
 
  182                            is_same_v<AccDataType, int32_t>))
 
  189             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, float> ||
 
  190                            is_same_v<AccDataType, int32_t> || is_same_v<AccDataType, double>))
 
  215                              const BDataType* p_b,
 
  223                              AElementwiseOperation,
 
  224                              BElementwiseOperation,
 
  225                              CElementwiseOperation)
 
  227         return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC};
 
  242                                                       AElementwiseOperation,
 
  243                                                       BElementwiseOperation,
 
  244                                                       CElementwiseOperation)
 override 
  246         return std::make_unique<Argument>(
static_cast<const ADataType*
>(p_a),
 
  247                                           static_cast<const BDataType*
>(p_b),
 
  248                                           static_cast<CDataType*
>(p_c),
 
  260         return std::make_unique<Invoker>(
Invoker{});
 
  266         auto str = std::stringstream();
 
  268         std::map<LoopScheduler, std::string> LoopSchedToString{
 
  271         std::map<PipelineVersion, std::string> PipelineVersionToString{{
PipelineVersion::v1, 
"v1"},
 
  275         str << 
"DeviceGemmXdl" 
  280             << K0PerBlock << 
", " 
  284             << MXdlPerWave << 
", " 
  285             << NXdlPerWave << 
", " 
  286             << ABlockTransferSrcScalarPerVector << 
", " 
  287             << ABlockTransferDstScalarPerVector_K1 << 
", " 
  288             << BBlockTransferSrcScalarPerVector << 
", " 
  289             << BBlockTransferDstScalarPerVector_K1
 
  292             << NumPrefetch << 
", " 
  294             << LoopSchedToString[LoopSched] << 
", " 
  295             << 
"PipelineVersion: " 
  296             << PipelineVersionToString[PipelineVer];
 
float launch_and_time_kernel(const StreamConfig &stream_config, F kernel, dim3 grid_dim, dim3 block_dim, std::size_t lds_byte, Args... args)
Definition: kernel_launch.hpp:13
 
GemmSpecialization
Definition: gemm_specialization.hpp:11
 
bool is_lds_direct_load_supported()
Definition: device_prop.hpp:61
 
std::string get_device_name()
Definition: device_prop.hpp:12
 
LoopScheduler
Definition: loop_scheduler.hpp:15
 
int32_t index_t
Definition: ck.hpp:289
 
PipelineVersion
Definition: gridwise_gemm_pipeline_selector.hpp:17
 
constexpr LoopScheduler make_default_loop_scheduler()
Definition: loop_scheduler.hpp:20
 
Definition: stream_config.hpp:10
 
Definition: gridwise_gemm_xdlops_v2r3.hpp:781
 
static constexpr __host__ bool CheckValidity(const Problem &problem)
Definition: gridwise_gemm_xdlops_v2r3.hpp:968
 
ck::GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, FloatAB, FloatAcc, FloatC, CGlobalMemoryDataOperation, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1Value, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, 1, make_default_loop_scheduler(), PipelineVersion::v1 >::CalculateHasMainKBlockLoop static constexpr __host__ bool CalculateHasMainKBlockLoop(index_t K)
Definition: gridwise_gemm_xdlops_v2r3.hpp:382
 
ck::GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3< BlockSize, FloatAB, FloatAcc, FloatC, CGlobalMemoryDataOperation, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation, MPerBlock, NPerBlock, K0PerBlock, MPerXDL, NPerXDL, K1Value, MXdlPerWave, NXdlPerWave, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_K1, AThreadTransferSrcResetCoordinateAfterRun, ABlockLdsExtraM, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, BBlockTransferDstScalarPerVector_K1, BThreadTransferSrcResetCoordinateAfterRun, BBlockLdsExtraN, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, 1, make_default_loop_scheduler(), PipelineVersion::v1 >::CalculateGridSize static __host__ auto CalculateGridSize(index_t M, index_t N)
Definition: gridwise_gemm_xdlops_v2r3.hpp:151
 
Definition: sequence.hpp:43
 
Definition: integral_constant.hpp:10
 
Definition: device_base.hpp:50
 
Definition: device_base.hpp:61
 
Definition: device_gemm.hpp:22
 
Definition: device_gemm_xdl.hpp:127
 
float Run(const BaseArgument *p_arg, const StreamConfig &stream_config=StreamConfig{}) override
Definition: device_gemm_xdl.hpp:164
 
float Run(const Argument &karg, const StreamConfig &stream_config=StreamConfig{})
Definition: device_gemm_xdl.hpp:128
 
Definition: device_gemm_xdl.hpp:71
 
static bool IsSupportedArgument(const Argument &karg)
Definition: device_gemm_xdl.hpp:177
 
static constexpr auto K1Number
Definition: device_gemm_xdl.hpp:76
 
static constexpr auto I0
Definition: device_gemm_xdl.hpp:72
 
bool IsSupportedArgument(const BaseArgument *p_arg) override
Definition: device_gemm_xdl.hpp:209
 
static auto MakeInvoker()
Definition: device_gemm_xdl.hpp:230
 
std::unique_ptr< BaseInvoker > MakeInvokerPointer() override
Definition: device_gemm_xdl.hpp:258
 
std::string GetTypeString() const override
Definition: device_gemm_xdl.hpp:264
 
static constexpr auto I2
Definition: device_gemm_xdl.hpp:74
 
static constexpr auto I1
Definition: device_gemm_xdl.hpp:73
 
static auto MakeArgument(const ADataType *p_a, const BDataType *p_b, CDataType *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation)
Definition: device_gemm_xdl.hpp:214
 
static constexpr bool IsValidCompilationParameter()
Definition: device_gemm_xdl.hpp:171
 
typename GridwiseGemm::Argument Argument
Definition: device_gemm_xdl.hpp:123
 
std::unique_ptr< BaseArgument > MakeArgumentPointer(const void *p_a, const void *p_b, void *p_c, index_t M, index_t N, index_t K, index_t StrideA, index_t StrideB, index_t StrideC, AElementwiseOperation, BElementwiseOperation, CElementwiseOperation) override
Definition: device_gemm_xdl.hpp:233