hipBLASLtExt Reference#
hipBLASLtExt Datatypes Reference#
GemmType#
GemmProblemType#
-
struct GemmProblemType#
hipblasLt extension ProblemType for gemm problems.
This structure sets the problem type of a gemm problem.
Public Members
-
hipblasOperation_t op_a#
The A matrix transpose.
-
hipblasOperation_t op_b#
The B matrix transpose.
-
hipDataType type_a#
The A matrix datatype.
-
hipDataType type_b#
The B matrix datatype.
-
hipDataType type_c#
The C matrix datatype.
-
hipDataType type_d#
The D matrix datatype.
-
hipblasComputeType_t type_compute#
The compute datatype.
-
hipblasOperation_t op_a#
GemmEpilogue#
-
struct GemmEpilogue#
hipblasLt extension Epilogue for gemm problems.
This structure sets the epilogue of a gemm problem.
Public Members
-
hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT#
The mode of epilogue. Default is gemm.
-
hipDataType bias_data_type = HIPBLASLT_DATATYPE_INVALID#
The bias datatype. Only works if mode is set to bias related epilogues.
-
int aux_ld = 0#
The aux leading dimension. Only works if mode is set to aux related epilogues.
-
int aux_stride = 0#
The aux batch stride. Only works if mode is set to aux related epilogues.
-
hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT#
GemmInputs#
-
struct GemmInputs#
hipblasLt extension Inputs for gemm problems.
This structure sets the input pointers of a gemm problem.
Public Members
-
void *a = nullptr#
The a matrix input pointer.
-
void *b = nullptr#
The b matrix input pointer.
-
void *c = nullptr#
The c matrix input pointer.
-
void *d = nullptr#
The d matrix input pointer.
-
void *alpha = nullptr#
The alpha value.
-
void *beta = nullptr#
The beta value.
-
void *bias = nullptr#
The bias input pointer.
-
void *scaleA = nullptr#
The Scale A input pointer.
-
void *scaleB = nullptr#
The Scale B input pointer.
-
void *scaleC = nullptr#
The Scale C input pointer.
-
void *scaleD = nullptr#
The Scale D input pointer.
-
void *scaleAux = nullptr#
The Scale AUX input pointer.
-
void *scaleAlphaVec = nullptr#
The scaleAlpha vector input pointer.
-
void *aux = nullptr#
The aux input pointer.
-
void *a = nullptr#
hipBLASLtExt Class Reference#
GemmPreference#
-
class GemmPreference#
hipblasLt extension preference for gemm problems.
Currently only supports setting max workspace size.
Public Functions
-
void setMaxWorkspaceBytes(size_t workspaceBytes)#
This function sets the max workspace size.
- Parameters:
workspaceBytes – [in] Set the max workspace size in bytes.
-
const size_t getMaxWorkspaceBytes() const#
This function returns the set max workspace size.
- Return values:
size_t – Returns the set max workspace size.
-
void setMaxWorkspaceBytes(size_t workspaceBytes)#
GemmInstance#
-
class GemmInstance#
hipblasLt extension instance for gemm problems.
Subclassed by hipblaslt_ext::Gemm, hipblaslt_ext::GroupedGemm
Public Functions
-
hipblasStatus_t algoGetHeuristic(const int requestedAlgoCount, const GemmPreference &pref, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#
Retrieve the possible algorithms.
This function retrieves the possible algorithms for the matrix multiply operation hipblasLtMatmul() function with the given data and compute type. The output is placed in heuristicResults in the order of increasing estimated compute time.
- Parameters:
requestedAlgoCount – [in] number of requested algorithms.
pref – [in] hipblasLt extension preference for gemm problems.
heuristicResults – [out] The algorithm heuristic vector.
- Return values:
HIPBLAS_STATUS_SUCCESS – If query was successful. Inspect heuristicResults.size() > 0 for the status of the results; note that heuristicResults.size() may be less than requestedAlgoCount.
HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function available for current configuration.
HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.
-
hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t &algo, size_t &workspaceSizeInBytes)#
Check if the algorithm supports the problem. (For hipblaslt extension API)
This function updates the problem saved inside the algorithm if the problem is supported. The required workspaceSizeInBytes is also returned.
- Parameters:
algo – [in] The algorithm heuristic.
workspaceSizeInBytes – [out] Return the required workspace size.
- Return values:
HIPBLAS_STATUS_SUCCESS – If query was successful. The problem is supported by the algorithm.
HIPBLAS_STATUS_INVALID_VALUE – The problem is not supported.
-
hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t &algo, void *workspace, bool useUserArgs = true, hipStream_t stream = 0)#
Create kernel arguments from a given hipblaslt_ext::GemmInstance.
This function creates kernel arguments from a given hipblaslt_ext::GemmInstance then saves the arguments inside the instance.
- Parameters:
algo – [in] Handle for matrix multiplication algorithm to be used. See hipblaslt.h::hipblasLtMatmulAlgo_t . When NULL, an implicit heuristics query with default search preferences will be performed to determine actual algorithm to use.
workspace – [in] Pointer to the workspace buffer allocated in the GPU memory. Pointer must be 16B aligned (that is, lowest 4 bits of address must be 0).
useUserArgs – [in] Use user args, this does not affect vanilla gemm. (May be deprecated in the future)
stream – [in] The HIP stream where all the GPU work will be submitted. (May be deprecated in the future)
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_INVALID_VALUE – If the gemm_count = 0.
-
hipblasStatus_t run(hipStream_t stream)#
Execute the kernel arguments stored inside the hipblaslt_ext::GemmInstance.
- Parameters:
stream – [in] The HIP stream where all the GPU work will be submitted.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
Protected Functions
-
explicit GemmInstance(hipblasLtHandle_t handle, GemmType type)#
Constructor of GemmInstance.
-
hipblasStatus_t algoGetHeuristic(const int requestedAlgoCount, const GemmPreference &pref, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#
Gemm#
-
class Gemm : public hipblaslt_ext::GemmInstance#
hipblasLt extension instance for gemm.
The instance can be used to create arguments to compute the matrix multiplication of matrices A and B to produce the output matrix D, according to the following operation: $D = \alpha \cdot (A \cdot B) + \beta \cdot (C)$, where A, B, and C are input matrices, and alpha and beta are input scalars.
Public Functions
-
explicit Gemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute)#
Constructor.
This function sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.
- Parameters:
handle – [in] The handle from hipBLASLt.
opA, opB – [in] The transpose type of matrix A, B
typeA, typeB, typeC, typeD – [in] The data type of matrix A, B, C, D
typeCompute – [in] The compute type of the gemm problem
-
explicit Gemm(hipblasLtHandle_t handle, hipblasLtMatmulDesc_t matmul_descr, const void *alpha, const void *A, hipblasLtMatrixLayout_t matA, const void *B, hipblasLtMatrixLayout_t matB, const void *beta, const void *C, hipblasLtMatrixLayout_t matC, void *D, hipblasLtMatrixLayout_t matD)#
Constructor that sets the gemm problem from hipblasLt structures.
This constructor sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.
- Parameters:
handle – [in] The handle from hipBLASLt.
matmul_descr – [in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Pointers to the scalars used in the multiplication.
matA, matB, matC, matD – [in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
A, B, C – [in] Pointers to the GPU memory associated with the corresponding descriptors matA, matB and matC.
D – [out] Pointer to the GPU memory associated with the descriptor matD.
-
hipblasStatus_t setProblem(int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue &epilogue, GemmInputs &inputs)#
Sets the problem for a gemm problem.
This function sets the problem with m, n, k, batch_count. It uses the problem type set by the constructor.
- Parameters:
m, n, k – [in] The problem size.
batch_count – [in] The batch count.
epilogue – [in] The structure that controls the epilogue.
inputs – [in] The inputs of the problem.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t setProblem(int64_t m, int64_t n, int64_t k, int64_t batch_count, int64_t lda, int64_t ldb, int64_t ldc, int64_t ldd, int64_t strideA, int64_t strideB, int64_t strideC, int64_t strideD, GemmEpilogue &epilogue, GemmInputs &inputs, GemmProblemType &problemtype)#
Sets the problem for a gemm problem.
This function sets the problem with m, n, k, batch_count. It uses the problem type set by the constructor.
- Parameters:
m, n, k – [in] The problem size.
batch_count – [in] The batch count.
lda, ldb, ldc, ldd – [in] The leading dimensions of the matrix.
strideA, strideB, strideC, strideD – [in] The batch stride of the matrix.
epilogue – [in] The structure that controls the epilogue.
inputs – [in] The inputs of the problem.
problemtype – [in] The structure that sets the problem type of a gemm problem.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t setProblem(hipblasLtMatmulDesc_t matmul_descr, const void *alpha, const void *A, hipblasLtMatrixLayout_t matA, const void *B, hipblasLtMatrixLayout_t matB, const void *beta, const void *C, hipblasLtMatrixLayout_t matC, void *D, hipblasLtMatrixLayout_t matD)#
Sets the gemm problem from hipblasLt structures.
This function sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.
- Parameters:
matmul_descr – [in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Pointers to the scalars used in the multiplication.
matA, matB, matC, matD – [in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
A, B, C – [in] Pointers to the GPU memory associated with the corresponding descriptors matA, matB and matC.
D – [out] Pointer to the GPU memory associated with the descriptor matD.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
explicit Gemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute)#
GroupedGemm#
-
class GroupedGemm : public hipblaslt_ext::GemmInstance#
hipblasLt extension instance for grouped gemm.
The instance can be used to create arguments to compute the matrix multiplication of matrices A and B to produce the output matrix D, according to the following operation: $D = \alpha \cdot (A \cdot B) + \beta \cdot (C)$, where A, B, and C are input matrices, and alpha and beta are input scalars.
Public Functions
-
explicit GroupedGemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute)#
Constructor.
This function sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.
- Parameters:
handle – [in] The handle from hipBLASLt.
opA, opB – [in] The transpose type of matrix A, B
typeA, typeB, typeC, typeD – [in] The data type of matrix A, B, C, D
typeCompute – [in] The compute type of the gemm problem
-
explicit GroupedGemm(hipblasLtHandle_t handle, std::vector<hipblasLtMatmulDesc_t> &matmul_descr, std::vector<void*> &alpha, std::vector<void*> &A, std::vector<hipblasLtMatrixLayout_t> &matA, std::vector<void*> &B, std::vector<hipblasLtMatrixLayout_t> &matB, std::vector<void*> &beta, std::vector<void*> &C, std::vector<hipblasLtMatrixLayout_t> &matC, std::vector<void*> &D, std::vector<hipblasLtMatrixLayout_t> &matD)#
Constructor that sets the grouped gemm problem from hipblasLt structures.
This constructor sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.
- Parameters:
handle – [in] The handle from hipBLASLt.
matmul_descr – [in] Vectors of handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Vectors of float used in the multiplication.
matA, matB, matC, matD – [in] Vectors of handle to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
A, B, C – [in] Vectors of pointers to the GPU memory associated with the corresponding descriptors matA, matB and matC.
D – [out] Vector of pointers to the GPU memory associated with the descriptor matD.
-
hipblasStatus_t setProblem(std::vector<int64_t> &m, std::vector<int64_t> &n, std::vector<int64_t> &k, std::vector<int64_t> &batch_count, std::vector<GemmEpilogue> &epilogue, std::vector<GemmInputs> &inputs)#
Sets the problem for a gemm problem.
This function sets the problem with m, n, k, batch_count. It uses the problem type set by the constructor.
- Parameters:
m, n, k – [in] The problem size in vector.
batch_count – [in] The batch count in vector.
epilogue – [in] The structure in vector that controls the epilogue.
inputs – [in] The inputs in vector of the problem.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t setProblem(std::vector<int64_t> &m, std::vector<int64_t> &n, std::vector<int64_t> &k, std::vector<int64_t> &batch_count, std::vector<int64_t> &lda, std::vector<int64_t> &ldb, std::vector<int64_t> &ldc, std::vector<int64_t> &ldd, std::vector<int64_t> &strideA, std::vector<int64_t> &strideB, std::vector<int64_t> &strideC, std::vector<int64_t> &strideD, std::vector<GemmEpilogue> &epilogue, std::vector<GemmInputs> &inputs, GemmProblemType &problemtype)#
Sets the problem for a gemm problem.
This function sets the problem with m, n, k, batch_count. It uses the problem type set by the constructor.
- Parameters:
m, n, k – [in] The problem size in vector.
batch_count – [in] The batch count in vector.
lda, ldb, ldc, ldd – [in] The leading dimensions in vector of the matrix.
strideA, strideB, strideC, strideD – [in] The batch stride in vector of the matrix.
epilogue – [in] The structure in vector that controls the epilogue.
inputs – [in] The inputs in vector of the problem.
problemtype – [in] The structure that sets the problem type of a gemm problem.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t setProblem(std::vector<hipblasLtMatmulDesc_t> &matmul_descr, std::vector<void*> &alpha, std::vector<void*> &A, std::vector<hipblasLtMatrixLayout_t> &matA, std::vector<void*> &B, std::vector<hipblasLtMatrixLayout_t> &matB, std::vector<void*> &beta, std::vector<void*> &C, std::vector<hipblasLtMatrixLayout_t> &matC, std::vector<void*> &D, std::vector<hipblasLtMatrixLayout_t> &matD)#
Sets the grouped gemm problem from hipblasLt structures.
This function sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.
- Parameters:
matmul_descr – [in] Vectors of handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Vectors of float used in the multiplication.
matA, matB, matC, matD – [in] Vectors of handle to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
A, B, C – [in] Vectors of pointers to the GPU memory associated with the corresponding descriptors matA, matB and matC.
D – [out] Vector of pointers to the GPU memory associated with the descriptor matD.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.
HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.
HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.
HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.
HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
-
hipblasStatus_t getDefaultValueForDeviceUserArguments(void *hostDeviceUserArgs)#
A helper function to initialize DeviceUserArguments using the set problem(s) saved in the gemm object.
- Parameters:
hostDeviceUserArgs – [in] The DeviceUserArguments structure allocated on the host. Note that the user must use the correct type of DeviceUserArguments.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
-
hipblasStatus_t run(void *deviceUserArgs, hipStream_t stream)#
Run the kernel using DeviceUserArguments.
- Parameters:
deviceUserArgs – [in] Pointer to the DeviceUserArguments buffer allocated in the GPU memory. Pointer must be 16B aligned (that is, lowest 4 bits of address must be 0).
stream – [in] The HIP stream where all the GPU work will be submitted.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
HIPBLAS_STATUS_INVALID_VALUE – If the gemm_count = 0.
-
hipblasStatus_t run(hipStream_t stream)#
Execute the kernel arguments stored inside the hipblaslt_ext::GemmInstance.
- Parameters:
stream – [in] The HIP stream where all the GPU work will be submitted.
- Return values:
HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.
-
explicit GroupedGemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute)#
hipBLASLtExt API Reference#
getAllAlgos()#
-
hipblasStatus_t hipblaslt_ext::getAllAlgos(hipblasLtHandle_t handle, GemmType typeGemm, hipblasOperation_t opA, hipblasOperation_t opB, hipDataType typeA, hipDataType typeB, hipDataType typeC, hipDataType typeD, hipblasComputeType_t typeCompute, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#
Retrieve the possible algorithms.
This function retrieves the possible algorithms for the matrix multiply operation hipblasLtMatmul() function with the given data and compute type. The output is placed in heuristicResults in the order of increasing estimated compute time.
- Parameters:
handle – [in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .
typeGemm – [in] Gemm type. ex. GEMM, GROUPED_GEMM.
opA, opB – [in] Transpose settings of A, B.
typeA, typeB, typeC, typeD – [in] The data type of matrix A, B, C, D.
typeCompute – [in] The compute type.
heuristicResults – [out] The algorithm heuristic vector.
- Return values:
HIPBLAS_STATUS_SUCCESS – If query was successful. Inspect heuristicResults.size() > 0 for the status of the results.
HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function available for current configuration.
HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.
getIndexFromAlgo()#
-
int hipblaslt_ext::getIndexFromAlgo(hipblasLtMatmulAlgo_t &algo)#
Retrieve the algorithm index.
- Parameters:
algo – [in] The algorithm.
- Return values:
int – The index of the algorithm, which can be used to get heuristic results from getAlgosFromIndex. Returns -1 if the index stored in algo < 0. Note that the index may not be valid if the algo struct is not initialized properly.
getAlgosFromIndex()#
-
hipblasStatus_t hipblaslt_ext::getAlgosFromIndex(hipblasLtHandle_t handle, std::vector<int> &algoIndex, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#
Retrieve the possible algorithms.
This function retrieves the possible algorithms for the matrix multiply operation hipblasLtMatmul() function with the given index. The output is placed in heuristicResults in the order of increasing estimated compute time.
- Parameters:
handle – [in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .
algoIndex – [in] The algorithm index vector.
heuristicResults – [out] The algorithm heuristic vector.
- Return values:
HIPBLAS_STATUS_SUCCESS – If query was successful. Inspect heuristicResults.size() > 0 for the status of the results.
HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function available for current configuration.
HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.
matmulIsAlgoSupported()#
-
hipblasStatus_t hipblaslt_ext::matmulIsAlgoSupported(hipblasLtHandle_t handle, hipblasLtMatmulDesc_t matmulDesc, const void *alpha, hipblasLtMatrixLayout_t Adesc, hipblasLtMatrixLayout_t Bdesc, const void *beta, hipblasLtMatrixLayout_t Cdesc, hipblasLtMatrixLayout_t Ddesc, hipblasLtMatmulAlgo_t &algo, size_t &workspaceSizeInBytes)#
Check if the algorithm supports the problem. (For hipblasLt API)
This function updates the problem saved inside the algorithm if the problem is supported. The required workspaceSizeInBytes is also returned.
- Parameters:
handle – [in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .
matmulDesc – [in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .
alpha, beta – [in] Pointers to the scalars used in the multiplication.
Adesc, Bdesc, Cdesc, Ddesc – [in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .
algo – [in] The algorithm heuristic.
workspaceSizeInBytes – [out] Return the required workspace size.
- Return values:
HIPBLAS_STATUS_SUCCESS – If query was successful. The problem is supported by the algorithm.
HIPBLAS_STATUS_INVALID_VALUE – The problem is not supported.
hipblasLtExt Usage#
Introduction#
hipBLASLt has extension APIs with namespace hipblaslt_ext. It is C++ compatible only. The extensions support:
Gemm
Grouped gemm
Get all algorithms
Gemm#
hipblasLt has its own instance.
The user must assign the problem type when constructing the instance, or import the problem from the hipBLASLt API.
HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t handle,
hipblasOperation_t opA,
hipblasOperation_t opB,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
hipblasComputeType_t typeCompute);
HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t handle,
hipblasLtMatmulDesc_t matmul_descr,
const void* alpha,
const void* A,
hipblasLtMatrixLayout_t matA,
const void* B,
hipblasLtMatrixLayout_t matB,
const void* beta,
const void* C,
hipblasLtMatrixLayout_t matC,
void* D,
hipblasLtMatrixLayout_t matD);
After the instance is created, the user can set the problem with the API. The API may require the following structures:
GemmProblemType lets the user change the problem type after the instance is initialized.
struct GemmProblemType
{
hipblasOperation_t op_a;
hipblasOperation_t op_b;
hipDataType type_a;
hipDataType type_b;
hipDataType type_c;
hipDataType type_d;
hipblasComputeType_t type_compute;
};
GemmEpilogue lets the user control the epilogue of the problem.
struct GemmEpilogue
{
hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT;
hipDataType bias_data_type;
int aux_ld;
int aux_stride;
};
GemmInputs is the problem inputs.
struct GemmInputs
{
void* a = nullptr;
void* b = nullptr;
void* c = nullptr;
void* d = nullptr;
void* alpha = nullptr;
void* beta = nullptr;
// Epilogue inputs
void* bias = nullptr;
void* aux = nullptr;
};
And the setProblem APIs:
HIPBLASLT_EXPORT hipblasStatus_t setProblem(
int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue& epilogue, GemmInputs& inputs);
The user can also set the leading dimensions, strides, and reassign the data type with the following API.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(int64_t m,
int64_t n,
int64_t k,
int64_t batch_count,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t ldd,
int64_t strideA,
int64_t strideB,
int64_t strideC,
int64_t strideD,
GemmEpilogue& epilogue,
GemmInputs& inputs,
GemmProblemType& problemtype);
The user can also import problems from hipblasLt APIs after the instance is created. Note that this may overwrite the problem type of the instance.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(hipblasLtMatmulDesc_t matmul_descr,
const void* alpha,
const void* A,
hipblasLtMatrixLayout_t matA,
const void* B,
hipblasLtMatrixLayout_t matB,
const void* beta,
const void* C,
hipblasLtMatrixLayout_t matC,
void* D,
hipblasLtMatrixLayout_t matD);
The user can get heuristics and make kernel arguments with the instance. If the properties of the gemm and the inputs don’t change, the user can call the run API to launch the kernel directly.
// Pseudo code
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Default epilogue mode is HIPBLASLT_EPILOGUE_DEFAULT
hipblaslt_ext::GemmEpilogue epilogue;
hipblaslt_ext::GemmInputs inputs;
inputs.a = a;
inputs.b = b;
inputs.c = c;
inputs.d = d;
inputs.alpha = alpha;
inputs.beta = beta;
hipblaslt_ext::Gemm gemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
std::vector<hipblasLtMatmulHeuristicResult_t> hueristic;
gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch
gemm.algoGetHeuristic(gemm, pref, hueristic);
gemm.initialize(hueristic[0].algo, d_workspace, stream);
for(int i = 0; i < 10; i++)
{
gemm.run(stream);
}
Grouped Gemm#
hipblasLtExt supports grouped gemm. It shares the same class with normal gemm.
After the problem is set, the user can check the problem type with function getGemmType().
enum class GemmType
{
HIPBLASLT_GEMM = 1,
HIPBLASLT_GROUPED_GEMM = 2
};
The grouped gemm class also has the setProblem APIs.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(
int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue& epilogue, GemmInputs& inputs);
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs);
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<int64_t>& lda,
std::vector<int64_t>& ldb,
std::vector<int64_t>& ldc,
std::vector<int64_t>& ldd,
std::vector<int64_t>& strideA,
std::vector<int64_t>& strideB,
std::vector<int64_t>& strideC,
std::vector<int64_t>& strideD,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs,
GemmProblemType& problemtype);
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<hipblasLtMatmulDesc_t>& matmul_descr,
std::vector<void*>& alpha,
std::vector<void*>& A,
std::vector<hipblasLtMatrixLayout_t>& matA,
std::vector<void*>& B,
std::vector<hipblasLtMatrixLayout_t>& matB,
std::vector<void*>& beta,
std::vector<void*>& C,
std::vector<hipblasLtMatrixLayout_t>& matC,
std::vector<void*>& D,
std::vector<hipblasLtMatrixLayout_t>& matD);
For the following API, the argument “epilogue” supports broadcasting. It will be broadcast to the length of the problem size by duplicating the last element.
HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>& m,
std::vector<int64_t>& n,
std::vector<int64_t>& k,
std::vector<int64_t>& batch_count,
std::vector<int64_t>& lda,
std::vector<int64_t>& ldb,
std::vector<int64_t>& ldc,
std::vector<int64_t>& ldd,
std::vector<int64_t>& strideA,
std::vector<int64_t>& strideB,
std::vector<int64_t>& strideC,
std::vector<int64_t>& strideD,
std::vector<GemmEpilogue>& epilogue,
std::vector<GemmInputs>& inputs,
GemmProblemType& problemtype);
Note that currently only a problemtype size equal to 1 is supported (only one GemmProblemType for all problems).
// Pseudo code
std::vector<int64_t> m, n, k;
// ...
for(size_t i = 0; i < problem_size, i++)
{
// ...
}
std::vector<GemmProblemType> problemtypes;
problemtypes.push_back(problemtype);
groupedgemm.setProblem(m, n, k, batch_count, lda, ldb, ldc, ldd, strideA, strideB, strideC, strideD, epilogue, inputs, problemtypes);
UserArguments#
Grouped gemm supports using external device memory to run the kernel. This is helpful if some of the arguments come from the output of the previous kernel. Please refer to section Fixed MK if you want to change the size (m, n, k, batch) related arguments.
// Per-gemm argument record used by the grouped gemm "UserArguments" API.
// The struct is declared packed, so the compiler inserts no padding between members.
// NOTE(review): field order and sizes presumably must match what the kernel reads — confirm before changing.
struct UserArguments
{
uint32_t m; //!< size m
uint32_t n; //!< size n
uint32_t batch; //!< size batch
uint32_t k; //!< size k
void* d; //!< The d matrix input pointer.
void* c; //!< The c matrix input pointer.
void* a; //!< The a matrix input pointer.
void* b; //!< The b matrix input pointer.
uint32_t strideD1; //!< The d leading dimension.
uint32_t strideD2; //!< The d batch stride
uint32_t strideC1; //!< The c leading dimension.
uint32_t strideC2; //!< The c batch stride
uint32_t strideA1; //!< The a leading dimension.
uint32_t strideA2; //!< The a batch stride
uint32_t strideB1; //!< The b leading dimension.
uint32_t strideB2; //!< The b batch stride
int8_t alpha[16]; //!< The alpha value, held in 16 bytes of raw storage.
int8_t beta[16]; //!< The beta value, held in 16 bytes of raw storage.
// Epilogue inputs
void* bias; //!< The bias input pointer.
int biasType; //!< The bias datatype. Only works if mode is set to bias related epilogues.
uint32_t reserved; //!< Reserved. NOTE(review): presumably keeps the following pointer 8-byte aligned in the packed layout — confirm.
void* e; //!< The aux input pointer. Only works if mode is set to aux related epilogues.
uint32_t strideE1; //!< The aux leading dimension. Only works if mode is set to aux related epilogues.
uint32_t strideE2; //!< The aux batch stride. Only works if mode is set to aux related epilogues.
float act0; //!< The activation value 1. Some activations might use it.
float act1; //!< The activation value 2.
int activationType; //!< The activation type. Only works if mode is set to activation related epilogues.
} __attribute__((packed));
We add two functions for the UserArguments-related API. The first API is a helper function that helps the user initialize the structure “UserArguments” from the saved problems inside the grouped gemm object. The second API is an overloaded function with an additional UserArguments device pointer input.
// Helper: initializes the UserArguments structure(s) pointed to by hostDeviceUserArgs
// from the problems saved inside the grouped gemm object.
HIPBLASLT_EXPORT hipblasStatus_t getDefaultValueForDeviceUserArguments(void* hostDeviceUserArgs);
// Overload of run() that takes an additional UserArguments device pointer.
HIPBLASLT_EXPORT hipblasStatus_t run(void* deviceUserArgs, hipStream_t stream);
The following is a simple example of how this API works.
// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
m[i] = 1;
n[i] = 1;
k[i] = 1;
batch_count[i] = 1;
epilogue[i].mode = HIPBLASLT_EPILOGUE_GELU;
inputs[i].a = a[i];
inputs[i].b = b[i];
inputs[i].c = c[i];
inputs[i].d = d[i];
inputs[i].alpha = alpha[i];
inputs[i].beta = beta[i];
}
// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host
// Next Copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
}
}
// Step 6: Initialize and run
// Only validIdx[0] is used, so one supported algorithm is sufficient.
if(!validIdx.empty())
{
    // initialize() takes (algo, workspace, useUserArgs, stream); useUserArgs is
    // passed explicitly so that `stream` is not bound to the bool parameter.
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        // Run with the device copy of the UserArguments array created in Step 5.
        groupedGemm.run(d_dUAFloat, stream);
    }
}
The base class (GemmInstance)#
This is the base class of class Gemm and GroupedGemm.
// Gets heuristic results from the instance.
HIPBLASLT_EXPORT hipblasStatus_t algoGetHeuristic(const int requestedAlgoCount,
const GemmPreference& pref,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
// Returns SUCCESS if the algo is supported, also returns the required workspace size in bytes.
HIPBLASLT_EXPORT hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t& algo, size_t& workspaceSizeInBytes);
// Initializes the instance before calling run. Required every time the problem is set.
HIPBLASLT_EXPORT hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t& algo, void* workspace, bool useUserArgs = true, hipStream_t stream = 0);
// Run the problem.
HIPBLASLT_EXPORT hipblasStatus_t run(hipStream_t stream);
Get all algorithms#
Get all algorithms lets users get all the algorithms of a specific problem type. It requires the transpose of A, B, the data type of the inputs, and the compute type.
HIPBLASLT_EXPORT
hipblasStatus_t hipblaslt_ext::getAllAlgos(hipblasLtHandle_t handle,
hipblasLtExtGemmTypeEnum_t typeGemm,
hipblasOperation_t opA,
hipblasOperation_t opB,
hipDataType typeA,
hipDataType typeB,
hipDataType typeC,
hipDataType typeD,
hipblasComputeType_t typeCompute,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
This API does not require any problem size or epilogue as input; instead, use another API, “isAlgoSupported”, to check whether an algorithm supports a problem.
hipblaslt_ext::matmulIsAlgoSupported()
gemm.isAlgoSupported()
The API will return the required workspace size in bytes on success.
// Get all algorithms
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(hipblaslt_ext::matmulIsAlgoSupported(handle,
matmul,
&(alpha),
matA,
matB,
&(beta),
matC,
matD,
heuristicResult[j].algo,
workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
heuristicResult[j].workspaceSize = workspace_size;
}
else
{
heuristicResult[j].workspaceSize = 0;
}
}
Using extension APIs.
Gemm#
// Pseudo code for gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
hipblaslt_ext::GemmEpilogue epilogue;
epilogue.mode = HIPBLASLT_EPILOGUE_GELU;
hipblaslt_ext::GemmInputs inputs;
inputs.a = a;
inputs.b = b;
inputs.c = c;
inputs.d = d;
inputs.alpha = alpha;
inputs.beta = beta;
hipblaslt_ext::Gemm gemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(gemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
heuristicResult[j].workspaceSize = workspace_size;
}
else
{
heuristicResult[j].workspaceSize = 0;
}
}
// Only validIdx[0] is used, so one supported algorithm is sufficient.
if(!validIdx.empty())
{
    // initialize() takes (algo, workspace, useUserArgs = true, stream = 0).
    // useUserArgs is given explicitly (at its default) so that `stream` reaches
    // the stream parameter instead of being converted to the bool.
    gemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        gemm.run(stream);
    }
}
Grouped gemm#
// Pseudo code for grouped gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
m[i] = 1;
n[i] = 1;
k[i] = 1;
batch_count[i] = 1;
epilogue[i].mode = HIPBLASLT_EPILOGUE_GELU;
inputs[i].a = a[i];
inputs[i].b = b[i];
inputs[i].c = c[i];
inputs[i].d = d[i];
inputs[i].alpha = alpha[i];
inputs[i].beta = beta[i];
}
hipblaslt_ext::GroupedGemm groupedGemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
}
}
// Only validIdx[0] is used, so one supported algorithm is sufficient.
if(!validIdx.empty())
{
    // initialize() takes (algo, workspace, useUserArgs = true, stream = 0).
    // useUserArgs is given explicitly (at its default) so that `stream` reaches
    // the stream parameter instead of being converted to the bool.
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        groupedGemm.run(stream);
    }
}
Algorithm Index#
The extension API lets the user get the algorithm index from hipblasLtMatmulAlgo_t.
HIPBLASLT_EXPORT int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo);
It also lets the user get the heuristic results by giving an index vector.
HIPBLASLT_EXPORT
hipblasStatus_t
getAlgosFromIndex(hipblasLtHandle_t handle,
std::vector<int>& algoIndex,
std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
Example code#
[Grouped Gemm] Fixed MK#
hipBLASLt extension supports changing the sizes (m, n, k, batch) from the device memory “UserArguments”, but the setup is a bit different from the normal routine.
Sum of n#
The sum of N is required as an input for the grouped gemm instance.
For example, we have a grouped gemm with gemm_count = 4. The sum of the Ns must not exceed the “sum of N” set in the setProblem API. In this mode, the first element of the array of Ns is the “sum of n”.
// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
HIPBLASLT_GEMM,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
in_out_datatype,
in_out_datatype,
in_out_datatype,
in_out_datatype,
HIPBLAS_COMPUTE_32F,
heuristicResult));
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
// Step 2.1: Calculate sum of n
int64_t sum_of_n = 0;
for(int i = 0; i < gemm_count; i++)
{
sum_of_n += n_arr[i];
}
// {sum_of_n, 1, 1, 1, ...}; // The array of N, the first element is the sum of N
for(int i = 0; i < gemm_count; i++)
{
m[i] = m_arr[i];
if(i == 0)
n[i] = sum_of_n;
else
n[i] = 1;
k[i] = k_arr[i];
batch_count[i] = 1;
inputs[i].a = a[i];
inputs[i].b = b[i];
inputs[i].c = c[i];
inputs[i].d = d[i];
inputs[i].alpha = alpha[i];
inputs[i].beta = beta[i];
}
// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
HIPBLAS_OP_N,
HIPBLAS_OP_N,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIP_R_16F,
HIPBLAS_COMPUTE_32F);
// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch
// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host
// Next Copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);
validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
size_t workspace_size = 0;
if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
== HIPBLAS_STATUS_SUCCESS)
{
validIdx.push_back(j);
}
}
// Launch configuration for kernelUpdateN: one thread per gemm.
int threads = 256;
int blocks  = (gemm_count + threads - 1) / threads; // integer ceil(gemm_count / threads)
// Step 6: Initialize and run
// Only validIdx[0] is used, so one supported algorithm is sufficient.
if(!validIdx.empty())
{
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace);
    for(int i = 0; i < 10; i++)
    {
        // Update the N of every gemm directly in the device UserArguments array.
        hipLaunchKernelGGL(kernelUpdateN,
                           dim3(blocks),
                           dim3(threads),
                           0,
                           stream,
                           gemm_count,
                           d_dUAFloat,
                           d_n_vec); // d_n_vec is a device pointer with Ns
        // Run with the same device UserArguments array the kernel just updated.
        groupedGemm.run(d_dUAFloat, stream);
    }
}
// .....
// Kernel: one thread per gemm. Overwrites the `n` field of each UserArguments
// entry with the matching value from sizes_n.
//   gemm_count  number of entries; threads whose id is >= gemm_count exit early
//   userArgs    pointer to an array of hipblaslt_ext::UserArguments
//   sizes_n     pointer to the new N values, indexed by gemm
__global__ void kernelUpdateN(uint32_t gemm_count, void* userArgs, int32_t* sizes_n)
{
    // Derive the global index from the actual launch block size instead of a
    // hard-coded 256, and widen before multiplying to avoid 32-bit wraparound.
    uint64_t id = static_cast<uint64_t>(hipBlockIdx_x) * hipBlockDim_x + hipThreadIdx_x;
    if(id >= gemm_count)
        return;
    hipblaslt_ext::UserArguments* dUAFloat = static_cast<hipblaslt_ext::UserArguments*>(userArgs);
    dUAFloat[id].n = sizes_n[id];
}