Install JAX on ROCm 7.12.0

This topic guides you through installing JAX with ROCm support on supported AMD GPUs and platforms. The instructions cover AMD Instinct GPUs, Radeon PRO GPUs, Radeon GPUs, and Ryzen APUs; the exact packages to install depend on your device family and operating system.

Prerequisites

Ensure your system has a supported Python version installed and accessible: 3.11, 3.12, 3.13, or 3.14.

Review the ROCm 7.12.0 compatibility matrix for more details.
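To confirm that the interpreter you plan to use meets the Python requirement above, you can run a short check such as the following. This is a minimal sketch: the `SUPPORTED_PYTHON` tuple and the `is_supported_python` helper are hypothetical names, not part of any ROCm or JAX tooling, and the version list is copied from this page.

```python
import sys

# Python versions listed as supported on this page, as (major, minor) pairs.
SUPPORTED_PYTHON = ((3, 11), (3, 12), (3, 13), (3, 14))


def is_supported_python(version_info=sys.version_info) -> bool:
    """Return True if the (major, minor) part of version_info is supported."""
    return tuple(version_info[:2]) in SUPPORTED_PYTHON


if __name__ == "__main__":
    current = ".".join(str(part) for part in sys.version_info[:3])
    if is_supported_python():
        print(f"Python {current} is supported.")
    else:
        print(f"Python {current} is not in the supported list (3.11 to 3.14).")
```

Run it with the same interpreter you intend to use for the virtual environment, for example `python3.12 check_python.py` if you saved it as the hypothetical file `check_python.py`.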

Important

Unlike the PyTorch wheels, the JAX wheels do not install rocm[libraries] as a dependency. You must install ROCm separately using the tarball installation method.
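Because the JAX wheels rely on a separately installed ROCm, it can help to check that a ROCm installation is actually reachable before installing JAX. The sketch below is illustrative only and rests on two assumptions that are not stated on this page: that you exported the directory where you extracted the ROCm tarball as `ROCM_PATH`, and that on Linux the HIP runtime library sits at `lib/libamdhip64.so` under that directory. The `find_rocm_root` helper is hypothetical; consult the ROCm tarball installation instructions for the real layout.

```python
import os
from pathlib import Path
from typing import Mapping, Optional

# Assumption: the HIP runtime library lives at <ROCM_PATH>/lib/libamdhip64.so on Linux.
HIP_RUNTIME_RELPATH = Path("lib") / "libamdhip64.so"


def find_rocm_root(env: Mapping[str, str] = os.environ) -> Optional[Path]:
    """Return the ROCm root named by ROCM_PATH if it contains the HIP runtime, else None."""
    # Assumption: ROCM_PATH points at the extracted ROCm tarball directory.
    root = env.get("ROCM_PATH")
    if not root:
        return None
    candidate = Path(root).expanduser()
    # is_file() follows symlinks, so a versioned .so symlinked from this name still counts.
    if (candidate / HIP_RUNTIME_RELPATH).is_file():
        return candidate
    return None


if __name__ == "__main__":
    rocm_root = find_rocm_root()
    if rocm_root is None:
        print("ROCm not found: set ROCM_PATH to your extracted ROCm tarball directory.")
    else:
        print(f"Found ROCm at {rocm_root}")
```

Passing the environment as a mapping keeps the helper easy to test without modifying the real process environment.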

Install JAX

For prerequisite steps and post-installation recommendations, see the ROCm installation instructions.

  1. Create a Python virtual environment using a supported Python version. For example, to create one with Python 3.12:

    python3.12 -m venv .venv
    
  2. Activate your Python virtual environment. For example, on Linux:

    source .venv/bin/activate

  3. Install the appropriate ROCm-enabled JAX libraries for your operating system and AMD hardware architecture.

    Note

    The jax package itself is not published to the AMD package repository. First install the jaxlib, jax_rocm7_plugin, and jax_rocm7_pjrt packages built for your GFX architecture from the AMD repository, then install a supported jax version from PyPI.

  4. Check your JAX installation.

    Important

    Before running JAX, set the environment variable AMD_COMGR_NAMESPACE=1. For details, see the known issue "JAX GPU initialization might fail without AMD_COMGR_NAMESPACE set."

    If AMD_COMGR_NAMESPACE=1 is not set:

    • JAX might fail to initialize the GPU

    • JAX workloads might unexpectedly run on the CPU instead of the GPU

    • Processes might crash during initialization

    export AMD_COMGR_NAMESPACE=1
    python -c "import jax; print(jax.devices())"

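    If JAX and ROCm are installed correctly, this prints a list of ROCm GPU devices, such as [RocmDevice(id=0)]. If the list contains only CPU devices, such as [CpuDevice(id=0)], JAX did not initialize the GPU.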
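If you prefer to script the installation check, the following sketch reports the installed versions of the four packages named in step 3 (using `importlib.metadata`), confirms that `AMD_COMGR_NAMESPACE` is set to `1`, and only then imports JAX and checks that `jax.default_backend()` is not `"cpu"`. The helper names are hypothetical, and treating any non-CPU backend as success is a simplifying assumption rather than an official test.

```python
import importlib.metadata
import os
from typing import Dict, Iterable, Mapping, Optional

# Distribution names taken from the installation steps on this page.
ROCM_JAX_PACKAGES = ("jax", "jaxlib", "jax_rocm7_plugin", "jax_rocm7_pjrt")


def installed_versions(packages: Iterable[str] = ROCM_JAX_PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = importlib.metadata.version(name)
        except importlib.metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def comgr_namespace_set(env: Mapping[str, str] = os.environ) -> bool:
    """Return True if AMD_COMGR_NAMESPACE is set to "1", as this page requires."""
    return env.get("AMD_COMGR_NAMESPACE") == "1"


def check_installation() -> bool:
    """Print package and backend details; return True if JAX uses a non-CPU backend."""
    ok = True
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
        ok = ok and version is not None
    if not comgr_namespace_set():
        print("AMD_COMGR_NAMESPACE is not set to 1; GPU initialization might fail.")
        ok = False
    if not ok:
        return False
    # Import JAX only after the checks pass, since initialization might crash otherwise.
    import jax

    backend = jax.default_backend()
    print(f"default backend: {backend}; devices: {jax.devices()}")
    # A "cpu" backend means JAX fell back to the CPU instead of the GPU.
    return backend != "cpu"


if __name__ == "__main__":
    print("GPU check passed" if check_installation() else "GPU check failed")
```

Run it inside the activated virtual environment after exporting AMD_COMGR_NAMESPACE=1. Deferring `import jax` until the package and environment checks succeed keeps the script from hitting the initialization problems described in step 4.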