Install JAX for ROCm
This page provides setup instructions to build, test, and run JAX with ROCm support using a pip or Docker installation, suitable for both runtime and CI workflows.
Note
These instructions are for installing JAX on Radeon GPUs.
For Instinct GPUs, refer to the ROCm Instinct documentation.
Install JAX
Follow these instructions to install JAX using pip or Docker.
pip installation
Follow these instructions to install JAX using pip.
Important
The packages must be installed in the following order:
1. Install the pjrt wheel.
2. Install the plugin wheel.
3. Install the jaxlib wheel.
4. Install the jax wheel.
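The required order can be captured in a small shell function that only prints the four pip commands, so you can review them before anything is installed. This is a minimal sketch: `rocm_jax_install_cmds` is a hypothetical helper, not part of ROCm or JAX. The repository URL, the 0.6.0 version, and the wheel file names are copied from the ROCm 7.0.2 commands in this section; the CPython tag argument (cp312 or cp310) is an assumption you must match to your Python version.

```bash
# Hypothetical helper: print the ROCm JAX pip commands in the required order
# (pjrt -> plugin -> jaxlib -> jax). This function installs nothing.
# $1: CPython tag of the plugin wheel, e.g. cp312 (Python 3.12) or cp310 (Python 3.10).
rocm_jax_install_cmds() {
  local py_tag=$1
  local repo="https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2"
  local ver="0.6.0"
  # 1. pjrt wheel: tagged py3-none, so the same file is used for every Python version
  echo "pip install ${repo}/jax_rocm7_pjrt-${ver}-py3-none-manylinux_2_28_x86_64.whl"
  # 2. plugin wheel: built per CPython version, hence the tag
  echo "pip install ${repo}/jax_rocm7_plugin-${ver}-${py_tag}-${py_tag}-manylinux_2_28_x86_64.whl"
  # 3. jaxlib and 4. jax, pinned to the same version as the ROCm wheels
  echo "pip install jaxlib==${ver}"
  echo "pip install jax==${ver}"
}
```

After checking the output of `rocm_jax_install_cmds cp312`, you could run it with `rocm_jax_install_cmds cp312 | sh -e`; the `-e` option stops at the first failing step, so a later wheel is never installed on top of a failed earlier one.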
Install JAX for Ubuntu 24.04.
1. Uninstall the previous version.
   pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
2. Install the pjrt wheel.
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel.
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel.
   pip install jaxlib==0.6.0
5. Install the jax wheel.
   pip install jax==0.6.0
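The plugin wheel name encodes the CPython version it was built for: cp312 in the Ubuntu 24.04 and RHEL 10 commands, cp310 in the Ubuntu 22.04 commands. If the tag does not match your interpreter, pip rejects the file as "not a supported wheel on this platform". A minimal sketch of a pre-check, assuming `python3` is the interpreter pip installs into (running `python3 -m pip` guarantees this); `py_tag_from_version` and `check_plugin_tag` are hypothetical helper names.

```bash
# Hypothetical helper: turn a Python version string such as "3.12.3"
# into the CPython wheel tag it corresponds to ("cp312").
py_tag_from_version() {
  local ver=$1 major rest minor
  major=${ver%%.*}   # "3.12.3" -> "3"
  rest=${ver#*.}     # "3.12.3" -> "12.3"
  minor=${rest%%.*}  # "12.3"   -> "12"
  printf 'cp%s%s\n' "$major" "$minor"
}

# Hypothetical helper: compare the running python3 with the tag of the plugin
# wheel you are about to install. Returns non-zero on a mismatch.
check_plugin_tag() {
  local expected=$1 actual
  actual=$(py_tag_from_version "$(python3 -c 'import platform; print(platform.python_version())')")
  if [ "$actual" = "$expected" ]; then
    echo "python3 matches the ${expected} plugin wheel"
  else
    echo "python3 is ${actual}, but the plugin wheel needs ${expected}" >&2
    return 1
  fi
}
```

For example, run `check_plugin_tag cp312` before the plugin step on Ubuntu 24.04 or RHEL 10, and `check_plugin_tag cp310` on Ubuntu 22.04.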
Install JAX for Ubuntu 22.04.
1. Uninstall the previous version.
   pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
2. Install the pjrt wheel.
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel.
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/jax_rocm7_plugin-0.6.0-cp310-cp310-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel.
   pip install jaxlib==0.6.0
5. Install the jax wheel.
   pip install jax==0.6.0
Install JAX for RHEL 10.
1. Uninstall the previous version.
   pip3 uninstall -y jax-rocm60-pjrt jax-rocm60-plugin jaxlib jax
2. Install the pjrt wheel.
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/jax_rocm7_pjrt-0.6.0-py3-none-manylinux_2_28_x86_64.whl
3. Install the plugin wheel.
   pip install https://repo.radeon.com/rocm/manylinux/rocm-rel-7.0.2/jax_rocm7_plugin-0.6.0-cp312-cp312-manylinux_2_28_x86_64.whl
4. Install the jaxlib wheel.
   pip install jaxlib==0.6.0
5. Install the jax wheel.
   pip install jax==0.6.0
Recommended for RHEL distributions: install the gcc-gfortran package (for libgfortran.so).
   sudo dnf install gcc-gfortran
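After installing the four wheels on any of the supported systems, you can confirm that all four distributions are present. A minimal sketch: `missing_rocm_jax_pkgs` is a hypothetical helper that reads `pip3 list --format=freeze` output on stdin. It assumes the ROCm distributions are named after their wheel files (jax_rocm7_pjrt, jax_rocm7_plugin) and, because pip may print either spelling, it lowercases names and maps `_` and `.` to `-` before comparing.

```bash
# Hypothetical helper: read `pip3 list --format=freeze` lines (name==version)
# on stdin and print each expected ROCm JAX distribution that is missing.
# Returns non-zero if anything is missing.
missing_rocm_jax_pkgs() {
  local installed pkg status=0
  # Drop the version, map "_" and "." to "-", and lowercase, so that
  # jax_rocm7_pjrt and jax-rocm7-pjrt compare equal.
  installed=$(sed -e 's/==.*//' -e 's/[_.]/-/g' | tr '[:upper:]' '[:lower:]')
  for pkg in jax-rocm7-pjrt jax-rocm7-plugin jaxlib jax; do
    # -x requires a whole-line match, so "jax" does not match "jaxlib".
    if ! printf '%s\n' "$installed" | grep -qx "$pkg"; then
      echo "$pkg"
      status=1
    fi
  done
  return "$status"
}
```

Usage: `pip3 list --format=freeze | missing_rocm_jax_pkgs && echo "all ROCm JAX packages installed"`.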
Docker installation
The ROCm JAX team provides prebuilt Docker images, which are the simplest way to use JAX on ROCm. These images are available on Docker Hub and come with JAX configured for ROCm.
Note
If pip fails to install Python packages inside the Docker container (for example, with an externally-managed-environment error), add the --break-system-packages parameter to the pip install command.
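The --break-system-packages parameter exists because pip refuses to install into a Python environment that the distribution marks as externally managed (PEP 668), using an EXTERNALLY-MANAGED file in the standard-library directory. A minimal sketch that adds the parameter only when that marker is present; `pip_system_flags` and `pip_install_in_container` are hypothetical helpers, and the sketch assumes the container's `python3` is the interpreter you want to install into.

```bash
# Hypothetical helper: print "--break-system-packages" if the given
# standard-library directory contains the PEP 668 EXTERNALLY-MANAGED marker.
pip_system_flags() {
  local stdlib_dir=$1
  if [ -f "${stdlib_dir}/EXTERNALLY-MANAGED" ]; then
    echo "--break-system-packages"
  fi
}

# Hypothetical wrapper: run `python3 -m pip install` with the flag only when needed.
pip_install_in_container() {
  local stdlib_dir flags
  stdlib_dir=$(python3 -c 'import sysconfig; print(sysconfig.get_path("stdlib"))')
  flags=$(pip_system_flags "$stdlib_dir")
  # $flags is intentionally unquoted so that an empty value adds no argument.
  python3 -m pip install $flags "$@"
}
```

For example, `pip_install_in_container numpy` behaves like a plain `python3 -m pip install numpy` on an unmarked interpreter.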
Install JAX for Ubuntu 24.04.
To pull the latest ROCm JAX Docker image, run:
docker pull rocm/jax:rocm7.0.2-jax0.6.0-py3.12
Note
For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub. Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.
Once the image is downloaded, launch a container and attach to it using the following commands:
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir \
    --name rocm_jax rocm/jax-community:latest /bin/bash
docker attach rocm_jax
Note
The --shm-size parameter allocates shared memory for the container. Adjust it based on your system's resources if needed. $(pwd) mounts the current directory at /jax_dir; replace it with the absolute path of another directory if you want to mount that instead. The command runs rocm/jax-community:latest; to use the rocm/jax image pulled above, replace rocm/jax-community:latest with rocm/jax:rocm7.0.2-jax0.6.0-py3.12.
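The three values you are most likely to change (image, mounted host directory, and shared-memory size) can be kept as parameters. A minimal sketch: `rocm_jax_run_cmd` is a hypothetical helper that prints the docker run command from this section with those values filled in; all other flags are copied unchanged. Because the output is plain text, it assumes the host path contains no spaces.

```bash
# Hypothetical helper: print this section's docker run command with the
# image, host directory, and shared-memory size as parameters.
# $1: image (default rocm/jax-community:latest)
# $2: absolute host path to mount at /jax_dir (default: current directory)
# $3: shared-memory size (default 64G)
rocm_jax_run_cmd() {
  local image=${1:-rocm/jax-community:latest}
  local host_dir=${2:-$(pwd)}
  local shm=${3:-64G}
  echo "docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri" \
       "--ipc=host --shm-size ${shm} --group-add video --cap-add=SYS_PTRACE" \
       "--security-opt seccomp=unconfined -v ${host_dir}:/jax_dir" \
       "--name rocm_jax ${image} /bin/bash"
}
```

For example, `rocm_jax_run_cmd rocm/jax:rocm7.0.2-jax0.6.0-py3.12 /data/jax 32G | sh && docker attach rocm_jax` starts the pulled rocm/jax image with 32 GB of shared memory, where /data/jax is an illustrative path.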
Install JAX for Ubuntu 22.04.
To pull the latest ROCm JAX Docker image, run:
docker pull rocm/jax:rocm7.0.2-jax0.6.0-py3.10-ubu22
Note
For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub. Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.
Once the image is downloaded, launch a container and attach to it using the following commands:
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir \
    --name rocm_jax rocm/jax-community:latest /bin/bash
docker attach rocm_jax
Note
The --shm-size parameter allocates shared memory for the container. Adjust it based on your system's resources if needed. $(pwd) mounts the current directory at /jax_dir; replace it with the absolute path of another directory if you want to mount that instead. The command runs rocm/jax-community:latest; to use the rocm/jax image pulled above, replace rocm/jax-community:latest with rocm/jax:rocm7.0.2-jax0.6.0-py3.10-ubu22.
Install JAX for RHEL 10.
To pull the latest ROCm JAX Docker image, run:
docker pull rocm/jax:rocm7.0.2-jax0.6.0-py3.12
Note
For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX Community on Docker Hub. Additional Docker images are available at ROCm JAX on Docker Hub. These contain the latest ROCm version but might use an older version of JAX.
Once the image is downloaded, launch a container and attach to it using the following commands:
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G \
    --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir \
    --name rocm_jax rocm/jax-community:latest /bin/bash
docker attach rocm_jax
Note
The --shm-size parameter allocates shared memory for the container. Adjust it based on your system's resources if needed. $(pwd) mounts the current directory at /jax_dir; replace it with the absolute path of another directory if you want to mount that instead. The command runs rocm/jax-community:latest; to use the rocm/jax image pulled above, replace rocm/jax-community:latest with rocm/jax:rocm7.0.2-jax0.6.0-py3.12.
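The Docker instructions pull a different rocm/jax tag depending on the system. A minimal sketch that makes this mapping explicit: `rocm_jax_image_for` is a hypothetical helper, and its keys (ubuntu24.04, ubuntu22.04, rhel10) are names chosen for this sketch; the tags are the ones listed in this section.

```bash
# Hypothetical helper: print the rocm/jax image tag this section uses for a system key.
rocm_jax_image_for() {
  case "$1" in
    # RHEL reports VERSION_ID with a minor version (for example 10.0), hence rhel10.*
    ubuntu24.04|rhel10|rhel10.*) echo "rocm/jax:rocm7.0.2-jax0.6.0-py3.12" ;;
    ubuntu22.04)                 echo "rocm/jax:rocm7.0.2-jax0.6.0-py3.10-ubu22" ;;
    *) echo "unsupported system key: $1" >&2; return 1 ;;
  esac
}
```

On the host, the key can be built from the standard ID and VERSION_ID fields of /etc/os-release: `(. /etc/os-release && docker pull "$(rocm_jax_image_for "${ID}${VERSION_ID}")")`.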
Verify installation
Refer to Testing your JAX installation with ROCm for verification steps.
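For a quick check before the full verification steps, you can ask JAX how many GPU devices it sees. A minimal sketch: `jax_gpu_count` is a hypothetical helper, and PYTHON is an assumed override for the interpreter (default `python3`). It relies on `jax.devices("gpu")` raising an error when no GPU backend is available, which the helper reports as 0.

```bash
# Hypothetical helper: print the number of GPU devices JAX can see, or 0 if
# JAX fails to import or reports no GPU backend.
# PYTHON is an assumption of this sketch; it defaults to python3.
jax_gpu_count() {
  "${PYTHON:-python3}" -c 'import jax; print(len(jax.devices("gpu")))' 2>/dev/null || echo 0
}
```

Usage: `if [ "$(jax_gpu_count)" -gt 0 ]; then echo "JAX sees the GPU"; else echo "No GPU visible to JAX"; fi`, run inside the Docker container or the environment where you installed the wheels.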