Install JAX for ROCm

This section describes the recommended method for installing JAX via PIP, as well as Docker-based installation.

Important!
These instructions are for Radeon GPUs only. Instinct users should refer to JAX installation for Instinct.

AMD recommends the PIP install method to create a JAX environment when working with ROCm™ for machine learning development.

Install JAX via PIP

  1. Enter the following command to unpack and begin setup.

    sudo apt install
    
  2. Enter this command to upgrade pip and wheel.

    pip3 install --upgrade pip wheel
    
  3. Select the applicable operating system version and enter the commands to install JAX for ROCm AMD GPU support.

    This may take several minutes.

    Important! AMD recommends using the ROCm WHLs available at repo.radeon.com. The ROCm WHLs available at PyTorch.org are not extensively tested by AMD, because they change regularly as the nightly builds are updated.

    Important! When manually downloading WHLs from repo.radeon.com, make sure to select the WHLs that are compatible with your Python version.

    See Compatibility matrices for support information.

    Ubuntu 24.04

    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.3.4/torch-2.4.0%2Brocm6.3.4.git7cecbf6d-cp312-cp312-linux_x86_64.whl
    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.3.4/torchvision-0.19.0%2Brocm6.3.4.gitfab84886-cp312-cp312-linux_x86_64.whl
    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.3.4/pytorch_triton_rocm-3.0.0%2Brocm6.3.4.git75cc27c2-cp312-cp312-linux_x86_64.whl
    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.3.4/torchaudio-2.4.0%2Brocm6.3.4.git69d40773-cp312-cp312-linux_x86_64.whl
    pip3 uninstall torch torchvision torchaudio pytorch-triton-rocm
    pip3 install torch-2.4.0+rocm6.3.4.git7cecbf6d-cp312-cp312-linux_x86_64.whl torchvision-0.19.0+rocm6.3.4.gitfab84886-cp312-cp312-linux_x86_64.whl torchaudio-2.4.0+rocm6.3.4.git69d40773-cp312-cp312-linux_x86_64.whl pytorch_triton_rocm-3.0.0+rocm6.3.4.git75cc27c2-cp312-cp312-linux_x86_64.whl
    

    Note

    The --break-system-packages flag must be added when installing wheels for Python 3.12 outside a virtual environment.

    Ubuntu 22.04

    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.3.4/torch-2.4.0%2Brocm6.3.4.git7cecbf6d-cp310-cp310-linux_x86_64.whl
    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.3.4/torchvision-0.19.0%2Brocm6.3.4.gitfab84886-cp310-cp310-linux_x86_64.whl
    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.3.4/pytorch_triton_rocm-3.0.0%2Brocm6.3.4.git75cc27c2-cp310-cp310-linux_x86_64.whl
    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-6.3.4/torchaudio-2.4.0%2Brocm6.3.4.git69d40773-cp310-cp310-linux_x86_64.whl
    pip3 uninstall torch torchvision torchaudio pytorch-triton-rocm
    pip3 install torch-2.4.0+rocm6.3.4.git7cecbf6d-cp310-cp310-linux_x86_64.whl torchvision-0.19.0+rocm6.3.4.gitfab84886-cp310-cp310-linux_x86_64.whl torchaudio-2.4.0+rocm6.3.4.git69d40773-cp310-cp310-linux_x86_64.whl pytorch_triton_rocm-3.0.0+rocm6.3.4.git75cc27c2-cp310-cp310-linux_x86_64.whl
    

    RHEL 10

    wget
      --
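
Because each WHL is built for a single Python version (the cp310 or cp312 tag in the file name), it can help to confirm the match before running pip3 install. The following is a minimal Python sketch, not an AMD tool. It assumes standard wheel file names of the form name-version-python-abi-platform.whl; the helper names (interpreter_tag, wheel_python_tag, in_virtualenv) and the script name check_wheels.py are hypothetical. It also reports whether you are outside a virtual environment on Python 3.12, the case covered by the --break-system-packages note.

```python
import os
import sys


def interpreter_tag() -> str:
    """CPython tag of the running interpreter, e.g. 'cp312' on Python 3.12."""
    return f"cp{sys.version_info.major}{sys.version_info.minor}"


def wheel_python_tag(filename: str) -> str:
    """Return the Python tag of a wheel file name.

    Wheel names follow name-version(-build)-python-abi-platform.whl,
    so the Python tag is always the third dash-separated field from the end.
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel file name: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) < 5:
        raise ValueError(f"unexpected wheel file name: {filename}")
    return parts[-3]


def in_virtualenv() -> bool:
    # Inside a venv, sys.prefix points at the venv while sys.base_prefix
    # still points at the system Python installation.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    here = interpreter_tag()
    # Only look at arguments that are wheel files; ignore anything else.
    for wheel in (a for a in sys.argv[1:] if a.endswith(".whl")):
        tag = wheel_python_tag(os.path.basename(wheel))
        verdict = "OK" if tag == here else f"MISMATCH: this interpreter is {here}"
        print(f"{wheel}: {tag} ({verdict})")
    if sys.version_info[:2] == (3, 12) and not in_virtualenv():
        print("Python 3.12 outside a virtual environment: add --break-system-packages to pip3 install")
```

For example, run python3 check_wheels.py torch-2.4.0+rocm6.3.4.git7cecbf6d-cp312-cp312-linux_x86_64.whl in the directory where you downloaded the WHLs.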
    

Next, verify your JAX installation.

Using Docker provides portability and access to a prebuilt Docker container that has been rigorously tested within AMD. Docker also cuts down compilation time and should perform as expected without installation issues.

Prerequisites to install JAX using Docker

Docker for Ubuntu® must be installed.

To install Docker for Ubuntu, enter the following command:

sudo apt install docker.io
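
To confirm this prerequisite before continuing, a short Python sketch can look for the docker CLI on your PATH and print its version. This is an illustrative check that assumes the docker.io package placed the docker binary on PATH; the function names are hypothetical. Running docker --version only queries the client, so it does not need sudo.

```python
import shutil
import subprocess
from typing import Optional


def find_docker() -> Optional[str]:
    """Return the full path of the docker CLI if it is on PATH, else None."""
    return shutil.which("docker")


def docker_version() -> Optional[str]:
    """Return the output of `docker --version`, or None if docker is unavailable."""
    path = find_docker()
    if path is None:
        return None
    result = subprocess.run([path, "--version"], capture_output=True, text=True, check=False)
    return result.stdout.strip() if result.returncode == 0 else None


if __name__ == "__main__":
    version = docker_version()
    print(version if version else "docker not found; install it with: sudo apt install docker.io")
```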

Use Docker image with pre-installed JAX

Follow these steps to install using a Docker image.

  1. Select the applicable operating system version and enter the following command to pull the public JAX Docker image.

    Optional: You can also download a specific supported configuration with a different user-space ROCm version, PyTorch version, or operating system.

    Refer to hub.docker.com/r/rocm/pytorch to download the PyTorch Docker image.

    Ubuntu 22.04

    sudo docker pull
    

    Ubuntu 24.04

    sudo docker pull
    

    RHEL 10

    sudo docker pull
      --
    
  2. Select the applicable operating system version and start a Docker container using the downloaded image.

    Ubuntu 22.04

    sudo docker run -it \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --device=/dev/kfd \
      --device=/dev/dri \
      --group-add video \
      --ipc=host \
      --shm-size 8G \
      rocm/pytorch:rocm6.3.4_ubuntu22.04_py3.10_pytorch_release_2.4.0
    

    Ubuntu 24.04

    sudo docker run -it \
      --cap-add=SYS_PTRACE \
      --security-opt seccomp=unconfined \
      --device=/dev/kfd \
      --device=/dev/dri \
      --group-add video \
      --ipc=host \
      --shm-size 8G \
      rocm/pytorch:rocm6.3.4_ubuntu24.04_py3.12_pytorch_release_2.4.0
    

    RHEL 10

    sudo docker run -it \
      --
    

    This command automatically downloads the image if it does not already exist on the host. You can also pass the -v argument to mount data directories from the host into the container.
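
The docker run options above are the same for each operating system; only the image tag changes. As an illustration, the following hypothetical Python helper assembles that argument list, optionally adds a -v HOST:CONTAINER bind mount, and checks that the /dev/kfd and /dev/dri device nodes passed to the container exist on the host. The container mount point /data and the helper names are assumptions, and the script prints the command for review instead of running it.

```python
import os
import shlex
from typing import List, Optional

# Options copied from the docker run examples above.
ROCM_RUN_FLAGS = [
    "--cap-add=SYS_PTRACE",
    "--security-opt", "seccomp=unconfined",
    "--device=/dev/kfd",
    "--device=/dev/dri",
    "--group-add", "video",
    "--ipc=host",
    "--shm-size", "8G",
]


def build_run_command(image: str, host_dir: Optional[str] = None,
                      container_dir: str = "/data") -> List[str]:
    """Build the docker run argument list; the image name must come last."""
    cmd = ["sudo", "docker", "run", "-it", *ROCM_RUN_FLAGS]
    if host_dir is not None:
        # -v HOST:CONTAINER bind-mounts a host directory into the container.
        cmd += ["-v", f"{os.path.abspath(host_dir)}:{container_dir}"]
    cmd.append(image)
    return cmd


def missing_devices() -> List[str]:
    """Return the GPU device nodes from the flags above that are absent on this host."""
    return [d for d in ("/dev/kfd", "/dev/dri") if not os.path.exists(d)]


if __name__ == "__main__":
    image = "rocm/pytorch:rocm6.3.4_ubuntu24.04_py3.12_pytorch_release_2.4.0"
    missing = missing_devices()
    if missing:
        print(f"Warning: device nodes not found on host: {missing}")
    print(shlex.join(build_run_command(image, host_dir=".")))
```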

Next, verify the JAX installation.

Verify JAX installation

Confirm that JAX is correctly installed.
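
The commands in the following steps import torch. To check the jax package itself, a minimal sketch, assuming JAX was installed with ROCm GPU support, can list the devices that jax.devices() reports. The assumption that a ROCm GPU appears with the platform name "gpu" or "rocm" is not confirmed by this guide, and the helper names are hypothetical. The script degrades gracefully when JAX is missing or no GPU backend is found.

```python
import importlib.util
from typing import Any, List


def jax_installed() -> bool:
    """True if the jax package can be found by the current interpreter."""
    return importlib.util.find_spec("jax") is not None


def gpu_devices() -> List[Any]:
    """Return JAX devices on a GPU platform, or [] if JAX or a GPU backend is missing."""
    if not jax_installed():
        return []
    import jax

    try:
        devices = jax.devices()  # devices of the default backend
    except RuntimeError:
        return []
    # Assumption: ROCm GPUs report platform "gpu" (or "rocm"); CPU devices report "cpu".
    return [d for d in devices if d.platform in ("gpu", "rocm")]


if __name__ == "__main__":
    if not jax_installed():
        print("Failure: jax is not importable")
    else:
        gpus = gpu_devices()
        print("Success" if gpus else "jax imported, but no GPU device was found")
        for i, d in enumerate(gpus):
            print(f"device name [{i}]:", d.device_kind)
```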

  1. Verify that JAX is installed and detects the GPU compute device.

    python3 -c 'import torch' 2> /dev/null && echo 'Success' || echo 'Failure'
    

    Expected result:

    Success
    
  2. Enter the following command to test whether the GPU is available.

    python3 -c 'import torch; print(torch.cuda.is_available())'
    

    Expected result:

    True
    
  3. Enter the following command to display the installed GPU device name.

    python3 -c "import torch; print(f'device name [0]:', torch.cuda.get_device_name(0))"
    

    Expected result (for example, device name [0]: Radeon RX 7900 XTX):

    device name [0]: <Supported AMD GPU>
    
  4. Enter the following command to display component information for the current PyTorch environment.

    python3 -m torch.utils.collect_env
    

    Expected result:

    PyTorch version
    ROCM used to build PyTorch
    OS
    Is CUDA available
    GPU model and configuration
    HIP runtime version
    MIOpen runtime version
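
To check the step 4 output automatically, the following sketch runs python3 -m torch.utils.collect_env with the current interpreter and reports any of the fields listed above that are missing. It assumes each field is printed as a "Field: value" line and matches field names by prefix, because the exact wording may vary between PyTorch versions. EXPECTED_PREFIXES and the function names are illustrative, not part of PyTorch; a field that is present but reads "Could not collect" still counts as found.

```python
import subprocess
import sys
from typing import Dict, List

# Field-name prefixes taken from the expected result above.
EXPECTED_PREFIXES = [
    "PyTorch version",
    "ROCM used to build PyTorch",
    "OS",
    "Is CUDA available",
    "GPU model",
    "HIP runtime version",
    "MIOpen runtime version",
]


def parse_collect_env(text: str) -> Dict[str, str]:
    """Parse 'Field: value' lines into a dict, keeping the first value per field."""
    fields: Dict[str, str] = {}
    for line in text.splitlines():
        key, sep, value = line.partition(":")
        if sep and key.strip():
            fields.setdefault(key.strip(), value.strip())
    return fields


def missing_fields(fields: Dict[str, str]) -> List[str]:
    """Return the expected prefixes that no parsed field name starts with."""
    return [p for p in EXPECTED_PREFIXES if not any(k.startswith(p) for k in fields)]


if __name__ == "__main__":
    # Same command as step 4; check=False so a missing torch is reported, not raised.
    result = subprocess.run(
        [sys.executable, "-m", "torch.utils.collect_env"],
        capture_output=True, text=True, check=False,
    )
    missing = missing_fields(parse_collect_env(result.stdout))
    print("All expected fields found" if not missing else f"Missing fields: {missing}")
```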
    

Environment setup is complete, and the system is ready to use JAX with machine learning models and algorithms.