This page contains proposed changes for a future release of ROCm. Read the latest Linux release of ROCm documentation for your production environments.

hipBLASLtExt API reference#

hipBLASLt contains extension APIs with the namespace hipblaslt_ext. They are only C++ compatible. The extensions support the following:

  1. GEMM

  2. Grouped GEMM

  3. Get all algorithms

hipBLASLtExt datatypes reference#

GemmType#

Warning

doxygenenum: Cannot find enum “GemmType” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GemmProblemType#

Warning

doxygenstruct: Cannot find class “hipblaslt_ext::GemmProblemType” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GemmProblemTypeV2#

Warning

doxygenclass: Cannot find class “hipblaslt_ext::GemmProblemTypeV2” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GemmEpilogue#

Warning

doxygenstruct: Cannot find class “hipblaslt_ext::GemmEpilogue” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GemmEpilogueV2#

Warning

doxygenclass: Cannot find class “hipblaslt_ext::GemmEpilogueV2” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GemmInputs#

Warning

doxygenstruct: Cannot find class “hipblaslt_ext::GemmInputs” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GemmInputsV2#

Warning

doxygenclass: Cannot find class “hipblaslt_ext::GemmInputsV2” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

hipBLASLtExt GEMM class reference#

GemmPreference#

Warning

doxygenclass: Cannot find class “hipblaslt_ext::GemmPreference” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GemmPreferenceV2#

Warning

doxygenclass: Cannot find class “hipblaslt_ext::GemmPreferenceV2” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GemmInstance#

Warning

doxygenclass: Cannot find class “hipblaslt_ext::GemmInstance” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

Gemm#

Warning

doxygenclass: Cannot find class “hipblaslt_ext::Gemm” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

GroupedGemm#

Warning

doxygenclass: Cannot find class “hipblaslt_ext::GroupedGemm” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

hipBLASLtExt API reference#

getAllAlgos()#

Warning

doxygenfunction: Cannot find function “getAllAlgos” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

getIndexFromAlgo()#

Warning

doxygenfunction: Cannot find function “getIndexFromAlgo” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

getAlgosFromIndex()#

Warning

doxygenfunction: Cannot find function “getAlgosFromIndex” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

matmulIsAlgoSupported()#

Warning

doxygenfunction: Cannot find function “matmulIsAlgoSupported” in doxygen xml output for project “hipBLASLt 0.15.0 Documentation” from directory: /home/docs/checkouts/readthedocs.org/user_builds/advanced-micro-devices-hipblaslt/checkouts/develop/docs/doxygen/xml

hipblasLtExt usage#

Here are the three use cases supported by the hipBLASLtExt APIs.

GEMM#

hipblasLt has its own instance. You must assign the problem type when constructing or importing the problem from the hipBLAS API.

HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t      handle,
                               hipblasOperation_t     opA,
                               hipblasOperation_t     opB,
                               hipDataType      typeA,
                               hipDataType      typeB,
                               hipDataType      typeC,
                               hipDataType      typeD,
                               hipblasComputeType_t typeCompute);

HIPBLASLT_EXPORT explicit Gemm(hipblasLtHandle_t       handle,
                               hipblasLtMatmulDesc_t   matmul_descr,
                               const void*             alpha,
                               const void*             A,
                               hipblasLtMatrixLayout_t matA,
                               const void*             B,
                               hipblasLtMatrixLayout_t matB,
                               const void*             beta,
                               const void*             C,
                               hipblasLtMatrixLayout_t matC,
                               void*                   D,
                               hipblasLtMatrixLayout_t matD);

