hipBLASLtExt Reference

hipBLASLtExt Datatypes Reference#

GemmType#

enum class hipblaslt_ext::GemmType#

An enumerated type that specifies the type of the gemm problem in hipBLASLtExt APIs.

Values:

enumerator HIPBLASLT_GEMM#
enumerator HIPBLASLT_GROUPED_GEMM#

GemmProblemType#

struct GemmProblemType#

hipblasLt extension ProblemType for gemm problems.

This structure sets the problem type of a gemm problem.

Public Members

hipblasOperation_t op_a#

The A matrix transpose.

hipblasOperation_t op_b#

The B matrix transpose.

hipblasltDatatype_t type_a#

The A matrix datatype.

hipblasltDatatype_t type_b#

The B matrix datatype.

hipblasltDatatype_t type_c#

The C matrix datatype.

hipblasltDatatype_t type_d#

The D matrix datatype.

hipblasLtComputeType_t type_compute#

The compute datatype.

GemmEpilogue#

struct GemmEpilogue#

hipblasLt extension Epilogue for gemm problems.

This structure sets the epilogue of a gemm problem.

Public Members

hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT#

The mode of epilogue. Default is gemm.

hipblasltDatatype_t bias_data_type = HIPBLASLT_DATATYPE_INVALID#

The bias datatype. Only used when mode is set to a bias-related epilogue.

int aux_ld = 0#

The aux leading dimension. Only used when mode is set to an aux-related epilogue.

int aux_stride = 0#

The aux batch stride. Only used when mode is set to an aux-related epilogue.

GemmInputs#

struct GemmInputs#

hipblasLt extension Inputs for gemm problems.

This structure sets the input pointers of a gemm problem.

Public Members

void *a = nullptr#

The a matrix input pointer.

void *b = nullptr#

The b matrix input pointer.

void *c = nullptr#

The c matrix input pointer.

void *d = nullptr#

The d matrix input pointer.

void *alpha = nullptr#

The alpha value.

void *beta = nullptr#

The beta value.

void *bias = nullptr#

The bias input pointer.

void *scaleA = nullptr#

The Scale A input pointer.

void *scaleB = nullptr#

The Scale B input pointer.

void *scaleC = nullptr#

The Scale C input pointer.

void *scaleD = nullptr#

The Scale D input pointer.

void *scaleAux = nullptr#

The Scale AUX input pointer.

void *scaleAlphaVec = nullptr#

The scaleAlpha vector input pointer.

void *aux = nullptr#

The aux input pointer.

hipBLASLtExt Class Reference#

GemmPreference#

class GemmPreference#

hipblasLt extension preference for gemm problems.

Currently only supports setting max workspace size.

Public Functions

void setMaxWorkspaceBytes(size_t workspaceBytes)#

This function sets the max workspace size.

Parameters:

workspaceBytes[in] Set the max workspace size in bytes.

const size_t getMaxWorkspaceBytes() const#

This function returns the set max workspace size.

Return values:

size_t – Returns the set max workspace size.

GemmInstance#

class GemmInstance#

hipblasLt extension instance for gemm problems.

Subclassed by hipblaslt_ext::Gemm, hipblaslt_ext::GroupedGemm

Public Functions

hipblasStatus_t algoGetHeuristic(const int requestedAlgoCount, const GemmPreference &pref, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#

Retrieve the possible algorithms.

This function retrieves the possible algorithms for the matrix multiply operation performed by hipblasLtMatmul() with the given data and compute types. The output is placed in heuristicResults in order of increasing estimated compute time.

Parameters:
  • requestedAlgoCount[in] number of requested algorithms.

  • pref[in] hipblasLt extension preference for gemm problems.

  • heuristicResults[out] The algorithm heuristic vector.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the query was successful. Check that heuristicResults.size() > 0; the size may be smaller than requestedAlgoCount.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function is available for the current configuration.

  • HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.

hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t &algo, size_t &workspaceSizeInBytes)#

Check if the algorithm supports the problem. (For hipblaslt extension API)

This function updates the problem saved inside the algorithm if the problem is supported. The required workspaceSizeInBytes is also returned.

Parameters:
  • algo[in] The algorithm heuristic.

  • workspaceSizeInBytes[out] Return the required workspace size.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the query was successful. The problem is supported by the algorithm.

  • HIPBLAS_STATUS_INVALID_VALUE – The problem is not supported.

hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t &algo, void *workspace, bool useUserArgs = true, hipStream_t stream = 0)#

Create kernel arguments from a given hipblaslt_ext::GemmInstance.

This function creates kernel arguments from a given hipblaslt_ext::GemmInstance then saves the arguments inside the instance.

Parameters:
  • algo[in] Handle for matrix multiplication algorithm to be used. See hipblaslt.h::hipblasLtMatmulAlgo_t . When NULL, an implicit heuristics query with default search preferences will be performed to determine actual algorithm to use.

  • workspace[in] Pointer to the workspace buffer allocated in the GPU memory. Pointer must be 16B aligned (that is, lowest 4 bits of address must be 0).

  • useUserArgs[in] Use user args, this does not affect vanilla gemm. (May be deprecated in the future)

  • stream[in] The HIP stream where all the GPU work will be submitted. (May be deprecated in the future)

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

  • HIPBLAS_STATUS_INVALID_VALUE – If the gemm_count = 0.
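The 16B alignment requirement on workspace can be checked on the host before calling initialize. The following is a minimal sketch; the helper name isAligned16 is hypothetical and is not part of hipBLASLt.

```cpp
#include <cstdint>

// Hypothetical helper (not a hipBLASLt API).
// Returns true when the lowest 4 bits of the address are 0,
// that is, when the pointer is 16-byte aligned.
inline bool isAligned16(const void* ptr)
{
    return (reinterpret_cast<std::uintptr_t>(ptr) & 0xF) == 0;
}
```

A check like this is mostly useful when the workspace is carved out of a larger buffer at an arbitrary offset, where the alignment of the sub-buffer is not guaranteed.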

hipblasStatus_t run(hipStream_t stream)#

Execute the kernel arguments stored inside the hipblaslt_ext::GemmInstance.

Parameters:

stream[in] The HIP stream where all the GPU work will be submitted.

Return values:

HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

Protected Functions

explicit GemmInstance(hipblasLtHandle_t handle, GemmType type)#

Constructor of GemmInstance.

Gemm#

class Gemm : public hipblaslt_ext::GemmInstance#

hipblasLt extension instance for gemm.

The instance can be used to create arguments to compute the matrix multiplication of matrices A and B to produce the output matrix D, according to the following operation: D = alpha * (A * B) + beta * C, where A, B, and C are input matrices, and alpha and beta are input scalars.
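To make the operation concrete, here is a small host-side reference of D = alpha * (A * B) + beta * C. It is only a sketch for checking results on tiny problems: the function name referenceGemm is hypothetical, and it assumes float data, column-major storage, no transposes, tightly packed matrices, and a single batch.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical host reference (not a hipBLASLt API).
// Computes D = alpha * (A * B) + beta * C for column-major,
// non-transposed, tightly packed matrices:
//   A is m x k, B is k x n, C and D are m x n.
std::vector<float> referenceGemm(std::size_t m, std::size_t n, std::size_t k,
                                 float alpha,
                                 const std::vector<float>& A,
                                 const std::vector<float>& B,
                                 float beta,
                                 const std::vector<float>& C)
{
    std::vector<float> D(m * n);
    for(std::size_t col = 0; col < n; ++col)
    {
        for(std::size_t row = 0; row < m; ++row)
        {
            float acc = 0.0f;
            for(std::size_t p = 0; p < k; ++p)
            {
                // Element (row, p) of A times element (p, col) of B
                acc += A[row + p * m] * B[p + col * k];
            }
            D[row + col * m] = alpha * acc + beta * C[row + col * m];
        }
    }
    return D;
}
```

Comparing the device output against a reference like this on a small problem is a simple way to confirm that the pointers, transposes, and scalars were set as intended.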

Public Functions

explicit Gemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipblasltDatatype_t typeA, hipblasltDatatype_t typeB, hipblasltDatatype_t typeC, hipblasltDatatype_t typeD, hipblasLtComputeType_t typeCompute)#

Constructor.

This constructor sets the problem type of the gemm from the given transpose operations, data types, and compute type. For more information about these types, see hipblasLtMatmul.

Parameters:
  • handle[in] The handle from hipBLASLt.

  • opA, opB[in] The transpose type of matrix A, B

  • typeA, typeB, typeC, typeD[in] The data type of matrix A, B, C, D

  • typeCompute[in] The compute type of the gemm problem

explicit Gemm(hipblasLtHandle_t handle, hipblasLtMatmulDesc_t matmul_descr, const void *alpha, const void *A, hipblasLtMatrixLayout_t matA, const void *B, hipblasLtMatrixLayout_t matB, const void *beta, const void *C, hipblasLtMatrixLayout_t matC, void *D, hipblasLtMatrixLayout_t matD)#

Constructor that sets the gemm problem from hipblasLt structures.

This constructor sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.

Parameters:
  • handle[in] The handle from hipBLASLt.

  • matmul_descr[in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .

  • alpha, beta[in] Pointers to the scalars used in the multiplication.

  • matA, matB, matC, matD[in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .

  • A, B, C[in] Pointers to the GPU memory associated with the corresponding descriptors matA, matB and matC .

  • D[out] Pointer to the GPU memory associated with the descriptor matD .

hipblasStatus_t setProblem(int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue &epilogue, GemmInputs &inputs)#

Sets the problem for a gemm problem.

This function sets the problem with m, n, k, batch_count. It uses the problem type set by the constructor.

Parameters:
  • m, n, k[in] The problem size.

  • batch_count[in] The batch count.

  • epilogue[in] The structure that controls the epilogue.

  • inputs[in] The inputs of the problem.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

  • HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.

  • HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.

  • HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.

  • HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.

hipblasStatus_t setProblem(int64_t m, int64_t n, int64_t k, int64_t batch_count, int64_t lda, int64_t ldb, int64_t ldc, int64_t ldd, int64_t strideA, int64_t strideB, int64_t strideC, int64_t strideD, GemmEpilogue &epilogue, GemmInputs &inputs, GemmProblemType &problemtype)#

Sets the problem for a gemm problem.

This function sets the problem with m, n, k, batch_count, the leading dimensions, and the batch strides. It takes the problem type from the problemtype argument.

Parameters:
  • m, n, k[in] The problem size.

  • batch_count[in] The batch count.

  • lda, ldb, ldc, ldd[in] The leading dimensions of the matrix.

  • strideA, strideB, strideC, strideD[in] The batch stride of the matrix.

  • epilogue[in] The structure that controls the epilogue.

  • inputs[in] The inputs of the problem.

  • problemtype[in] The structure that sets the problem type of a gemm problem.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

  • HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.

  • HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.

  • HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.

  • HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.
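As an illustration of the lda, ldb, ldc, ldd and stride arguments, the sketch below computes the values for tightly packed matrices. It assumes column-major storage with no padding between columns or batches; the struct PackedLayout and the function packedLayout are hypothetical helpers, not part of hipBLASLt.

```cpp
#include <cstdint>

// Hypothetical helper types (not hipBLASLt APIs).
struct PackedLayout
{
    std::int64_t lda, ldb, ldc, ldd;
    std::int64_t strideA, strideB, strideC, strideD;
};

// Leading dimensions and batch strides for tightly packed,
// column-major matrices. transA / transB are true when the
// corresponding matrix is stored transposed.
inline PackedLayout packedLayout(std::int64_t m, std::int64_t n, std::int64_t k,
                                 bool transA, bool transB)
{
    PackedLayout p{};
    p.lda     = transA ? k : m; // A is stored as m x k, or k x m when transposed
    p.ldb     = transB ? n : k; // B is stored as k x n, or n x k when transposed
    p.ldc     = m;              // C is m x n
    p.ldd     = m;              // D is m x n
    p.strideA = m * k;          // elements between consecutive batches of A
    p.strideB = k * n;
    p.strideC = m * n;
    p.strideD = m * n;
    return p;
}
```

Padded or strided layouts would use larger leading dimensions and strides than the packed values computed here.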

hipblasStatus_t setProblem(hipblasLtMatmulDesc_t matmul_descr, const void *alpha, const void *A, hipblasLtMatrixLayout_t matA, const void *B, hipblasLtMatrixLayout_t matB, const void *beta, const void *C, hipblasLtMatrixLayout_t matC, void *D, hipblasLtMatrixLayout_t matD)#

Sets the gemm problem from hipblasLt structures.

This function sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.

Parameters:
  • matmul_descr[in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .

  • alpha, beta[in] Pointers to the scalars used in the multiplication.

  • matA, matB, matC, matD[in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .

  • A, B, C[in] Pointers to the GPU memory associated with the corresponding descriptors matA, matB and matC .

  • D[out] Pointer to the GPU memory associated with the descriptor matD .

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

  • HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.

  • HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.

  • HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.

  • HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.

GroupedGemm#

class GroupedGemm : public hipblaslt_ext::GemmInstance#

hipblasLt extension instance for grouped gemm.

The instance can be used to create arguments to compute the matrix multiplication of matrices A and B to produce the output matrix D, according to the following operation: D = alpha * (A * B) + beta * C, where A, B, and C are input matrices, and alpha and beta are input scalars.

Public Functions

explicit GroupedGemm(hipblasLtHandle_t handle, hipblasOperation_t opA, hipblasOperation_t opB, hipblasltDatatype_t typeA, hipblasltDatatype_t typeB, hipblasltDatatype_t typeC, hipblasltDatatype_t typeD, hipblasLtComputeType_t typeCompute)#

Constructor.

This constructor sets the problem type of the grouped gemm from the given transpose operations, data types, and compute type. For more information about these types, see hipblasLtMatmul.

Parameters:
  • handle[in] The handle from hipBLASLt.

  • opA, opB[in] The transpose type of matrix A, B

  • typeA, typeB, typeC, typeD[in] The data type of matrix A, B, C, D

  • typeCompute[in] The compute type of the gemm problem

explicit GroupedGemm(hipblasLtHandle_t handle, std::vector<hipblasLtMatmulDesc_t> &matmul_descr, std::vector<void*> &alpha, std::vector<void*> &A, std::vector<hipblasLtMatrixLayout_t> &matA, std::vector<void*> &B, std::vector<hipblasLtMatrixLayout_t> &matB, std::vector<void*> &beta, std::vector<void*> &C, std::vector<hipblasLtMatrixLayout_t> &matC, std::vector<void*> &D, std::vector<hipblasLtMatrixLayout_t> &matD)#

Constructor that sets the grouped gemm problem from hipblasLt structures.

This constructor sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.

Parameters:
  • handle[in] The handle from hipBLASLt.

  • matmul_descr[in] Vectors of handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .

  • alpha, beta[in] Vectors of float used in the multiplication.

  • matA, matB, matC, matD[in] Vectors of handle to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .

  • A, B, C[in] Vectors of pointer to the GPU memory associated with the corresponding descriptors matA, matB and matC .

  • D[out] Vector of pointer to the GPU memory associated with the descriptor matD .

hipblasStatus_t setProblem(std::vector<int64_t> &m, std::vector<int64_t> &n, std::vector<int64_t> &k, std::vector<int64_t> &batch_count, std::vector<GemmEpilogue> &epilogue, std::vector<GemmInputs> &inputs)#

Sets the problem for a grouped gemm problem.

This function sets the problem with m, n, k, batch_count, each given as a vector. It uses the problem type set by the constructor.

Parameters:
  • m, n, k[in] The problem size in vector.

  • batch_count[in] The batch count in vector.

  • epilogue[in] The structure in vector that controls the epilogue.

  • inputs[in] The inputs in vector of the problem.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

  • HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.

  • HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.

  • HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.

  • HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.

hipblasStatus_t setProblem(std::vector<int64_t> &m, std::vector<int64_t> &n, std::vector<int64_t> &k, std::vector<int64_t> &batch_count, std::vector<int64_t> &lda, std::vector<int64_t> &ldb, std::vector<int64_t> &ldc, std::vector<int64_t> &ldd, std::vector<int64_t> &strideA, std::vector<int64_t> &strideB, std::vector<int64_t> &strideC, std::vector<int64_t> &strideD, std::vector<GemmEpilogue> &epilogue, std::vector<GemmInputs> &inputs, GemmProblemType &problemtype)#

Sets the problem for a grouped gemm problem.

This function sets the problem with m, n, k, batch_count, the leading dimensions, and the batch strides, each given as a vector. It takes the problem type from the problemtype argument.

Parameters:
  • m, n, k[in] The problem size in vector.

  • batch_count[in] The batch count in vector.

  • lda, ldb, ldc, ldd[in] The leading dimensions in vector of the matrix.

  • strideA, strideB, strideC, strideD[in] The batch stride in vector of the matrix.

  • epilogue[in] The structure in vector that controls the epilogue.

  • inputs[in] The inputs in vector of the problem.

  • problemtype[in] The structure that sets the problem type of a gemm problem.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

  • HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.

  • HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.

  • HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.

  • HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.

hipblasStatus_t setProblem(std::vector<hipblasLtMatmulDesc_t> &matmul_descr, std::vector<void*> &alpha, std::vector<void*> &A, std::vector<hipblasLtMatrixLayout_t> &matA, std::vector<void*> &B, std::vector<hipblasLtMatrixLayout_t> &matB, std::vector<void*> &beta, std::vector<void*> &C, std::vector<hipblasLtMatrixLayout_t> &matC, std::vector<void*> &D, std::vector<hipblasLtMatrixLayout_t> &matD)#

Sets the grouped gemm problem from hipblasLt structures.

This function sets the problem from hipblasLt structures. For more information about the structures, see hipblasLtMatmul.

Parameters:
  • matmul_descr[in] Vectors of handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .

  • alpha, beta[in] Vectors of float used in the multiplication.

  • matA, matB, matC, matD[in] Vectors of handle to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .

  • A, B, C[in] Vectors of pointer to the GPU memory associated with the corresponding descriptors matA, matB and matC .

  • D[out] Vector of pointer to the GPU memory associated with the descriptor matD .

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

  • HIPBLAS_STATUS_EXECUTION_FAILED – If HIP reported an execution error from the device.

  • HIPBLAS_STATUS_ARCH_MISMATCH – If the configured operation cannot be run using the selected device.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If the current implementation on the selected device doesn’t support the configured operation.

  • HIPBLAS_STATUS_INVALID_VALUE – If the parameters are unexpectedly NULL, in conflict or in an impossible configuration.

  • HIPBLAS_STATUS_NOT_INITIALIZED – If hipBLASLt handle has not been initialized.

hipblasStatus_t getDefaultValueForDeviceUserArguments(void *hostDeviceUserArgs)#

A helper function to initialize DeviceUserArguments using the set problem(s) saved in the gemm object.

Parameters:

hostDeviceUserArgs[in] The DeviceUserArguments structure allocated on the host. Note that the user must use the correct DeviceUserArguments type.

Return values:

HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

hipblasStatus_t run(void *deviceUserArgs, hipStream_t stream)#

Run the kernel using DeviceUserArguments.

Parameters:
  • deviceUserArgs[in] Pointer to the DeviceUserArguments buffer allocated in the GPU memory. Pointer must be 16B aligned (that is, lowest 4 bits of address must be 0).

  • stream[in] The HIP stream where all the GPU work will be submitted.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

  • HIPBLAS_STATUS_INVALID_VALUE – If the gemm_count = 0.

hipblasStatus_t run(hipStream_t stream)#

Execute the kernel arguments stored inside the hipblaslt_ext::GemmInstance.

Parameters:

stream[in] The HIP stream where all the GPU work will be submitted.

Return values:

HIPBLAS_STATUS_SUCCESS – If the operation completed successfully.

hipBLASLtExt API Reference#

getAllAlgos()#

hipblasStatus_t hipblaslt_ext::getAllAlgos(hipblasLtHandle_t handle, GemmType typeGemm, hipblasOperation_t opA, hipblasOperation_t opB, hipblasltDatatype_t typeA, hipblasltDatatype_t typeB, hipblasltDatatype_t typeC, hipblasltDatatype_t typeD, hipblasLtComputeType_t typeCompute, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#

Retrieve the possible algorithms.

This function retrieves the possible algorithms for the matrix multiply operation performed by hipblasLtMatmul() with the given data and compute types. The output is placed in heuristicResults in order of increasing estimated compute time.

Parameters:
  • handle[in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .

  • typeGemm[in] Gemm type. ex. GEMM, GROUPED_GEMM.

  • opA, opB[in] Transpose settings of A, B.

  • typeA, typeB, typeC, typeD[in] The data type of matrix A, B, C, D.

  • typeCompute[in] The compute type.

  • heuristicResults[out] The algorithm heuristic vector.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the query was successful. Check that heuristicResults.size() > 0 to confirm that results were returned.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function available for current configuration.

  • HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.

getIndexFromAlgo()#

int hipblaslt_ext::getIndexFromAlgo(hipblasLtMatmulAlgo_t &algo)#

Retrieve the algorithm index.

Parameters:

algo[in] The algorithm.

Return values:

int – The index of the algorithm, which can be used to get heuristic results from getAlgosFromIndex. Returns -1 if the index stored in algo < 0. Note that the index may not be valid if the algo struct is not initialized properly.

getAlgosFromIndex()#

hipblasStatus_t hipblaslt_ext::getAlgosFromIndex(hipblasLtHandle_t handle, std::vector<int> &algoIndex, std::vector<hipblasLtMatmulHeuristicResult_t> &heuristicResults)#

Retrieve the possible algorithms.

This function retrieves the possible algorithms for the matrix multiply operation performed by hipblasLtMatmul() with the given index. The output is placed in heuristicResults in order of increasing estimated compute time.

Parameters:
  • handle[in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .

  • algoIndex[in] The algorithm index vector.

  • heuristicResults[out] The algorithm heuristic vector.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the query was successful. Check that heuristicResults.size() > 0 to confirm that results were returned.

  • HIPBLAS_STATUS_NOT_SUPPORTED – If no heuristic function available for current configuration.

  • HIPBLAS_STATUS_INVALID_VALUE – If no solution is found.

matmulIsAlgoSupported()#

hipblasStatus_t hipblaslt_ext::matmulIsAlgoSupported(hipblasLtHandle_t handle, hipblasLtMatmulDesc_t matmulDesc, const void *alpha, hipblasLtMatrixLayout_t Adesc, hipblasLtMatrixLayout_t Bdesc, const void *beta, hipblasLtMatrixLayout_t Cdesc, hipblasLtMatrixLayout_t Ddesc, hipblasLtMatmulAlgo_t &algo, size_t &workspaceSizeInBytes)#

Check if the algorithm supports the problem. (For hipblasLt API)

This function updates the problem saved inside the algorithm if the problem is supported. The required workspaceSizeInBytes is also returned.

Parameters:
  • handle[in] Pointer to the allocated hipBLASLt handle for the hipBLASLt context. See hipblasLtHandle_t .

  • matmulDesc[in] Handle to a previously created matrix multiplication descriptor of type hipblasLtMatmulDesc_t .

  • alpha, beta[in] Pointers to the scalars used in the multiplication.

  • Adesc, Bdesc, Cdesc, Ddesc[in] Handles to the previously created matrix layout descriptors of the type hipblasLtMatrixLayout_t .

  • algo[in] The algorithm heuristic.

  • workspaceSizeInBytes[out] Return the required workspace size.

Return values:
  • HIPBLAS_STATUS_SUCCESS – If the query was successful. The problem is supported by the algorithm.

  • HIPBLAS_STATUS_INVALID_VALUE – The problem is not supported.

hipblasLtExt Usage#

Introduction#

hipBLASLt provides extension APIs in the namespace hipblaslt_ext. They are C++ only. The extensions support:

  1. Gemm

  2. Grouped gemm

  3. Get all algorithms

Gemm#

hipblasLtExt provides its own gemm instance.

The user must either assign the problem type when constructing the instance or import the problem from the hipBLASLt API.

HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t      handle,
                               hipblasOperation_t     opA,
                               hipblasOperation_t     opB,
                               hipblasltDatatype_t    typeA,
                               hipblasltDatatype_t    typeB,
                               hipblasltDatatype_t    typeC,
                               hipblasltDatatype_t    typeD,
                               hipblasLtComputeType_t typeCompute);

HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t       handle,
                               hipblasLtMatmulDesc_t   matmul_descr,
                               const void*             alpha,
                               const void*             A,
                               hipblasLtMatrixLayout_t matA,
                               const void*             B,
                               hipblasLtMatrixLayout_t matB,
                               const void*             beta,
                               const void*             C,
                               hipblasLtMatrixLayout_t matC,
                               void*                   D,
                               hipblasLtMatrixLayout_t matD);

After the instance is created, the user can set the problem with the API. The API may require the following structures:

GemmProblemType lets the user change the problem type after the instance is initialized.

struct GemmProblemType
{
    hipblasOperation_t     op_a;
    hipblasOperation_t     op_b;
    hipblasltDatatype_t    type_a;
    hipblasltDatatype_t    type_b;
    hipblasltDatatype_t    type_c;
    hipblasltDatatype_t    type_d;
    hipblasLtComputeType_t type_compute;
};

GemmEpilogue lets the user control the epilogue of the problem.

struct GemmEpilogue
{
    hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT;
    hipblasltDatatype_t bias_data_type;
    int                 aux_ld;
    int                 aux_stride;
};

GemmInputs holds the problem inputs.

struct GemmInputs
{
    void* a = nullptr;
    void* b = nullptr;
    void* c = nullptr;
    void* d = nullptr;
    void* alpha = nullptr;
    void* beta = nullptr;
    // Epilogue inputs
    void* bias = nullptr;
    void* aux = nullptr;
};

And the setProblem APIs:

HIPBLASLT_EXPORT hipblasStatus_t setProblem(
    int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue& epilogue, GemmInputs& inputs);

The user can also set the leading dimensions, strides, and reassign the data type with the following API.

HIPBLASLT_EXPORT hipblasStatus_t setProblem(int64_t          m,
                                            int64_t          n,
                                            int64_t          k,
                                            int64_t          batch_count,
                                            int64_t          lda,
                                            int64_t          ldb,
                                            int64_t          ldc,
                                            int64_t          ldd,
                                            int64_t          strideA,
                                            int64_t          strideB,
                                            int64_t          strideC,
                                            int64_t          strideD,
                                            GemmEpilogue&    epilogue,
                                            GemmInputs&      inputs,
                                            GemmProblemType& problemtype);

The user can also import problems from hipblasLt APIs after the instance is created. Note that this may overwrite the problem type of the instance.

HIPBLASLT_EXPORT hipblasStatus_t setProblem(hipblasLtMatmulDesc_t   matmul_descr,
                                            const void*             alpha,
                                            const void*             A,
                                            hipblasLtMatrixLayout_t matA,
                                            const void*             B,
                                            hipblasLtMatrixLayout_t matB,
                                            const void*             beta,
                                            const void*             C,
                                            hipblasLtMatrixLayout_t matC,
                                            void*                   D,
                                            hipblasLtMatrixLayout_t matD);

The user can get heuristics and create kernel arguments with the instance. If the properties of the gemm and the inputs don’t change, the user can call the run API to launch the kernel directly.

// Pseudo code: handle, stream, d_workspace, and the device pointers
// a, b, c, d, alpha, beta are assumed to be created by the caller.
hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Default epilogue mode is HIPBLASLT_EPILOGUE_DEFAULT
hipblaslt_ext::GemmEpilogue epilogue;
hipblaslt_ext::GemmInputs inputs;
inputs.a = a;
inputs.b = b;
inputs.c = c;
inputs.d = d;
inputs.alpha = alpha;
inputs.beta = beta;

hipblaslt_ext::Gemm gemm(handle,
                         HIPBLAS_OP_N,
                         HIPBLAS_OP_N,
                         HIPBLASLT_R_16F,
                         HIPBLASLT_R_16F,
                         HIPBLASLT_R_16F,
                         HIPBLASLT_R_16F,
                         HIPBLASLT_COMPUTE_F32);
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic;
gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch
// The first argument is requestedAlgoCount; request a single algorithm here.
gemm.algoGetHeuristic(1, pref, heuristic);
// The result vector may be empty even when the query succeeds.
if(!heuristic.empty())
{
    // Arguments: algo, workspace, useUserArgs (default true), stream
    gemm.initialize(heuristic[0].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        gemm.run(stream);
    }
}

Grouped Gemm#

hipblasLtExt supports grouped gemm through the GroupedGemm class, which shares the GemmInstance base class with normal gemm.

After the problem is set, the user can check the problem type with function getGemmType().

enum class GemmType
{
    HIPBLASLT_GEMM             = 1,
    HIPBLASLT_GROUPED_GEMM     = 2
};

The grouped gemm class also has the setProblem APIs.

HIPBLASLT_EXPORT hipblasStatus_t setProblem(
    int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogue& epilogue, GemmInputs& inputs);

HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>&      m,
                                            std::vector<int64_t>&      n,
                                            std::vector<int64_t>&      k,
                                            std::vector<int64_t>&      batch_count,
                                            std::vector<GemmEpilogue>& epilogue,
                                            std::vector<GemmInputs>&   inputs);

HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>&      m,
                                            std::vector<int64_t>&      n,
                                            std::vector<int64_t>&      k,
                                            std::vector<int64_t>&      batch_count,
                                            std::vector<int64_t>&      lda,
                                            std::vector<int64_t>&      ldb,
                                            std::vector<int64_t>&      ldc,
                                            std::vector<int64_t>&      ldd,
                                            std::vector<int64_t>&      strideA,
                                            std::vector<int64_t>&      strideB,
                                            std::vector<int64_t>&      strideC,
                                            std::vector<int64_t>&      strideD,
                                            std::vector<GemmEpilogue>& epilogue,
                                            std::vector<GemmInputs>&   inputs,
                                            GemmProblemType&           problemtype);

HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<hipblasLtMatmulDesc_t>&   matmul_descr,
                                            std::vector<void*>&                   alpha,
                                            std::vector<void*>&                   A,
                                            std::vector<hipblasLtMatrixLayout_t>& matA,
                                            std::vector<void*>&                   B,
                                            std::vector<hipblasLtMatrixLayout_t>& matB,
                                            std::vector<void*>&                   beta,
                                            std::vector<void*>&                   C,
                                            std::vector<hipblasLtMatrixLayout_t>& matC,
                                            std::vector<void*>&                   D,
                                            std::vector<hipblasLtMatrixLayout_t>& matD);

In the following API, the “epilogue” argument supports broadcasting: if the vector is shorter than the number of problems, it is extended to that length by duplicating its last element.

HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>&      m,
                                            std::vector<int64_t>&      n,
                                            std::vector<int64_t>&      k,
                                            std::vector<int64_t>&      batch_count,
                                            std::vector<int64_t>&      lda,
                                            std::vector<int64_t>&      ldb,
                                            std::vector<int64_t>&      ldc,
                                            std::vector<int64_t>&      ldd,
                                            std::vector<int64_t>&      strideA,
                                            std::vector<int64_t>&      strideB,
                                            std::vector<int64_t>&      strideC,
                                            std::vector<int64_t>&      strideD,
                                            std::vector<GemmEpilogue>& epilogue,
                                            std::vector<GemmInputs>&   inputs,
                                            GemmProblemType&           problemtype);

Note that problemtype currently supports only a size of 1 (one GemmProblemType shared by all problems).

// Pseudo code
std::vector<int64_t> m, n, k;
// ...
for(size_t i = 0; i < problem_size; i++)
{
    // ...
}
std::vector<GemmProblemType> problemtypes;
problemtypes.push_back(problemtype);
groupedgemm.setProblem(m, n, k, batch_count, lda, ldb, ldc, ldd, strideA, strideB, strideC, strideD, epilogue, inputs, problemtypes);
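
The broadcasting rule described above can be illustrated with a small stand-alone sketch. The helper name broadcastToLength is hypothetical and is not part of hipBLASLt; it only mimics the documented behaviour of padding a shorter vector by repeating its last element. How the library treats an empty or over-long vector is not stated here, so the sketch simply returns such input unchanged.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not a hipBLASLt API): pads `values` to `problemCount`
// entries by repeating the last element, mirroring the broadcasting rule
// documented for the "epilogue" argument.
template <typename T>
std::vector<T> broadcastToLength(std::vector<T> values, std::size_t problemCount)
{
    // Assumption: an empty vector has no last element to repeat, and a vector
    // that is already long enough needs no padding, so both are returned as-is.
    if(values.empty() || values.size() >= problemCount)
        return values;

    const T last = values.back(); // copy first: resize may reallocate
    values.resize(problemCount, last);
    return values;
}
```

With this rule, a single epilogue supplied for a group of four problems would behave as if the same epilogue had been listed four times.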

UserArguments#

Grouped gemm can run the kernel with arguments read from external device memory. This is helpful when some of the arguments come from the output of a previous kernel. To change the size-related arguments (m, n, k, batch), refer to the section Fixed MK.

struct UserArguments
{
    uint32_t m; //!< size m
    uint32_t n; //!< size n
    uint32_t batch; //!< size batch
    uint32_t k; //!< size k
    void*    d; //!< The d matrix input pointer.
    void*    c; //!< The c matrix input pointer.
    void*    a; //!< The a matrix input pointer.
    void*    b; //!< The b matrix input pointer.
    uint32_t strideD1; //!< The d leading dimension.
    uint32_t strideD2; //!< The d batch stride
    uint32_t strideC1; //!< The c leading dimension.
    uint32_t strideC2; //!< The c batch stride
    uint32_t strideA1; //!< The a leading dimension.
    uint32_t strideA2; //!< The a batch stride
    uint32_t strideB1; //!< The b leading dimension.
    uint32_t strideB2; //!< The b batch stride
    int8_t   alpha[16]; //!< The alpha value.
    int8_t   beta[16]; //!< The beta value.
    // Epilogue inputs
    void*    bias; //!< The bias input pointer.
    int      biasType; //!< The bias datatype. Only works if mode is set to bias related epilogues.
    uint32_t reserved;
    void*    e; //!< The aux input pointer. Only works if mode is set to aux related epilogues.
    uint32_t strideE1; //!< The aux leading dimension. Only works if mode is set to aux related epilogues.
    uint32_t strideE2; //!< The aux batch stride. Only works if mode is set to aux related epilogues.
    float    act0; //!< The activation value 1. Some activations might use it.
    float    act1; //!< The activation value 2.
    int      activationType; //!< The activation type.  Only works if mode is set to activation related epilogues.
} __attribute__((packed));
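
Because the structure is packed, its size is the plain sum of its field sizes, and the alpha and beta scalars travel as raw 16-byte buffers. The sketch below mirrors the layout listed above in a local type so that it compiles without the library; UserArgumentsMirror, setScalar and getScalarAsFloat are hypothetical names. It assumes that a scalar is stored in its native representation at the start of the buffer, which this section does not state explicitly, and it relies on the GCC/Clang packed attribute.

```cpp
#include <cstdint>
#include <cstring>

// Local mirror of the UserArguments layout listed above, for illustration
// only. Real code should use hipblaslt_ext::UserArguments from the library.
struct UserArgumentsMirror
{
    uint32_t m, n, batch, k;
    void*    d;
    void*    c;
    void*    a;
    void*    b;
    uint32_t strideD1, strideD2, strideC1, strideC2;
    uint32_t strideA1, strideA2, strideB1, strideB2;
    int8_t   alpha[16];
    int8_t   beta[16];
    void*    bias;
    int      biasType;
    uint32_t reserved;
    void*    e;
    uint32_t strideE1, strideE2;
    float    act0, act1;
    int      activationType;
} __attribute__((packed));

// 15 uint32_t + 2 int + 2 float + 6 pointers + two 16-byte buffers, no padding.
static_assert(sizeof(UserArgumentsMirror)
                  == 15 * sizeof(uint32_t) + 2 * sizeof(int) + 2 * sizeof(float)
                         + 6 * sizeof(void*) + 32,
              "packed layout should contain no padding");

// Assumption: the scalar occupies the first bytes of the 16-byte buffer in
// its native representation; the remaining bytes are zeroed.
inline void setScalar(int8_t* buffer16, float value)
{
    std::memset(buffer16, 0, 16);
    std::memcpy(buffer16, &value, sizeof(value));
}

// Reads a float back from the start of a 16-byte scalar buffer.
inline float getScalarAsFloat(const int8_t* buffer16)
{
    float value;
    std::memcpy(&value, buffer16, sizeof(value));
    return value;
}
```

Using memcpy rather than a pointer cast avoids alignment and aliasing problems, since the packed structure gives the buffers no alignment guarantee.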

Two functions are provided for UserArguments. The first is a helper that initializes an array of “UserArguments” structures from the problems saved inside the grouped gemm object. The second is an overload of run that takes an additional UserArguments device pointer.

HIPBLASLT_EXPORT hipblasStatus_t getDefaultValueForDeviceUserArguments(void* hostDeviceUserArgs);

HIPBLASLT_EXPORT hipblasStatus_t run(void* deviceUserArgs, hipStream_t stream);

The following is a simple example of how this API works.

// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GROUPED_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLASLT_COMPUTE_F32,
                                                 heuristicResult));

hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
    m[i] = 1;
    n[i] = 1;
    k[i] = 1;
    batch_count[i] = 1;
    epilogue[i].mode = HIPBLASLT_EPILOGUE_GELU;
    inputs[i].a = a[i];
    inputs[i].b = b[i];
    inputs[i].c = c[i];
    inputs[i].d = d[i];
    inputs[i].alpha = alpha[i];
    inputs[i].beta = beta[i];
}

// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
                                       HIPBLAS_OP_N,
                                       HIPBLAS_OP_N,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_COMPUTE_F32);

// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch

// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host

// Next Copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
    }
}

// Step 6: Initialize and run
if(!validIdx.empty())
{
    // Arguments: algo, workspace, useUserArgs, stream
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        groupedGemm.run(d_dUAFloat, stream); // device copy of the UserArguments array
    }
}

The base class (GemmInstance)#

This is the base class of the Gemm and GroupedGemm classes.

// Gets heuristic results from the instance.
HIPBLASLT_EXPORT hipblasStatus_t algoGetHeuristic(const int                                      requestedAlgoCount,
                                                  const GemmPreference&                          pref,
                                                  std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);

// Returns SUCCESS if the algo is supported, also returns the required workspace size in bytes.
HIPBLASLT_EXPORT hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t& algo, size_t& workspaceSizeInBytes);

// Initializes the instance before calling run. Required every time the problem is set.
HIPBLASLT_EXPORT hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t& algo, void* workspace, bool useUserArgs = true, hipStream_t stream = 0);

// Run the problem.
HIPBLASLT_EXPORT hipblasStatus_t run(hipStream_t stream);

Get all algorithms#

Get all algorithms lets users retrieve every algorithm available for a specific problem type. It requires the transpose operations of A and B, the data types of the matrices, and the compute type.

HIPBLASLT_EXPORT
hipblasStatus_t hipblaslt_ext::getAllAlgos(hipblasLtHandle_t                              handle,
                                           hipblasLtExtGemmTypeEnum_t                     typeGemm,
                                           hipblasOperation_t                             opA,
                                           hipblasOperation_t                             opB,
                                           hipblasltDatatype_t                            typeA,
                                           hipblasltDatatype_t                            typeB,
                                           hipblasltDatatype_t                            typeC,
                                           hipblasltDatatype_t                            typeD,
                                           hipblasLtComputeType_t                         typeCompute,
                                           std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);

This API does not take a problem size or epilogue as input. To check whether an algorithm supports a specific problem, use one of the “isAlgoSupported” APIs:

hipblaslt_ext::matmulIsAlgoSupported()
gemm.isAlgoSupported()

On success, the API also returns the required workspace size in bytes.

// Get all algorithms
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLASLT_COMPUTE_F32,
                                                 heuristicResult));

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(hipblaslt_ext::matmulIsAlgoSupported(handle,
                                            matmul,
                                            &(alpha),
                                            matA,
                                            matB,
                                            &(beta),
                                            matC,
                                            matD,
                                            heuristicResult[j].algo,
                                            workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
        heuristicResult[j].workspaceSize = workspace_size;
    }
    else
    {
        heuristicResult[j].workspaceSize = 0;
    }
}
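
Once the loop above has recorded which algorithms are supported and how much workspace each needs, a caller still has to choose one that fits the workspace it actually allocated. The following is a minimal sketch of that selection step. The Candidate type and pickFirstFitting function are hypothetical stand-ins, not hipBLASLt types; they hold the same information the loop stores in validIdx and workspaceSize.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the information the loop above records for each
// algorithm; not a hipBLASLt type.
struct Candidate
{
    int         index;         // position in heuristicResult
    bool        supported;     // isAlgoSupported returned HIPBLAS_STATUS_SUCCESS
    std::size_t workspaceSize; // bytes reported by isAlgoSupported
};

// Returns the index of the first supported candidate whose workspace fits in
// maxWorkspaceBytes, or -1 if there is none.
int pickFirstFitting(const std::vector<Candidate>& candidates, std::size_t maxWorkspaceBytes)
{
    for(const Candidate& c : candidates)
    {
        if(c.supported && c.workspaceSize <= maxWorkspaceBytes)
            return c.index;
    }
    return -1;
}
```

Taking the first fitting entry keeps the order returned by the library; a caller that benchmarks the candidates could instead pick the fastest one among those that fit.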

Using extension APIs.

Gemm#

// Pseudo code for gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLASLT_COMPUTE_F32,
                                                 heuristicResult));

hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
hipblaslt_ext::GemmEpilogue epilogue;
epilogue.mode = HIPBLASLT_EPILOGUE_GELU;
hipblaslt_ext::GemmInputs inputs;
inputs.a = a;
inputs.b = b;
inputs.c = c;
inputs.d = d;
inputs.alpha = alpha;
inputs.beta = beta;

hipblaslt_ext::Gemm gemm(handle,
                         HIPBLAS_OP_N,
                         HIPBLAS_OP_N,
                         HIPBLASLT_R_16F,
                         HIPBLASLT_R_16F,
                         HIPBLASLT_R_16F,
                         HIPBLASLT_R_16F,
                         HIPBLASLT_COMPUTE_F32);

gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(gemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
        heuristicResult[j].workspaceSize = workspace_size;
    }
    else
    {
        heuristicResult[j].workspaceSize = 0;
    }
}

if(!validIdx.empty())
{
    // Arguments: algo, workspace, useUserArgs, stream
    gemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        gemm.run(stream);
    }
}

Grouped gemm#

// Pseudo code for grouped gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GROUPED_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLASLT_COMPUTE_F32,
                                                 heuristicResult));

hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);

std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
    m[i] = 1;
    n[i] = 1;
    k[i] = 1;
    batch_count[i] = 1;
    epilogue[i].mode = HIPBLASLT_EPILOGUE_GELU;
    inputs[i].a = a[i];
    inputs[i].b = b[i];
    inputs[i].c = c[i];
    inputs[i].d = d[i];
    inputs[i].alpha = alpha[i];
    inputs[i].beta = beta[i];
}


hipblaslt_ext::GroupedGemm groupedGemm(handle,
                                       HIPBLAS_OP_N,
                                       HIPBLAS_OP_N,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_COMPUTE_F32);

groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
    }
}

if(!validIdx.empty())
{
    // Arguments: algo, workspace, useUserArgs, stream
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, true, stream);
    for(int i = 0; i < 10; i++)
    {
        groupedGemm.run(stream);
    }
}

Algorithm Index#

The extension API lets the user get the algorithm index from a hipblasLtMatmulAlgo_t.

HIPBLASLT_EXPORT int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo);

It also lets the user get the heuristic results that correspond to a vector of indices.

HIPBLASLT_EXPORT
hipblasStatus_t
    getAlgosFromIndex(hipblasLtHandle_t                              handle,
                      std::vector<int>&                              algoIndex,
                      std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);
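
One use of these two functions is to remember which algorithms were chosen in an earlier run. The sketch below shows only the storage side: it turns a list of indices (as returned by getIndexFromAlgo) into text and back into a vector that could be passed to getAlgosFromIndex. The function names serializeIndices and parseIndices are hypothetical. Whether an index stays valid across library versions or devices is not stated in this section, so a restored index is best treated as a hint and re-checked with isAlgoSupported.

```cpp
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

// Joins algorithm indices into a space-separated string, for example to be
// written to a tuning cache file.
std::string serializeIndices(const std::vector<int>& indices)
{
    std::ostringstream out;
    for(std::size_t i = 0; i < indices.size(); ++i)
    {
        if(i != 0)
            out << ' ';
        out << indices[i];
    }
    return out.str();
}

// Parses the string back into a vector of indices; parsing stops at the
// first token that is not an integer.
std::vector<int> parseIndices(const std::string& text)
{
    std::vector<int>   indices;
    std::istringstream in(text);
    int                value = 0;
    while(in >> value)
        indices.push_back(value);
    return indices;
}
```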

Example code#

[Grouped Gemm] Fixed MK#

hipBLASLt extension supports changing the sizes (m, n, k, batch) through the device memory “UserArguments”, but the setup differs slightly from the normal routine.

Sum of n#

The grouped gemm instance requires a sum of N as input.

For example, consider a grouped gemm with gemm_count = 4. The sum of the Ns used at run time must not exceed the “sum of N” set in the setProblem API. In this mode, the first element of the array of Ns is the “sum of N”.

// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GROUPED_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLASLT_COMPUTE_F32,
                                                 heuristicResult));

hipblaslt_ext::GemmPreference pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogue> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputs> inputs(gemm_count);

// Step 2.1: Calculate sum of n
int64_t sum_of_n = 0;
for(int i = 0; i < gemm_count; i++)
{
    sum_of_n += n_arr[i];
}

// {sum_of_n, 1, 1, 1, ...}; // The array of N, the first element is the sum of N
for(int i = 0; i < gemm_count; i++)
{
    m[i] = m_arr[i];
    if(i == 0)
        n[i] = sum_of_n;
    else
        n[i] = 1;
    k[i] = k_arr[i];
    batch_count[i] = 1;
    inputs[i].a = a[i];
    inputs[i].b = b[i];
    inputs[i].c = c[i];
    inputs[i].d = d[i];
    inputs[i].alpha = alpha[i];
    inputs[i].beta = beta[i];
}

// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
                                       HIPBLAS_OP_N,
                                       HIPBLAS_OP_N,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_R_16F,
                                       HIPBLASLT_COMPUTE_F32);

// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch

// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host

// Next Copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
    }
}

int threads = 256;
int blocks  = ceil((double)gemm_count / threads);

// Step 6: Initialize and run
if(!validIdx.empty())
{
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace);
    for(int i = 0; i < 10; i++)
    {
        // Update n in the device-side UserArguments before each run
        hipLaunchKernelGGL(kernelUpdateN,
                           dim3(blocks),
                           dim3(threads),
                           0,
                           stream,
                           gemm_count,
                           d_dUAFloat,
                           d_n_vec); // d_n_vec is a device pointer with Ns
        groupedGemm.run(d_dUAFloat, stream);
    }
}

// .....

__global__ void kernelUpdateN(uint32_t gemm_count, void* userArgs, int32_t* sizes_n)
{
    // One thread per gemm; use the launch block size instead of a hard-coded 256
    uint64_t id = hipBlockIdx_x * hipBlockDim_x + hipThreadIdx_x;

    if(id >= gemm_count)
        return;

    hipblaslt_ext::UserArguments* dUAFloat = static_cast<hipblaslt_ext::UserArguments*>(userArgs);
    dUAFloat[id].n                         = sizes_n[id];
}
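
The “sum of N” rule from the start of this section can also be checked on the host before launching anything. The sketch below builds the N vector in the form {sum_of_n, 1, 1, ...} and verifies that a set of run-time Ns stays within the registered sum. Both function names are hypothetical and no hipBLASLt calls are involved; the check reflects only the constraint stated above.

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Builds the N vector passed to setProblem in "sum of N" mode:
// {sum_of_n, 1, 1, ...}, one entry per gemm.
std::vector<int64_t> makeSumOfNVector(const std::vector<int64_t>& nPerGemm)
{
    if(nPerGemm.empty())
        return {};

    const int64_t sum = std::accumulate(nPerGemm.begin(), nPerGemm.end(), int64_t{0});
    std::vector<int64_t> n(nPerGemm.size(), 1);
    n[0] = sum;
    return n;
}

// Returns true if the Ns written to the device at run time stay within the
// "sum of N" that was registered through setProblem.
bool fitsWithinSumOfN(const std::vector<int32_t>& runtimeN, int64_t registeredSumOfN)
{
    int64_t total = 0;
    for(int32_t value : runtimeN)
    {
        if(value < 0)
            return false; // UserArguments::n is unsigned
        total += value;
    }
    return total <= registeredSumOfN;
}
```

For gemm_count = 4 with Ns of 3, 5, 7 and 9, the vector passed to setProblem would be {24, 1, 1, 1}, and any later set of Ns whose total is at most 24 passes the check.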