Install JAX for ROCm

This page provides setup instructions for installing, testing, and running JAX with ROCm support through pip, suitable for both runtime and CI workflows.

Note
These instructions are for JAX installation on Radeon GPUs.
To install ROCm on Instinct GPUs, refer to the ROCm Instinct documentation.

Install JAX

pip installation

Follow these instructions to install JAX with pip.

Important
The packages must be installed in the following order:

  1. Install the pjrt wheel.

  2. Install the plugin wheel.

  3. Install the jaxlib wheel.

  4. Install the jax wheel.
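The required ordering can be captured in a small shell function that installs the four packages in sequence and stops at the first failure, so a later package is never installed on top of a failed earlier one. This is a minimal sketch: the function name `install_jax_rocm` and the `PIP` override variable are hypothetical conveniences, not part of the ROCm tooling, and the wheel URLs and version are passed in from the platform-specific steps that follow.

```shell
# Hypothetical helper: install the JAX ROCm packages in the required order.
# Usage: install_jax_rocm <pjrt_wheel_url> <plugin_wheel_url> <jax_version>
# Set PIP to override the pip command (defaults to "pip").
install_jax_rocm() {
    if [ "$#" -ne 3 ]; then
        echo "usage: install_jax_rocm <pjrt_url> <plugin_url> <jax_version>" >&2
        return 2
    fi
    pip_cmd=${PIP:-pip}
    # 1. pjrt wheel
    $pip_cmd install "$1" || return 1
    # 2. plugin wheel
    $pip_cmd install "$2" || return 1
    # 3. jaxlib and 4. jax, pinned to the same version
    $pip_cmd install "jaxlib==$3" || return 1
    $pip_cmd install "jax==$3" || return 1
}

# Example (Ubuntu 24.04, Python 3.12):
# install_jax_rocm \
#   https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_pjrt-0.7.1-py3-none-manylinux_2_28_x86_64.whl \
#   https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_plugin-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl \
#   0.7.1
```

Because each step returns on failure, a broken plugin download leaves jaxlib and jax untouched instead of pairing them with a missing plugin.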

Install JAX for Ubuntu 24.04.

  1. Uninstall any previous versions.

    pip3 uninstall -y jax-rocm7-pjrt jax-rocm7-plugin jaxlib jax
    
  2. Install the pjrt wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_pjrt-0.7.1-py3-none-manylinux_2_28_x86_64.whl
    
  3. Install the plugin wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_plugin-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl
    
  4. Install the jaxlib wheel.

    pip install jaxlib==0.7.1
    
  5. Install the jax wheel.

    pip install jax==0.7.1
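The Ubuntu 24.04 plugin wheel carries the `cp312` tag, so pip accepts it only for a Python 3.12 interpreter. A minimal pre-flight check compares the interpreter's `major.minor` version against the expected one before anything is installed. It assumes `python3` is the interpreter that `pip` installs into; the function name `require_python` and the `PYTHON` override variable are hypothetical.

```shell
# Hypothetical pre-flight check: succeed only if the interpreter's
# major.minor version equals the expected one (for example "3.12").
# Set PYTHON to override the interpreter (defaults to "python3").
require_python() {
    expected=$1
    py_cmd=${PYTHON:-python3}
    # Ask the interpreter itself for its "major.minor" version
    actual=$($py_cmd -c 'import sys; print("%d.%d" % sys.version_info[:2])') || return 1
    if [ "$actual" != "$expected" ]; then
        echo "Python $expected is required, found $actual" >&2
        return 1
    fi
}

# Example: the Ubuntu 24.04 plugin wheel is built for cp312
# require_python 3.12 && echo "OK to install"
```

The same check applies to the other sections on this page with the version matching their plugin wheel tag: `3.11` for the `cp311` wheel on Ubuntu 22.04 and `3.12` for the `cp312` wheel on RHEL 10.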
    

Install JAX for Ubuntu 22.04.

Note
The latest JAX releases no longer support Python 3.10. Install Python 3.11 to use JAX with Ubuntu 22.04.

  1. Uninstall any previous versions.

    pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jax-rocm7-pjrt jax-rocm7-plugin jaxlib jax
    
  2. Install the pjrt wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_pjrt-0.7.1-py3-none-manylinux_2_28_x86_64.whl
    
  3. Install the plugin wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_plugin-0.7.1-cp311-cp311-manylinux_2_28_x86_64.whl
    
  4. Install the jaxlib wheel.

    pip install jaxlib==0.7.1
    
  5. Install the jax wheel.

    pip install jax==0.7.1
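One difference between the Ubuntu 22.04 and Ubuntu 24.04 steps is the Python tag of the plugin wheel (`cp311` instead of `cp312`). A minimal sketch derives the plugin wheel URL from the Python version. It covers only the two tags that appear on this page and rejects other versions rather than guessing a URL. The function name `jax_rocm_plugin_url` is hypothetical.

```shell
# Hypothetical helper: print the ROCm 7.1 JAX plugin wheel URL for a
# given Python "major.minor" version. Only the tags listed on this page
# (cp311, cp312) are supported; anything else is an error.
jax_rocm_plugin_url() {
    base=https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1
    case "$1" in
        3.11) tag=cp311 ;;
        3.12) tag=cp312 ;;
        *)
            echo "no plugin wheel listed for Python $1" >&2
            return 1
            ;;
    esac
    echo "$base/jax_rocm7_plugin-0.7.1-$tag-$tag-manylinux_2_28_x86_64.whl"
}

# Example (Ubuntu 22.04 with Python 3.11):
# pip install "$(jax_rocm_plugin_url 3.11)"
```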
    

Install JAX for RHEL 10.

  1. Uninstall any previous versions.

    pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jax-rocm7-pjrt jax-rocm7-plugin jaxlib jax
    
  2. Install the pjrt wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_pjrt-0.7.1-py3-none-manylinux_2_28_x86_64.whl
    
  3. Install the plugin wheel.

    pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_plugin-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl
    
  4. Install the jaxlib wheel.

    pip install jaxlib==0.7.1
    
  5. Install the jax wheel.

    pip install jax==0.7.1
    
  6. Recommended for RHEL-based distributions: install the gcc-gfortran package, which provides libgfortran.so.

    sudo dnf install gcc-gfortran
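You can check whether a `libgfortran` shared library is already visible to the dynamic linker before or after installing `gcc-gfortran`. A minimal sketch searches the output of `ldconfig -p`, which lists the linker cache. The function name `has_libgfortran` and the `LDCONFIG` override variable are hypothetical. On some systems `ldconfig` lives in `/sbin` and may need its full path.

```shell
# Hypothetical check: succeed if the dynamic linker cache lists a
# libgfortran shared library (for example libgfortran.so.5).
# Set LDCONFIG to override the ldconfig command (defaults to "ldconfig").
has_libgfortran() {
    ldconfig_cmd=${LDCONFIG:-ldconfig}
    # -p prints the cached libraries; grep -q only reports whether one matched
    $ldconfig_cmd -p 2>/dev/null | grep -q 'libgfortran\.so'
}

# Example:
# has_libgfortran || sudo dnf install gcc-gfortran
```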
    

Verify installation

Refer to Testing your JAX installation with ROCm for verification steps.
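For a quick smoke test alongside those verification steps, the sketch below imports JAX, prints its version and visible devices, and fails if JAX sees only CPU devices. A CPU-only result would suggest the ROCm plugin was not loaded. It uses only the public `jax.__version__` and `jax.devices()` APIs. The function name `jax_rocm_smoke_test` and the `PYTHON` override variable are hypothetical.

```shell
# Hypothetical smoke test: return non-zero unless JAX reports at least one
# non-CPU device. Set PYTHON to override the interpreter (defaults to "python3").
jax_rocm_smoke_test() {
    py_cmd=${PYTHON:-python3}
    $py_cmd -c '
import sys
import jax

print("jax", jax.__version__)
devices = jax.devices()
print(devices)
# Any device whose platform is not "cpu" indicates an accelerator backend
sys.exit(0 if any(d.platform != "cpu" for d in devices) else 1)
'
}

# Example:
# jax_rocm_smoke_test && echo "JAX can see a GPU"
```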