Install JAX on ROCm 7.12.0
This topic guides you through installing JAX with ROCm support on AMD
hardware. It applies to the AMD GPUs and APUs listed below, grouped by device family.
Supported devices by AMD device family:
Instinct: MI355X, MI350X, MI325X, MI300X, MI300A, MI250X, MI250, MI210, MI100
Radeon PRO: AI PRO R9700, AI PRO R9600D, W7900 Dual Slot, W7900, W7800 48GB, W7800, W7700, V710
Radeon: RX 9070 XT, RX 9070 GRE, RX 9070, RX 9060 XT LP, RX 9060 XT, RX 9060, RX 7900 XTX, RX 7900 XT, RX 7900 GRE, RX 7800 XT, RX 7700 XT, RX 7700 XE, RX 7700, RX 7600
Ryzen: AI Max+ PRO 395, AI Max PRO 390, AI Max PRO 385, AI Max PRO 380, AI Max+ 395, AI Max 390, AI Max 385, AI 9 HX PRO 475, AI 9 HX PRO 470, AI 9 PRO 465, AI 7 PRO 450, AI 5 PRO 440, AI 5 PRO 435, AI 9 HX 375, AI 9 HX 370, AI 9 365, 9 270, 7 260, 7 250, 5 240, 5 230, 5 220, 3 210
Install JAX
For prerequisite steps and post-installation recommendations, see the
ROCm installation instructions.
Set up your Python virtual environment. For example, run the following
command to create a virtual environment in the .venv directory:
python3 -m venv .venv
Activate your Python virtual environment. For example:
source .venv/bin/activate
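You can quickly confirm that the activated environment is the one pip will install into. The following minimal sketch uses only the standard library. It relies on the documented venv behavior: inside a virtual environment, sys.prefix points at the environment directory, and sys.base_prefix still points at the base Python installation.

```python
import sys


def in_virtualenv() -> bool:
    """Return True if this interpreter runs inside a venv.

    In a venv, sys.prefix points at the environment directory, while
    sys.base_prefix still points at the base Python installation.
    """
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    if in_virtualenv():
        print(f"Active virtual environment: {sys.prefix}")
    else:
        print("No virtual environment is active; run: source .venv/bin/activate")
```

Save the sketch under any name, for example check_venv.py (a hypothetical file name), and run it with python after activating .venv.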
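Install the ROCm-enabled JAX libraries that match your operating system
and AMD GPU architecture. Each command block below uses a different
--extra-index-url for one GPU architecture. Run only the block for your hardware.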
Note
The jax package itself is not published to the AMD package
repository. After installing the architecture-specific jaxlib,
jax_rocm7_plugin, and jax_rocm7_pjrt packages from the AMD
repository, install the matching jax version from PyPI (jax==0.8.2
for the 0.8.2+rocm7.12.0 packages).
For gfx950 (Instinct MI355X, MI350X):
python -m pip install \
--extra-index-url https://repo.amd.com/rocm/whl/gfx950-dcgpu/ \
"jaxlib==0.8.2+rocm7.12.0" \
"jax_rocm7_plugin==0.8.2+rocm7.12.0" \
"jax_rocm7_pjrt==0.8.2+rocm7.12.0"
# Install jax from PyPI
python -m pip install "jax==0.8.2"
For gfx94X (Instinct MI325X, MI300X, MI300A):
python -m pip install \
--extra-index-url https://repo.amd.com/rocm/whl/gfx94X-dcgpu/ \
"jaxlib==0.8.2+rocm7.12.0" \
"jax_rocm7_plugin==0.8.2+rocm7.12.0" \
"jax_rocm7_pjrt==0.8.2+rocm7.12.0"
# Install jax from PyPI
python -m pip install "jax==0.8.2"
For gfx90a (Instinct MI250X, MI250, MI210):
python -m pip install \
--extra-index-url https://repo.amd.com/rocm/whl/gfx90a/ \
"jaxlib==0.8.2+rocm7.12.0" \
"jax_rocm7_plugin==0.8.2+rocm7.12.0" \
"jax_rocm7_pjrt==0.8.2+rocm7.12.0"
# Install jax from PyPI
python -m pip install "jax==0.8.2"
For gfx908 (Instinct MI100):
python -m pip install \
--extra-index-url https://repo.amd.com/rocm/whl/gfx908/ \
"jaxlib==0.8.2+rocm7.12.0" \
"jax_rocm7_plugin==0.8.2+rocm7.12.0" \
"jax_rocm7_pjrt==0.8.2+rocm7.12.0"
# Install jax from PyPI
python -m pip install "jax==0.8.2"
For gfx120X (Radeon RX 9070 and RX 9060 series, Radeon AI PRO R9700 and R9600D):
python -m pip install \
--extra-index-url https://repo.amd.com/rocm/whl/gfx120X-all/ \
"jaxlib==0.8.2+rocm7.12.0" \
"jax_rocm7_plugin==0.8.2+rocm7.12.0" \
"jax_rocm7_pjrt==0.8.2+rocm7.12.0"
# Install jax from PyPI
python -m pip install "jax==0.8.2"
For gfx110X (Radeon RX 7000 series, Radeon PRO W7900, W7800, W7700, and V710):
python -m pip install \
--extra-index-url https://repo.amd.com/rocm/whl/gfx110X-dgpu/ \
"jaxlib==0.8.2+rocm7.12.0" \
"jax_rocm7_plugin==0.8.2+rocm7.12.0" \
"jax_rocm7_pjrt==0.8.2+rocm7.12.0"
# Install jax from PyPI
python -m pip install "jax==0.8.2"
For gfx1151 (Ryzen AI Max and AI Max+ series, including PRO models):
python -m pip install \
--extra-index-url https://repo.amd.com/rocm/whl/gfx1151/ \
"jaxlib==0.8.2+rocm7.12.0" \
"jax_rocm7_plugin==0.8.2+rocm7.12.0" \
"jax_rocm7_pjrt==0.8.2+rocm7.12.0"
# Install jax from PyPI
python -m pip install "jax==0.8.2"
For gfx1150 (for example, Ryzen AI 9 HX 375, AI 9 HX 370, and AI 9 365):
python -m pip install \
--extra-index-url https://repo.amd.com/rocm/whl/gfx1150/ \
"jaxlib==0.8.2+rocm7.12.0" \
"jax_rocm7_plugin==0.8.2+rocm7.12.0" \
"jax_rocm7_pjrt==0.8.2+rocm7.12.0"
# Install jax from PyPI
python -m pip install "jax==0.8.2"
Check your JAX installation.
Important
Set the environment variable AMD_COMGR_NAMESPACE=1 before running JAX (for
example, run export AMD_COMGR_NAMESPACE=1 in your shell). For details, see the known issue
"JAX GPU initialization might fail without AMD_COMGR_NAMESPACE set."
If AMD_COMGR_NAMESPACE=1 is not set:
JAX might fail to initialize the GPU
JAX workloads might unexpectedly run on the CPU instead of the GPU
Processes might crash during initialization
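You may launch JAX from your own Python entry point instead of from a shell where you exported the variable. In that case, one option is to set the variable in os.environ before importing JAX. This is a sketch that assumes the ROCm libraries read AMD_COMGR_NAMESPACE only after JAX is imported. For that reason, the import is deferred into main().

```python
import os
from collections.abc import MutableMapping


def ensure_comgr_namespace(env: MutableMapping[str, str] = os.environ) -> None:
    """Set AMD_COMGR_NAMESPACE=1 in the given environment mapping."""
    env["AMD_COMGR_NAMESPACE"] = "1"


def main() -> None:
    # Set the variable first; JAX is imported afterwards so the ROCm
    # libraries it loads see AMD_COMGR_NAMESPACE=1 in the process environment.
    ensure_comgr_namespace()
    try:
        import jax
    except ImportError:
        print("JAX is not installed in this environment")
        return
    print(jax.devices())


if __name__ == "__main__":
    main()
```

Exporting the variable in the shell remains the simpler choice when you control how the process is started.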
python -c "import jax; print(jax.devices())"
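If JAX and ROCm are installed correctly, this prints a list of GPU devices, for example [RocmDevice(id=0)]. If the list contains only CPU devices, for example [CpuDevice(id=0)], JAX is not using the GPU.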
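The one-line check only lists devices. The following sketch goes a step further. It flags the CPU-fallback case described in the Important note, then runs a small computation on the first non-CPU device. It filters on device.platform != "cpu" rather than on a specific platform string. This avoids depending on the exact name JAX reports for ROCm devices, which this sketch treats as unknown.

```python
def is_gpu_platform(platform: str) -> bool:
    """Treat every non-CPU platform name as an accelerator."""
    return platform != "cpu"


def gpu_devices(devices):
    """Return only the devices that are not CPU devices."""
    return [d for d in devices if is_gpu_platform(d.platform)]


def main() -> None:
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("JAX is not installed in this environment")
        return

    all_devices = jax.devices()
    gpus = gpu_devices(all_devices)
    if not gpus:
        # The symptom from the note: JAX fell back to the CPU.
        print("No GPU devices found; JAX is using:", all_devices)
        return

    # Place a small array on the first GPU and reduce it there.
    x = jax.device_put(jnp.arange(4.0), gpus[0])
    total = (x * 2).sum()
    print("GPU devices:", gpus)
    print("Result device(s):", total.devices(), "value:", float(total))


if __name__ == "__main__":
    main()
```

Run the script with the variable set, for example AMD_COMGR_NAMESPACE=1 python check_jax.py, where check_jax.py is a hypothetical file name.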