JAX on ROCm installation

2025-10-23

Applies to Linux

JAX is a library for array-oriented numerical computation (similar to NumPy), with automatic differentiation and just-in-time (JIT) compilation to enable high-performance machine learning research.

This topic covers setup instructions and the necessary files to build, test, and run JAX with ROCm support in a Docker container or directly on a Linux host. To learn more about JAX on ROCm, including use cases, recommendations, and hardware and software compatibility, see JAX compatibility.

Install JAX on ROCm

To install JAX on ROCm, you have the following options:

Use a prebuilt Docker image with JAX preinstalled

The ROCm JAX team provides prebuilt Docker images, which are the simplest way to use JAX on ROCm. The images are available on Docker Hub and come with JAX already configured for ROCm.

  1. To pull the latest ROCm JAX Docker image, run:

    docker pull rocm/jax:latest
    

    Note

    For specific versions of JAX, review the periodically pushed Docker images at ROCm JAX on Docker Hub.

  2. Once the image is downloaded, launch a container using the following command:

    docker run -it \
        --network=host \
        --device=/dev/kfd \
        --device=/dev/dri \
        --ipc=host \
        --shm-size 64G \
        --group-add video \
        --cap-add=SYS_PTRACE \
        --security-opt seccomp=unconfined \
        -v $(pwd):/jax_dir \
        --name rocm_jax \
        rocm/jax:latest /bin/bash
    

    Tip

    • The --shm-size parameter allocates shared memory for the container. Adjust it based on your system’s resources if needed.

    • $(pwd) mounts the current directory at /jax_dir inside the container. To mount a different directory, replace $(pwd) with that directory's absolute path.

  3. Verify the installation of ROCm JAX. See Test the JAX installation.
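If you launch containers from scripts, the docker run command above can be expressed as an argument list so that the shared-memory size and mount directory from the tips become parameters. The following is a minimal Python sketch, not an official tool: the function name rocm_jax_run_args and its defaults are hypothetical, chosen to mirror the command above, and the sketch only builds the command without running it.

```python
import shlex
from pathlib import Path


def rocm_jax_run_args(image="rocm/jax:latest", mount_dir=".", shm_size="64G",
                      name="rocm_jax"):
    """Build the `docker run` argument list from this guide (hypothetical helper)."""
    # Resolve to an absolute path, which is what $(pwd) provides in the shell version.
    host_dir = Path(mount_dir).resolve()
    return [
        "docker", "run", "-it",
        "--network=host",
        "--device=/dev/kfd",          # ROCm compute interface
        "--device=/dev/dri",          # GPU render nodes
        "--ipc=host",
        "--shm-size", shm_size,       # shared memory; size it to your system
        "--group-add", "video",
        "--cap-add=SYS_PTRACE",
        "--security-opt", "seccomp=unconfined",
        "-v", f"{host_dir}:/jax_dir",
        "--name", name,
        image, "/bin/bash",
    ]


if __name__ == "__main__":
    # Print a shell-quoted command line, here with a smaller shared-memory size.
    print(shlex.join(rocm_jax_run_args(shm_size="32G")))
```

To start the container, you could pass the list to subprocess.run. Passing image="rocm/dev-ubuntu-22.04:7.0-complete" produces the command used with the ROCm base image later in this topic.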

Docker image support

AMD validates and publishes ready-made JAX images with ROCm backends on Docker Hub. The following Docker image tags and associated inventories are validated for ROCm 7.0.0. For jax-community images, see rocm/jax-community on Docker Hub.

  • Python 3.12:

    docker pull rocm/jax:rocm7.0-jax0.6.0-py3.12

    See rocm/jax:rocm7.0-jax0.6.0-py3.12 on Docker Hub.

  • Python 3.10:

    docker pull rocm/jax:rocm7.0-jax0.6.0-py3.10

    See rocm/jax:rocm7.0-jax0.6.0-py3.10 on Docker Hub.
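The validated tags above follow a visible pattern: ROCm version, then JAX version, then Python version. As a hedged illustration, the following sketch splits such a tag into its parts, for example to pick the image that matches your Python version. The regular expression is inferred only from the two tags listed here; other tags on Docker Hub (such as latest) may not follow it, and the names JaxImageTag and parse_tag are hypothetical.

```python
import re
from typing import NamedTuple


class JaxImageTag(NamedTuple):
    rocm: str
    jax: str
    python: str


# Assumed tag pattern, inferred from tags like "rocm7.0-jax0.6.0-py3.12".
_TAG = re.compile(r"^rocm(?P<rocm>[\d.]+)-jax(?P<jax>[\d.]+)-py(?P<python>[\d.]+)$")


def parse_tag(tag: str) -> JaxImageTag:
    """Split a rocm/jax image tag into its ROCm, JAX, and Python versions."""
    tag = tag.split(":", 1)[-1]  # accept "rocm/jax:<tag>" as well as a bare tag
    m = _TAG.match(tag)
    if m is None:
        raise ValueError(f"unrecognized tag format: {tag!r}")
    return JaxImageTag(m["rocm"], m["jax"], m["python"])


if __name__ == "__main__":
    print(parse_tag("rocm/jax:rocm7.0-jax0.6.0-py3.12"))
```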

Use a ROCm base Docker image to install JAX

If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu container, follow these steps to install JAX in the container.

  1. Pull the ROCm Ubuntu Docker image. For example:

    docker pull rocm/dev-ubuntu-22.04:7.0-complete
    
  2. Launch a container from the image you pulled:

    docker run -it \
        --network=host \
        --device=/dev/kfd \
        --device=/dev/dri \
        --ipc=host \
        --shm-size 64G \
        --group-add video \
        --cap-add=SYS_PTRACE \
        --security-opt seccomp=unconfined \
        -v $(pwd):/jax_dir \
        --name rocm_jax \
        rocm/dev-ubuntu-22.04:7.0-complete /bin/bash
    
  3. Install JAX with ROCm support. Inside the running container, use pip:

    pip3 install "jax[rocm]"
    
  4. Verify the installed JAX version. Check whether the correct version of JAX and its ROCm plugins are installed.

    pip3 freeze | grep jax
    

    Example output (the exact versions depend on the release that pip installs):

    jax==0.4.35
    jax-rocm60-pjrt==0.4.35
    jax-rocm60-plugin==0.4.35
    jaxlib==0.4.35
    
  5. Explicitly set the LLVM_PATH environment variable. This helps XLA find the ld.lld linker at runtime.

    export LLVM_PATH=/opt/rocm/llvm
    
  6. Verify the installation of ROCm JAX. See Test the JAX installation.
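The version check in step 4 can also be scripted. This minimal sketch parses pip3 freeze output and reports whether every jax* package carries the same version, as in the sample output above. Treating matching versions as "correct" is a heuristic drawn from that sample, not an official compatibility rule, and the helper names jax_versions and versions_consistent are hypothetical.

```python
import re

# Matches pinned lines such as "jax==0.4.35" or "jax-rocm60-plugin==0.4.35".
_PIN = re.compile(r"^(?P<name>jax[\w.-]*)==(?P<version>\S+)$")


def jax_versions(freeze_output: str) -> dict:
    """Map each jax* package in `pip3 freeze` output to its pinned version."""
    versions = {}
    for line in freeze_output.splitlines():
        m = _PIN.match(line.strip())
        if m:
            versions[m.group("name")] = m.group("version")
    return versions


def versions_consistent(versions: dict) -> bool:
    """True if at least one jax* package was found and all share one version."""
    return len(set(versions.values())) == 1
```

For example, inside the container you could call jax_versions(subprocess.run(["pip3", "freeze"], capture_output=True, text=True).stdout) and then pass the result to versions_consistent.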

Install JAX on bare-metal or a custom container

Follow these steps if you prefer to install ROCm manually on your host system or in a custom container.

  1. Install ROCm. Follow the ROCm installation guide to install ROCm on your system.

    Once installed, verify your ROCm installation using:

    rocm-smi
    
    Example output:

    ========================================== ROCm System Management Interface ==========================================
    ==================================================== Concise Info ====================================================
    Device  [Model : Revision]    Temp        Power     Partitions      SCLK     MCLK     Fan  Perf  PwrCap  VRAM%  GPU%
            Name (20 chars)       (Junction)  (Socket)  (Mem, Compute)
    ======================================================================================================================
    0       [0x74a1 : 0x00]       50.0°C      170.0W    NPS1, SPX       131Mhz   900Mhz   0%   auto  750.0W    0%   0%
            AMD Instinct MI300X
    1       [0x74a1 : 0x00]       51.0°C      176.0W    NPS1, SPX       132Mhz   900Mhz   0%   auto  750.0W    0%   0%
            AMD Instinct MI300X
    2       [0x74a1 : 0x00]       50.0°C      177.0W    NPS1, SPX       132Mhz   900Mhz   0%   auto  750.0W    0%   0%
            AMD Instinct MI300X
    3       [0x74a1 : 0x00]       53.0°C      176.0W    NPS1, SPX       132Mhz   900Mhz   0%   auto  750.0W    0%   0%
            AMD Instinct MI300X
    ======================================================================================================================
    ================================================ End of ROCm SMI Log =================================================
    
  2. Install the required version of JAX with ROCm support using pip:

    pip3 install "jax[rocm]"
    
  3. Verify the installed JAX version. Check whether the correct version of JAX and its ROCm plugins are installed.

    pip3 freeze | grep jax
    
  4. Explicitly set the LLVM_PATH environment variable so that XLA can find the ld.lld linker at runtime.

    export LLVM_PATH=/opt/rocm/llvm
    
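  5. Apply the namespace patch to your Python site-packages directory. The patch path is relative to the current directory, so run the command from the directory that contains jax_rocm_plugin/ (for example, the root of a ROCm/rocm-jax checkout):

    patch -p1 \
        -d "$(python3 -c "import sysconfig; print(sysconfig.get_paths()['purelib'])")" \
        < jax_rocm_plugin/third_party/jax/namespace.patch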
    
  6. Verify the installation of ROCm JAX.

    Run the following commands to verify that ROCm JAX is installed correctly:

    python3 -c "import jax; print(jax.devices())"
    python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
    

    Expected output on a system with four GPUs:

    [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
    [0 1 2 3 4]
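Several of the bare-metal steps depend on the environment: rocm-smi must be on the PATH, LLVM_PATH must point at the ROCm LLVM directory, and the patch is applied to the directory returned by sysconfig.get_paths()['purelib']. The following Python sketch gathers those facts in one place. It assumes that ld.lld sits in the bin/ subdirectory of LLVM_PATH (for example /opt/rocm/llvm/bin/ld.lld); that layout is an assumption rather than something this guide states, and the preflight helper is hypothetical. Run it with the same python3 that pip3 installs into.

```python
import os
import shutil
import sysconfig
from pathlib import Path


def preflight(env=None):
    """Collect facts that the bare-metal install steps rely on (hypothetical helper).

    `env` defaults to os.environ; pass a dict to inspect a different environment.
    """
    env = os.environ if env is None else env
    llvm_path = env.get("LLVM_PATH")
    # Assumption: ld.lld lives in LLVM_PATH/bin (e.g. /opt/rocm/llvm/bin/ld.lld).
    ld_lld = Path(llvm_path, "bin", "ld.lld") if llvm_path else None
    return {
        # Step 1: is the rocm-smi tool reachable on the PATH?
        "rocm_smi_on_path": shutil.which("rocm-smi", path=env.get("PATH")) is not None,
        # Step 4: is LLVM_PATH set, and does the assumed linker location exist?
        "llvm_path": llvm_path,
        "ld_lld_found": ld_lld is not None and ld_lld.is_file(),
        # Step 5: the directory the patch command passes to `patch -d`.
        "purelib": sysconfig.get_paths()["purelib"],
    }


if __name__ == "__main__":
    for key, value in preflight().items():
        print(f"{key}: {value}")
```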
    

Build JAX from source

The ROCm/rocm-jax repository contains the sources for the ROCm plugin for JAX and the Dockerfiles used to build the AMD rocm/jax images. For the most up-to-date build instructions, refer directly to the repository:

  • See Quick build for concise high-level steps.

  • See Building for more in-depth build instructions and troubleshooting suggestions.

Test the JAX installation

After launching the container, test whether JAX detects ROCm devices as expected:

python3 -c "import jax; print(jax.devices())"
python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"

If the setup is successful, the first command lists all available ROCm devices and the second prints a small array.

Expected output on a system with four GPUs:

[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
[0 1 2 3 4]
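The one-liners above only list devices and create an array. A slightly fuller check, sketched below under the assumption that you run it inside the same environment, also compiles a small function with jax.jit, which exercises the XLA compilation path that the LLVM_PATH step is meant to support. The helper names are hypothetical, and the sketch deliberately prints the platform names rather than asserting a particular string, since this guide does not specify what they are for ROCm devices.

```python
def summarize_devices(devices):
    """Return (count, sorted platform names) for a sequence of JAX devices."""
    return len(devices), sorted({d.platform for d in devices})


def main():
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        print("JAX is not importable from this Python interpreter")
        return False

    count, platforms = summarize_devices(jax.devices())
    print(f"default backend: {jax.default_backend()}")
    print(f"{count} device(s) on platform(s): {platforms}")

    # Compile and run a small function so XLA actually generates and executes code.
    double = jax.jit(lambda v: v * 2)
    print(double(jnp.arange(5)))
    return True


if __name__ == "__main__":
    main()
```

On a working installation, the last line printed should be [0 2 4 6 8]. If the device count is lower than expected, compare it against the GPUs that rocm-smi reports.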