Install JAX for ROCm
This page provides setup instructions and the required wheel packages to build, test, and run JAX with ROCm support using pip, suitable for both runtime and CI workflows.
Note
These instructions are for installing JAX on Radeon GPUs.
For Instinct GPUs, refer to the ROCm Instinct documentation.
Install JAX
pip installation
Follow these instructions to install JAX via pip.
Important
The packages must be installed in the following order:
1. Install the pjrt wheel.
2. Install the plugin wheel.
3. Install the jaxlib wheel.
4. Install the jax wheel.
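Because the order matters, a CI job can encode it once instead of relying on four hand-typed commands. The sketch below is a minimal illustration, assuming Python 3.8 or later and the Ubuntu 24.04 wheel URLs listed on this page; the helper name `build_install_commands` is hypothetical and not part of any ROCm or JAX tooling.

```python
"""Hypothetical helper: emit pip commands in the required install order."""
import shlex
import subprocess
import sys

# Wheel sources for Ubuntu 24.04 (Python 3.12), copied from this page.
BASE = "https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1"
PACKAGES = [
    # Order matters: pjrt, then plugin, then jaxlib, then jax.
    ("pjrt", f"{BASE}/jax_rocm7_pjrt-0.7.1-py3-none-manylinux_2_28_x86_64.whl"),
    ("plugin", f"{BASE}/jax_rocm7_plugin-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl"),
    ("jaxlib", "jaxlib==0.7.1"),
    ("jax", "jax==0.7.1"),
]


def build_install_commands(python=sys.executable):
    """Return one pip command (as an argv list) per package, in install order."""
    return [[python, "-m", "pip", "install", spec] for _, spec in PACKAGES]


def main(dry_run=True):
    for cmd in build_install_commands():
        print(shlex.join(cmd))
        if not dry_run:
            # check=True stops at the first failure, so later wheels are not
            # installed on top of a broken earlier step.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    main(dry_run="--run" not in sys.argv)
```

By default the script only prints the commands. Passing `--run` executes them in order with the interpreter that runs the script, which keeps pip and Python in sync.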
Install JAX for Ubuntu 24.04.
1. Uninstall the previous version:
   pip3 uninstall -y jax-rocm7-pjrt jax-rocm7-plugin jaxlib jax
2. Install the pjrt wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_pjrt-0.7.1-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_plugin-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel:
   pip install jaxlib==0.7.1
5. Install the jax wheel:
   pip install jax==0.7.1
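After the installs finish, it can help to confirm that pip resolved the versions you asked for, because a leftover jaxlib or plugin from an earlier install is easy to miss. A minimal sketch using the standard-library `importlib.metadata` module follows. The distribution names are the ones from the Ubuntu 24.04 uninstall command, and the expected version `0.7.1` is assumed from the Ubuntu 24.04 commands; adjust both for other platforms.

```python
"""Report installed versions of the JAX-on-ROCm distributions."""
from importlib.metadata import PackageNotFoundError, version

# Distribution names as used in the Ubuntu 24.04 pip uninstall command.
DISTRIBUTIONS = ["jax-rocm7-pjrt", "jax-rocm7-plugin", "jaxlib", "jax"]


def installed_versions(names=DISTRIBUTIONS):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def mismatches(expected="0.7.1", names=DISTRIBUTIONS):
    """Return the distributions whose installed version differs from `expected`."""
    return {
        name: found
        for name, found in installed_versions(names).items()
        if found != expected
    }


if __name__ == "__main__":
    for name, found in installed_versions().items():
        print(f"{name}: {found or 'not installed'}")
    bad = mismatches()
    if bad:
        print("Version mismatch:", bad)
```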
Install JAX for Ubuntu 22.04.
Note
The latest JAX release no longer supports Python 3.10. Install Python 3.11 to use JAX with Ubuntu 22.04.
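Each plugin wheel is built for one CPython version, encoded in the `cpXY` tag of its filename: `cp311` for the Ubuntu 22.04 wheel, and `cp312` for the Ubuntu 24.04 and RHEL 10 wheels. The following pre-flight check compares the running interpreter against that tag before anything is installed. It is a sketch that assumes you run it with the same interpreter pip will install into, and the function names are illustrative only.

```python
"""Pre-flight check: does this interpreter match the plugin wheel's cp tag?"""
import re
import sys

# Plugin wheel filenames from this page (the cpXY tag is the part that matters).
PLUGIN_WHEELS = {
    "ubuntu-24.04": "jax_rocm7_plugin-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl",
    "ubuntu-22.04": "jax_rocm7_plugin-0.7.1-cp311-cp311-manylinux_2_28_x86_64.whl",
    "rhel-10": "jax_rocm7_plugin-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl",
}


def interpreter_tag(version_info=sys.version_info):
    """Return the CPython tag for an interpreter version, e.g. (3, 11) -> 'cp311'."""
    return f"cp{version_info[0]}{version_info[1]}"


def wheel_tag(filename):
    """Extract the Python tag (e.g. 'cp311') from a wheel filename."""
    # Wheel names follow name-version-pythontag-abitag-platformtag.whl.
    match = re.search(r"-(cp\d+)-", filename)
    if match is None:
        raise ValueError(f"no cpXY tag in {filename!r}")
    return match.group(1)


def is_compatible(platform, version_info=sys.version_info):
    """True if the running interpreter matches the plugin wheel for `platform`."""
    return interpreter_tag(version_info) == wheel_tag(PLUGIN_WHEELS[platform])


if __name__ == "__main__":
    for name in PLUGIN_WHEELS:
        status = "OK" if is_compatible(name) else "needs a different Python"
        print(f"{name}: {status} (running {interpreter_tag()})")
```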
1. Uninstall the previous version:
   pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
2. Install the pjrt wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_pjrt-0.7.1-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_plugin-0.7.1-cp311-cp311-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel:
   pip install jaxlib==0.6.0
5. Install the jax wheel:
   pip install jax==0.6.0
Install JAX for RHEL 10.
1. Uninstall the previous version:
   pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
2. Install the pjrt wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_pjrt-0.7.1-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel:
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1/jax_rocm7_plugin-0.7.1-cp312-cp312-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel:
   pip install jaxlib==0.6.0
5. Install the jax wheel:
   pip install jax==0.6.0
Recommended for RHEL distributions: install the gcc-gfortran package (for libgfortran.so):
   sudo dnf install gcc-gfortran
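To check whether a Fortran runtime can already be found before or after installing `gcc-gfortran`, you can use the standard-library `ctypes.util.find_library`. This is a rough heuristic, not an official ROCm check. It assumes the library is named `libgfortran`. On Linux, `find_library` consults `ldconfig` and compiler or linker search paths, so its answer can differ from what a running process actually loads.

```python
"""Heuristic check for a loadable libgfortran shared library."""
import ctypes
import ctypes.util


def find_gfortran():
    """Return the library name find_library reports for gfortran, or None."""
    # find_library takes the name without the 'lib' prefix or '.so' suffix.
    return ctypes.util.find_library("gfortran")


def can_load_gfortran():
    """Try to dlopen the library found by find_gfortran(); True on success."""
    name = find_gfortran()
    if name is None:
        return False
    try:
        ctypes.CDLL(name)
    except OSError:
        return False
    return True


if __name__ == "__main__":
    name = find_gfortran()
    if name is None:
        print("libgfortran not found; consider: sudo dnf install gcc-gfortran")
    else:
        print(f"found {name}; loadable: {can_load_gfortran()}")
```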
Verify installation
Refer to Testing your JAX installation with ROCm for verification steps.
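As a quick smoke test before the full verification steps, you can import JAX and list the devices it sees. The sketch below assumes the installation above completed. It uses only `jax.devices()`, `jax.default_backend()`, and `jax.numpy`, and it reports a missing or CPU-only JAX as a failed check instead of raising. This page does not state the exact backend string reported for ROCm devices, so the script prints it rather than asserting a specific value.

```python
"""Minimal JAX smoke test: import, list devices, run one computation."""


def jax_smoke_test():
    """Return a dict describing what JAX sees; never raises on a missing install."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError as exc:
        return {"ok": False, "error": f"import failed: {exc}"}

    backend = jax.default_backend()
    devices = [str(d) for d in jax.devices()]
    # A small matrix product exercises compilation and execution on the
    # default device; each entry is 128, so the sum is 128**3.
    x = jnp.ones((128, 128))
    total = float((x @ x).sum())
    return {
        # Treat a CPU-only backend as "GPU not detected".
        "ok": backend != "cpu",
        "backend": backend,
        "devices": devices,
        "checksum": total,
    }


if __name__ == "__main__":
    for key, value in jax_smoke_test().items():
        print(f"{key}: {value}")
```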