Install JAX on ROCm 7.12.0
This topic guides you through installing JAX with ROCm support on AMD hardware. It applies to supported AMD GPUs and platforms.
Prerequisites
Ensure your system has a supported Python version installed and accessible: 3.11, 3.12, 3.13, or 3.14.
Review the ROCm 7.12.0 compatibility matrix for more details.
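To confirm that an interpreter meets the version requirement above, you can run a short check with the Python you plan to use for the virtual environment. This is a minimal sketch. The SUPPORTED set mirrors the versions listed above, and the function name is illustrative.

```python
import sys

# Supported versions as listed in the Prerequisites above (major, minor).
SUPPORTED = {(3, 11), (3, 12), (3, 13), (3, 14)}


def is_supported(version_info=sys.version_info) -> bool:
    """Return True if the major.minor part of version_info is in SUPPORTED."""
    return (version_info[0], version_info[1]) in SUPPORTED


if __name__ == "__main__":
    found = f"{sys.version_info.major}.{sys.version_info.minor}"
    status = "supported" if is_supported() else "NOT supported"
    print(f"Python {found}: {status} for JAX on ROCm 7.12.0")
```

For example, running it with python3.12 should report that 3.12 is supported.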
Important
Unlike PyTorch, the JAX wheels do not automatically install rocm[libraries] as a dependency. You must install ROCm separately using a tarball installation.
Install JAX
For prerequisite steps and post-installation recommendations, see the ROCm installation instructions.
Set up your Python virtual environment. For example, run the following command to create a virtual environment:
python3.12 -m venv .venv
Activate your Python virtual environment. For example:
source .venv/bin/activate
Install the appropriate ROCm-enabled JAX libraries for your operating system and AMD hardware architecture.
Note
The jax package itself is not published to the AMD package repository. After installing the GFX architecture-based jaxlib, jax_rocm7_plugin, and jax_rocm7_pjrt packages from the AMD repository, install a supported JAX version from PyPI.
Check your JAX installation.
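Before you start JAX, you can check that all four packages from the note above (jaxlib, jax_rocm7_plugin, and jax_rocm7_pjrt from the AMD repository, plus jax from PyPI) are installed in the active environment. This sketch uses only the standard library. It reports versions and does not check whether they are mutually compatible. REQUIRED and installed_versions are illustrative names.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the note above.
REQUIRED = ("jax", "jaxlib", "jax_rocm7_plugin", "jax_rocm7_pjrt")


def installed_versions(names=REQUIRED) -> dict:
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    found = installed_versions()
    for name, ver in found.items():
        print(f"{name:20} {ver or 'NOT INSTALLED'}")
    missing = [name for name, ver in found.items() if ver is None]
    if missing:
        print("Missing:", ", ".join(missing))
```

Run it with the virtual environment activated. Otherwise it inspects a different interpreter's packages.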
Important
Set the environment variable AMD_COMGR_NAMESPACE=1. See the known issue "JAX GPU initialization might fail without AMD_COMGR_NAMESPACE set."
Set LD_LIBRARY_PATH to include the ROCm SDK core library path before running JAX. See the known issue "JAX fails to initialize due to missing ROCm shared libraries." Replace python3.12 with your actual Python version (3.11, 3.12, 3.13, or 3.14):
export LD_LIBRARY_PATH=/opt/python/lib/python3.12/site-packages/_rocm_sdk_core/lib:$LD_LIBRARY_PATH
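Instead of editing the Python version by hand, you can have Python build the library path for the interpreter that runs it. This sketch assumes the same /opt/python prefix as the export line above, so adjust PREFIX if your _rocm_sdk_core directory lives elsewhere. It also avoids a subtle issue with the ...:$LD_LIBRARY_PATH form. If LD_LIBRARY_PATH starts out empty, that form leaves a zero-length entry, which the dynamic loader treats as the current working directory. The helper names are hypothetical.

```python
import os
import sys

# Assumption: same install prefix as the example export line above.
PREFIX = "/opt/python"


def rocm_core_lib_dir(prefix: str = PREFIX) -> str:
    """Build <prefix>/lib/pythonX.Y/site-packages/_rocm_sdk_core/lib for this interpreter."""
    pyver = f"python{sys.version_info.major}.{sys.version_info.minor}"
    return os.path.join(prefix, "lib", pyver, "site-packages", "_rocm_sdk_core", "lib")


def prepend_path(entry: str, current: str) -> str:
    """Prepend entry to a colon-separated list, dropping empty and duplicate items."""
    parts = [p for p in current.split(":") if p and p != entry]
    return ":".join([entry] + parts)


if __name__ == "__main__":
    lib_dir = rocm_core_lib_dir()
    if not os.path.isdir(lib_dir):
        # Warn on stderr so stdout stays usable inside $(...).
        print(f"warning: {lib_dir} does not exist", file=sys.stderr)
    print(prepend_path(lib_dir, os.environ.get("LD_LIBRARY_PATH", "")))
```

If you save the sketch as a file (for example, the hypothetical rocm_ld_path.py), you could run export LD_LIBRARY_PATH="$(python rocm_ld_path.py)" in the same shell, using the same Python you installed JAX with.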
export AMD_COMGR_NAMESPACE=1
export LD_LIBRARY_PATH=/opt/python/lib/python3.12/site-packages/_rocm_sdk_core/lib:$LD_LIBRARY_PATH
python -c "import jax; print(jax.devices())"
If JAX and ROCm are installed properly, this prints something like [RocmDevice(id=0)], with one entry per visible GPU.
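Listing devices shows that JAX found the GPU. It does not show that JAX can compile and run work there. As a slightly stronger check, this sketch runs one small jitted matrix product on the default device and compares the result with NumPy. The check assumes a single tiny computation is representative enough. It also reports whether the two environment variables from the Important note appear to be set in the current process. check_jax_rocm is an illustrative name, not part of JAX.

```python
import os

import jax
import jax.numpy as jnp
import numpy as np


def check_jax_rocm(n: int = 8) -> dict:
    """Run a jitted x @ x.T on the default device and compare it with NumPy."""
    # LD_LIBRARY_PATH is read by the dynamic loader at process start, so this
    # only reports what the process was launched with; set it in the shell.
    env_problems = []
    if os.environ.get("AMD_COMGR_NAMESPACE") != "1":
        env_problems.append("AMD_COMGR_NAMESPACE is not set to 1")
    if "_rocm_sdk_core" not in os.environ.get("LD_LIBRARY_PATH", ""):
        env_problems.append("LD_LIBRARY_PATH has no _rocm_sdk_core entry")

    # Small integer values are exactly representable, so the product is exact
    # regardless of the matmul precision the backend uses.
    x = jnp.arange(n * n, dtype=jnp.float32).reshape(n, n)
    result = jax.jit(lambda a: a @ a.T)(x)
    result.block_until_ready()  # wait for the device computation to finish

    x_np = np.asarray(x)
    ok = np.allclose(np.asarray(result), x_np @ x_np.T)
    return {
        "backend": jax.default_backend(),
        "device": repr(jax.devices()[0]),
        "matmul_ok": bool(ok),
        "env_problems": env_problems,
    }


if __name__ == "__main__":
    for key, value in check_jax_rocm().items():
        print(f"{key}: {value}")
```

Run it in the same shell where you exported the variables. On a working installation, matmul_ok should be True and the device should be a ROCm device rather than a CPU.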