After the instance is created, you can set the problem using the API. The API might require the following structures:

  • GemmProblemType: This lets you to change the problem type after the instance is initialized.

    Note

    This structure is deprecated. Use GemmProblemTypeV2 instead.

    struct GemmProblemType
    {
       hipblasOperation_t     op_a;
       hipblasOperation_t     op_b;
       hipDataType      type_a;
       hipDataType      type_b;
       hipDataType      type_c;
       hipDataType      type_d;
       hipblasComputeType_t type_compute;
    };
    
  • GemmEpilogue: This lets you control the epilogue of the problem.

    Note

    This structure is deprecated. Use GemmEpilogueV2 instead.

    struct GemmEpilogue
    {
       hipblasLtEpilogue_t mode = HIPBLASLT_EPILOGUE_DEFAULT;
       hipDataType   bias_data_type;
       int                 aux_ld;
       int                 aux_stride;
    };
    
  • GemmInputs: This specifies the problem inputs.

    Note

    This structure is deprecated. Use GemmInputsV2 instead.

    struct GemmInputs
    {
       void* a = nullptr;
       void* b = nullptr;
       void* c = nullptr;
       void* d = nullptr;
       void* alpha = nullptr;
       void* beta = nullptr;
       // Epilogue inputs
       void* bias = nullptr;
       void* aux = nullptr;
    };
    
  • setProblem APIs:

    HIPBLASLT_EXPORT hipblasStatus_t setProblem(
       int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogueV2& epilogue, GemmInputsV2& inputs);
    

You can set the leading dimensions and strides and reassign the data type using the following API:

HIPBLASLT_EXPORT hipblasStatus_t setProblem(int64_t            m,
                                            int64_t            n,
                                            int64_t            k,
                                            int64_t            batch_count,
                                            int64_t            lda,
                                            int64_t            ldb,
                                            int64_t            ldc,
                                            int64_t            ldd,
                                            int64_t            strideA,
                                            int64_t            strideB,
                                            int64_t            strideC,
                                            int64_t            strideD,
                                            GemmEpilogueV2&    epilogue,
                                            GemmInputsV2&      inputs,
                                            GemmProblemTypeV2& problemtype);

You can import problems from the hipblasLt APIs after the instance is created.

Note

This can overwrite the problem type of the instance.

HIPBLASLT_EXPORT hipblasStatus_t setProblem(hipblasLtMatmulDesc_t   matmul_descr,
                                            const void*             alpha,
                                            const void*             A,
                                            hipblasLtMatrixLayout_t matA,
                                            const void*             B,
                                            hipblasLtMatrixLayout_t matB,
                                            const void*             beta,
                                            const void*             C,
                                            hipblasLtMatrixLayout_t matC,
                                            void*                   D,
                                            hipblasLtMatrixLayout_t matD);

You can retrieve heuristics and set kernel arguments with the instance. If the properties of the GEMM and the inputs don’t change, you can call the run API to launch the kernel directly.

// Pseudo code
hipblaslt_ext::GemmPreferenceV2 pref;
pref.setMaxWorkspaceBytes(1000000);
// Default epilogue mode is HIPBLASLT_EPILOGUE_DEFAULT
hipblaslt_ext::GemmEpilogueV2 epilogue;
hipblaslt_ext::GemmInputsV2 inputs;
inputs.setA(d_a);
inputs.setB(d_b);
inputs.setC(d_c);
inputs.setD(d_d);
inputs.setAlpha(&alpha);
inputs.setBeta(&beta);

hipblaslt_ext::Gemm gemm(handle,
                         HIPBLAS_OP_N,
                         HIPBLAS_OP_N,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIPBLAS_COMPUTE_32F);
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic;
gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch
gemm.algoGetHeuristic(gemm, pref, heuristic);
gemm.initialize(heuristic[0].algo, d_workspace, stream);
for(int i = 0; i < 10; i++)
{
    gemm.run(stream);
}

Grouped GEMM#

hipblasLtExt supports grouped GEMM. It shares the same class with normal GEMM.

After the problem is set, you can check the problem type using the function getGemmType().

enum class GemmType
{
    HIPBLASLT_GEMM             = 1,
    HIPBLASLT_GROUPED_GEMM     = 2
};

The grouped GEMM class also includes the setProblem APIs.

HIPBLASLT_EXPORT hipblasStatus_t setProblem(
    int64_t m, int64_t n, int64_t k, int64_t batch_count, GemmEpilogueV2& epilogue, GemmInputsV2& inputs);

HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>&        m,
                                            std::vector<int64_t>&        n,
                                            std::vector<int64_t>&        k,
                                            std::vector<int64_t>&        batch_count,
                                            std::vector<GemmEpilogueV2>& epilogue,
                                            std::vector<GemmInputsV2>&   inputs);

HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>&        m,
                                            std::vector<int64_t>&        n,
                                            std::vector<int64_t>&        k,
                                            std::vector<int64_t>&        batch_count,
                                            std::vector<int64_t>&        lda,
                                            std::vector<int64_t>&        ldb,
                                            std::vector<int64_t>&        ldc,
                                            std::vector<int64_t>&        ldd,
                                            std::vector<int64_t>&        strideA,
                                            std::vector<int64_t>&        strideB,
                                            std::vector<int64_t>&        strideC,
                                            std::vector<int64_t>&        strideD,
                                            std::vector<GemmEpilogueV2>& epilogue,
                                            std::vector<GemmInputsV2>&   inputs,
                                            GemmProblemTypeV2&           problemtype);

HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<hipblasLtMatmulDesc_t>&   matmul_descr,
                                            std::vector<void*>&                   alpha,
                                            std::vector<void*>&                   A,
                                            std::vector<hipblasLtMatrixLayout_t>& matA,
                                            std::vector<void*>&                   B,
                                            std::vector<hipblasLtMatrixLayout_t>& matB,
                                            std::vector<void*>&                   beta,
                                            std::vector<void*>&                   C,
                                            std::vector<hipblasLtMatrixLayout_t>& matC,
                                            std::vector<void*>&                   D,
                                            std::vector<hipblasLtMatrixLayout_t>& matD);

For the following API, the epilogue argument supports broadcasting to the length of the problem size by duplicating the last element.

HIPBLASLT_EXPORT hipblasStatus_t setProblem(std::vector<int64_t>&        m,
                                            std::vector<int64_t>&        n,
                                            std::vector<int64_t>&        k,
                                            std::vector<int64_t>&        batch_count,
                                            std::vector<int64_t>&        lda,
                                            std::vector<int64_t>&        ldb,
                                            std::vector<int64_t>&        ldc,
                                            std::vector<int64_t>&        ldd,
                                            std::vector<int64_t>&        strideA,
                                            std::vector<int64_t>&        strideB,
                                            std::vector<int64_t>&        strideC,
                                            std::vector<int64_t>&        strideD,
                                            std::vector<GemmEpilogueV2>& epilogue,
                                            std::vector<GemmInputsV2>&   inputs,
                                            GemmProblemTypeV2&           problemtype);

Note

Only a problemtype size equal to 1 is currently supported. (This means only one GemmProblemTypeV2 for all problems.)

// Pseudo code
std::vector<int64_t> m, n, k;
// ...
for(size_t i = 0; i < problem_size, i++)
{
    // ...
}
std::vector<GemmProblemTypeV2> problemtypes;
problemtypes.push_back(problemtype);
groupedgemm.setProblem(m, n, k, batch_count, lda, ldb, ldc, ldd, strideA, strideB, strideC, strideD, epilogue, inputs, problemtypes);

The UserArguments structure#

Grouped GEMM supports the use of external device memory to run the kernel. This is helpful if some of the arguments are from the output of the pervious kernel. To change the size-related arguments m, n, k, and batch, see Fixed MK.

struct UserArguments
{
    uint32_t m; //!< size m
    uint32_t n; //!< size n
    uint32_t batch; //!< size batch
    uint32_t k; //!< size k
    void*    d; //!< The d matrix input pointer.
    void*    c; //!< The c matrix input pointer.
    void*    a; //!< The a matrix input pointer.
    void*    b; //!< The b matrix input pointer.
    uint32_t strideD1; //!< The d leading dimension.
    uint32_t strideD2; //!< The d batch stride
    uint32_t strideC1; //!< The c leading dimension.
    uint32_t strideC2; //!< The c batch stride
    uint32_t strideA1; //!< The a leading dimension.
    uint32_t strideA2; //!< The a batch stride
    uint32_t strideB1; //!< The b leading dimension.
    uint32_t strideB2; //!< The b batch stride
    int8_t   alpha[16]; //!< The alpha value.
    int8_t   beta[16]; //!< The beta value.
    // Epilogue inputs
    void*    bias; //!< The bias input pointer.
    int      biasType; //!< The bias datatype. Only works if mode is set to bias related epilogues.
    uint32_t reserved;
    void*    e; //!< The aux input pointer. Only works if mode is set to aux related epilogues.
    uint32_t strideE1; //!< The aux leading dimension. Only works if mode is set to aux related epilogues.
    uint32_t strideE2; //!< The aux batch stride. Only works if mode is set to aux related epilogues.
    float    act0; //!< The activation value 1. Some activations might use it.
    float    act1; //!< The activation value 2.
    int      activationType; //!< The activation type.  Only works if mode is set to activation related epilogues.
} __attribute__((packed));

