hipBLASLtExt API reference#
hipBLASLt has extension APIs in the namespace hipblaslt_ext. They are available from C++ only. The extensions support gemm, grouped gemm, and retrieving all available algorithms (getAllAlgos()).
hipBLASLtExt datatypes reference#
GemmType#
GemmProblemType#
-
struct GemmProblemType#
hipblasLt extension ProblemType for gemm problems.
This structure sets the problem type of a gemm problem.
Public Members
-
hipblasOperation_t op_a#
The A matrix transpose.
-
hipblasOperation_t op_b#
The B matrix transpose.
-
hipDataType type_a#
The A matrix datatype.
-
hipDataType type_b#
The B matrix datatype.
-
hipDataType type_c#
The C matrix datatype.
-
hipDataType type_d#
The D matrix datatype.
-
hipblasComputeType_t type_compute#
The compute datatype.
-
hipblasOperation_t op_a#
GemmEpilogue#
-
struct GemmEpilogue#
hipblasLt extension Epilogue for gemm problems.
This structure sets the epilogue of a gemm problem.
Public Members
-
hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT#
The mode of epilogue. Default is gemm.
-
hipDataType bias_data_type = HIPBLASLT_DATATYPE_INVALID#
The bias datatype. Only works if mode is set to bias related epilogues.
-
int aux_ld = 0#
The aux leading dimension. Only works if mode is set to aux related epilogues.
-
int aux_stride = 0#
The aux batch stride. Only works if mode is set to aux related epilogues.
-
hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT#
GemmInputs#
-
struct GemmInputs#
hipblasLt extension Inputs for gemm problems.
This structure sets the input pointers of a gemm problem.
Public Members
-
void *a = nullptr#
The a matrix input pointer.
-
void *b = nullptr#
The b matrix input pointer.
-
void *c = nullptr#
The c matrix input pointer.
-
void *d = nullptr#
The d matrix input pointer.
-
void *alpha = nullptr#
The alpha value.
-
void *beta = nullptr#
The beta value.
-
void *bias = nullptr#
The bias input pointer.
-
void *scaleA = nullptr#
The Scale A input pointer.
-
void *scaleB = nullptr#
The Scale B input pointer.
-
void *scaleC = nullptr#
The Scale C input pointer.
-
void *scaleD = nullptr#
The Scale D input pointer.
-
void *scaleAux = nullptr#
The Scale AUX input pointer.
-
void *scaleAlphaVec = nullptr#
The scaleAlpha vector input pointer.
-
void *aux = nullptr#
The aux input pointer.
-
void *a = nullptr#
hipBLASLtExt Class Reference#
GemmPreference#
-
class GemmPreference#
hipblasLt extension preference for gemm problems.
Currently only supports setting max workspace size.
Public Functions
-
void setMaxWorkspaceBytes(size_t workspaceBytes)#
This function sets the max workspace size.
- Parameters:
workspaceBytes – [in] Set the max workspace size in bytes.
-
const size_t getMaxWorkspaceBytes() const#
This function returns the set max workspace size.
- Return values:
size_t – Returns the set max workspace size.
-
void setMaxWorkspaceBytes(size_t workspaceBytes)#
GemmInstance#
-
class GemmInstance#
hipblasLt extension instance for gemm problems.
Subclassed by hipblaslt_ext::Gemm, hipblaslt_ext::GroupedGemm
Public Functions
-
hipblasStatus_t algoGetHeuristic(const int requestedAlgoCount, const GemmPreference &pref, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#
Retrieve the possible algorithms.
This function retrieves the possible algorithms for the matrix multiply operation hipblasLtMatmul() function with the given data and compute type. The output is placed in heuristicResults in the order of increasing estimated compute time.
- Parameters:
requestedAlgoCount – [in] number of requested algorithms.
pref – [in] hipblasLt extension preference for gemm problems.
heuristicResults – [out] The algorithm heuristic vector.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the query was successful. Check that heuristicResults.size() > 0; note that heuristicResults.size() may be smaller than requestedAlgoCount.
HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function available for current configuration.
HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.
-
hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t &algo, size_t &workspaceSizeInBytes)#
Check if the algorithm supports the problem. (For hipblaslt extension API)
This function updates the problem saved inside the algorithm if the problem is supported. The required workspaceSizeInBytes is also returned.
- Parameters:
algo – [in] The algorithm heuristic.
workspaceSizeInBytes – [out] Return the required workspace size.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the query was successful. The problem is supported by the algorithm.
HIPBLAS_STATUS_INVALID_VALUE – The problem is not supported.
-
hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t &algo, GemmTuning &tuning, size_t &workspaceSizeInBytes)#
Check if the algorithm supports the problem. (For hipblaslt extension API)
This function updates the problem saved inside the algorithm if the problem is supported. The required workspaceSizeInBytes is also returned.
- Parameters:
algo – [in] The algorithm heuristic.
tuning – [in] The tuning parameters.
workspaceSizeInBytes – [out] Return the required workspace size.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the query was successful. The problem is supported by the algorithm.
HIPBLAS_STATUS_INVALID_VALUE – The problem is not supported.
-
hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t &algo, void *workspace, bool useUserArgs = true, hipStream_t stream = 0)#
Create kernel arguments from a given hipblaslt_ext::GemmInstance.
This function creates kernel arguments from a given hipblaslt_ext::GemmInstance then saves the arguments inside the instance.
- Parameters:
algo – [in] Handle for matrix multiplication algorithm to be used. See hipblaslt.h::hipblasLtMatmulAlgo_t . When NULL, an implicit heuristics query with default search preferences will be performed to determine actual algorithm to use.
workspace – [in] Pointer to the workspace buffer allocated in the GPU memory. Pointer must be 16B aligned (that is, lowest 4 bits of address must be 0).
useUserArgs – [in] Use user args, this does not affect vanilla gemm. (May be deprecated in the future)
stream – [in] The HIP stream where all the GPU work will be submitted. (May be deprecated in the future)
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_INVALID_VALUE – If the gemm_count = 0.
-
hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t &algo, GemmTuning &tuning, void *workspace, bool useUserArgs = true, hipStream_t stream = 0)#
Create kernel arguments from a given hipblaslt_ext::GemmInstance.
This function creates kernel arguments from a given hipblaslt_ext::GemmInstance then saves the arguments inside the instance.
- Parameters:
algo – [in] Handle for matrix multiplication algorithm to be used. See hipblaslt.h::hipblasLtMatmulAlgo_t . When NULL, an implicit heuristics query with default search preferences will be performed to determine actual algorithm to use.
tuning – [in] Structure with user tuning parameters. Note that not every algo supports user tuning parameters. Returns HIPBLAS_STATUS_INVALID_VALUE if not supported.
workspace – [in] Pointer to the workspace buffer allocated in the GPU memory. Pointer must be 16B aligned (that is, lowest 4 bits of address must be 0).
useUserArgs – [in] Use user args, this does not affect vanilla gemm. (May be deprecated in the future)
stream – [in] The HIP stream where all the GPU work will be submitted. (May be deprecated in the future)
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_INVALID_VALUE – If the gemm_count = 0.
-
hipblasStatus_t run(hipStream_t stream)#
Execute the kernel arguments stored inside the hipblaslt_ext::GemmInstance.
- Parameters:
stream – [in] The HIP stream where all the GPU work will be submitted.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
Protected Functions
-
explicit GemmInstance(hipblasLtHandle_t handle, GemmType type)#
Constructor of GemmInstance.
-
hipblasStatus_t algoGetHeuristic(const int requestedAlgoCount, const GemmPreference &pref, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#
Gemm#
-
class Gemm : public hipblaslt_ext::GemmInstance#
hipblasLt extension instance for gemm.
The instance can be used to create arguments to compute the matrix multiplication of matrices A and B to produce the output matrix D, according to the following operation:
$$D = \alpha \cdot (A \cdot B) + \beta \cdot C,$$
where $A$, $B$, and $C$ are input matrices, and $\alpha$ and $\beta$ are input scalars.
Public Functions
-
explicit Gemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute)#
Constructor.
This constructor sets the problem type (transposes, data types, and compute type). The problem sizes and inputs are set afterwards with setProblem().
- Parameters:
handle – [in] The handle from hipBLASLt.
opA, opB – [in] The transpose type of matrix A, B
typeA, typeB, typeC, typeD – [in] The data type of matrix A, B, C, D
typeCompute – [in] The compute type of the gemm problem
-
explicit Gemm(hipblasLtHandle_t handle, hipblasLtMatmulDesc_t matmul_descr, const void *alpha, const void *A, hipblasLtMatrixLayout_t matA, const void *B, hipblasLtMatrixLayout_t matB, const void *beta, const void *C, hipblasLtMatrixLayout_t matC, void *D, hipblasLtMatrixLayout_t matD)#
Constructor that sets the gemm problem from hipblasLt structures.
This constructor sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.
- Parameters:
handle – [in] The handle from hipBLASLt.
matmul_descr – [in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Pointers to the scalars used in the multiplication.
matA, matB, matC, matD – [in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
A, B, C – [in] Pointers to the GPU memory associated with the corresponding descriptors matA, matB, and matC.
D – [out] Pointer to the GPU memory associated with the descriptor matD.
-
hipblasStatus_t setProblem(int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue &epilogue, GemmInputs &inputs)#
Sets the problem for a gemm problem.
This function sets the problem with m, n, k, batch_count. It uses the problem type set in the constructor.
- Parameters:
m, n, k – [in] The problem size.
batch_count – [in] The batch count.
epilogue – [in] The structure that controls the epilogue.
inputs – [in] The inputs of the problem.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t setProblem(int64_t m, int64_t n, int64_t k, int64_t batch_count, int64_t lda, int64_t ldb, int64_t ldc, int64_t ldd, int64_t strideA, int64_t strideB, int64_t strideC, int64_t strideD, GemmEpilogue &epilogue, GemmInputs &inputs, GemmProblemType &problemtype)#
Sets the problem for a gemm problem.
This function sets the problem with m, n, k, batch_count, the leading dimensions, and the batch strides. It replaces the problem type of the instance with problemtype.
- Parameters:
m, n, k – [in] The problem size.
batch_count – [in] The batch count.
lda, ldb, ldc, ldd – [in] The leading dimensions of the matrix.
strideA, strideB, strideC, strideD – [in] The batch stride of the matrix.
epilogue – [in] The structure that controls the epilogue.
inputs – [in] The inputs of the problem.
problemtype – [in] The structure that sets the problem type of a gemm problem.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t setProblem(hipblasLtMatmulDesc_t matmul_descr, const void *alpha, const void *A, hipblasLtMatrixLayout_t matA, const void *B, hipblasLtMatrixLayout_t matB, const void *beta, const void *C, hipblasLtMatrixLayout_t matC, void *D, hipblasLtMatrixLayout_t matD)#
Sets the gemm problem from hipblasLt structures.
This function sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul. Note that this may overwrite the problem type of the instance.
- Parameters:
matmul_descr – [in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Pointers to the scalars used in the multiplication.
matA, matB, matC, matD – [in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
A, B, C – [in] Pointers to the GPU memory associated with the corresponding descriptors matA, matB, and matC.
D – [out] Pointer to the GPU memory associated with the descriptor matD.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
explicit Gemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute)#
GroupedGemm#
-
class GroupedGemm : public hipblaslt_ext::GemmInstance#
hipblasLt extension instance for grouped gemm.
The instance can be used to create arguments to compute the matrix multiplication of matrices A and B to produce the output matrix D, according to the following operation:
$$D = \alpha \cdot (A \cdot B) + \beta \cdot C,$$
where $A$, $B$, and $C$ are input matrices, and $\alpha$ and $\beta$ are input scalars.
Public Functions
-
explicit GroupedGemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute)#
Constructor.
This constructor sets the problem type (transposes, data types, and compute type). The problem sizes and inputs are set afterwards with setProblem().
- Parameters:
handle – [in] The handle from hipBLASLt.
opA, opB – [in] The transpose type of matrix A, B
typeA, typeB, typeC, typeD – [in] The data type of matrix A, B, C, D
typeCompute – [in] The compute type of the gemm problem
-
explicit GroupedGemm(hipblasLtHandle_t handle, std::vector<hipblasLtMatmulDesc_t> &matmul_descr, std::vector<void*> &alpha, std::vector<void*> &A, std::vector<hipblasLtMatrixLayout_t> &matA, std::vector<void*> &B, std::vector<hipblasLtMatrixLayout_t> &matB, std::vector<void*> &beta, std::vector<void*> &C, std::vector<hipblasLtMatrixLayout_t> &matC, std::vector<void*> &D, std::vector<hipblasLtMatrixLayout_t> &matD)#
Constructor that sets the grouped gemm problem from hipblasLt structures.
This constructor sets the grouped problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.
- Parameters:
handle – [in] The handle from hipBLASLt.
matmul_descr – [in] Vector of handles to previously created matrix multiplication descriptors of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Vectors of pointers to the scalars used in the multiplication.
matA, matB, matC, matD – [in] Vectors of handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
A, B, C – [in] Vectors of pointers to the GPU memory associated with the corresponding descriptors matA, matB, and matC.
D – [out] Vector of pointers to the GPU memory associated with the descriptors matD.
-
hipblasStatus_t setProblem(std::vector<int64_t> &m, std::vector<int64_t> &n, std::vector<int64_t> &k, std::vector<int64_t> &batch_count, std::vector<GemmEpilogue> &epilogue, std::vector<GemmInputs> &inputs)#
Sets the problem for a gemm problem.
This function sets the problems with m, n, k, batch_count. It uses the problem type set in the constructor.
- Parameters:
m, n, k – [in] The problem size in vector.
batch_count – [in] The batch count in vector.
epilogue – [in] The structure in vector that controls the epilogue.
inputs – [in] The inputs in vector of the problem.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t setProblem(std::vector<int64_t> &m, std::vector<int64_t> &n, std::vector<int64_t> &k, std::vector<int64_t> &batch_count, std::vector<int64_t> &lda, std::vector<int64_t> &ldb, std::vector<int64_t> &ldc, std::vector<int64_t> &ldd, std::vector<int64_t> &strideA, std::vector<int64_t> &strideB, std::vector<int64_t> &strideC, std::vector<int64_t> &strideD, std::vector<GemmEpilogue> &epilogue, std::vector<GemmInputs> &inputs, GemmProblemType &problemtype)#
Sets the problem for a gemm problem.
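This function sets the problems with m, n, k, batch_count, the leading dimensions, and the batch strides. It replaces the problem type of the instance with problemtype, which applies to all problems in the group.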
- Parameters:
m, n, k – [in] The problem size in vector.
batch_count – [in] The batch count in vector.
lda, ldb, ldc, ldd – [in] The leading dimensions in vector of the matrix.
strideA, strideB, strideC, strideD – [in] The batch stride in vector of the matrix.
epilogue – [in] The structure in vector that controls the epilogue.
inputs – [in] The inputs in vector of the problem.
problemtype – [in] The structure that sets the problem type of a gemm problem.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t setProblem(std::vector<hipblasLtMatmulDesc_t> &matmul_descr, std::vector<void*> &alpha, std::vector<void*> &A, std::vector<hipblasLtMatrixLayout_t> &matA, std::vector<void*> &B, std::vector<hipblasLtMatrixLayout_t> &matB, std::vector<void*> &beta, std::vector<void*> &C, std::vector<hipblasLtMatrixLayout_t> &matC, std::vector<void*> &D, std::vector<hipblasLtMatrixLayout_t> &matD)#
Sets the grouped gemm problem from hipblasLt structures.
This function sets the grouped problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul. Note that this may overwrite the problem type of the instance.
- Parameters:
matmul_descr – [in] Vector of handles to previously created matrix multiplication descriptors of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Vectors of pointers to the scalars used in the multiplication.
matA, matB, matC, matD – [in] Vectors of handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
A, B, C – [in] Vectors of pointers to the GPU memory associated with the corresponding descriptors matA, matB, and matC.
D – [out] Vector of pointers to the GPU memory associated with the descriptors matD.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t getDefaultValueForDeviceUserArguments(void *hostDeviceUserArgs)#
A helper function to initialize DeviceUserArguments using the set problem(s) saved in the gemm object.
- Parameters:
hostDeviceUserArgs – [out] The DeviceUserArguments structure allocated on the host, which this function fills in. Note that the user must pass the correct type of DeviceUserArguments.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
-
hipblasStatus_t run(void *deviceUserArgs, hipStream_t stream)#
Run the kernel using DeviceUserArguments.
- Parameters:
deviceUserArgs – [in] Pointer to the DeviceUserArguments buffer allocated in the GPU memory. Pointer must be 16B aligned (that is, lowest 4 bits of address must be 0).
stream – [in] The HIP stream where all the GPU work will be submitted.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_INVALID_VALUE – If the gemm_count = 0.
-
hipblasStatus_t run(hipStream_t stream)#
Execute the kernel arguments stored inside the hipblaslt_ext::GemmInstance.
- Parameters:
stream – [in] The HIP stream where all the GPU work will be submitted.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
-
explicit GroupedGemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute)#
hipBLASLtExt API Reference#
getAllAlgos()#
-
hipblasStatus_t hipblaslt_ext::getAllAlgos(hipblasLtHandle_t handle, GemmType typeGemm, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#
Retrieve the possible algorithms.
This function retrieves the possible algorithms for the matrix multiply operation hipblasLtMatmul() function with the given data and compute type. The output is placed in heuristicResults in the order of increasing estimated compute time.
- Parameters:
handle – [in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .
typeGemm – [in] Gemm type, for example GemmType::HIPBLASLT_GEMM or GemmType::HIPBLASLT_GROUPED_GEMM.
opA, opB – [in] Transpose settings of A, B.
typeA, typeB, typeC, typeD – [in] The data type of matrix A, B, C, D.
typeCompute – [in] The compute type.
heuristicResults – [out] The algorithm heuristic vector.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the query was successful. Check that heuristicResults.size() > 0.
HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function available for current configuration.
HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.
getIndexFromAlgo()#
-
int hipblaslt_ext::getIndexFromAlgo(hipblasLtMatmulAlgo_t &algo)#
Retrieve the algorithm index.
- Parameters:
algo – [in] The algorithm.
- Return values:
int – The index of the algorithm, can be used to get heuristic results from getAlgosFromIndex. Returns -1 if the index stored in algo < 0. Note that the index may not be valid if the algo struct is not initialized properly.
getAlgosFromIndex()#
-
hipblasStatus_t hipblaslt_ext::getAlgosFromIndex(hipblasLtHandle_t handle, std::vector<int> &algoIndex, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#
Retrieve the possible algorithms.
This function retrieves the possible algorithms for the matrix multiply operation hipblasLtMatmul() function with the given index. The output is placed in heuristicResults in the order of increasing estimated compute time.
- Parameters:
handle – [in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .
algoIndex – [in] The algorithm index vector.
heuristicResults – [out] The algorithm heuristic vector.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the query was successful. Check that heuristicResults.size() > 0.
HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function available for current configuration.
HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.
matmulIsAlgoSupported()#
-
hipblasStatus_t hipblaslt_ext::matmulIsAlgoSupported(hipblasLtHandle_t handle, hipblasLtMatmulDesc_t matmulDesc, const void *alpha, hipblasLtMatrixLayout_t Adesc, hipblasLtMatrixLayout_t Bdesc, const void *beta, hipblasLtMatrixLayout_t Cdesc, hipblasLtMatrixLayout_t Ddesc, hipblasLtMatmulAlgo_t &algo, size_t &workspaceSizeInBytes)#
Check if the algorithm supports the problem. (For hipblasLt API)
This function updates the problem saved inside the algorithm if the problem is supported. The required workspaceSizeInBytes is also returned.
- Parameters:
handle – [in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .
matmulDesc – [in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Pointers to the scalars used in the multiplication.
Adesc, Bdesc, Cdesc, Ddesc – [in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
algo – [in] The algorithm heuristic.
workspaceSizeInBytes – [out] Return the required workspace size.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the query was successful. The problem is supported by the algorithm.
HIPBLAS_STATUS_INVALID_VALUE – The problem is not supported.
hipblasLtExt usage#
Here are the three use-cases supported by the hipBLASLtExt APIs.
Gemm#
hipBLASLtExt provides its own gemm instance class, hipblaslt_ext::Gemm.
You must assign the problem type when constructing the instance or when importing the problem from the hipBLASLt API.
/// Constructs a gemm instance and assigns its problem type.
/// The problem sizes and inputs are supplied afterwards through setProblem().
/// @param handle                      The handle from hipBLASLt.
/// @param opA, opB                    The transpose type of matrix A, B.
/// @param typeA, typeB, typeC, typeD  The data type of matrix A, B, C, D.
/// @param typeCompute                 The compute type of the gemm problem.
HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t handle,
hipblasOperation_t opA,
hipblasOperation_t opB,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
hipblasComputeType_t typeCompute);
/// Constructs a gemm instance and imports the whole problem (including its
/// problem type) from hipBLASLt structures, as accepted by hipblasLtMatmul.
/// @param handle              The handle from hipBLASLt.
/// @param matmul_descr        A previously created hipblasLtMatmulDesc_t.
/// @param alpha, beta         Pointers to the scalars used in the multiplication.
/// @param A, B, C             Pointers to the GPU memory described by matA, matB, matC.
/// @param D                   [out] Pointer to the GPU memory described by matD.
/// @param matA, matB, matC, matD  Previously created hipblasLtMatrixLayout_t descriptors.
HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t handle,
hipblasLtMatmulDesc_t matmul_descr,
const void* alpha,
const void* A,
hipblasLtMatrixLayout_t matA,
const void* B,
hipblasLtMatrixLayout_t matB,
const void* beta,
const void* C,
hipblasLtMatrixLayout_t matC,
void* D,
hipblasLtMatrixLayout_t matD);
After the instance is created, you can set the problem using the API. The API may require the following structures:
GemmProblemType allows you to change the problem type after the instance is initialized.
// Problem type of a gemm problem: transposes, matrix data types, and compute type.
// Field order defines the aggregate-initialization order; do not reorder.
struct GemmProblemType
{
hipblasOperation_t op_a; // The A matrix transpose.
hipblasOperation_t op_b; // The B matrix transpose.
hipDataType type_a; // The A matrix datatype.
hipDataType type_b; // The B matrix datatype.
hipDataType type_c; // The C matrix datatype.
hipDataType type_d; // The D matrix datatype.
hipblasComputeType_t type_compute; // The compute datatype.
};
GemmEpilogue allows you to control the epilogue of the problem.
// Controls the epilogue of a gemm problem.
// Every member has a default so that a default-constructed epilogue (plain gemm)
// carries no indeterminate values; the defaults match the API reference.
struct GemmEpilogue
{
    hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT; // Epilogue mode; the default is plain gemm.
    hipDataType bias_data_type = HIPBLASLT_DATATYPE_INVALID; // Bias datatype; only used by bias-related epilogues.
    int aux_ld = 0; // Aux leading dimension; only used by aux-related epilogues.
    int aux_stride = 0; // Aux batch stride; only used by aux-related epilogues.
};
GemmInputs specifies the problem inputs.
// Input pointers of a gemm problem. Every pointer defaults to nullptr, so only
// the inputs a problem actually uses need to be assigned.
struct GemmInputs
{
    void* a     = nullptr; // The a matrix input pointer.
    void* b     = nullptr; // The b matrix input pointer.
    void* c     = nullptr; // The c matrix input pointer.
    void* d     = nullptr; // The d matrix input pointer.
    void* alpha = nullptr; // The alpha value.
    void* beta  = nullptr; // The beta value.
    // Epilogue inputs
    void* bias = nullptr; // The bias input pointer.
    // Scaling inputs
    void* scaleA        = nullptr; // The scale A input pointer.
    void* scaleB        = nullptr; // The scale B input pointer.
    void* scaleC        = nullptr; // The scale C input pointer.
    void* scaleD        = nullptr; // The scale D input pointer.
    void* scaleAux      = nullptr; // The scale AUX input pointer.
    void* scaleAlphaVec = nullptr; // The scaleAlpha vector input pointer.
    void* aux           = nullptr; // The aux input pointer.
};
setProblem APIs:
/// Sets the problem size, epilogue, and inputs. The problem type assigned in the
/// constructor is kept.
/// @param m, n, k      The problem size.
/// @param batch_count  The batch count.
/// @param epilogue     The structure that controls the epilogue.
/// @param inputs       The input pointers of the problem.
/// @return HIPBLAS_STATUS_SUCCESS if the operation completed successfully, otherwise an error status.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(
int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue& epilogue, GemmInputs& inputs);
You can also set the leading dimensions and strides, and reassign the data type with the following API:
/// Sets the problem size, leading dimensions, batch strides, epilogue, and inputs,
/// and replaces the problem type of the instance with `problemtype`.
/// @param m, n, k                             The problem size.
/// @param batch_count                         The batch count.
/// @param lda, ldb, ldc, ldd                  The leading dimensions of the matrices.
/// @param strideA, strideB, strideC, strideD  The batch strides of the matrices.
/// @param epilogue                            The structure that controls the epilogue.
/// @param inputs                              The input pointers of the problem.
/// @param problemtype                         The new problem type.
/// @return HIPBLAS_STATUS_SUCCESS if the operation completed successfully, otherwise an error status.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(int64_t m,
int64_t n,
int64_t k,
int64_t batch_count,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t ldd,
int64_t strideA,
int64_t strideB,
int64_t strideC,
int64_t strideD,
GemmEpilogue& epilogue,
GemmInputs& inputs,
GemmProblemType& problemtype);
You can also import problems from hipblasLt APIs after the instance is created. Note that this may overwrite the problem type of the instance.
/// Imports the problem from hipBLASLt structures (as accepted by hipblasLtMatmul).
/// This may overwrite the problem type of the instance.
/// @param matmul_descr            A previously created hipblasLtMatmulDesc_t.
/// @param alpha, beta             Pointers to the scalars used in the multiplication.
/// @param A, B, C                 Pointers to the GPU memory described by matA, matB, matC.
/// @param D                       [out] Pointer to the GPU memory described by matD.
/// @param matA, matB, matC, matD  Previously created hipblasLtMatrixLayout_t descriptors.
/// @return HIPBLAS_STATUS_SUCCESS if the operation completed successfully, otherwise an error status.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(hipblasLtMatmulDesc_t matmul_descr,
const void* alpha,
const void* A,
hipblasLtMatrixLayout_t matA,
const void* B,
hipblasLtMatrixLayout_t matB,
const void* beta,
const void* C,
hipblasLtMatrixLayout_t matC,
void* D,
hipblasLtMatrixLayout_t matD);
You can get heuristics and make kernel arguments with the instance. If the properties of the gemm and the inputs don’t change, you can call the run API to launch the kernel directly.
// Pseudo code
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Default epilogue mode is HIPBLASLT_EPILOGUE_DEFAULT
hipblaslt_ext::GemmEpilogue epilogue;
hipblaslt_ext::GemmInputs inputs;
inputs.a = a;
inputs.b = b;
inputs.c = c;
inputs.d = d;
inputs.alpha = alpha;
inputs.beta = beta;
hipblaslt_ext::Gemm gemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic;
gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch
gemm.algoGetHeuristic(gemm, pref, heuristic);
gemm.initialize(heuristic[0].algo, d_workspace, stream);
for(int i = 0; i < 10; i++)
{
gemm.run(stream);
}
Grouped gemm#
hipblasLtExt supports grouped gemm through hipblaslt_ext::GroupedGemm, which shares the same base class (hipblaslt_ext::GemmInstance) with normal gemm.
After the problem is set, you can check the problem type with the function getGemmType().
// Kind of gemm problem held by a hipblaslt_ext::GemmInstance.
enum class GemmType
{
    HIPBLASLT_GEMM = 1,    //!< Plain gemm problem.
    HIPBLASLT_GROUPED_GEMM //!< Grouped gemm problem (implicitly 2, one past HIPBLASLT_GEMM).
};
The grouped gemm class also has the setProblem APIs.
/// Sets a single problem (size, epilogue, inputs) on the grouped gemm instance.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(
int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue& epilogue, GemmInputs& inputs);
/// Sets the group of problems; element i of each vector describes problem i.
/// The problem type assigned in the constructor is kept.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs);
/// Sets the group of problems including leading dimensions and batch strides,
/// and replaces the problem type with `problemtype`, which applies to every problem.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<int64_t>& lda,
std::vector<int64_t>& ldb,
std::vector<int64_t>& ldc,
std::vector<int64_t>& ldd,
std::vector<int64_t>& strideA,
std::vector<int64_t>& strideB,
std::vector<int64_t>& strideC,
std::vector<int64_t>& strideD,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs,
GemmProblemType& problemtype);
/// Imports the group of problems from hipBLASLt structures (one entry per problem).
/// This may overwrite the problem type of the instance.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<hipblasLtMatmulDesc_t>& matmul_descr,
std::vector<void*>& alpha,
std::vector<void*>& A,
std::vector<hipblasLtMatrixLayout_t>& matA,
std::vector<void*>& B,
std::vector<hipblasLtMatrixLayout_t>& matB,
std::vector<void*>& beta,
std::vector<void*>& C,
std::vector<hipblasLtMatrixLayout_t>& matC,
std::vector<void*>& D,
std::vector<hipblasLtMatrixLayout_t>& matD);
For the following API, the argument “epilogue” supports broadcasting: it is broadcast to the number of problems by duplicating its last element.
/// Sets the group of problems including leading dimensions and batch strides.
/// `epilogue` may be shorter than the number of problems: it is broadcast by
/// duplicating its last element. `problemtype` applies to every problem.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<int64_t>& lda,
std::vector<int64_t>& ldb,
std::vector<int64_t>& ldc,
std::vector<int64_t>& ldd,
std::vector<int64_t>& strideA,
std::vector<int64_t>& strideB,
std::vector<int64_t>& strideC,
std::vector<int64_t>& strideD,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs,
GemmProblemType& problemtype);
Note that problemtype is a single GemmProblemType that applies to all problems in the group.
// Pseudo code
std::vector<int64_t> m, n, k;
// ...
for(size_t i = 0; i < problem_size; i++)
{
    // ...
}
// A single GemmProblemType is passed by reference and applies to every problem in the group.
groupedgemm.setProblem(m, n, k, batch_count, lda, ldb, ldc, ldd, strideA, strideB, strideC, strideD, epilogue, inputs, problemtype);
UserArguments#
Grouped gemm supports using external device memory to run the kernel. This is helpful if some of the arguments come from the output of a previous kernel. To change the size-related arguments (m, n, k, batch), refer to Fixed MK.
// Per-gemm arguments that the grouped gemm kernel can read from device memory
// (see run(void* deviceUserArgs, hipStream_t)). The struct is packed, so field order and
// sizes define its binary layout. Fill it with getDefaultValueForDeviceUserArguments first,
// then modify only the fields you need.
struct UserArguments
{
uint32_t m; //!< Size m.
uint32_t n; //!< Size n.
uint32_t batch; //!< Batch count.
uint32_t k; //!< Size k.
void* d; //!< The d matrix pointer.
void* c; //!< The c matrix input pointer.
void* a; //!< The a matrix input pointer.
void* b; //!< The b matrix input pointer.
uint32_t strideD1; //!< The d leading dimension.
uint32_t strideD2; //!< The d batch stride.
uint32_t strideC1; //!< The c leading dimension.
uint32_t strideC2; //!< The c batch stride.
uint32_t strideA1; //!< The a leading dimension.
uint32_t strideA2; //!< The a batch stride.
uint32_t strideB1; //!< The b leading dimension.
uint32_t strideB2; //!< The b batch stride.
int8_t alpha[16]; //!< The alpha value, stored as 16 raw bytes.
int8_t beta[16]; //!< The beta value, stored as 16 raw bytes.
// Epilogue inputs
void* bias; //!< The bias input pointer.
int biasType; //!< The bias datatype. Only works if mode is set to bias related epilogues.
uint32_t reserved; //!< Reserved; presumably padding — keep the default value.
void* e; //!< The aux input pointer. Only works if mode is set to aux related epilogues.
uint32_t strideE1; //!< The aux leading dimension. Only works if mode is set to aux related epilogues.
uint32_t strideE2; //!< The aux batch stride. Only works if mode is set to aux related epilogues.
float act0; //!< First activation argument. Some activations might use it.
float act1; //!< Second activation argument. Some activations might use it.
int activationType; //!< The activation type. Only works if mode is set to activation related epilogues.
} __attribute__((packed));
Two functions are provided for the UserArguments-related API. The first is a helper function that initializes the “UserArguments” structures from the problems saved inside the grouped gemm object. The second is an overload of run that takes an additional device pointer to the UserArguments.
// Fills the host buffer hostDeviceUserArgs (one UserArguments per gemm in the group) with
// default values derived from the problems saved in this grouped gemm object.
HIPBLASLT_EXPORT hipblasStatus_t getDefaultValueForDeviceUserArguments(void* hostDeviceUserArgs);
// Runs the grouped gemm on stream, reading the UserArguments array at deviceUserArgs (device memory).
HIPBLASLT_EXPORT hipblasStatus_t run(void* deviceUserArgs, hipStream_t stream);
The following is a simple example of how this API works.
// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLAS_COMPUTE_32F,
                                                 heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
    m[i]             = 1;
    n[i]             = 1;
    k[i]             = 1;
    batch_count[i]   = 1;
    epilogue[i].mode = HIPBLASLT_EPILOGUE_GELU;
    inputs[i].a      = a[i];
    inputs[i].b      = b[i];
    inputs[i].c      = c[i];
    inputs[i].d      = d[i];
    inputs[i].alpha  = alpha[i];
    inputs[i].beta   = beta[i];
}
// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
                                       HIPBLAS_OP_N,
                                       HIPBLAS_OP_N,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIPBLAS_COMPUTE_32F);
// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
// Step 5: Get default value from the instance
// The host buffer holds one UserArguments entry per gemm; std::vector releases it automatically.
std::vector<hipblaslt_ext::UserArguments> dUAFloat(gemm_count);
groupedGemm.getDefaultValueForDeviceUserArguments(dUAFloat.data());
// Once you get the default value here, you can make several copies and change the values
// from the host
// Next Copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat,
          dUAFloat.data(),
          sizeof(hipblaslt_ext::UserArguments) * gemm_count,
          hipMemcpyHostToDevice);
std::vector<int> validIdx;
for(int j = 0; j < static_cast<int>(heuristicResult.size()); j++)
{
    size_t workspace_size = 0;
    if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
    }
}
// Step 6: Initialize and run
// Only validIdx[0] is used, so a single supported algorithm is enough.
if(!validIdx.empty())
{
    // initialize(algo, workspace, useUserArgs, stream): pass useUserArgs explicitly so that
    // the stream is not bound to the bool useUserArgs parameter.
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        // Run with the UserArguments copied to device memory in Step 5.
        groupedGemm.run(d_dUAFloat, stream);
    }
}
// Wait for the queued runs before releasing the device-side UserArguments.
hipStreamSynchronize(stream);
hipFree(d_dUAFloat);
The base class (GemmInstance)#
This is the base class for Gemm and GroupedGemm.
// Gets heuristic results from the instance.
HIPBLASLT_EXPORT hipblasStatus_t algoGetHeuristic(const int requestedAlgoCount,
const GemmPreference& pref,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
// Returns SUCCESS if the algo is supported, also returns the required workspace size in bytes.
HIPBLASLT_EXPORT hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t& algo, size_t& workspaceSizeInBytes);
// Initializes the instance before calling run. Must be called again every time the problem is set.
// The stream is the fourth parameter: to pass a stream, also pass useUserArgs explicitly,
// e.g. initialize(algo, workspace, true, stream). A stream passed as the third argument would
// bind to the bool useUserArgs (assuming hipStream_t is a pointer type).
HIPBLASLT_EXPORT hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t& algo, void* workspace, bool useUserArgs = true, hipStream_t stream = 0);
// Runs the problem on the given stream.
HIPBLASLT_EXPORT hipblasStatus_t run(hipStream_t stream);
Get all algorithms#
Get all algorithms lets you retrieve all the algorithms for a specific problem type. It requires the transpose operations of A and B, the data types of the inputs, and the compute type.
// Returns in heuristicResults all algorithms for the given gemm type, transposes, data types
// and compute type. No problem size or epilogue is needed; check each algorithm against a
// concrete problem with an isAlgoSupported API.
HIPBLASLT_EXPORT
hipblasStatus_t hipblaslt_ext::getAllAlgos(hipblasLtHandle_t handle,
hipblasLtExtGemmTypeEnum_t typeGemm,
hipblasOperation_t opA,
hipblasOperation_t opB,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
hipblasComputeType_t typeCompute,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
This API doesn’t require any problem size or epilogue as input. Instead, use one of the isAlgoSupported APIs to check whether an algorithm supports a specific problem:
- hipblaslt_ext::matmulIsAlgoSupported()
- gemm.isAlgoSupported()
On success, these APIs also return the required workspace size in bytes.
// Get all algorithms
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(hipblaslt_ext::matmulIsAlgoSupported(handle,
matmul,
&(alpha),
matA,
matB,
&(beta),
matC,
matD,
heuristicResult[j].algo,
workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
heuristicResult[j].workspaceSize = workspace_size;
}
else
{
heuristicResult[j].workspaceSize = 0;
}
}
Algorithm index#
This extension API lets you get the index of an algorithm from its hipblasLtMatmulAlgo_t.
// Returns the index of algo; the index can be stored and later passed to getAlgosFromIndex.
HIPBLASLT_EXPORT int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo);
It also lets you get the heuristic results for a given vector of algorithm indices.
// Fills heuristicResults with the algorithms identified by the indices in algoIndex
// (indices obtained from getIndexFromAlgo).
HIPBLASLT_EXPORT
hipblasStatus_t
getAlgosFromIndex(hipblasLtHandle_t handle,
std::vector<int>& algoIndex,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
Sample codes#
Here are the sample codes demonstrating use cases of the extension APIs.
Gemm#
// Pseudo code for gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLAS_COMPUTE_32F,
                                                 heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
hipblaslt_ext::GemmEpilogue epilogue;
epilogue.mode = HIPBLASLT_EPILOGUE_GELU;
hipblaslt_ext::GemmInputs inputs;
inputs.a     = a;
inputs.b     = b;
inputs.c     = c;
inputs.d     = d;
inputs.alpha = alpha;
inputs.beta  = beta;
hipblaslt_ext::Gemm gemm(handle,
                         HIPBLAS_OP_N,
                         HIPBLAS_OP_N,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIPBLAS_COMPUTE_32F);
gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch
// Keep the algorithms that support this problem and record their workspace sizes.
std::vector<int> validIdx;
for(int j = 0; j < static_cast<int>(heuristicResult.size()); j++)
{
    size_t workspace_size = 0;
    if(gemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
        heuristicResult[j].workspaceSize = workspace_size;
    }
    else
    {
        heuristicResult[j].workspaceSize = 0;
    }
}
// Only validIdx[0] is used, so a single supported algorithm is enough.
if(!validIdx.empty())
{
    // initialize(algo, workspace, useUserArgs, stream): pass useUserArgs explicitly so that
    // the stream is not bound to the bool useUserArgs parameter.
    gemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        gemm.run(stream);
    }
}
Grouped gemm#
// Pseudo code for grouped gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLAS_COMPUTE_32F,
                                                 heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
    m[i]             = 1;
    n[i]             = 1;
    k[i]             = 1;
    batch_count[i]   = 1;
    epilogue[i].mode = HIPBLASLT_EPILOGUE_GELU;
    inputs[i].a      = a[i];
    inputs[i].b      = b[i];
    inputs[i].c      = c[i];
    inputs[i].d      = d[i];
    inputs[i].alpha  = alpha[i];
    inputs[i].beta   = beta[i];
}
hipblaslt_ext::GroupedGemm groupedGemm(handle,
                                       HIPBLAS_OP_N,
                                       HIPBLAS_OP_N,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIPBLAS_COMPUTE_32F);
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
// Keep the algorithms that support this grouped problem.
std::vector<int> validIdx;
for(int j = 0; j < static_cast<int>(heuristicResult.size()); j++)
{
    size_t workspace_size = 0;
    if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
    }
}
// Only validIdx[0] is used, so a single supported algorithm is enough.
if(!validIdx.empty())
{
    // initialize(algo, workspace, useUserArgs, stream): pass useUserArgs explicitly so that
    // the stream is not bound to the bool useUserArgs parameter.
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        groupedGemm.run(stream);
    }
}
Algorithm index#
int index = hipblaslt_ext::getIndexFromAlgo(testResults[i].algo);
// Save the index to disk or somewhere else for later use.
// Get the index from previous state.
// Brace-initialize so the vector holds the single element `index`;
// `std::vector<int> algoIndex(index)` would instead create `index` zero-valued elements.
std::vector<int> algoIndex{index};
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResults;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResults));
[Grouped Gemm] Fixed MK#
The hipBLASLt extension supports changing the sizes (m, n, k, batch) through the UserArguments stored in device memory, but the setup differs slightly from the normal routine.
Sum of n:
A sum of N needs to be used as an input for the grouped gemm instance.
{1000, 1, 1, 1}; // The array of N passed to setProblem; the first element is the sum of N
// Below are the values stored in "UserArguments"
{256, 256, 1, 1}; // This is a valid configuration because 256 + 256 + 1 + 1 < 1000
{512, 512, 1, 1}; // This is NOT a valid configuration because 512 + 512 + 1 + 1 > 1000
For example, consider a grouped gemm with gemm_count = 4. The sum of the Ns stored in UserArguments must not exceed the “sum of N” passed to the setProblem API. In this mode, the first element of the array of Ns passed to setProblem is the “sum of N”.
// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
// Step 2.1: Calculate sum of n
int64_t sum_of_n = 0;
for(int i = 0; i < gemm_count; i++)
{
sum_of_n += n_arr[i];
}
// {sum_of_n, 1, 1, 1, ...}; // The array of N, the first element is the sum of N
for(int i = 0; i < gemm_count; i++)
{
m[i] = m_arr[i];
if(i == 0)
n[i] = sum_of_n;
else
n[i] = 1;
k[i] = k_arr[i];
batch_count[i] = 1;
inputs[i].a = a[i];
inputs[i].b = b[i];
inputs[i].c = c[i];
inputs[i].d = d[i];
inputs[i].alpha = alpha[i];
inputs[i].beta = beta[i];
}
// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host
// Next Copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
}
}
int threads = 256;
int blocks = ceil((double)gemm_count / threads);
// Step 6: Initialize and run
if(validIdx.size() > 1)
{
groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace);
for(int i = 0; i < 10; i++)
{
hipLaunchKernelGGL(kernelUpdateN,
dim3(blocks),
dim3(threads),
0,
stream,
gemm_count,
d_dUAFloat,
d_n_vec); // d_n_vec is a device pointer with Ns
groupedGemm.run(userArgs, stream);
}
}
// .....
// Device kernel: copies sizes_n[id] into the n field of the id-th UserArguments entry at
// userArgs, one thread per gemm. Launch with at least gemm_count threads in total; threads
// with id >= gemm_count exit without writing.
__global__ void kernelUpdateN(uint32_t gemm_count, void* userArgs, int32_t* sizes_n)
{
    // Global thread index. Use the launched block size instead of a hard-coded 256 so the
    // kernel stays correct if the launch configuration changes.
    uint64_t id = static_cast<uint64_t>(hipBlockIdx_x) * hipBlockDim_x + hipThreadIdx_x;
    if(id >= gemm_count)
        return;
    hipblaslt_ext::UserArguments* dUAFloat = static_cast<hipblaslt_ext::UserArguments*>(userArgs);
    dUAFloat[id].n = sizes_n[id];
}