Install JAX on ROCm 7.13.0

This topic explains how to install JAX with ROCm support on supported AMD GPUs and platforms.

Prerequisites

  • Ensure your system has a supported Python version installed and accessible: 3.11, 3.12, 3.13, or 3.14.

  • Install the ROCm Core SDK. We recommend using pip to install JAX and the ROCm Core SDK in the same Python virtual environment.

    Important

    Unlike the PyTorch packages, the JAX packages do not install rocm[libraries] as a dependency, so you must install the ROCm Core SDK yourself before installing JAX.
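Before you continue, you can confirm that your default Python 3 interpreter is one of the supported versions. The following is a minimal sketch, assuming a Bash-compatible shell with python3 on your PATH; is_supported_python is a hypothetical helper, not part of ROCm or JAX.

```shell
#!/usr/bin/env bash
# Hypothetical helper: succeeds (exit status 0) if the given version string
# belongs to one of the supported Python releases: 3.11, 3.12, 3.13, or 3.14.
is_supported_python() {
  case "$1" in
    3.11|3.11.*|3.12|3.12.*|3.13|3.13.*|3.14|3.14.*) return 0 ;;
    *) return 1 ;;
  esac
}

# `python3 --version` prints something like "Python 3.13.1";
# keep only the second field, the version number.
py_version=$(python3 --version 2>&1 | awk '{print $2}')

if is_supported_python "$py_version"; then
  echo "Python $py_version is supported"
else
  echo "Python ${py_version:-(not found)} is not in the supported list" >&2
fi
```

If the check fails, install a supported interpreter and use it explicitly (for example, python3.13) when you create the virtual environment.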

Install JAX using pip

For prerequisite steps and post-installation recommendations, see the ROCm installation instructions.

  1. Set up your Python virtual environment. For example, run the following command to create one with Python 3.13:

    python3.13 -m venv .venv
    
  2. Activate your Python virtual environment. For example:

    source .venv/bin/activate
    
  3. Install the appropriate ROCm-enabled JAX libraries for your operating system and AMD hardware architecture.

  4. Set the following environment variable before running JAX to work around a known segfault in JAX 0.9.1 (see Known issues).

    export XLA_FLAGS="--xla_gpu_enable_command_buffer="
    
  5. Check your JAX installation:

    python -c "import jax; print(jax.devices())"
    

    If JAX and ROCm are installed correctly, this prints a list of ROCm devices, such as [RocmDevice(id=0)].
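When you script this check, it helps to tell a working ROCm setup apart from a silent fallback to the CPU, in which case jax.devices() lists CpuDevice entries instead of ROCm devices. The following minimal sketch classifies the captured output by matching device class names. It assumes a Bash-compatible shell and the RocmDevice and CpuDevice names shown in the printed output; classify_jax_devices is a hypothetical helper.

```shell
#!/usr/bin/env bash
# Hypothetical helper: classify the text printed by
#   python -c "import jax; print(jax.devices())"
# by looking for device class names in the output.
classify_jax_devices() {
  case "$1" in
    *RocmDevice*) echo "rocm" ;;          # at least one ROCm device is visible
    *CpuDevice*)  echo "cpu-fallback" ;;  # JAX found only CPU devices
    *)            echo "unknown" ;;       # an error or unexpected output
  esac
}

# Capture the verification output (including errors) and classify it.
# Assumes the virtual environment that contains JAX is active.
devices=$(python -c "import jax; print(jax.devices())" 2>&1)
echo "JAX devices: $devices"
echo "Result: $(classify_jax_devices "$devices")"
```

A result of cpu-fallback usually means the ROCm-enabled JAX packages or the ROCm libraries are not installed in the active environment.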

Known issues

These are known issues related to JAX installation on ROCm 7.13.0 and their workarounds.

Segfaults with JAX 0.9.1

JAX 0.9.1 might segfault during execution. To work around this, disable XLA command buffers by setting the following flag before running your script:

export XLA_FLAGS="--xla_gpu_enable_command_buffer="
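The export command above replaces any XLA_FLAGS value you have already set. XLA_FLAGS holds a space-separated list of flags, so if you rely on other XLA flags, you can append the workaround flag instead. The following is a minimal sketch, assuming a Bash-compatible shell; add_xla_flag is a hypothetical helper.

```shell
#!/usr/bin/env bash
# Flag that disables XLA command buffers (the workaround described above).
WORKAROUND_FLAG="--xla_gpu_enable_command_buffer="

# Hypothetical helper: append a flag to XLA_FLAGS, keeping existing flags.
# It only skips an exact duplicate. It does not remove a different value
# that XLA_FLAGS may already set for the same flag.
add_xla_flag() {
  local flag="$1"
  case " ${XLA_FLAGS:-} " in
    *" $flag "*) ;;  # already present, so leave XLA_FLAGS unchanged
    *) XLA_FLAGS="${XLA_FLAGS:+$XLA_FLAGS }$flag" ;;
  esac
  export XLA_FLAGS
}

add_xla_flag "$WORKAROUND_FLAG"
echo "XLA_FLAGS=$XLA_FLAGS"
```

Source this snippet or paste it into your shell so that the exported value persists in your current session. Running it as a separate script changes only that script's environment. To apply the workaround to a single run instead, prefix the command, for example XLA_FLAGS="--xla_gpu_enable_command_buffer=" python your_script.py, where your_script.py stands in for your own program. This form also replaces any existing XLA_FLAGS for that run.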