hipBLASLt adds two functions to the UserArguments-related API. The first API is a helper function that helps you initialize the UserArguments structure from the saved problems inside the grouped GEMM object. The second API is an overload function with an additional UserArguments device pointer input.

HIPBLASLT_EXPORT hipblasStatus_t getDefaultValueForDeviceUserArguments(void* hostDeviceUserArgs);

HIPBLASLT_EXPORT hipblasStatus_t run(void* deviceUserArgs, hipStream_t stream);

Here is a simple example that shows how this API works.

// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLAS_COMPUTE_32F,
                                                 heuristicResult));

hipblaslt_ext::GemmPreferenceV2 pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogueV2> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputsV2> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
    m[i] = 1;
    n[i] = 1;
    k[i] = 1;
    batch_count[i] = 1;
    epilogue[i].setMode(HIPBLASLT_EPILOGUE_GELU);
    inputs[i].setA(d_a[i]);
    inputs[i].setB(d_b[i]);
    inputs[i].setC(d_c[i]);
    inputs[i].setD(d_d[i]);
    inputs[i].setAlpha(&alpha[i]);
    inputs[i].setBeta(&beta[i]);
}

// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
                                       HIPBLAS_OP_N,
                                       HIPBLAS_OP_N,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIPBLAS_COMPUTE_32F);

// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch

// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host

// Next copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
    }
}

// Step 6: Initialize and run
if(validIdx.size() > 1)
{
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, stream);
    for(int i = 0; i < 10; i++)
    {
        groupedGemm.run(userArgs, stream);
    }
}

The base class (GemmInstance)#

This is the base class for Gemm and GroupedGemm.

// Gets heuristic from the instance.
HIPBLASLT_EXPORT hipblasStatus_t algoGetHeuristic(const int                                      requestedAlgoCount,
                                                  const GemmPreferenceV2&                        pref,
                                                  std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);

// Returns SUCCESS if the algo is supported, also returns the required workspace size in bytes.
HIPBLASLT_EXPORT hipblasStatus_t isAlgoSupported(hipblasLtMatmulAlgo_t& algo, size_t& workspaceSizeInBytes);

// Initializes the instance before calling run. Requires every time the problem is set.
HIPBLASLT_EXPORT hipblasStatus_t initialize(const hipblasLtMatmulAlgo_t& algo, void* workspace, bool useUserArgs = true, hipStream_t stream = 0);

// Run the problem.
HIPBLASLT_EXPORT hipblasStatus_t run(hipStream_t stream);

Get all algorithms#

Get all algorithms allows you to get all the algorithms for a specific problem type. It requires the transpose of A, B, the data type of the inputs, and the compute type.

HIPBLASLT_EXPORT
hipblasStatus_t hipblaslt_ext::getAllAlgos(hipblasLtHandle_t                              handle,
                                           hipblasLtExtGemmTypeEnum_t                     typeGemm,
                                           hipblasOperation_t                             opA,
                                           hipblasOperation_t                             opB,
                                           hipDataType                              typeA,
                                           hipDataType                              typeB,
                                           hipDataType                              typeC,
                                           hipDataType                              typeD,
                                           hipblasComputeType_t                         typeCompute,
                                           std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);

This API doesn’t require a problem size or epilogue as input. It uses another API named isAlgoSupported to check if the algorithm supports a problem.

hipblaslt_ext::matmulIsAlgoSupported()
gemm.isAlgoSupported()

The API returns the required workspace size in bytes upon successful completion.

// Get all algorithms
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLAS_COMPUTE_32F,
                                                 heuristicResult));

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(hipblaslt_ext::matmulIsAlgoSupported(handle,
                                            matmul,
                                            &(alpha),
                                            matA,
                                            matB,
                                            &(beta),
                                            matC,
                                            matD,
                                            heuristicResult[j].algo,
                                            workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
        heuristicResult[j].workspaceSize = workspace_size;
    }
    else
    {
        heuristicResult[j].workspaceSize = 0;
    }
}

Algorithm index#

This extension API lets you to get the algorithm index using hipblasLtMatmulAlgo_t.

HIPBLASLT_EXPORT int getIndexFromAlgo(hipblasLtMatmulAlgo_t& algo);

