Install JAX on ROCm 7.13.0#
This topic guides you through installing JAX with ROCm support on AMD hardware. It applies to supported AMD GPUs and platforms.
Prerequisites#
Ensure your system has a supported Python version installed and accessible: 3.11, 3.12, 3.13, or 3.14.
Install the ROCm Core SDK. It is recommended to use pip to install both JAX and the ROCm Core SDK in the same Python virtual environment.
Important
Unlike PyTorch, the JAX packages do not automatically install rocm[libraries] as a dependency.
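Before installing anything, it can help to confirm that the interpreter you intend to use matches the prerequisites above. The following is a minimal sketch, not part of the official instructions; the helper names `is_supported_python` and `in_virtual_environment` are illustrative, and the version set is copied from the list above.

```python
import sys

# Versions listed in the prerequisites above.
SUPPORTED_PYTHON_VERSIONS = {(3, 11), (3, 12), (3, 13), (3, 14)}


def is_supported_python(version_info=None):
    """Return True if the (major, minor) pair is one of the listed versions."""
    if version_info is None:
        version_info = sys.version_info
    return tuple(version_info[:2]) in SUPPORTED_PYTHON_VERSIONS


def in_virtual_environment():
    """Return True if the interpreter is running inside a venv."""
    # Inside a venv, sys.prefix points at the environment, while
    # sys.base_prefix still points at the base installation.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("Python", ".".join(str(part) for part in sys.version_info[:3]))
    print("supported version:", is_supported_python())
    print("inside venv:", in_virtual_environment())
```

Run the script with the same interpreter you plan to use for JAX, so that the result reflects the environment that pip will install into.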
Install JAX using pip#
For prerequisite steps and post-installation recommendations, see the ROCm installation instructions.
Set up your Python virtual environment. For example, run the following command to create one with Python 3.13:
python3.13 -m venv .venv
Activate your Python virtual environment. For example, on Linux:
source .venv/bin/activate
Install the appropriate ROCm-enabled JAX libraries for your operating system and AMD hardware architecture.
Set the following environment variable before running JAX. It disables XLA command buffers as a workaround for the segfault described under Known issues.
export XLA_FLAGS="--xla_gpu_enable_command_buffer="
Check your JAX installation.
python -c "import jax; print(jax.devices())"
This prints something like [RocmDevice(id=0)] if JAX and ROCm are installed properly.
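The one-line check above can be extended into a small script that tells a working GPU setup apart from a CPU-only one. This is a sketch under assumptions: the helper names are illustrative, and it assumes that JAX lists only CPU devices when no GPU backend is available. The exact platform string reported for AMD GPUs may vary between JAX versions, so the script only tests for "not CPU".

```python
import importlib.util


def describe_jax_devices():
    """Return 'platform:id' strings for visible devices, or None if JAX is absent."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return [f"{device.platform}:{device.id}" for device in jax.devices()]


def has_accelerator(descriptions):
    """Return True if any described device is something other than a CPU."""
    return any(not item.startswith("cpu:") for item in (descriptions or []))


if __name__ == "__main__":
    found = describe_jax_devices()
    if found is None:
        print("JAX is not installed in this environment.")
    elif has_accelerator(found):
        print("Accelerator devices visible to JAX:", found)
    else:
        print("Only CPU devices are visible; check the ROCm installation:", found)
```

If the script reports only CPU devices, revisit the installation steps before investigating anything else.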
Known issues#
These are known issues related to JAX installation on ROCm 7.13.0 and their workarounds.
Segfaults with JAX 0.9.1#
JAX 0.9.1 might segfault during execution. To work around this, disable XLA command buffers by setting the following flag before running your script:
export XLA_FLAGS="--xla_gpu_enable_command_buffer="
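Note that the export above replaces any value XLA_FLAGS already holds. If your environment sets other XLA flags, one option is to apply the workaround from Python instead, before JAX is imported. The following is a minimal sketch; `merge_xla_flags` is a hypothetical helper, not a JAX or ROCm API, and it assumes that flags are separated by whitespace.

```python
import os

# The workaround flag from this section: an empty value disables command buffers.
COMMAND_BUFFER_FLAG = "--xla_gpu_enable_command_buffer="


def merge_xla_flags(existing, flag=COMMAND_BUFFER_FLAG):
    """Append `flag` to an XLA_FLAGS string, replacing any earlier setting of it."""
    name = flag.split("=", 1)[0]
    # Keep every existing flag except previous settings of the same flag name.
    kept = [item for item in (existing or "").split()
            if item.split("=", 1)[0] != name]
    return " ".join(kept + [flag])


def apply_command_buffer_workaround():
    """Update XLA_FLAGS in this process; call this before importing jax."""
    os.environ["XLA_FLAGS"] = merge_xla_flags(os.environ.get("XLA_FLAGS"))
    return os.environ["XLA_FLAGS"]


if __name__ == "__main__":
    print(apply_command_buffer_workaround())
```

Call `apply_command_buffer_workaround()` at the top of your script, ahead of the first `import jax`, so that the flag is in place before XLA reads the environment.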