JAX on ROCm installation
2025-10-23
JAX is a library for array-oriented numerical computation (similar to NumPy), with automatic differentiation and just-in-time (JIT) compilation to enable high-performance machine learning research.
This topic covers setup instructions and the necessary files to build, test, and run JAX with ROCm support in a Docker environment. To learn more about JAX on ROCm, including its use cases, recommendations, and hardware and software compatibility, see JAX compatibility.
Install JAX on ROCm
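To install JAX on ROCm, choose one of the following options:
Use a prebuilt Docker image with JAX preinstalled (recommended)
Use a ROCm base Docker image to install JAX
Install JAX on bare-metal or a custom container
Build JAX from source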
Use a prebuilt Docker image with JAX preinstalled
The ROCm JAX team provides prebuilt Docker images, which are the simplest way to use JAX on ROCm. These images are available on Docker Hub and come with JAX configured for ROCm.
To pull the latest ROCm JAX Docker image, run:
docker pull rocm/jax:latest
Note
For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX on Docker Hub.
Once the image is downloaded, launch a container using the following command:
docker run -it \
    --network=host \
    --device=/dev/kfd \
    --device=/dev/dri \
    --ipc=host \
    --shm-size 64G \
    --group-add video \
    --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    -v $(pwd):/jax_dir \
    --name rocm_jax \
    rocm/jax:latest /bin/bash
Tip
The --shm-size parameter allocates shared memory for the container. Adjust it based on your system's resources if needed.
The -v $(pwd):/jax_dir option mounts the current directory at /jax_dir inside the container. To mount a different directory, replace $(pwd) with its absolute path.
Verify the installation of ROCm JAX. See Test the JAX installation.
Docker image support
AMD validates and publishes ready-made JAX images with ROCm backends on Docker Hub. The following Docker image tags and associated inventories are validated for ROCm 7.0.0.
For jax-community images, see rocm/jax-community on Docker Hub.
Docker pull tag (Python 3.12):
docker pull rocm/jax:rocm7.0-jax0.6.0-py3.12
See rocm/jax:rocm7.0-jax0.6.0-py3.12 on Docker Hub.
Docker pull tag (Python 3.10):
docker pull rocm/jax:rocm7.0-jax0.6.0-py3.10
See rocm/jax:rocm7.0-jax0.6.0-py3.10 on Docker Hub.
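The validated tags above appear to follow a rocm<ROCm version>-jax<JAX version>-py<Python version> pattern. If you select images from a script, a minimal Python sketch like the following can pick the validated tag that matches a given Python minor version. The naming pattern is inferred from the two tags listed here, not from a documented specification, and VALIDATED_TAGS, parse_tag, and pick_tag are hypothetical names.

```python
import re
import sys

# Assumed tag layout, inferred from the validated tags listed above:
#   rocm<ROCm version>-jax<JAX version>-py<Python major.minor>
TAG_RE = re.compile(r"^rocm(?P<rocm>[\d.]+)-jax(?P<jax>[\d.]+)-py(?P<python>\d+\.\d+)$")

# Hypothetical list: copy the tags validated for your ROCm release here.
VALIDATED_TAGS = [
    "rocm7.0-jax0.6.0-py3.12",
    "rocm7.0-jax0.6.0-py3.10",
]


def parse_tag(tag: str) -> dict:
    """Split a rocm/jax tag into its ROCm, JAX, and Python versions."""
    match = TAG_RE.match(tag)
    if match is None:
        raise ValueError(f"unrecognized tag format: {tag!r}")
    return match.groupdict()


def pick_tag(python_version: str, tags=VALIDATED_TAGS):
    """Return the first validated tag built for python_version, or None."""
    for tag in tags:
        if parse_tag(tag)["python"] == python_version:
            return tag
    return None


if __name__ == "__main__":
    # Use the Python version of the interpreter running this script.
    local = f"{sys.version_info.major}.{sys.version_info.minor}"
    tag = pick_tag(local)
    if tag is None:
        print(f"No validated rocm/jax tag for Python {local}")
    else:
        print(f"docker pull rocm/jax:{tag}")
```

Matching on the Python minor version is only one possible selection rule; if you pin a JAX version instead, compare the "jax" field returned by parse_tag.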
Use a ROCm base Docker image to install JAX
If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu container, follow these steps to install JAX in the container.
Pull the ROCm Ubuntu Docker image. For example:
docker pull rocm/dev-ubuntu-22.04:7.0-complete
Launch the Docker container. After pulling the image, start a container with this command:
docker run -it \
    --network=host \
    --device=/dev/kfd \
    --device=/dev/dri \
    --ipc=host \
    --shm-size 64G \
    --group-add video \
    --cap-add=SYS_PTRACE \
    --security-opt seccomp=unconfined \
    -v $(pwd):/jax_dir \
    --name rocm_jax \
    rocm/dev-ubuntu-22.04:7.0-complete /bin/bash
Install the latest version of JAX. Inside the running container, install the latest JAX release with ROCm support using pip (the quotes keep the shell from treating the brackets as a glob pattern):
pip3 install "jax[rocm]"
Verify the installed JAX version. Check that the correct versions of JAX and its ROCm plugins are installed:
pip3 freeze | grep jax
Example output (the version numbers depend on the release that pip installs):
jax==0.4.35
jax-rocm60-pjrt==0.4.35
jax-rocm60-plugin==0.4.35
jaxlib==0.4.35
Explicitly set the LLVM_PATH environment variable. This helps XLA find ld.lld in the PATH at runtime:
export LLVM_PATH=/opt/rocm/llvm
Verify the installation of ROCm JAX. See Test the JAX installation.
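The package-version and LLVM_PATH steps above can also be checked from a script. The following is a minimal Python sketch, not an official ROCm tool. It reads installed-package metadata with importlib.metadata (similar in spirit to pip3 freeze | grep jax), flags JAX packages whose version differs from jaxlib, and looks for ld.lld under $LLVM_PATH/bin before falling back to PATH. Three assumptions are built in: the package-name filter (jax, jaxlib, and jax-rocm* names, taken from the sample output above), the rule that all of them should report jaxlib's version (also taken from that output), and the bin/ld.lld location. A mismatch is a prompt to investigate, not proof of a broken install.

```python
import os
import shutil
from importlib import metadata
from pathlib import Path


def is_jax_package(name: str) -> bool:
    # Assumption: the relevant packages are jax, jaxlib, and the jax-rocm*
    # PJRT/plugin packages shown in the sample pip3 freeze output.
    name = name.lower()
    return name in ("jax", "jaxlib") or name.startswith("jax-rocm")


def parse_freeze(text: str) -> dict:
    """Map JAX package names to versions from `pip3 freeze`-style text."""
    packages = {}
    for token in text.split():
        name, sep, version = token.partition("==")
        if sep and is_jax_package(name):
            packages[name] = version
    return packages


def installed_jax_packages() -> dict:
    """Same mapping, read from this interpreter's installed-package metadata."""
    packages = {}
    for dist in metadata.distributions():
        name = dist.metadata.get("Name") or ""
        if is_jax_package(name):
            packages[name] = dist.version
    return packages


def version_problems(packages: dict) -> list:
    """List packages whose version differs from jaxlib's (heuristic check)."""
    if "jaxlib" not in packages:
        return ["jaxlib is not installed"]
    ref = packages["jaxlib"]
    return [f"{n}=={v} (jaxlib=={ref})" for n, v in sorted(packages.items()) if v != ref]


def find_ld_lld(llvm_path=None):
    """Return a path to ld.lld, or None if it cannot be found.

    Assumption: the linker lives at $LLVM_PATH/bin/ld.lld; otherwise the
    function falls back to searching PATH.
    """
    if llvm_path is None:
        llvm_path = os.environ.get("LLVM_PATH")
    if llvm_path:
        candidate = Path(llvm_path) / "bin" / "ld.lld"
        if candidate.is_file() and os.access(candidate, os.X_OK):
            return str(candidate)
    return shutil.which("ld.lld")


if __name__ == "__main__":
    problems = version_problems(installed_jax_packages())
    print("\n".join(problems) if problems else "All JAX packages match jaxlib.")
    print("LLVM_PATH:", os.environ.get("LLVM_PATH"))
    print("ld.lld:", find_ld_lld())
```

Run it with the same python3 that pip3 installed into, because importlib.metadata only sees the packages of the interpreter that runs the script. The same checks apply to the bare-metal steps below.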
Install JAX on bare-metal or a custom container
Follow these steps if you prefer to install ROCm manually on your host system or in a custom container.
Install ROCm. Follow the ROCm installation guide to install ROCm on your system.
Once installed, verify your ROCm installation using:
rocm-smi
========================================== ROCm System Management Interface ==========================================
==================================================== Concise Info ====================================================
Device  [Model : Revision]    Temp        Power     Partitions      SCLK    MCLK    Fan  Perf  PwrCap  VRAM%  GPU%
        Name (20 chars)       (Junction)  (Socket)  (Mem, Compute)
======================================================================================================================
0       [0x74a1 : 0x00]       50.0°C      170.0W    NPS1, SPX       131Mhz  900Mhz  0%   auto  750.0W    0%   0%
        AMD Instinct MI300X
1       [0x74a1 : 0x00]       51.0°C      176.0W    NPS1, SPX       132Mhz  900Mhz  0%   auto  750.0W    0%   0%
        AMD Instinct MI300X
2       [0x74a1 : 0x00]       50.0°C      177.0W    NPS1, SPX       132Mhz  900Mhz  0%   auto  750.0W    0%   0%
        AMD Instinct MI300X
3       [0x74a1 : 0x00]       53.0°C      176.0W    NPS1, SPX       132Mhz  900Mhz  0%   auto  750.0W    0%   0%
        AMD Instinct MI300X
======================================================================================================================
================================================ End of ROCm SMI Log =================================================
Install the required version of JAX with ROCm support using pip:
pip3 install "jax[rocm]"
Verify the installed JAX version. Check that the correct versions of JAX and its ROCm plugins are installed:
pip3 freeze | grep jax
Explicitly set the LLVM_PATH environment variable:
export LLVM_PATH=/opt/rocm/llvm
Apply the namespace patch. The patch file path is relative, so run this from the directory that contains jax_rocm_plugin:
patch -p1 \
    -d "$(python3 -c "import sysconfig; print(sysconfig.get_paths()['purelib'])")" \
    < jax_rocm_plugin/third_party/jax/namespace.patch
Verify the installation of ROCm JAX.
Run the following commands to verify that ROCm JAX is installed correctly:
python3 -c "import jax; print(jax.devices())"
python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
Expected output:
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
[0 1 2 3 4]
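Two of the bare-metal steps above are easy to get subtly wrong: the number of GPUs rocm-smi reports should match what you expect (and, later, the length of jax.devices()), and the namespace patch lands in whichever site-packages directory the python3 on your PATH reports. The following hypothetical preflight helper is a minimal Python sketch under stated assumptions. count_gpus assumes each device row in the concise rocm-smi table starts with the device index followed by "[0x<model> : 0x<revision>]", as in the sample output above, and that layout may change between ROCm releases. patch_command rebuilds the namespace-patch invocation as an argument list and, by default, adds GNU patch's --dry-run option so you can preview the target directory before modifying anything.

```python
import re
import shlex
import subprocess
import sysconfig

# Assumption: in rocm-smi's concise table, each device row starts with the
# device index followed by "[0x<model> : 0x<revision>]", as in the sample output.
DEVICE_ROW = re.compile(
    r"^[ \t]*\d+[ \t]+\[0x[0-9a-fA-F]+[ \t]*:[ \t]*0x[0-9a-fA-F]+\]", re.MULTILINE
)


def count_gpus(smi_output: str) -> int:
    """Count device rows in rocm-smi concise output."""
    return len(DEVICE_ROW.findall(smi_output))


def patch_command(patch_file: str, dry_run: bool = True) -> list:
    """Build the namespace-patch invocation as an argument list.

    The target directory is this interpreter's site-packages ("purelib"),
    the same path the embedded python3 -c command prints.
    """
    cmd = ["patch", "-p1", "-d", sysconfig.get_paths()["purelib"]]
    if dry_run:
        cmd.append("--dry-run")  # GNU patch: report changes without writing files
    return cmd + ["-i", patch_file]


if __name__ == "__main__":
    try:
        smi = subprocess.run(["rocm-smi"], capture_output=True, text=True, check=True)
        print("rocm-smi device rows:", count_gpus(smi.stdout))
    except (FileNotFoundError, subprocess.CalledProcessError) as err:
        print("rocm-smi failed:", err)
    # Print a shell-quoted preview command; run it yourself, then drop --dry-run.
    print(shlex.join(patch_command("jax_rocm_plugin/third_party/jax/namespace.patch")))
```

Run the helper with the same python3 you use for JAX so that the printed site-packages path matches the one the patch step targets.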
Build JAX from source
The ROCm/rocm-jax repository contains sources for the ROCm plugin for JAX as well as Dockerfiles used to build the AMD rocm/jax images.
For the most up-to-date instructions, refer directly to the instructions in the repository:
See Quick build for concise high-level steps.
See Building for more in-depth build instructions and troubleshooting suggestions.
Test the JAX installation
After launching the container, test whether JAX detects ROCm devices as expected:
python3 -c "import jax; print(jax.devices())"
python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
If the setup is successful, the output should list all available ROCm devices.
Expected output:
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
[0 1 2 3 4]
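Listing devices confirms that the ROCm plugin loaded. A slightly broader check also exercises just-in-time compilation and automatic differentiation, the two features mentioned in the introduction. The following is a minimal sketch that uses only public JAX APIs (jax.jit, jax.grad, jax.default_backend, jax.devices); the function names loss and smoke_test are illustrative. JAX falls back to the CPU when no GPU backend is available, so the arithmetic checks pass on any backend. Check the printed backend and device list as well.

```python
import jax
import jax.numpy as jnp


@jax.jit
def loss(x):
    # Simple scalar function: the sum of squares, compiled with XLA via jax.jit.
    return jnp.sum(x ** 2)


def smoke_test():
    """Run a jitted computation and its gradient; return (backend, value, grad)."""
    x = jnp.arange(5.0)           # [0., 1., 2., 3., 4.]
    value = loss(x)               # 0 + 1 + 4 + 9 + 16 = 30
    grad = jax.grad(loss)(x)      # d/dx sum(x^2) = 2x
    return jax.default_backend(), float(value), grad.tolist()


if __name__ == "__main__":
    backend, value, grad = smoke_test()
    print("backend:", backend)
    print("devices:", jax.devices())
    print("loss:", value, "grad:", grad)
    if backend == "cpu":
        print("warning: JAX is running on the CPU; the ROCm backend was not picked up")
```

On a working ROCm setup, the device list should match the GPUs that rocm-smi reports. The loss and gradient values are the same on every backend.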