You can use an index vector to retrieve the heuristic results.

HIPBLASLT_EXPORT
hipblasStatus_t
    getAlgosFromIndex(hipblasLtHandle_t                              handle,
                      std::vector<int>&                              algoIndex,
                      std::vector<hipblasLtMatmulHeuristicResult_t>& heuristicResults);

Sample code#

This section contains some code samples that demonstrate the use cases of the extension APIs.

GEMM#

// Pseudo code for gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLAS_COMPUTE_32F,
                                                 heuristicResult));

hipblaslt_ext::GemmPreferenceV2 pref;
pref.setMaxWorkspaceBytes(1000000);
hipblaslt_ext::GemmEpilogueV2 epilogue;
epilogue.setMode(HIPBLASLT_EPILOGUE_GELU);
hipblaslt_ext::GemmInputsV2 inputs;
inputs.setA(d_a);
inputs.setB(d_b);
inputs.setC(d_c);
inputs.setD(d_d);
inputs.setAlpha(&alpha);
inputs.setBeta(&beta);

hipblaslt_ext::Gemm gemm(handle,
                         HIPBLAS_OP_N,
                         HIPBLAS_OP_N,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIP_R_16F,
                         HIPBLAS_COMPUTE_32F);

gemm.setProblem(1, 1, 1, 1, epilogue, inputs); // m, n, k, batch

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(gemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
        heuristicResult[j].workspaceSize = workspace_size;
    }
    else
    {
        heuristicResult[j].workspaceSize = 0;
    }
}

if(validIdx.size() > 1)
{
    gemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, stream);
    for(int i = 0; i < 10; i++)
    {
        gemm.run(stream);
    }
}

Grouped GEMM#

// Pseudo code for grouped gemm problem
// Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLAS_COMPUTE_32F,
                                                 heuristicResult));

hipblaslt_ext::GemmPreferenceV2 pref;
pref.setMaxWorkspaceBytes(1000000);

std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogueV2> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputsV2> inputs(gemm_count);
for(int i = 0; i < gemm_count; i++)
{
    m[i] = 1;
    n[i] = 1;
    k[i] = 1;
    batch_count[i] = 1;
    epilogue[i].setMode(HIPBLASLT_EPILOGUE_GELU);
    inputs[i].setA(d_a[i]);
    inputs[i].setB(d_b[i]);
    inputs[i].setC(d_c[i]);
    inputs[i].setD(d_d[i]);
    inputs[i].setAlpha(&alpha[i]);
    inputs[i].setBeta(&beta[i]);
}


hipblaslt_ext::GroupedGemm groupedGemm(handle,
                                       HIPBLAS_OP_N,
                                       HIPBLAS_OP_N,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIPBLAS_COMPUTE_32F);

groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
    }
}

if(validIdx.size() > 1)
{
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace, stream);
    for(int i = 0; i < 10; i++)
    {
        groupedGemm.run(stream);
    }
}

Algorithm index#

int index = hipblaslt_ext::getIndexFromAlgo(testResults[i].algo);
// Save the index to disk or somewhere else for later use.

// Get the index from previous state.
std::vector<int> algoIndex{index};
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResults;
// If the index is out of the bound of solutions, getAlgosFromIndex will return HIPBLAS_STATUS_INVALID_VALUE
if(HIPBLAS_STATUS_INVALID_VALUE
    == hipblaslt_ext::getAlgosFromIndex(handle, algoIndex, heuristicResults))
{
    std::cout << "Indexes are all out of bound." << std::endl;
    break;
}

[Grouped Gemm] Fixed MK#

The hipBLASLt extension supports changing the sizes (m, n, k, and batch) from the device memory UserArguments. However, the setup is a bit different from the normal routing.

Sum of N#

A sum of N needs to be used as an input for the grouped GEMM instance.

{1000, 1, 1, 1}; // The array of N, the first element is the sum of N

// Below is the values stored in "UserArguments"
{256, 256, 1, 1}; // This is a valid configuration cause 256 + 256 + 1 + 1 < 1000
{512, 512, 1, 1}; // This is NOT a valid configuration cause 512 + 512 + 1 + 1 > 1000

For example, consider a grouped GEMM with gemm_count = 4. The sum of N must not exceed the “sum of N” set in the setProblem API. In this mode, the first element is the “sum of N” in the array of Ns.

