Installing JAX for ROCm
2024-05-21
JAX provides a NumPy-like API, which combines automatic differentiation and the Accelerated Linear Algebra (XLA) compiler to achieve high-performance machine learning at scale.
JAX applies composable transformations to Python and NumPy programs: just-in-time (JIT) compilation, automatic vectorization, and parallelization.
To learn about JAX, including profiling and optimizations, refer to the JAX documentation.
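As a minimal sketch of what "composable transformations" means in practice, the snippet below chains automatic differentiation, automatic vectorization, and JIT compilation on one small function. The function `loss` and the arrays `w` and `xs` are made up for illustration and do not come from the JAX examples.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Illustrative scalar function: squared dot product of w and x.
    return jnp.dot(w, x) ** 2


# grad: differentiate loss with respect to its first argument (w).
grad_loss = jax.grad(loss)

# vmap: apply grad_loss to every row of a batch of x while sharing w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# jit: compile the composed function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

grads = fast_batched_grad(w, xs)
print(grads.shape)
```

Each row of `grads` is the gradient `2 * dot(w, x) * x` for the matching row of `xs`, so the result has one gradient per batch element. The same code runs unchanged on a CPU or a GPU backend.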
Compatibility
You can currently use JAX with the following hardware and software:
GPUs: MI250 and MI300
OS: Ubuntu 20.04
Python: 3.9, 3.10, 3.11
ROCm: 6.0.0, 6.1.0
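If you want to check the Python requirement before installing, a small script along these lines may help. It is only a sketch: the set `SUPPORTED_PYTHON` and the function `python_is_supported` are hypothetical names, and the version list is copied from the compatibility list above, so it needs updating whenever that list changes.

```python
import sys

# Python versions copied from the compatibility list above (assumption:
# only the major.minor pair matters for compatibility).
SUPPORTED_PYTHON = {(3, 9), (3, 10), (3, 11)}


def python_is_supported(version_info=sys.version_info):
    """Return True if the (major, minor) pair is in the supported set."""
    return (version_info[0], version_info[1]) in SUPPORTED_PYTHON


current = "{}.{}".format(sys.version_info[0], sys.version_info[1])
if python_is_supported():
    print("Python " + current + " is in the supported list")
else:
    print("Python " + current + " is not in the supported list")
```

The check covers only the Python version. The GPU model, operating system, and ROCm version still need to be confirmed separately.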
Installing JAX
JAX wheels and Docker images are released through the GitHub ROCm JAX fork.
Tip
To build JAX from source files, refer to the JAX developer documentation or use the ROCm build script.
Pull the latest public JAX Docker image.
docker pull rocm/jax:latest
Start a Docker container.
docker run -it -w /workspace --device=/dev/kfd --device=/dev/dri --group-add video \
  --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G rocm/jax:latest
Verify the installation.
python3 -c 'import jax' 2> /dev/null && echo 'Success' || echo 'Failure'
Verify that the GPU is accessible from JAX.
python3 -c 'import jax; print(jax.devices())'
Run a basic example to confirm that the installation works.
git clone https://github.com/google/jax.git jax
cd jax
export PYTHONPATH=/workspace/jax/examples/:$PYTHONPATH
python3 examples/mnist_classifier.py
Your output should look similar to this:
Starting training...
Epoch 0 in 10.97 sec
Training set accuracy 0.871916651725769
Test set accuracy 0.880299985408783
Epoch 1 in 0.34 sec
Training set accuracy 0.8979166746139526
Test set accuracy 0.9030999541282654
Epoch 2 in 0.33 sec
Training set accuracy 0.9092333316802979
Test set accuracy 0.9142999649047852
Epoch 3 in 0.33 sec
Training set accuracy 0.9170833230018616
Test set accuracy 0.9220999479293823
Epoch 4 in 0.33 sec
Training set accuracy 0.9226333498954773
Test set accuracy 0.9279999732971191
Epoch 5 in 0.33 sec
Training set accuracy 0.9271667003631592
Test set accuracy 0.9297999739646912
Epoch 6 in 0.34 sec
Training set accuracy 0.9323500394821167
Test set accuracy 0.9328999519348145
Epoch 7 in 0.34 sec
Training set accuracy 0.935699999332428
Test set accuracy 0.9364999532699585
Epoch 8 in 0.33 sec
Training set accuracy 0.938800036907196
Test set accuracy 0.9393999576568604
Epoch 9 in 0.33 sec
Training set accuracy 0.9425833225250244
Test set accuracy 0.9418999552726746
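The GPU-visibility step above prints the device list and leaves the interpretation to you. As a sketch, the same check can be scripted so that it reports whether any non-CPU device was found. The helper names `accelerator_platforms` and `main` are hypothetical. The script assumes that each JAX device exposes a `platform` string and that the exact string for ROCm GPUs may vary between JAX versions, which is why it tests for "not cpu" rather than for a specific name.

```python
def accelerator_platforms(platform_names):
    """Return the sorted, distinct platform names that are not 'cpu'."""
    return sorted({name for name in platform_names if name != "cpu"})


def main():
    try:
        import jax
    except ImportError:
        print("JAX is not importable in this environment")
        return 1

    # Each device object reports the backend it belongs to via .platform.
    names = [device.platform for device in jax.devices()]
    found = accelerator_platforms(names)
    if found:
        print("Accelerator platforms visible to JAX:", ", ".join(found))
        return 0
    print("Only CPU devices are visible to JAX")
    return 1


if __name__ == "__main__":
    raise SystemExit(main())
```

A non-zero exit code makes the script usable in a container health check or CI step, where a CPU-only fallback should count as a failure.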