// Pseudo code
// Step 1: Get all algorithms
std::vector<hipblasLtMatmulHeuristicResult_t> heuristicResult;
CHECK_HIPBLASLT_ERROR(hipblaslt_ext::getAllAlgos(handle,
                                                 HIPBLASLT_GEMM,
                                                 HIPBLAS_OP_N,
                                                 HIPBLAS_OP_N,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 in_out_datatype,
                                                 HIPBLAS_COMPUTE_32F,
                                                 heuristicResult));

hipblaslt_ext::GemmPreferenceV2 pref;
pref.setMaxWorkspaceBytes(1000000);
// Step 2: Setup problem
std::vector<int64_t> m(gemm_count);
std::vector<int64_t> n(gemm_count);
std::vector<int64_t> k(gemm_count);
std::vector<int64_t> batch_count(gemm_count);
std::vector<hipblaslt_ext::GemmEpilogueV2> epilogue(gemm_count);
std::vector<hipblaslt_ext::GemmInputsV2> inputs(gemm_count);

// Step 2.1: Calculate sum of n
int64_t sum_of_n = 0;
for(int i = 0; i < gemm_count; i++)
{
    sum_of_n += n_arr[i];
}

// {sum_of_n, 1, 1, 1, ...}; // The array of N, the first element is the sum of N
for(int i = 0; i < gemm_count; i++)
{
    m[i] = m_arr[i];
    if(i == 0)
        n[i] = sum_of_n;
    else
        n[i] = 1;
    k[i] = k_arr[i];
    batch_count[i] = 1;
    inputs[i].setA(d_a[i]);
    inputs[i].setB(d_b[i]);
    inputs[i].setC(d_c[i]);
    inputs[i].setD(d_d[i]);
    inputs[i].setAlpha(&alpha[i]);
    inputs[i].setBeta(&beta[i]);
}

// Step 3: Create grouped gemm instance
hipblaslt_ext::GroupedGemm groupedGemm(handle,
                                       HIPBLAS_OP_N,
                                       HIPBLAS_OP_N,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIP_R_16F,
                                       HIPBLAS_COMPUTE_32F);

// Step 4: Set problem
groupedGemm.setProblem(m, n, k, batch_count, epilogue, inputs); // m, n, k, batch

// Step 5: Get default value from the instance
hipblaslt_ext::UserArguments* dUAFloat = new hipblaslt_ext::UserArguments[gemm_count];
groupedGemm.getDefaultValueForDeviceUserArguments((void*)dUAFloat);
// Once you get the default value here, you can make several copies and change the values
// from the host

// Next Copy them to the device memory
hipblaslt_ext::UserArguments* d_dUAFloat = nullptr;
hipMalloc(&d_dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count);
hipMemcpy(d_dUAFloat, dUAFloat, sizeof(hipblaslt_ext::UserArguments) * gemm_count, hipMemcpyHostToDevice);

validIdx.clear();
for(int j = 0; j < heuristicResult.size(); j++)
{
    size_t workspace_size = 0;
    if(groupedGemm.isAlgoSupported(heuristicResult[j].algo, workspace_size)
       == HIPBLAS_STATUS_SUCCESS)
    {
        validIdx.push_back(j);
    }
}

int threads = 256;
int blocks  = ceil((double)gemm_count / threads);

// Step 6: Initialize and run
if(validIdx.size() > 1)
{
    groupedGemm.initialize(heuristicResult[validIdx[0]].algo, d_workspace);
    for(int i = 0; i < 10; i++)
    {
        hipLaunchKernelGGL(kernelUpdateN,
                            dim3(blocks),
                            dim3(threads),
                            0,
                            stream,
                            gemm_count,
                            d_dUAFloat,
                            d_n_vec);  // d_n_vec is a device pointer with Ns
        groupedGemm.run(userArgs, stream);
    }
}

// .....

__global__ void kernelUpdateN(uint32_t gemm_count, void* userArgs, int32_t* sizes_n)
{
uint64_t id = hipBlockIdx_x * 256 + hipThreadIdx_x;

if(id >= gemm_count)
    return;

hipblaslt_ext::UserArguments* dUAFloat = static_cast<hipblaslt_ext::UserArguments*>(userArgs);
dUAFloat[id].n                         = sizes_n[id];
}